Source code for anomalib.models.cflow.lightning_model
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.https://arxiv.org/pdf/2107.12571v1.pdf"""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.fromtypingimportList,Tuple,Unionimporteinopsimporttorchimporttorch.nn.functionalasFfromomegaconfimportDictConfig,ListConfigfrompytorch_lightning.callbacksimportEarlyStoppingfrompytorch_lightning.utilities.cliimportMODEL_REGISTRYfromtorchimportoptimfromanomalib.models.cflow.torch_modelimportCflowModelfromanomalib.models.cflow.utilsimportget_logp,positional_encoding_2dfromanomalib.models.componentsimportAnomalyModule__all__=["Cflow","CflowLightning"]@MODEL_REGISTRY
class Cflow(AnomalyModule):
    """PL Lightning Module for the CFLOW algorithm.

    All arguments except ``lr`` are forwarded unchanged to ``CflowModel``.

    Args:
        input_size (Tuple[int, int]): Passed to ``CflowModel`` as ``input_size``.
        backbone (str): Passed to ``CflowModel`` as ``backbone``.
        layers (List[str]): Passed to ``CflowModel`` as ``layers``.
        pre_trained (bool): Passed to ``CflowModel`` as ``pre_trained``.
        fiber_batch_size (int): Passed to ``CflowModel``; ``training_step`` reads it
            back from the model as the number of fibers per optimizer step.
        decoder (str): Passed to ``CflowModel`` as ``decoder``.
        condition_vector (int): Passed to ``CflowModel``; ``training_step`` reads it
            back from the model as the first argument of the positional encoding.
        coupling_blocks (int): Passed to ``CflowModel`` as ``coupling_blocks``.
        clamp_alpha (float): Passed to ``CflowModel`` as ``clamp_alpha``.
        permute_soft (bool): Passed to ``CflowModel`` as ``permute_soft``.
        lr (float): Learning rate of the Adam optimizer built in
            ``configure_optimizers``.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
        pre_trained: bool = True,
        fiber_batch_size: int = 64,
        decoder: str = "freia-cflow",
        condition_vector: int = 128,
        coupling_blocks: int = 8,
        clamp_alpha: float = 1.9,
        permute_soft: bool = False,
        lr: float = 0.0001,
    ):
        super().__init__()

        self.model: CflowModel = CflowModel(
            input_size=input_size,
            backbone=backbone,
            pre_trained=pre_trained,
            layers=layers,
            fiber_batch_size=fiber_batch_size,
            decoder=decoder,
            condition_vector=condition_vector,
            coupling_blocks=coupling_blocks,
            clamp_alpha=clamp_alpha,
            permute_soft=permute_soft,
        )
        # training_step drives the optimizer itself (several steps per image batch).
        self.automatic_optimization = False
        # TODO: LR should be part of optimizer in config.yaml! Since cflow has custom
        #   optimizer this is to be addressed later.
        self.learning_rate = lr

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configures optimizers for each decoder.

        Note:
            This method is used for the existing CLI.
            When PL CLI is introduced, configure optimizers method will be
            deprecated, and optimizers will be configured from either
            config.yaml file or from CLI.

        Returns:
            Optimizer: A single Adam optimizer over the parameters of every decoder
            (one decoder per entry of ``self.model.pool_layers``).
        """
        decoders_parameters = [
            parameter
            for decoder_idx, _ in enumerate(self.model.pool_layers)
            for parameter in self.model.decoders[decoder_idx].parameters()
        ]
        return optim.Adam(
            params=decoders_parameters,
            lr=self.learning_rate,
        )

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of CFLOW.

        For each batch, decoder layers are trained with a dynamic fiber batch size.
        Training step is performed manually as multiple training steps are involved
        per batch of input images.

        Args:
            batch: Input batch; the images are read from ``batch["image"]``.
            _: Index of the batch.

        Returns:
            Dictionary whose ``"loss"`` entry is the loss summed over every fiber of
            every pooling layer for this batch (a float64 tensor with one element).
        """
        opt = self.optimizers()
        # Only the decoders are optimized; the encoder output is detached below.
        self.model.encoder.eval()

        images = batch["image"]
        activation = self.model.encoder(images)
        # Despite its name this accumulates a sum, not a mean.
        avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device)

        fiber_batch_size = self.model.fiber_batch_size

        for layer_idx, layer in enumerate(self.model.pool_layers):
            encoder_activations = activation[layer].detach()  # BxCxHxW

            batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
            image_size = im_height * im_width
            embedding_length = batch_size * image_size  # number of rows in the conditional vector

            # repeats positional encoding for the entire batch 1 C H W to B C H W
            pos_encoding = einops.repeat(
                positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
                "b c h w-> (tile b) c h w",
                tile=batch_size,
            ).to(images.device)
            c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c")  # BHWxP
            e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c")  # BHWxC
            perm = torch.randperm(embedding_length)  # BHW
            decoder = self.model.decoders[layer_idx].to(images.device)

            fiber_batches = embedding_length // fiber_batch_size  # number of fiber batches
            assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!"

            for batch_num in range(fiber_batches):  # per-fiber processing
                opt.zero_grad()

                start = batch_num * fiber_batch_size
                if batch_num < (fiber_batches - 1):
                    stop = start + fiber_batch_size
                else:
                    # The last fiber batch also takes the remainder left by the
                    # integer division, so it runs up to the end of the permutation.
                    stop = embedding_length
                # Random fibers for this step; the same rows are used for the
                # condition and for the features.
                fiber_idx = perm[start:stop]
                c_p = c_r[fiber_idx]  # NxP
                e_p = e_r[fiber_idx]  # NxC

                # decoder returns the transformed variable z and the log Jacobian determinant
                p_u, log_jac_det = decoder(e_p, [c_p])
                decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
                loss = -F.logsigmoid(log_prob)

                # One optimizer step per fiber batch (manual optimization).
                self.manual_backward(loss.mean())
                opt.step()
                avg_loss += loss.sum()

        return {"loss": avg_loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of CFLOW.

        The images in ``batch["image"]`` are passed through the model and the
        result is stored in the batch under ``"anomaly_maps"``.

        Args:
            batch: Input batch
            _: Index of the batch.

        Returns:
            The input batch dictionary with the additional ``"anomaly_maps"`` entry.
        """
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch
class CflowLightning(Cflow):
    """PL Lightning Module for the CFLOW algorithm.

    Thin wrapper that builds ``Cflow`` from a config object instead of
    individual keyword arguments.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model params. The values are read
            from ``hparams.model`` and ``hparams.dataset.fiber_batch_size``.
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
        model_params = hparams.model
        # NOTE(review): `pre_trained` and `lr` are not forwarded here, so the
        # defaults of `Cflow.__init__` apply - confirm this is intended.
        super().__init__(
            input_size=model_params.input_size,
            backbone=model_params.backbone,
            layers=model_params.layers,
            fiber_batch_size=hparams.dataset.fiber_batch_size,
            decoder=model_params.decoder,
            condition_vector=model_params.condition_vector,
            coupling_blocks=model_params.coupling_blocks,
            clamp_alpha=model_params.clamp_alpha,
            permute_soft=model_params.soft_permutation,
        )
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Note:
            This method is used for the existing CLI.
            When PL CLI is introduced, configure callback method will be
            deprecated, and callbacks will be configured from either
            config.yaml file or from CLI.

        Returns:
            A one-element list holding an ``EarlyStopping`` callback built from
            ``hparams.model.early_stopping`` (``metric``, ``patience``, ``mode``).
        """
        stopping_params = self.hparams.model.early_stopping
        return [
            EarlyStopping(
                monitor=stopping_params.metric,
                patience=stopping_params.patience,
                mode=stopping_params.mode,
            )
        ]