Reverse Distillation#

Anomaly Detection via Reverse Distillation from One-Class Embedding.

This module implements the Reverse Distillation model for anomaly detection as described in Deng et al. (2022).

The model consists of:

  • A pre-trained encoder (e.g. ResNet) that extracts multi-scale features

  • A bottleneck layer that compresses features into a compact representation

  • A decoder that reconstructs features back to the original feature space

  • A scoring mechanism based on reconstruction error

Example

>>> from anomalib.models import ReverseDistillation
>>> from anomalib.data import MVTec
>>> from anomalib.engine import Engine
>>> # Initialize model and data
>>> datamodule = MVTec()
>>> model = ReverseDistillation(
...     backbone="wide_resnet50_2",
...     layers=["layer1", "layer2", "layer3"]
... )
>>> # Train using the Engine
>>> engine = Engine()
>>> engine.fit(model=model, datamodule=datamodule)
>>> # Get predictions
>>> predictions = engine.predict(model=model, datamodule=datamodule)

See also

  • ReverseDistillation: Lightning implementation of the model

  • ReverseDistillationModel: PyTorch implementation of the model

  • ReverseDistillationLoss: Loss function for training

class anomalib.models.image.reverse_distillation.lightning_model.ReverseDistillation(backbone='wide_resnet50_2', layers=('layer1', 'layer2', 'layer3'), anomaly_map_mode=AnomalyMapGenerationMode.ADD, pre_trained=True, pre_processor=True, post_processor=True, evaluator=True, visualizer=True)#

Bases: AnomalibModule

PyTorch Lightning module for the Reverse Distillation algorithm.

Parameters:
  • backbone (str) – Backbone of the CNN network. Defaults to wide_resnet50_2.

  • layers (list[str]) – Layers from which to extract features in the backbone CNN. Defaults to ["layer1", "layer2", "layer3"].

  • anomaly_map_mode (AnomalyMapGenerationMode, optional) – Mode to generate anomaly map. Defaults to AnomalyMapGenerationMode.ADD.

  • pre_trained (bool, optional) – Whether to use a pre-trained backbone. Defaults to True.

  • pre_processor (PreProcessor, optional) – Pre-processor for the model, used to pre-process the input data before it is passed to the model. Defaults to True.
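
For reference, the parameters above map onto the constructor like this (a minimal instantiation that simply spells out the documented defaults explicitly):

>>> from anomalib.models import ReverseDistillation
>>> from anomalib.models.image.reverse_distillation.anomaly_map import (
...     AnomalyMapGenerationMode
... )
>>> model = ReverseDistillation(
...     backbone="wide_resnet50_2",
...     layers=["layer1", "layer2", "layer3"],
...     anomaly_map_mode=AnomalyMapGenerationMode.ADD,
...     pre_trained=True,
... )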

configure_optimizers()#

Configure optimizers for decoder and bottleneck.

Returns:

Adam optimizer over the decoder and bottleneck parameters

Return type:

Optimizer
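
A minimal sketch of what this hook amounts to, assuming the decoder and bottleneck are reachable as self.model.decoder and self.model.bottleneck (the attribute names and the omitted optimizer hyperparameters are assumptions, not the library's exact code):

>>> import torch
>>> def configure_optimizers(self):
...     # One Adam optimizer over the trainable decoder and bottleneck weights;
...     # the pre-trained encoder is not optimized.
...     return torch.optim.Adam(
...         params=list(self.model.decoder.parameters())
...         + list(self.model.bottleneck.parameters())
...     )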

property learning_type: LearningType#

Return the learning type of the model.

Returns:

Learning type of the model.

Return type:

LearningType

property trainer_arguments: dict[str, Any]#

Return Reverse Distillation trainer arguments.

training_step(batch, *args, **kwargs)#

Perform a training step of the Reverse Distillation model.

Features are extracted from three layers of the encoder model, passed through the bottleneck layer, and then reconstructed by the decoder network. The loss is calculated based on the cosine similarity between the encoder and decoder features.

Parameters:
  • batch (Batch) – Input batch

  • args – Additional arguments.

  • kwargs – Additional keyword arguments.

Return type:

Union[Tensor, Mapping[str, Any], None]

Returns:

Loss value for the training step
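
Put together, the step described above amounts to roughly the following sketch (the attribute names self.model and self.loss and the batch.image field are assumptions for illustration, not the exact library code):

>>> def training_step(self, batch, *args, **kwargs):
...     # Training-mode forward pass yields encoder and decoder feature lists
...     encoder_features, decoder_features = self.model(batch.image)
...     # Cosine-similarity loss between corresponding feature pairs
...     loss = self.loss(encoder_features, decoder_features)
...     return {"loss": loss}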

validation_step(batch, *args, **kwargs)#

Perform a validation step of the Reverse Distillation model.

Similar to the training step, encoder and decoder features are extracted from the CNN for each batch, and the anomaly map is computed.

Parameters:
  • batch (Batch) – Input batch

  • args – Additional arguments.

  • kwargs – Additional keyword arguments.

Return type:

Union[Tensor, Mapping[str, Any], None]

Returns:

Dictionary containing images, anomaly maps, true labels and masks. These are required in validation_epoch_end for feature concatenation.

PyTorch model implementation for Reverse Distillation.

This module implements the core PyTorch model architecture for the Reverse Distillation anomaly detection method as described in Deng et al. (2022).

The model consists of:

  • A pre-trained encoder (e.g. ResNet) that extracts multi-scale features

  • A bottleneck layer that compresses features into a compact representation

  • A decoder that reconstructs features back to the original feature space

  • A scoring mechanism based on reconstruction error

Example

>>> import torch
>>> from anomalib.models.image.reverse_distillation.torch_model import (
...     ReverseDistillationModel
... )
>>> model = ReverseDistillationModel(
...     backbone="wide_resnet50_2",
...     input_size=(256, 256),
...     layers=["layer1", "layer2", "layer3"],
...     anomaly_map_mode="multiply"
... )
>>> features = model(torch.randn(1, 3, 256, 256))

See also

  • ReverseDistillationModel: Main PyTorch model implementation

  • ReverseDistillationLoss: Loss function for training

  • AnomalyMapGenerator: Anomaly map generation from features

class anomalib.models.image.reverse_distillation.torch_model.ReverseDistillationModel(backbone, input_size, layers, anomaly_map_mode, pre_trained=True)#

Bases: Module

PyTorch implementation of the Reverse Distillation model.

The model consists of an encoder-decoder architecture where the encoder extracts multi-scale features and the decoder reconstructs them back to the original feature space. The reconstruction error is used to detect anomalies.

Parameters:
  • backbone (str) – Name of the backbone CNN architecture used for encoder and decoder. Supported backbones can be found in timm library.

  • input_size (tuple[int, int]) – Size of input images in format (H, W).

  • layers (Sequence[str]) – Names of layers from which to extract features. For example ["layer1", "layer2", "layer3"].

  • anomaly_map_mode (AnomalyMapGenerationMode) – Mode used to generate anomaly map. Options are "multiply" or "add".

  • pre_trained (bool, optional) – Whether to use pre-trained weights for the encoder backbone. Defaults to True.

Example

>>> import torch
>>> from anomalib.models.image.reverse_distillation.torch_model import (
...     ReverseDistillationModel
... )
>>> model = ReverseDistillationModel(
...     backbone="wide_resnet50_2",
...     input_size=(256, 256),
...     layers=["layer1", "layer2", "layer3"],
...     anomaly_map_mode="multiply"
... )
>>> input_tensor = torch.randn(1, 3, 256, 256)
>>> features = model(input_tensor)

Note

The original paper uses torchvision’s pre-trained wide_resnet50_2 as the encoder backbone.

tiler#

Optional tiler for processing large images in patches.

Type:

Tiler | None

encoder#

Feature extraction backbone.

Type:

TimmFeatureExtractor

bottleneck#

Bottleneck layer to compress features.

Type:

nn.Module

decoder#

Decoder network to reconstruct features.

Type:

nn.Module

anomaly_map_generator#

Module to generate anomaly maps from features.

Type:

AnomalyMapGenerator
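
Conceptually, these components are composed as follows during a forward pass (a rough sketch only; the real forward() additionally handles tiling and the exact format of the extracted features):

>>> def forward_sketch(model, images):
...     # Multi-scale features from the frozen encoder
...     encoder_features = model.encoder(images)
...     # Compress via the bottleneck, then reconstruct with the decoder
...     decoder_features = model.decoder(model.bottleneck(encoder_features))
...     # Anomaly map from the mismatch between decoder and encoder features
...     return model.anomaly_map_generator(decoder_features, encoder_features)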

forward(images)#

Forward pass through the model.

The behavior differs between training and evaluation modes:

  • Training: Returns encoder and decoder features for computing loss

  • Evaluation: Returns anomaly maps and scores

Parameters:

images (torch.Tensor) – Input tensor of shape (N, C, H, W) where N is batch size, C is number of channels, H and W are height and width.

Returns:

  • In training mode: Tuple of lists containing encoder and decoder features

  • In evaluation mode: InferenceBatch containing anomaly maps and scores

Return type:

tuple[list[torch.Tensor], list[torch.Tensor]] | InferenceBatch

Example

>>> import torch
>>> model = ReverseDistillationModel(
...     backbone="wide_resnet50_2",
...     input_size=(256, 256),
...     layers=["layer1", "layer2", "layer3"],
...     anomaly_map_mode="multiply"
... )
>>> input_tensor = torch.randn(1, 3, 256, 256)
>>> # Training mode
>>> model.train()
>>> encoder_features, decoder_features = model(input_tensor)
>>> # Evaluation mode
>>> model.eval()
>>> predictions = model(input_tensor)

Loss function for Reverse Distillation model.

This module implements the loss function used to train the Reverse Distillation model for anomaly detection. The loss is based on cosine similarity between encoder and decoder features.

The loss function:

  1. Takes encoder and decoder feature maps as input

  2. Flattens the spatial dimensions of each feature map

  3. Computes cosine similarity between corresponding encoder-decoder pairs

  4. Averages the similarities across spatial dimensions and feature pairs
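
A minimal sketch of these steps, using torch.nn.functional.cosine_similarity over the channel dimension (the helper name is hypothetical, and the library's exact flattening and reduction may differ):

>>> import torch
>>> import torch.nn.functional as F
>>> def cosine_reconstruction_loss(encoder_features, decoder_features):
...     # Hypothetical helper mirroring the steps above, not the library API
...     total = 0
...     for enc, dec in zip(encoder_features, decoder_features):
...         # Flatten spatial dims: (B, C, H, W) -> (B, C, H*W)
...         enc_flat = enc.flatten(start_dim=2)
...         dec_flat = dec.flatten(start_dim=2)
...         # Per-location cosine similarity over channels, turned into a distance
...         total = total + torch.mean(1 - F.cosine_similarity(enc_flat, dec_flat, dim=1))
...     return total / len(encoder_features)
>>> loss = cosine_reconstruction_loss(
...     [torch.randn(2, 64, 32, 32)], [torch.randn(2, 64, 32, 32)]
... )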

Example

>>> import torch
>>> from anomalib.models.image.reverse_distillation.loss import (
...     ReverseDistillationLoss
... )
>>> criterion = ReverseDistillationLoss()
>>> encoder_features = [torch.randn(2, 64, 32, 32)]
>>> decoder_features = [torch.randn(2, 64, 32, 32)]
>>> loss = criterion(encoder_features, decoder_features)

See also

  • ReverseDistillationLoss: Main loss class implementation

  • ReverseDistillation: Lightning implementation of the full model

class anomalib.models.image.reverse_distillation.loss.ReverseDistillationLoss(*args, **kwargs)#

Bases: Module

Loss function for Reverse Distillation model.

This class implements the cosine similarity loss used to train the Reverse Distillation model. The loss measures the dissimilarity between encoder and decoder feature maps.

The loss computation involves:

  1. Flattening the spatial dimensions of encoder and decoder feature maps

  2. Computing cosine similarity between corresponding encoder-decoder pairs

  3. Subtracting the similarities from 1 to get a dissimilarity measure

  4. Taking the mean across spatial dimensions and feature pairs

Example

>>> import torch
>>> from anomalib.models.image.reverse_distillation.loss import (
...     ReverseDistillationLoss
... )
>>> criterion = ReverseDistillationLoss()
>>> encoder_features = [torch.randn(2, 64, 32, 32)]
>>> decoder_features = [torch.randn(2, 64, 32, 32)]
>>> loss = criterion(encoder_features, decoder_features)

static forward(encoder_features, decoder_features)#

Compute cosine similarity loss between encoder and decoder features.

Parameters:
  • encoder_features (list[torch.Tensor]) – List of feature tensors from the encoder network. Each tensor has shape (B, C, H, W) where B is batch size, C is channels, H and W are spatial dimensions.

  • decoder_features (list[torch.Tensor]) – List of feature tensors from the decoder network. Must match encoder features in length and shapes.

Returns:

Scalar loss value computed as the mean of (1 - cosine similarity) across all feature pairs.

Return type:

torch.Tensor

Anomaly map computation for Reverse Distillation model.

This module implements functionality to generate anomaly heatmaps from the feature reconstruction errors of the Reverse Distillation model.

The anomaly maps are generated by:

  1. Computing the reconstruction error between original and reconstructed features

  2. Upscaling the error maps to the original image size

  3. Optional smoothing via Gaussian blur

  4. Combining the multi-scale errors via addition or multiplication
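
A rough sketch of steps 1, 2 and 4 (the helper name and interpolation settings are assumptions; the library's AnomalyMapGenerator also applies the Gaussian smoothing of step 3):

>>> import torch
>>> import torch.nn.functional as F
>>> def combine_scale_errors(encoder_features, decoder_features, image_size, mode="add"):
...     anomaly_map = None
...     for enc, dec in zip(encoder_features, decoder_features):
...         # Step 1: per-location reconstruction error as 1 - cosine similarity
...         error = 1 - F.cosine_similarity(enc, dec, dim=1).unsqueeze(1)
...         # Step 2: upscale the error map to the original image size
...         error = F.interpolate(error, size=image_size, mode="bilinear", align_corners=True)
...         # Step 4: combine the scales by addition or multiplication
...         if anomaly_map is None:
...             anomaly_map = error
...         elif mode == "multiply":
...             anomaly_map = anomaly_map * error
...         else:
...             anomaly_map = anomaly_map + error
...     return anomaly_map
>>> maps = combine_scale_errors(
...     [torch.randn(1, 64, 32, 32)], [torch.randn(1, 64, 32, 32)], image_size=(256, 256)
... )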

Example

>>> import torch
>>> from anomalib.models.image.reverse_distillation.anomaly_map import (
...     AnomalyMapGenerator
... )
>>> generator = AnomalyMapGenerator(image_size=(256, 256))
>>> student_features = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)]
>>> teacher_features = [torch.randn(1, 64, 32, 32), torch.randn(1, 128, 16, 16)]
>>> anomaly_map = generator(student_features, teacher_features)

class anomalib.models.image.reverse_distillation.anomaly_map.AnomalyMapGenerationMode(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)#

Bases: str, Enum

Type of mode used when generating the anomaly map.
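
Because the enum also inherits from str, the mode can be given either as an enum member or as its string value, assumed here to be the lowercase member name (as in the "multiply" string used in the examples above):

>>> from anomalib.models.image.reverse_distillation.anomaly_map import (
...     AnomalyMapGenerationMode
... )
>>> AnomalyMapGenerationMode.MULTIPLY == "multiply"
True
>>> AnomalyMapGenerationMode.ADD.value
'add'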

class anomalib.models.image.reverse_distillation.anomaly_map.AnomalyMapGenerator(image_size, sigma=4, mode=AnomalyMapGenerationMode.MULTIPLY)#

Bases: Module

Generate Anomaly Heatmap.

Parameters:
  • image_size (ListConfig, tuple) – Size of original image used for upscaling the anomaly map.

  • sigma (int) – Standard deviation of the Gaussian kernel used to smooth the anomaly map. Defaults to 4.

  • mode (AnomalyMapGenerationMode, optional) – Operation used to generate anomaly map. Options are AnomalyMapGenerationMode.ADD and AnomalyMapGenerationMode.MULTIPLY. Defaults to AnomalyMapGenerationMode.MULTIPLY.

Raises:

ValueError – If a mode other than multiply or add is passed.
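
A usage example mirroring the constructor signature above (the keyword values shown are illustrative):

>>> from anomalib.models.image.reverse_distillation.anomaly_map import (
...     AnomalyMapGenerationMode,
...     AnomalyMapGenerator,
... )
>>> generator = AnomalyMapGenerator(
...     image_size=(256, 256),
...     sigma=4,
...     mode=AnomalyMapGenerationMode.ADD,
... )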

forward(student_features, teacher_features)#

Compute anomaly map given encoder and decoder features.

Parameters:
  • student_features (list[torch.Tensor]) – List of feature maps from the student (decoder) network.

  • teacher_features (list[torch.Tensor]) – List of feature maps from the teacher (encoder) network.

Returns:

Anomaly maps, one per image in the batch.

Return type:

Tensor