# Source code for anomalib.utils.callbacks.visualizer.visualizer_metric

"""Metric Visualizer Callback."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path

import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

from anomalib.models.components import AnomalyModule

from .visualizer_base import BaseVisualizerCallback


@CALLBACK_REGISTRY
class MetricVisualizerCallback(BaseVisualizerCallback):
    """Callback that visualizes the metric results of a model by plotting the corresponding curves.

    To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the
    config.yaml file.
    """

    def on_test_end(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
        """Log images of the metrics contained in pl_module.

        In order to also plot custom metrics, they need to have implemented a `generate_figure` function that returns
        Tuple[matplotlib.figure.Figure, str].

        Each generated figure is sent to the logger (if `self.log_images`) and/or written to
        `self.image_save_path` as `<prefix><log_name>.png` (if `self.save_images`). Every figure is closed
        afterwards, regardless of which of the two outputs is enabled.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer.
            pl_module (AnomalyModule): pytorch lightning module.
        """
        if self.save_images or self.log_images:
            for metrics in (pl_module.image_metrics, pl_module.pixel_metrics):
                # A collection created without a prefix would otherwise yield file names starting with "None".
                prefix = metrics.prefix or ""
                for metric in metrics.values():
                    # `generate_figure` needs to be defined for every metric that should be plotted automatically
                    if not hasattr(metric, "generate_figure"):
                        continue
                    fig, log_name = metric.generate_figure()
                    file_name = f"{prefix}{log_name}"
                    try:
                        if self.log_images:
                            self._add_to_logger(fig, pl_module, trainer, file_name)
                        if self.save_images:
                            img = self._figure_to_array(fig)
                            self.visualizer.save(Path(self.image_save_path) / f"{file_name}.png", img)
                    finally:
                        # Always release the figure. Closing it only when saving leaked one open pyplot figure
                        # per metric whenever only logging was enabled (or when logging/saving raised).
                        plt.close(fig)

        super().on_test_end(trainer, pl_module)

    @staticmethod
    def _figure_to_array(fig: "plt.Figure") -> np.ndarray:
        """Render `fig` and return its pixels as a contiguous (height, width, 3) uint8 RGB array.

        Uses the canvas' `buffer_rgba()` instead of the deprecated `tostring_rgb()` (removed in matplotlib 3.10).
        The RGBA buffer already carries the rendered pixel dimensions, so no manual reshape based on
        `get_width_height()` is needed. That reshape used the logical rather than the physical size, which breaks
        when the canvas has a device-pixel-ratio other than 1.
        """
        fig.canvas.draw()
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy out of the renderer-owned buffer so the result stays valid
        # after the figure is closed and is contiguous for downstream image writers.
        return np.ascontiguousarray(rgba[..., :3])