[docs]classMinMaxNormalizationCallback(Callback):"""Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""# pylint: disable=unused-argument
def setup(self, trainer: pl.Trainer, pl_module: AnomalyModule, stage: Optional[str] = None) -> None:
    """Ensure ``pl_module`` carries a ``MinMax`` metric as ``normalization_metrics``.

    When the attribute is missing, a new ``MinMax`` (moved to CPU) is attached.
    When it already exists, it is accepted only if it is a ``MinMax`` instance.

    Args:
        trainer: The Lightning trainer (not used here).
        pl_module: Model whose ``normalization_metrics`` attribute is created or checked.
        stage: Lightning stage name (not used here).

    Raises:
        AttributeError: If ``pl_module.normalization_metrics`` exists but is not a ``MinMax``.
    """
    if not hasattr(pl_module, "normalization_metrics"):
        pl_module.normalization_metrics = MinMax().cpu()
        return

    if isinstance(pl_module.normalization_metrics, MinMax):
        # A compatible metric was already provided; keep it.
        return

    raise AttributeError(
        f"Expected normalization_metrics to be of type MinMax, got {type(pl_module.normalization_metrics)}"
    )

# pylint: disable=unused-argument
def on_test_start(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
    """Set the threshold of the image- and pixel-level metrics to 0.5 when testing begins.

    During testing, ``on_test_batch_end`` normalizes predictions before the
    metrics see them, so the metrics use the fixed threshold 0.5.
    NOTE(review): this presumes ``normalize`` maps the learned threshold onto
    0.5 — confirm against the ``normalize`` implementation.

    Args:
        trainer: The Lightning trainer (not used here).
        pl_module: Model exposing ``image_metrics`` and ``pixel_metrics``;
            whichever of the two is ``None`` is skipped.
    """
    candidates = (pl_module.image_metrics, pl_module.pixel_metrics)
    for metric_collection in (m for m in candidates if m is not None):
        metric_collection.set_threshold(0.5)
def on_validation_batch_end(
    self,
    _trainer: pl.Trainer,
    pl_module: AnomalyModule,
    outputs: STEP_OUTPUT,
    _batch: Any,
    _batch_idx: int,
    _dataloader_idx: int = 0,
) -> None:
    """Called when the validation batch ends, update the min and max observed values.

    If the step produced ``"anomaly_maps"``, those are fed to
    ``pl_module.normalization_metrics``. Otherwise the ``"pred_scores"`` are
    fed to it instead.

    Args:
        _trainer: The Lightning trainer (unused).
        pl_module: Model holding the ``normalization_metrics`` metric.
        outputs: Validation-step output. It is indexed by key here, so it is
            expected to be a mapping containing ``"pred_scores"`` and optionally
            ``"anomaly_maps"``.
        _batch: The input batch (unused).
        _batch_idx: Index of the batch (unused).
        _dataloader_idx: Index of the dataloader (unused). It defaults to 0
            because Lightning 2.x omits this argument when only a single
            dataloader is used.
    """
    key = "anomaly_maps" if "anomaly_maps" in outputs.keys() else "pred_scores"
    pl_module.normalization_metrics(outputs[key])
def on_test_batch_end(
    self,
    _trainer: pl.Trainer,
    pl_module: AnomalyModule,
    outputs: STEP_OUTPUT,
    _batch: Any,
    _batch_idx: int,
    _dataloader_idx: int = 0,
) -> None:
    """Called when the test batch ends, normalizes the predicted scores and anomaly maps.

    The normalization is delegated to ``self._normalize_batch``, which
    rewrites entries of ``outputs`` in place.

    Args:
        _trainer: The Lightning trainer (unused).
        pl_module: Model providing the normalization statistics and thresholds.
        outputs: Test-step output. It is expected to be a mapping with
            ``"pred_scores"`` and optionally ``"anomaly_maps"``.
        _batch: The input batch (unused).
        _batch_idx: Index of the batch (unused).
        _dataloader_idx: Index of the dataloader (unused). It defaults to 0
            because Lightning 2.x omits this argument when only a single
            dataloader is used.
    """
    self._normalize_batch(outputs, pl_module)
def on_predict_batch_end(
    self,
    _trainer: pl.Trainer,
    pl_module: AnomalyModule,
    outputs: Dict,
    _batch: Any,
    _batch_idx: int,
    _dataloader_idx: int = 0,
) -> None:
    """Called when the predict batch ends, normalizes the predicted scores and anomaly maps.

    The normalization is delegated to ``self._normalize_batch``, which
    rewrites entries of ``outputs`` in place.

    Args:
        _trainer: The Lightning trainer (unused).
        pl_module: Model providing the normalization statistics and thresholds.
        outputs: Predict-step output containing ``"pred_scores"`` and
            optionally ``"anomaly_maps"``.
        _batch: The input batch (unused).
        _batch_idx: Index of the batch (unused).
        _dataloader_idx: Index of the dataloader (unused). It defaults to 0
            because Lightning 2.x omits this argument when only a single
            dataloader is used.
    """
    self._normalize_batch(outputs, pl_module)
@staticmethod
[docs]def_normalize_batch(outputs,pl_module):"""Normalize a batch of predictions."""stats=pl_module.normalization_metrics.cpu()outputs["pred_scores"]=normalize(outputs["pred_scores"],pl_module.image_threshold.value.cpu(),stats.min,stats.max)if"anomaly_maps"inoutputs.keys():outputs["anomaly_maps"]=normalize(outputs["anomaly_maps"],pl_module.pixel_threshold.value.cpu(),stats.min,stats.max