# Source code for anomalib.models.patchcore.lightning_model

"""Towards Total Recall in Industrial Anomaly Detection.

Paper https://arxiv.org/abs/2106.08265.
"""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import List, Tuple, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import Tensor

from anomalib.models.components import AnomalyModule
from anomalib.models.patchcore.torch_model import PatchcoreModel

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@MODEL_REGISTRY
class Patchcore(AnomalyModule):
    """PatchcoreLightning Module to train PatchCore algorithm.

    No optimizer is configured. Each training batch is passed through the
    model and the returned embedding is stored; when validation starts, the
    stored embeddings are stacked and handed to the model for core-set
    subsampling.

    Args:
        input_size (Tuple[int, int]): Size of the model input.
        backbone (str): Backbone CNN network
        layers (List[str]): Layers to extract features from the backbone CNN
        pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
            Defaults to True.
        coreset_sampling_ratio (float, optional): Coreset sampling ratio to subsample embedding.
            Defaults to 0.1.
        num_neighbors (int, optional): Number of nearest neighbors. Defaults to 9.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
        pre_trained: bool = True,
        coreset_sampling_ratio: float = 0.1,
        num_neighbors: int = 9,
    ) -> None:
        super().__init__()

        self.model: PatchcoreModel = PatchcoreModel(
            input_size=input_size,
            backbone=backbone,
            pre_trained=pre_trained,
            layers=layers,
            num_neighbors=num_neighbors,
        )
        self.coreset_sampling_ratio = coreset_sampling_ratio
        # One entry per training batch; stacked in `on_validation_start`.
        self.embeddings: List[Tensor] = []

    def configure_optimizers(self) -> None:
        """Configure optimizers.

        Returns:
            None: Do not set optimizers by returning None.
        """
        return None

    def training_step(self, batch, _batch_idx) -> None:  # pylint: disable=arguments-differ
        """Generate feature embedding of the batch and store it.

        Nothing is returned; the embedding is appended to ``self.embeddings``.

        Args:
            batch (Dict[str, Any]): Batch containing image filename, image, label and mask
            _batch_idx (int): Batch Index
        """
        # Keep the feature extractor in eval mode while embeddings are collected.
        self.model.feature_extractor.eval()
        embedding = self.model(batch["image"])

        # NOTE: `self.embeddings` accumulates each batch embedding to
        #   store the training set embedding. We manually append these
        #   values mainly due to the new order of hooks introduced after PL v1.4.0
        #   https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
        self.embeddings.append(embedding)

    def on_validation_start(self) -> None:
        """Apply subsampling to the embedding collected from the training set."""
        # NOTE: Previous anomalib versions fit subsampling at the end of the epoch.
        #   This is not possible anymore with PyTorch Lightning v1.4.0 since validation
        #   is run within train epoch.
        # NOTE(review): `self.embeddings` is never cleared here, so it keeps growing
        #   if training runs for more than one epoch - confirm this is intended.
        logger.info("Aggregating the embedding extracted from the training set.")
        embeddings = torch.vstack(self.embeddings)

        logger.info("Applying core-set subsampling to get the embedding.")
        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Get batch of anomaly maps from input image batch.

        Args:
            batch (Dict[str, Any]): Batch containing image filename, image, label and mask
            _ (int): Batch Index

        Returns:
            Dict[str, Any]: The input batch, extended in place with the
                ``anomaly_maps`` and ``pred_scores`` entries.
        """
        anomaly_maps, anomaly_score = self.model(batch["image"])
        batch["anomaly_maps"] = anomaly_maps
        # NOTE(review): `unsqueeze(0)` adds a leading dimension, which suggests the
        #   model returns a single score per call - confirm for batch sizes > 1.
        batch["pred_scores"] = anomaly_score.unsqueeze(0)

        return batch
class PatchcoreLightning(Patchcore):
    """PatchcoreLightning Module to train PatchCore algorithm.

    Thin wrapper that builds ``Patchcore`` from a config object instead of
    individual keyword arguments, and records the config as hyperparameters.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model params. The values are
            read from ``hparams.model``; ``pre_trained`` is optional there.
    """

    def __init__(self, hparams) -> None:
        super().__init__(
            input_size=hparams.model.input_size,
            backbone=hparams.model.backbone,
            layers=hparams.model.layers,
            # Forward `pre_trained` so it can be set from the config. Configs
            # without this key fall back to True, the `Patchcore` default.
            pre_trained=hparams.model.get("pre_trained", True),
            coreset_sampling_ratio=hparams.model.coreset_sampling_ratio,
            num_neighbors=hparams.model.num_neighbors,
        )
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)