Source code for anomalib.models.components.base.anomaly_module
"""Base Anomaly Module for Training Task."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.importloggingfromabcimportABCfromtypingimportAny,List,Optionalimportpytorch_lightningasplfrompytorch_lightning.callbacks.baseimportCallbackfromtorchimportTensor,nnfromanomalib.utils.metricsimport(AdaptiveThreshold,AnomalibMetricCollection,AnomalyScoreDistribution,MinMax,)
class AnomalyModule(pl.LightningModule, ABC):
    """AnomalyModule to train, validate, predict and test images.

    Acts as a base class for all the Anomaly Modules in the library.
    """

    def __init__(self):
        """Log construction, save hyper-parameters and create the shared threshold/statistics state."""
        super().__init__()
        # Fetch the module logger directly rather than depending on a module-level ``logger``
        # global, which would otherwise raise NameError if it is not defined.
        logging.getLogger(__name__).info("Initializing %s model.", self.__class__.__name__)
        self.save_hyperparameters()

        # Declared (annotated, not assigned) here; presumably provided by subclasses or
        # callbacks -- TODO confirm where each is set.
        self.model: nn.Module
        self.loss: Tensor
        self.callbacks: List[Callback]

        # Read by ``validation_epoch_end`` to decide whether thresholds are re-estimated.
        self.adaptive_threshold: bool

        # Threshold and score-statistic objects, explicitly placed on the CPU.
        self.image_threshold = AdaptiveThreshold().cpu()
        self.pixel_threshold = AdaptiveThreshold().cpu()

        self.training_distribution = AnomalyScoreDistribution().cpu()
        self.min_max = MinMax().cpu()

        # Create placeholders for image and pixel metrics.
        # If set from the config file, MetricsConfigurationCallback will
        # create the metric collections upon setup.
        self.image_metrics: AnomalibMetricCollection
        self.pixel_metrics: AnomalibMetricCollection
[docs]defforward(self,batch):# pylint: disable=arguments-differ"""Forward-pass input tensor to the module. Args: batch (Tensor): Input Tensor Returns: Tensor: Output tensor from the model. """returnself.model(batch)
[docs]defvalidation_step(self,batch,batch_idx)->dict:# type: ignore # pylint: disable=arguments-differ"""To be implemented in the subclasses."""raiseNotImplementedError
[docs]defpredict_step(self,batch:Any,batch_idx:int,_dataloader_idx:Optional[int]=None)->Any:"""Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`. Override to add any processing logic. Args: batch (Tensor): Current batch batch_idx (int): Index of current batch _dataloader_idx (int): Index of the current dataloader Return: Predicted output """outputs=self.validation_step(batch,batch_idx)self._post_process(outputs)outputs["pred_labels"]=outputs["pred_scores"]>=self.image_threshold.valueif"anomaly_maps"inoutputs.keys():outputs["pred_masks"]=outputs["anomaly_maps"]>=self.pixel_threshold.valuereturnoutputs
[docs]deftest_step(self,batch,_):# pylint: disable=arguments-differ"""Calls validation_step for anomaly map/score calculation. Args: batch (Tensor): Input batch _: Index of the batch. Returns: Dictionary containing images, features, true labels and masks. These are required in `validation_epoch_end` for feature concatenation. """returnself.predict_step(batch,_)
[docs]defvalidation_step_end(self,val_step_outputs):# pylint: disable=arguments-differ"""Called at the end of each validation step."""self._outputs_to_cpu(val_step_outputs)self._post_process(val_step_outputs)returnval_step_outputs
[docs]deftest_step_end(self,test_step_outputs):# pylint: disable=arguments-differ"""Called at the end of each test step."""self._outputs_to_cpu(test_step_outputs)self._post_process(test_step_outputs)returntest_step_outputs
[docs]defvalidation_epoch_end(self,outputs):"""Compute threshold and performance metrics. Args: outputs: Batch of outputs from the validation step """ifself.adaptive_threshold:self._compute_adaptive_threshold(outputs)self._collect_outputs(self.image_metrics,self.pixel_metrics,outputs)self._log_metrics()
[docs]deftest_epoch_end(self,outputs):"""Compute and save anomaly scores of the test set. Args: outputs: Batch of outputs from the validation step """self._collect_outputs(self.image_metrics,self.pixel_metrics,outputs)self._log_metrics()
[docs]def_post_process(self,outputs):"""Compute labels based on model predictions."""if"pred_scores"notinoutputsand"anomaly_maps"inoutputs:outputs["pred_scores"]=(outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0],-1).max(dim=1).values
)
[docs]def_outputs_to_cpu(self,output):# for output in outputs:forkey,valueinoutput.items():ifisinstance(value,Tensor):output[key]=value.cpu()