Source code for anomalib.utils.metrics.anomaly_score_distribution

"""Module that computes the parameters of the normal data distribution of the training set."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Tuple

import torch
from torch import Tensor
from torchmetrics import Metric


[docs]class AnomalyScoreDistribution(Metric): """Mean and standard deviation of the anomaly scores of normal training data.""" def __init__(self, **kwargs): super().__init__(**kwargs) self.anomaly_maps = [] self.anomaly_scores = [] self.add_state("image_mean", torch.empty(0), persistent=True) self.add_state("image_std", torch.empty(0), persistent=True) self.add_state("pixel_mean", torch.empty(0), persistent=True) self.add_state("pixel_std", torch.empty(0), persistent=True) self.image_mean = torch.empty(0) self.image_std = torch.empty(0) self.pixel_mean = torch.empty(0) self.pixel_std = torch.empty(0) # pylint: disable=arguments-differ
[docs] def update( # type: ignore self, anomaly_scores: Optional[Tensor] = None, anomaly_maps: Optional[Tensor] = None ) -> None: """Update the precision-recall curve metric.""" if anomaly_maps is not None: self.anomaly_maps.append(anomaly_maps) if anomaly_scores is not None: self.anomaly_scores.append(anomaly_scores)
[docs] def compute(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Compute stats.""" anomaly_scores = torch.hstack(self.anomaly_scores) anomaly_scores = torch.log(anomaly_scores) self.image_mean = anomaly_scores.mean() self.image_std = anomaly_scores.std() if self.anomaly_maps: anomaly_maps = torch.vstack(self.anomaly_maps) anomaly_maps = torch.log(anomaly_maps).cpu() self.pixel_mean = anomaly_maps.mean(dim=0).squeeze() self.pixel_std = anomaly_maps.std(dim=0).squeeze() return self.image_mean, self.image_std, self.pixel_mean, self.pixel_std