# Source code for anomalib.models.stfpm.lightning_model

"""STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.

https://arxiv.org/abs/2103.04257
"""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import optim

from anomalib.models.components import AnomalyModule
from anomalib.models.stfpm.loss import STFPMLoss
from anomalib.models.stfpm.torch_model import STFPMModel

# Both Lightning modules defined in this file are exported by
# ``from ... import *``: the component-argument variant and the config-driven one.
__all__ = ["Stfpm", "StfpmLightning"]


@MODEL_REGISTRY
class Stfpm(AnomalyModule):
    """PL Lightning Module for the STFPM algorithm.

    Wraps an ``STFPMModel`` (teacher/student pair) together with an
    ``STFPMLoss`` and exposes the training and validation steps.

    Args:
        input_size (Tuple[int, int]): Size of the model input.
        backbone (str): Backbone CNN network.
        layers (List[str]): Layers to extract features from the backbone CNN.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
    ) -> None:
        super().__init__()

        self.model = STFPMModel(
            input_size=input_size,
            backbone=backbone,
            layers=layers,
        )
        self.loss = STFPMLoss()

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of STFPM.

        For each batch, teacher and student features are extracted from the
        model and the loss between them is computed.

        Args:
            batch: Input batch; must provide the images under the ``"image"`` key.
            _: Index of the batch (unused).

        Returns:
            Dictionary with a single ``"loss"`` entry holding the value computed
            by ``STFPMLoss`` from the teacher and student features.
        """
        # The teacher is forced into eval mode on every step.
        # NOTE(review): presumably because the trainer puts the whole module in
        # train mode while the teacher is meant to stay fixed - confirm.
        self.model.teacher_model.eval()

        # Call the module itself rather than ``.forward`` so registered hooks
        # run, consistent with ``validation_step``.
        teacher_features, student_features = self.model(batch["image"])
        loss = self.loss(teacher_features, student_features)
        return {"loss": loss}

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Validation Step of STFPM.

        Runs the model on the batch images and stores the output in the batch
        under ``"anomaly_maps"``.

        Args:
            batch: Input batch; must provide the images under the ``"image"`` key.
                It is modified in place.
            _: Index of the batch (unused).

        Returns:
            The same ``batch`` object with the ``"anomaly_maps"`` entry added.
            NOTE(review): the model output is presumed to be an anomaly map
            when the model is in eval mode - confirm against ``STFPMModel``.
        """
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch
class StfpmLightning(Stfpm):
    """PL Lightning Module for the STFPM algorithm.

    Config-driven variant of ``Stfpm``: the model arguments, early-stopping
    settings and optimizer settings are all read from ``hparams.model``.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model params
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
        super().__init__(
            input_size=hparams.model.input_size,
            backbone=hparams.model.backbone,
            layers=hparams.model.layers,
        )
        # Annotation only: narrows the type of ``hparams`` for static checkers.
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)

    def configure_callbacks(self) -> List[EarlyStopping]:
        """Configure model-specific callbacks.

        Note:
            This method is used for the existing CLI.
            When PL CLI is introduced, configure callback method will be
                deprecated, and callbacks will be configured from either
                config.yaml file or from CLI.

        Returns:
            List containing one ``EarlyStopping`` callback whose ``monitor``,
            ``patience`` and ``mode`` come from
            ``hparams.model.early_stopping`` (``metric``, ``patience``, ``mode``).
        """
        early_stopping = EarlyStopping(
            monitor=self.hparams.model.early_stopping.metric,
            patience=self.hparams.model.early_stopping.patience,
            mode=self.hparams.model.early_stopping.mode,
        )
        return [early_stopping]

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Configure the optimizer for the student model.

        Note:
            This method is used for the existing CLI.
            When PL CLI is introduced, configure optimizers method will be
                deprecated, and optimizers will be configured from either
                config.yaml file or from CLI.

        Returns:
            Optimizer: SGD optimizer over the student model parameters only,
            with ``lr``, ``momentum`` and ``weight_decay`` taken from
            ``hparams.model``.
        """
        # Only the student's parameters are handed to the optimizer; the
        # teacher's parameters are not updated by it.
        return optim.SGD(
            params=self.model.student_model.parameters(),
            lr=self.hparams.model.lr,
            momentum=self.hparams.model.momentum,
            weight_decay=self.hparams.model.weight_decay,
        )