"""Helper functions for CFlow implementation."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.importloggingimportmathimportnumpyasnpimporttorchfromtorchimportnnfromanomalib.models.components.freia.frameworkimportSequenceINNfromanomalib.models.components.freia.modulesimportAllInOneBlock
def get_logp(dim_feature_vector: int, p_u: torch.Tensor, logdet_j: torch.Tensor) -> torch.Tensor:
    """Estimate the log likelihood from the latent variable and the log-determinant.

    The result is the sum of three terms: a constant
    ``dim_feature_vector * -ln(sqrt(2 * pi))``, the quadratic term
    ``-0.5 * sum(p_u ** 2)`` reduced over dimension 1, and ``logdet_j``.

    Args:
        dim_feature_vector (int): Dimensions of the condition vector.
        p_u (torch.Tensor): Random variable u.
        logdet_j (torch.Tensor): Log of the determinant of the Jacobian returned
            from the invertible decoder.

    Returns:
        torch.Tensor: Log probability.
    """
    # Note the sign: this is -ln(sqrt(2 * pi)), not ln(sqrt(2 * pi)).
    neg_ln_sqrt_2pi = -np.log(np.sqrt(2 * np.pi))
    normalisation = dim_feature_vector * neg_ln_sqrt_2pi
    squared_norm = torch.sum(p_u**2, 1)
    return normalisation - 0.5 * squared_norm + logdet_j
def positional_encoding_2d(condition_vector: int, height: int, width: int) -> torch.Tensor:
    """Create a 2D sine/cosine embedding of the position of each feature vector.

    The first half of the channels depends only on the column (width) index and
    the second half only on the row (height) index. Within each half, even
    channel offsets hold sines and odd channel offsets hold cosines.

    Args:
        condition_vector (int): Length of the condition vector (number of
            channels). Must be a positive multiple of 4.
        height (int): H of the positions.
        width (int): W of the positions.

    Raises:
        ValueError: If ``condition_vector`` is not a positive multiple of 4.

    Returns:
        torch.Tensor: ``condition_vector x height x width`` position matrix.
    """
    # Each spatial axis gets half of the channels, and each half is split again
    # into sine/cosine pairs, hence the multiple-of-4 requirement. Zero is
    # rejected too because it would divide by zero below.
    if condition_vector <= 0 or condition_vector % 4 != 0:
        raise ValueError(
            "Cannot use sin/cos positional encoding with a dimension that is not "
            f"a positive multiple of 4 (got dim={condition_vector})"
        )

    pos_encoding = torch.zeros(condition_vector, height, width)

    # Each dimension uses half of the condition vector.
    half_dim = condition_vector // 2
    div_term = torch.exp(torch.arange(0.0, half_dim, 2) * -(math.log(1e4) / half_dim))
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)

    # Shapes after the transpose: (half_dim // 2, width) and (half_dim // 2, height).
    angles_w = (pos_w * div_term).transpose(0, 1)
    angles_h = (pos_h * div_term).transpose(0, 1)

    # Width encoding: constant along the height axis.
    pos_encoding[0:half_dim:2, :, :] = torch.sin(angles_w).unsqueeze(1).repeat(1, height, 1)
    pos_encoding[1:half_dim:2, :, :] = torch.cos(angles_w).unsqueeze(1).repeat(1, height, 1)
    # Height encoding: constant along the width axis.
    pos_encoding[half_dim::2, :, :] = torch.sin(angles_h).unsqueeze(2).repeat(1, 1, width)
    pos_encoding[half_dim + 1 :: 2, :, :] = torch.cos(angles_h).unsqueeze(2).repeat(1, 1, width)
    return pos_encoding
def subnet_fc(dims_in: int, dims_out: int) -> nn.Sequential:
    """Build the subnetwork which predicts the affine coefficients.

    The network is ``Linear(dims_in, 2 * dims_in) -> ReLU -> Linear(2 * dims_in, dims_out)``.

    Args:
        dims_in (int): Input dimensions.
        dims_out (int): Output dimensions.

    Returns:
        nn.Sequential: Feed-forward subnetwork.
    """
    hidden_dims = 2 * dims_in
    layers = [
        nn.Linear(dims_in, hidden_dims),
        nn.ReLU(),
        nn.Linear(hidden_dims, dims_out),
    ]
    return nn.Sequential(*layers)
def cflow_head(
    condition_vector: int, coupling_blocks: int, clamp_alpha: float, n_features: int, permute_soft: bool = False
) -> SequenceINN:
    """Create invertible decoder network.

    A ``SequenceINN`` is created for ``n_features`` and ``coupling_blocks``
    identically configured ``AllInOneBlock`` entries are appended to it. Every
    block is appended with ``cond=0`` and ``cond_shape=(condition_vector,)``,
    and uses ``subnet_fc`` as its subnet constructor.

    Args:
        condition_vector (int): Length of the condition vector.
        coupling_blocks (int): Number of coupling blocks to build the decoder.
        clamp_alpha (float): Clamping value to avoid exploding values.
        n_features (int): Number of decoder features.
        permute_soft (bool): Whether to sample the permutation matrix :math:`R` from :math:`SO(N)`,
            or to use hard permutations instead. Note, ``permute_soft=True`` is very slow
            when working with >512 dimensions.

    Returns:
        SequenceINN: Decoder network block.
    """
    coder = SequenceINN(n_features)
    # Resolve the logger through the ``logging`` module so this function does
    # not depend on a module-level ``logger`` name being defined.
    logging.getLogger(__name__).info("CNF coder: %d", n_features)
    for _ in range(coupling_blocks):
        coder.append(
            AllInOneBlock,
            cond=0,
            cond_shape=(condition_vector,),
            subnet_constructor=subnet_fc,
            affine_clamping=clamp_alpha,
            global_affine_type="SOFTPLUS",
            permute_soft=permute_soft,
        )
    return coder