"""PyTorch model for CFlow model implementation."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.fromtypingimportList,Tupleimporteinopsimporttorchimporttorchvisionfromtorchimportnnfromanomalib.models.cflow.anomaly_mapimportAnomalyMapGeneratorfromanomalib.models.cflow.utilsimportcflow_head,get_logp,positional_encoding_2dfromanomalib.models.componentsimportFeatureExtractor
class CflowModel(nn.Module):
    """CFLOW: Conditional Normalizing Flows.

    A frozen, pretrained torchvision backbone (wrapped in ``FeatureExtractor``) extracts
    feature maps at ``layers``. For each layer, a ``cflow_head`` decoder is applied to every
    spatial feature vector, conditioned on a 2D positional encoding. The resulting per-layer
    log-likelihoods are converted into an anomaly map by ``AnomalyMapGenerator``.

    Args:
        input_size: Image size, passed to ``AnomalyMapGenerator`` as ``image_size``.
        backbone: Name of a model constructor in ``torchvision.models``. It is instantiated
            with ``pretrained=True``.
        layers: Names of the backbone layers to extract features from. One decoder is
            built per layer.
        fiber_batch_size: Number of feature vectors ("fibers") decoded per chunk in
            :meth:`forward`.
        decoder: Decoder architecture name. Stored as ``dec_arch``; not otherwise used in
            this class.
        condition_vector: Size of the positional-encoding condition vector.
        coupling_blocks: Forwarded to ``cflow_head``.
        clamp_alpha: Forwarded to ``cflow_head``.
        permute_soft: Forwarded to ``cflow_head``.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
        fiber_batch_size: int = 64,
        decoder: str = "freia-cflow",
        condition_vector: int = 128,
        coupling_blocks: int = 8,
        clamp_alpha: float = 1.9,
        permute_soft: bool = False,
    ):
        super().__init__()

        self.backbone = getattr(torchvision.models, backbone)
        self.fiber_batch_size = fiber_batch_size
        self.condition_vector: int = condition_vector
        self.dec_arch = decoder
        self.pool_layers = layers

        # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
        # `weights=`; left unchanged here for compatibility with the pinned version.
        self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
        self.pool_dims = self.encoder.out_dims
        # One decoder per extracted layer; `n_features` comes from the matching entry of
        # `encoder.out_dims`.
        self.decoders = nn.ModuleList(
            [
                cflow_head(
                    condition_vector=self.condition_vector,
                    coupling_blocks=coupling_blocks,
                    clamp_alpha=clamp_alpha,
                    n_features=pool_dim,
                    permute_soft=permute_soft,
                )
                for pool_dim in self.pool_dims
            ]
        )

        # encoder model is fixed
        for parameters in self.encoder.parameters():
            parameters.requires_grad = False

        self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size), pool_layers=self.pool_layers)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Forward-pass images into the network to extract encoder features and compute probability.

        The encoder and decoders run in eval mode without gradient tracking. Afterwards the
        decoders' previous train/eval mode is restored, even if an exception is raised. The
        encoder is left in eval mode because its parameters are frozen.

        Args:
            images: Batch of images.

        Returns:
            Predicted anomaly maps, moved to the device of ``images``.
        """
        decoders_were_training = self.decoders.training
        self.encoder.eval()
        self.decoders.eval()
        try:
            distribution: List[torch.Tensor] = []
            height: List[int] = []
            width: List[int] = []

            with torch.no_grad():
                activation = self.encoder(images)

                for layer_idx, layer in enumerate(self.pool_layers):
                    encoder_activations = activation[layer]  # BxCxHxW

                    batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
                    height.append(im_height)
                    width.append(im_width)
                    # Number of rows once the feature map is flattened to (B*H*W) x C.
                    embedding_length = batch_size * im_height * im_width

                    # repeats positional encoding for the entire batch 1 C H W to B C H W
                    pos_encoding = einops.repeat(
                        positional_encoding_2d(self.condition_vector, im_height, im_width).unsqueeze(0),
                        "b c h w-> (tile b) c h w",
                        tile=batch_size,
                    ).to(images.device)
                    c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
                    e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
                    decoder = self.decoders[layer_idx].to(images.device)

                    # Seed with the same empty default-dtype tensor the old incremental
                    # concatenation started from, so the final dtype and the no-fiber result
                    # are unchanged.
                    layer_log_probs = [torch.empty(0, device=images.device)]

                    # Decode in chunks of `fiber_batch_size` rows. Slicing clamps the end index,
                    # so a partial last chunk (E not divisible by N) is handled without a
                    # special case.
                    for start in range(0, embedding_length, self.fiber_batch_size):
                        stop = start + self.fiber_batch_size
                        c_p = c_r[start:stop]  # NxP
                        e_p = e_r[start:stop]  # NxC
                        # decoder returns the transformed variable z and the log Jacobian determinant
                        p_u, log_jac_det = decoder(e_p, [c_p])
                        decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                        layer_log_probs.append(decoder_log_prob / dim_feature_vector)  # likelihood per dim

                    # One concatenation per layer instead of re-copying the accumulated
                    # tensor after every fiber batch.
                    distribution.append(torch.cat(layer_log_probs))

            output = self.anomaly_map_generator(distribution=distribution, height=height, width=width)
        finally:
            # Restore the caller's mode instead of forcing training mode.
            self.decoders.train(decoders_were_training)

        return output.to(images.device)