DRAEM

DRÆM.

A discriminatively trained reconstruction embedding for surface anomaly detection.

Paper: https://arxiv.org/abs/2108.07610

This module implements the DRÆM model for surface anomaly detection. DRÆM uses a discriminatively trained reconstruction embedding approach to detect anomalies by comparing input images with their reconstructions.

class anomalib.models.image.draem.lightning_model.Draem(dtd_dir='./datasets/dtd', enable_sspcab=False, sspcab_lambda=0.1, beta=(0.1, 1.0), pre_processor=True, post_processor=True, evaluator=True, visualizer=True)

Bases: AnomalibModule

DRÆM.

A discriminatively trained reconstruction embedding for surface anomaly detection.

The model consists of two main components:
  1. A reconstruction network that learns to reconstruct normal images

  2. A discriminative network that learns to identify anomalous regions

Parameters:
  • dtd_dir (Path | str) – Directory path for the DTD dataset used for anomaly generation. Defaults to ./datasets/dtd.

  • enable_sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.

  • sspcab_lambda (float, optional) – Weight factor for SSPCAB loss. Defaults to 0.1.

  • anomaly_source_path (str | None, optional) – Path to directory containing anomaly source images. If None, random noise is used. Defaults to None.

  • beta (float | tuple[float, float], optional) – Blend factor for anomaly generation. If tuple, represents range for random sampling. Defaults to (0.1, 1.0).

  • pre_processor (PreProcessor | bool, optional) – Pre-processor instance or flag to use default. Defaults to True.

  • post_processor (PostProcessor | bool, optional) – Post-processor instance or flag to use default. Defaults to True.

  • evaluator (Evaluator | bool, optional) – Evaluator instance or flag to use default. Defaults to True.

  • visualizer (Visualizer | bool, optional) – Visualizer instance or flag to use default. Defaults to True.
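The role of beta can be illustrated with a small, library-free sketch. It assumes the blending rule described in the DRÆM paper, where pixels under a binary mask are mixed with an anomaly-source texture and beta is the weight kept from the original image. The helper names sample_beta and blend_pixel are hypothetical and are not part of anomalib.

```python
import random


def sample_beta(beta, rng=random):
    """Return a blend factor: a fixed float, or a draw from a (low, high) range."""
    if isinstance(beta, (tuple, list)):
        low, high = beta
        return rng.uniform(low, high)
    return float(beta)


def blend_pixel(image_px, texture_px, mask_px, beta):
    """Blend one pixel value (assumed rule, following the DRÆM paper).

    mask_px is 0 for normal pixels and 1 inside the simulated anomaly.
    beta is the weight kept from the original image inside the mask.
    """
    inside = beta * image_px + (1.0 - beta) * texture_px
    return (1.0 - mask_px) * image_px + mask_px * inside


rng = random.Random(0)  # seeded only to make the sketch repeatable
beta = sample_beta((0.1, 1.0), rng)

normal = blend_pixel(0.8, 0.2, mask_px=0, beta=beta)     # pixel outside the mask
anomalous = blend_pixel(0.8, 0.2, mask_px=1, beta=beta)  # pixel mixed with texture
```

A beta close to 1.0 leaves the masked region almost identical to the input, which yields subtle simulated anomalies; a beta close to 0.1 replaces most of the region with the texture.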

configure_optimizers()

Configure optimizer and learning rate scheduler.

Returns:

Tuple containing optimizer and scheduler lists.

Return type:

tuple[list[Adam], list[MultiStepLR]]
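To show what a MultiStepLR schedule does to the learning rate over time, here is a minimal sketch of the decay rule in plain Python. The base learning rate, milestones, and gamma below are illustrative assumptions, not values stated in this documentation, and multistep_lr is a hypothetical helper.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate at `epoch` under a multi-step decay schedule.

    The rate is multiplied by `gamma` once for every milestone that has
    been reached, which is the closed-form rule behind MultiStepLR.
    """
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma**passed


# Illustrative values only; the model's actual hyperparameters may differ.
BASE_LR, MILESTONES, GAMMA = 1e-4, [400, 600], 0.1

for epoch in (0, 399, 400, 600):
    print(epoch, multistep_lr(BASE_LR, MILESTONES, GAMMA, epoch))
```

The rate stays constant between milestones and drops by a factor of gamma at each one, so with two milestones there are three plateaus.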

classmethod configure_pre_processor(image_size=None)

Configure default pre-processor for DRÆM.

Note

Imagenet normalization is not used in this model.

Parameters:

image_size (tuple[int, int] | None, optional) – Target image size. If None, (256, 256) is used. Defaults to None.

Returns:

Configured pre-processor with resize transform.

Return type:

PreProcessor

property learning_type: LearningType

Get the learning type of the model.

Returns:

The learning type (LearningType.ONE_CLASS).

Return type:

LearningType

on_train_start()

Validates transforms before training begins.

Raises:

ValueError – If transforms contain normalization.

Return type:

None
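The kind of check performed before training can be sketched as a recursive search for a normalization step. This is only an illustration under stated assumptions: Resize, Normalize, and Compose below are stand-in classes defined for the example rather than the real transform classes, and validate_transforms is a hypothetical helper, not the anomalib implementation.

```python
class Resize:  # stand-in for a resize transform
    def __init__(self, size):
        self.size = size


class Normalize:  # stand-in for a normalization transform
    def __init__(self, mean, std):
        self.mean, self.std = mean, std


class Compose:  # stand-in for a container of transforms
    def __init__(self, transforms):
        self.transforms = list(transforms)


def contains_normalize(transform):
    """Return True if `transform`, or anything nested in it, is a Normalize."""
    if isinstance(transform, Normalize):
        return True
    children = getattr(transform, "transforms", [])
    return any(contains_normalize(child) for child in children)


def validate_transforms(transform):
    """Raise ValueError when a normalization step is present."""
    if transform is not None and contains_normalize(transform):
        msg = "Transforms must not contain normalization for this model."
        raise ValueError(msg)


validate_transforms(Compose([Resize((256, 256))]))  # passes silently
```

Checking nested containers matters because a normalization step can be hidden inside a composed transform rather than appearing at the top level.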

setup_sspcab()

Set up SSPCAB forward hooks.

Prepares the model for SSPCAB training by adding forward hooks to capture layer activations from specific points in the network.

Return type:

None

property trainer_arguments: dict[str, Any]

Get DRÆM-specific trainer arguments.

Returns:

Dictionary containing trainer arguments:
  • gradient_clip_val: 0

  • num_sanity_val_steps: 0

Return type:

dict[str, Any]

training_step(batch, *args, **kwargs)

Perform training step for DRAEM.

The step consists of:
  1. Generating simulated anomalies

  2. Computing reconstructions and predictions

  3. Calculating the loss

Parameters:
  • batch (Batch) – Input batch containing images and metadata.

  • args – Additional positional arguments (unused).

  • kwargs – Additional keyword arguments (unused).

Returns:

Dictionary containing the training loss.

Return type:

STEP_OUTPUT

validation_step(batch, *args, **kwargs)

Perform validation step for DRAEM.

Uses softmax predictions of the anomalous class as anomaly maps.

Parameters:
  • batch (Batch) – Input batch containing images and metadata.

  • args – Additional positional arguments (unused).

  • kwargs – Additional keyword arguments (unused).

Returns:

Dictionary containing predictions and metadata.

Return type:

STEP_OUTPUT
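How a two-channel prediction becomes an anomaly map can be shown with a small sketch in plain Python. It applies a softmax over the two class logits of each pixel and keeps the probability of the anomalous class. Taking the maximum of the map as the image-level score is an assumption made for this example, and the function names are hypothetical.

```python
import math


def anomalous_probability(logit_normal, logit_anomalous):
    """Softmax over two class logits; return the anomalous-class probability."""
    peak = max(logit_normal, logit_anomalous)  # subtract max for stability
    exp_normal = math.exp(logit_normal - peak)
    exp_anomalous = math.exp(logit_anomalous - peak)
    return exp_anomalous / (exp_normal + exp_anomalous)


def anomaly_map(logits_normal, logits_anomalous):
    """Build a 2D anomaly map from two 2D grids of per-pixel logits."""
    return [
        [anomalous_probability(n, a) for n, a in zip(row_n, row_a)]
        for row_n, row_a in zip(logits_normal, logits_anomalous)
    ]


# A 2x2 image: only the bottom-right pixel favours the anomalous class.
normal_logits = [[3.0, 2.0], [2.5, -1.0]]
anomalous_logits = [[-3.0, -2.0], [-2.5, 4.0]]

amap = anomaly_map(normal_logits, anomalous_logits)
image_score = max(max(row) for row in amap)  # assumed: max over the map
```

Because the softmax output lies in [0, 1], the resulting map can be read directly as a per-pixel anomaly probability.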

PyTorch model for the DRAEM model implementation.

The DRAEM model consists of two sub-networks:
  1. A reconstructive sub-network that learns to reconstruct input images

  2. A discriminative sub-network that detects anomalies by comparing original and reconstructed images

class anomalib.models.image.draem.torch_model.DraemModel(sspcab=False)

Bases: Module

DRAEM PyTorch model with reconstructive and discriminative sub-networks.

Parameters:

sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.

Example

>>> import torch
>>> from anomalib.models.image.draem.torch_model import DraemModel
>>> model = DraemModel(sspcab=True)
>>> input_tensor = torch.randn(32, 3, 256, 256)
>>> reconstruction, prediction = model(input_tensor)

forward(batch)

Forward pass through both sub-networks.

Parameters:

batch (torch.Tensor) – Input batch of images of shape (batch_size, channels, height, width)

Returns:

tuple: Tuple containing:
  • Reconstructed images

  • Predicted anomaly masks

During inference:

InferenceBatch: Contains anomaly map and prediction score

Return type:

During training

Example

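>>> import torch
>>> from anomalib.data import InferenceBatch
>>> from anomalib.models.image.draem.torch_model import DraemModel
>>> model = DraemModel()
>>> batch = torch.randn(32, 3, 256, 256)
>>> reconstruction, prediction = model(batch)  # Training mode
>>> model.eval()
>>> output = model(batch)  # Inference mode
>>> assert isinstance(output, InferenceBatch)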

Loss function for the DRAEM model implementation.

This module implements the loss function used to train the DRAEM model for anomaly detection. The loss combines L2 reconstruction loss, focal loss for anomaly segmentation, and structural similarity (SSIM) loss.

Example

>>> import torch
>>> from anomalib.models.image.draem.loss import DraemLoss
>>> criterion = DraemLoss()
>>> input_image = torch.randn(8, 3, 256, 256)
>>> reconstruction = torch.randn(8, 3, 256, 256)
>>> anomaly_mask = torch.randint(0, 2, (8, 1, 256, 256))
>>> prediction = torch.randn(8, 2, 256, 256)
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)

class anomalib.models.image.draem.loss.DraemLoss

Bases: Module

Overall loss function of the DRAEM model.

The total loss consists of three components:
  1. L2 loss between the reconstructed and input images

  2. Focal loss between predicted and ground truth anomaly masks

  3. Structural Similarity (SSIM) loss between reconstructed and input images

The final loss is computed as: loss = l2_loss + ssim_loss + focal_loss
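Two of the three terms are simple enough to sketch in plain Python. The example below assumes that the L2 term is a mean squared error and that the focal term follows the standard definition, -alpha * (1 - p_t)^gamma * log(p_t). The alpha and gamma values are illustrative defaults rather than values confirmed by this documentation, the SSIM term is omitted for brevity, and both function names are hypothetical.

```python
import math


def focal_loss_pixel(p_anomalous, target, alpha=1.0, gamma=2.0):
    """Focal loss for one pixel of a two-class prediction.

    p_anomalous: predicted probability of the anomalous class, in (0, 1).
    target: ground-truth label, 0 (normal) or 1 (anomalous).
    alpha, gamma: illustrative defaults, not values confirmed by the docs.
    """
    p_true = p_anomalous if target == 1 else 1.0 - p_anomalous
    return -alpha * (1.0 - p_true) ** gamma * math.log(p_true)


def l2_loss(image, reconstruction):
    """Mean squared error between two flat lists of pixel values."""
    return sum((x - y) ** 2 for x, y in zip(image, reconstruction)) / len(image)


image = [0.2, 0.4, 0.6, 0.8]
reconstruction = [0.25, 0.35, 0.6, 0.7]
mask = [0, 0, 0, 1]                  # only the last pixel is anomalous
p_anomalous = [0.1, 0.2, 0.05, 0.6]  # predicted anomalous-class probabilities

focal = sum(focal_loss_pixel(p, t) for p, t in zip(p_anomalous, mask)) / len(mask)
partial_loss = l2_loss(image, reconstruction) + focal  # SSIM term omitted here
```

The (1 - p_t)^gamma factor shrinks the contribution of pixels that are already classified confidently, which helps when anomalous pixels are a small fraction of the image.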

Example

>>> criterion = DraemLoss()
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)

forward(input_image, reconstruction, anomaly_mask, prediction)

Compute the combined loss over a batch for the DRAEM model.

Parameters:
  • input_image (Tensor) – Original input images of shape (batch_size, num_channels, height, width)

  • reconstruction (Tensor) – Reconstructed images from the model of shape (batch_size, num_channels, height, width)

  • anomaly_mask (Tensor) – Ground truth anomaly masks of shape (batch_size, 1, height, width)

  • prediction (Tensor) – Model predictions of shape (batch_size, num_classes, height, width)

Returns:

Combined loss value

Return type:

torch.Tensor