class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    Subclasses are expected to implement :meth:`validation_step`.
    """

    def __init__(self):
        super().__init__()
        logger.info("Initializing %s model.", self.__class__.__name__)

        self.save_hyperparameters()

        # Annotation-only declarations: no value is assigned in this constructor.
        self.model: nn.Module
        self.loss: Tensor
        self.callbacks: List[Callback]
        self.adaptive_threshold: bool

        # One threshold for image-level scores and one for pixel-level maps,
        # both explicitly moved to the CPU on creation.
        self.image_threshold = AdaptiveThreshold().cpu()
        self.pixel_threshold = AdaptiveThreshold().cpu()

        # Assigned by `_load_normalization_class` when a checkpoint is loaded.
        self.normalization_metrics: Metric

        # Annotation-only declarations: no value is assigned in this constructor.
        self.image_metrics: AnomalibMetricCollection
        self.pixel_metrics: AnomalibMetricCollection

    def forward(self, batch):  # pylint: disable=arguments-differ
        """Forward-pass input tensor to the module.

        Args:
            batch (Tensor): Input Tensor

        Returns:
            Tensor: Output tensor from the model.
        """
        return self.model(batch)

    def validation_step(self, batch, batch_idx) -> dict:  # type: ignore  # pylint: disable=arguments-differ
        """To be implemented in the subclasses."""
        raise NotImplementedError

    def predict_step(self, batch: Any, batch_idx: int, _dataloader_idx: Optional[int] = None) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.

        Runs :meth:`validation_step` on the batch, post-processes the outputs and
        adds thresholded predictions to them.

        Args:
            batch (Tensor): Current batch
            batch_idx (int): Index of current batch
            _dataloader_idx (int): Index of the current dataloader (unused)

        Returns:
            The outputs of :meth:`validation_step`, extended with ``pred_labels``
            and, when ``anomaly_maps`` is present, ``pred_masks``.
        """
        outputs = self.validation_step(batch, batch_idx)
        # Make sure "pred_scores" exists before it is thresholded below.
        self._post_process(outputs)
        outputs["pred_labels"] = outputs["pred_scores"] >= self.image_threshold.value
        if "anomaly_maps" in outputs:
            outputs["pred_masks"] = outputs["anomaly_maps"] >= self.pixel_threshold.value
        return outputs

    def test_step(self, batch, _):  # pylint: disable=arguments-differ
        """Delegate to :meth:`predict_step` for anomaly map/score calculation.

        Args:
            batch (Tensor): Input batch
            _: Index of the batch.

        Returns:
            Whatever :meth:`predict_step` returns for this batch.
        """
        return self.predict_step(batch, _)

    def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
        """Called at the end of each validation step."""
        self._outputs_to_cpu(val_step_outputs)
        self._post_process(val_step_outputs)
        return val_step_outputs

    def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
        """Called at the end of each test step."""
        self._outputs_to_cpu(test_step_outputs)
        self._post_process(test_step_outputs)
        return test_step_outputs

    def validation_epoch_end(self, outputs):
        """Compute threshold and performance metrics.

        Args:
            outputs: Batch of outputs from the validation step
        """
        # The threshold is only recomputed when adaptive thresholding is enabled.
        if self.adaptive_threshold:
            self._compute_adaptive_threshold(outputs)
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    def test_epoch_end(self, outputs):
        """Collect the test outputs into the metrics and log them.

        Args:
            outputs: Batch of outputs from the test step
        """
        self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
        self._log_metrics()

    @staticmethod
    def _post_process(outputs):
        """Derive image-level scores from the anomaly maps when they are missing.

        Declared static because the signature takes no ``self`` while callers
        invoke it as ``self._post_process(outputs)``.

        Args:
            outputs: Step outputs; modified in place.
        """
        if "pred_scores" not in outputs and "anomaly_maps" in outputs:
            # Flatten everything but the first dimension and take the maximum per row.
            outputs["pred_scores"] = (
                outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
            )

    def _load_normalization_class(self, state_dict: OrderedDict[str, Tensor]):
        """Assigns the normalization method to use.

        The method is chosen from the keys present in ``state_dict``; a warning
        is emitted when none of the known keys is found.
        """
        if "normalization_metrics.max" in state_dict:
            self.normalization_metrics = MinMax()
        elif "normalization_metrics.image_mean" in state_dict:
            self.normalization_metrics = AnomalyScoreDistribution()
        else:
            warn("No known normalization found in model weights.")

    def load_state_dict(self, state_dict: OrderedDict[str, Tensor], strict: bool = True):
        """Load state dict from checkpoint.

        Selects the normalization metric matching the checkpoint before the
        parent implementation loads the weights.
        """
        # Must run first so the matching normalization attribute exists on the module.
        self._load_normalization_class(state_dict)
        super().load_state_dict(state_dict, strict=strict)