Source code for anomalib.models.stfpm.lightning_model
"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.

https://arxiv.org/abs/2103.04257
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Tuple, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import optim

from anomalib.models.components import AnomalyModule
from anomalib.models.stfpm.loss import STFPMLoss
from anomalib.models.stfpm.torch_model import STFPMModel

# Both Lightning modules are public: ``Stfpm`` is registered with the PL CLI
# model registry below, and ``StfpmLightning`` wraps it for config-based use.
__all__ = ["Stfpm", "StfpmLightning"]


@MODEL_REGISTRY
class Stfpm(AnomalyModule):
    """PL Lightning Module for the STFPM algorithm.

    Args:
        input_size (Tuple[int, int]): Size of the model input.
        backbone (str): Backbone CNN network.
        layers (List[str]): Layers to extract features from the backbone CNN.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
    ) -> None:
        super().__init__()

        self.model = STFPMModel(
            input_size=input_size,
            backbone=backbone,
            layers=layers,
        )
        self.loss = STFPMLoss()

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of STFPM.

        For each batch, teacher and student features are extracted from the
        CNN and compared with the STFPM loss.

        Args:
          batch (Dict[str, Any]): Input batch; the images are read from
            ``batch["image"]``.
          _: Index of the batch.

        Returns:
          Dict[str, Tensor]: Dictionary holding the training loss under the
          ``"loss"`` key.
        """
        # Force the teacher back into eval mode on every step. Lightning puts
        # the whole module (teacher included) into train mode for training, and
        # only the student's parameters are handed to the optimizer, so the
        # teacher presumably must keep inference-mode layers (e.g. BatchNorm
        # statistics) frozen. NOTE(review): confirm against STFPMModel.
        self.model.teacher_model.eval()
        # Invoke the module itself rather than ``.forward`` so that registered
        # nn.Module hooks run; this also matches ``validation_step`` below.
        teacher_features, student_features = self.model(batch["image"])
        loss = self.loss(teacher_features, student_features)
        return {"loss": loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of STFPM.

        Similar to the training step, student/teacher features are extracted
        from the CNN for each batch, and an anomaly map is computed.

        Args:
          batch (Dict[str, Any]): Input batch; the images are read from
            ``batch["image"]``.
          _: Index of the batch.

        Returns:
          The same ``batch`` dictionary, updated in place with the model's
          output stored under ``"anomaly_maps"``. The other entries (e.g.
          labels, masks) are whatever the dataloader provided. Presumably the
          result is consumed in ``validation_epoch_end``; confirm in
          ``AnomalyModule``.
        """
        batch["anomaly_maps"] = self.model(batch["image"])

        return batch
[docs]classStfpmLightning(Stfpm):"""PL Lightning Module for the STFPM algorithm. Args: hparams (Union[DictConfig, ListConfig]): Model params """def__init__(self,hparams:Union[DictConfig,ListConfig])->None:super().__init__(input_size=hparams.model.input_size,backbone=hparams.model.backbone,layers=hparams.model.layers,)self.hparams:Union[DictConfig,ListConfig]# type: ignoreself.save_hyperparameters(hparams)
[docs]defconfigure_callbacks(self):"""Configure model-specific callbacks. Note: This method is used for the existing CLI. When PL CLI is introduced, configure callback method will be deprecated, and callbacks will be configured from either config.yaml file or from CLI. """early_stopping=EarlyStopping(monitor=self.hparams.model.early_stopping.metric,patience=self.hparams.model.early_stopping.patience,mode=self.hparams.model.early_stopping.mode,)return[early_stopping]
[docs]defconfigure_optimizers(self)->torch.optim.Optimizer:"""Configures optimizers for each decoder. Note: This method is used for the existing CLI. When PL CLI is introduced, configure optimizers method will be deprecated, and optimizers will be configured from either config.yaml file or from CLI. Returns: Optimizer: Adam optimizer for each decoder """returnoptim.SGD(params=self.model.student_model.parameters(),lr=self.hparams.model.lr,momentum=self.hparams.model.momentum,weight_decay=self.hparams.model.weight_decay,