DRAEM

DRÆM.

A discriminatively trained reconstruction embedding for surface anomaly detection.

Paper: https://arxiv.org/abs/2108.07610

This module implements the DRÆM model for surface anomaly detection. DRÆM trains a reconstructive sub-network together with a discriminative sub-network, and detects anomalies by comparing each input image with its reconstruction.

class anomalib.models.image.draem.lightning_model.Draem(enable_sspcab=False, sspcab_lambda=0.1, anomaly_source_path=None, beta=(0.1, 1.0), pre_processor=True, post_processor=True, evaluator=True, visualizer=True)

Bases: AnomalibModule

DRÆM.

A discriminatively trained reconstruction embedding for surface anomaly detection.

The model consists of two main components:
  1. A reconstruction network that learns to reconstruct normal images
  2. A discriminative network that learns to identify anomalous regions

Parameters:
  • enable_sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.

  • sspcab_lambda (float, optional) – Weight factor for SSPCAB loss. Defaults to 0.1.

  • anomaly_source_path (str | None, optional) – Path to directory containing anomaly source images. If None, random noise is used. Defaults to None.

  • beta (float | tuple[float, float], optional) – Blend factor for anomaly generation. If tuple, represents range for random sampling. Defaults to (0.1, 1.0).

  • pre_processor (PreProcessor | bool, optional) – Pre-processor instance or flag to use default. Defaults to True.

  • post_processor (PostProcessor | bool, optional) – Post-processor instance or flag to use default. Defaults to True.

  • evaluator (Evaluator | bool, optional) – Evaluator instance or flag to use default. Defaults to True.

  • visualizer (Visualizer | bool, optional) – Visualizer instance or flag to use default. Defaults to True.
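The `beta` blend factor can be illustrated with a minimal NumPy sketch. It assumes the blending described in the DRÆM paper: inside the anomaly mask the result is `(1 - beta) * image + beta * anomaly_source`, and pixels outside the mask are left unchanged. anomalib's own anomaly generator may weight the terms differently. The helpers `sample_beta` and `blend_anomaly` are hypothetical and not part of anomalib.

```python
import numpy as np


def sample_beta(beta, rng):
    """Return beta unchanged, or draw it uniformly from a (low, high) tuple."""
    if isinstance(beta, tuple):
        low, high = beta
        return float(rng.uniform(low, high))
    return float(beta)


def blend_anomaly(image, anomaly_source, mask, beta):
    """Blend `anomaly_source` into `image` inside `mask` (hypothetical helper).

    image, anomaly_source: float arrays of shape (H, W, C) with values in [0, 1]
    mask: float array of shape (H, W, 1) holding 0 (keep) or 1 (blend)
    Assumes the paper's convention: beta weights the anomaly texture.
    """
    blended = (1.0 - beta) * image + beta * anomaly_source
    return (1.0 - mask) * image + mask * blended


rng = np.random.default_rng(seed=0)
image = np.zeros((4, 4, 3))    # stand-in for a normal image
source = np.ones((4, 4, 3))    # stand-in for an anomaly texture
mask = np.zeros((4, 4, 1))
mask[:2] = 1.0                 # simulated defect covers the top two rows
beta = sample_beta((0.1, 1.0), rng)
augmented = blend_anomaly(image, source, mask, beta)
```

Drawing `beta` from a range instead of fixing it varies the opacity of the simulated anomalies from one training sample to the next.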

configure_optimizers()

Configure optimizer and learning rate scheduler.

Returns:

Tuple containing optimizer and scheduler lists.

Return type:

tuple[list[Adam], list[MultiStepLR]]
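`MultiStepLR` multiplies the learning rate by a factor `gamma` each time a milestone epoch is reached. The pure-Python sketch below reproduces that decay rule. The base rate, milestones and `gamma` in the usage lines are illustrative placeholders, not values confirmed by this page, and `multistep_lr` is a hypothetical helper.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate at `epoch` under a MultiStepLR-style schedule.

    The rate is multiplied by `gamma` once for every milestone that `epoch`
    has reached, i.e. base_lr * gamma ** (number of milestones <= epoch).
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


# Illustrative values only: decay at epochs 400 and 600.
for epoch in (0, 399, 400, 600):
    print(epoch, multistep_lr(1e-4, epoch, [400, 600]))
```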

static configure_transforms(image_size=None)

Configure default transforms for DRAEM.

Note

Normalization is not needed because the dataset already scales images to [0, 1].

Parameters:

image_size (tuple[int, int] | None, optional) – Target size for image resizing. If None, (256, 256) is used. Defaults to None.

Returns:

Composed transform including resizing.

Return type:

Transform

property learning_type: LearningType

Get the learning type of the model.

Returns:

The learning type (LearningType.ONE_CLASS).

Return type:

LearningType

setup_sspcab()

Set up SSPCAB forward hooks.

Prepares the model for SSPCAB training by adding forward hooks to capture layer activations from specific points in the network.

Return type:

None
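Forward hooks let a model record intermediate activations without changing its `forward` method. The sketch below, a minimal example assuming PyTorch, registers one hook on a stand-in block, stores that block's input and output, and weights their mean squared error by `sspcab_lambda`, in the spirit of SSPCAB's self-supervised reconstruction loss. The toy encoder, the hooked layer and the single-hook layout are assumptions; this page does not state which layers anomalib hooks.

```python
import torch
from torch import nn

# Toy encoder; index 2 stands in for the block whose input and output are compared.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)

activations = {}


def capture(module, inputs, output):
    """Forward hook: store the hooked module's input and output tensors."""
    activations["input"] = inputs[0]
    activations["output"] = output


handle = encoder[2].register_forward_hook(capture)

sspcab_lambda = 0.1
features = encoder(torch.rand(2, 3, 16, 16))
# Weighted MSE between the block's output and its input (SSPCAB-style term).
sspcab_loss = sspcab_lambda * nn.functional.mse_loss(
    activations["output"], activations["input"]
)
handle.remove()  # detach the hook once it is no longer needed
```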

property trainer_arguments: dict[str, Any]

Get DRÆM-specific trainer arguments.

Returns:

Dictionary containing trainer arguments:
  • gradient_clip_val: 0

  • num_sanity_val_steps: 0

Return type:

dict[str, Any]

training_step(batch, *args, **kwargs)

Perform training step for DRAEM.

The step consists of:
  1. Generating simulated anomalies
  2. Computing reconstructions and predictions
  3. Calculating the loss

Parameters:
  • batch (Batch) – Input batch containing images and metadata.

  • args – Additional positional arguments (unused).

  • kwargs – Additional keyword arguments (unused).

Returns:

Dictionary containing the training loss.

Return type:

STEP_OUTPUT
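The three steps of `training_step` can be outlined in plain Python, with the anomaly generator, model and loss passed in as callables. This is a hypothetical sketch, not anomalib's code. The assumption it encodes is that the loss compares the reconstruction of the augmented image with the original, anomaly-free image, so the reconstructive sub-network learns to remove simulated defects.

```python
import numpy as np


def training_step(images, generate_anomalies, model, criterion):
    """Hypothetical outline of one DRAEM training step."""
    # 1. Simulate anomalies: augmented images plus their ground-truth masks.
    augmented, masks = generate_anomalies(images)
    # 2. Reconstruct the augmented images and predict per-pixel anomaly logits.
    reconstruction, prediction = model(augmented)
    # 3. Score the reconstruction against the ORIGINAL, anomaly-free images.
    loss = criterion(images, reconstruction, masks, prediction)
    return {"loss": loss}


# Toy stand-ins: the "generator" shifts pixels by +1 and the "model" undoes it.
def toy_generate(x):
    return x + 1.0, np.ones((x.shape[0], 1) + x.shape[2:])


def toy_model(x):
    return x - 1.0, np.zeros((x.shape[0], 2) + x.shape[2:])


def toy_criterion(original, reconstruction, mask, prediction):
    return float(np.mean((reconstruction - original) ** 2))


result = training_step(np.zeros((2, 3, 4, 4)), toy_generate, toy_model, toy_criterion)
```

Because the toy model exactly undoes the toy corruption, the reconstruction matches the clean input and the loss is zero.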

validation_step(batch, *args, **kwargs)

Perform validation step for DRAEM.

Uses softmax predictions of the anomalous class as anomaly maps.

Parameters:
  • batch (Batch) – Input batch containing images and metadata.

  • args – Additional positional arguments (unused).

  • kwargs – Additional keyword arguments (unused).

Returns:

Dictionary containing predictions and metadata.

Return type:

STEP_OUTPUT
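Turning the discriminative sub-network's two-channel output into an anomaly map amounts to a softmax over the class dimension followed by selecting the anomalous class. A minimal NumPy sketch follows. It assumes logits of shape (batch_size, 2, height, width) with channel 1 as the anomalous class; the per-image score taken as the maximum of the map is an assumption, not something this page documents.

```python
import numpy as np


def anomaly_map_from_logits(prediction):
    """(N, 2, H, W) logits -> (N, H, W) softmax probability of class 1."""
    shifted = prediction - prediction.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)
    return probs[:, 1]


def image_scores(anomaly_map):
    """Assumed image-level score: the highest pixel anomaly probability."""
    return anomaly_map.max(axis=(1, 2))


logits = np.zeros((2, 2, 3, 3))
logits[1, 1, 0, 0] = 5.0          # one confidently anomalous pixel in image 1
maps = anomaly_map_from_logits(logits)
scores = image_scores(maps)
```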

PyTorch model for the DRAEM model implementation.

The DRAEM model consists of two sub-networks:
  1. A reconstructive sub-network that learns to reconstruct input images
  2. A discriminative sub-network that detects anomalies by comparing original and reconstructed images
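A toy PyTorch sketch of the two-sub-network layout follows. The discriminative part receives the input image and its reconstruction concatenated along the channel dimension, as in the DRÆM paper. The class name `ToyDraem`, the layer sizes and the concatenation order are hypothetical placeholders for the real encoder-decoder and U-Net-style sub-networks.

```python
import torch
from torch import nn


class ToyDraem(nn.Module):
    """Hypothetical two-sub-network model; layer sizes are placeholders."""

    def __init__(self, channels: int = 3) -> None:
        super().__init__()
        # Stand-in for the reconstructive sub-network.
        self.reconstructive = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, channels, kernel_size=3, padding=1),
        )
        # Stand-in for the discriminative sub-network: it sees the input and
        # the reconstruction stacked along the channel dimension.
        self.discriminative = nn.Sequential(
            nn.Conv2d(2 * channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 2, kernel_size=1),  # two logits per pixel: normal / anomalous
        )

    def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        reconstruction = self.reconstructive(batch)
        joined = torch.cat([batch, reconstruction], dim=1)
        prediction = self.discriminative(joined)
        return reconstruction, prediction


model = ToyDraem()
reconstruction, prediction = model(torch.rand(2, 3, 32, 32))
```

Only the training-mode output is sketched here; at inference the real `DraemModel` returns an InferenceBatch instead.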

class anomalib.models.image.draem.torch_model.DraemModel(sspcab=False)

Bases: Module

DRAEM PyTorch model with reconstructive and discriminative sub-networks.

Parameters:

sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.

Example

>>> import torch
>>> from anomalib.models.image.draem.torch_model import DraemModel
>>> model = DraemModel(sspcab=True)
>>> input_tensor = torch.randn(32, 3, 256, 256)
>>> reconstruction, prediction = model(input_tensor)

forward(batch)

Forward pass through both sub-networks.

Parameters:

batch (torch.Tensor) – Input batch of images of shape (batch_size, channels, height, width)

Returns:

During training, a tuple containing:
  • Reconstructed images

  • Predicted anomaly masks

During inference, an InferenceBatch containing the anomaly map and prediction score.

Return type:

tuple[torch.Tensor, torch.Tensor] | InferenceBatch

Example

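>>> import torch
>>> from anomalib.data import InferenceBatch
>>> from anomalib.models.image.draem.torch_model import DraemModel
>>> model = DraemModel()
>>> batch = torch.randn(32, 3, 256, 256)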
>>> reconstruction, prediction = model(batch)  # Training mode
>>> model.eval()
>>> output = model(batch)  # Inference mode
>>> assert isinstance(output, InferenceBatch)

Loss function for the DRAEM model implementation.

This module implements the loss function used to train the DRAEM model for anomaly detection. The loss combines L2 reconstruction loss, focal loss for anomaly segmentation, and structural similarity (SSIM) loss.

Example

>>> import torch
>>> from anomalib.models.image.draem.loss import DraemLoss
>>> criterion = DraemLoss()
>>> input_image = torch.randn(8, 3, 256, 256)
>>> reconstruction = torch.randn(8, 3, 256, 256)
>>> anomaly_mask = torch.randint(0, 2, (8, 1, 256, 256))
>>> prediction = torch.randn(8, 2, 256, 256)
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)

class anomalib.models.image.draem.loss.DraemLoss

Bases: Module

Overall loss function of the DRAEM model.

The total loss consists of three components:
  1. L2 loss between the reconstructed and input images
  2. Focal loss between predicted and ground-truth anomaly masks
  3. Structural similarity (SSIM) loss between reconstructed and input images

The final loss is computed as: loss = l2_loss + ssim_loss + focal_loss

Example

>>> criterion = DraemLoss()
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)
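The combination loss = l2_loss + ssim_loss + focal_loss can be sketched in NumPy. This is a simplified, hypothetical version: SSIM is computed once over the whole batch instead of with a sliding window, and the focal-loss settings (`alpha=1.0`, `gamma=2.0`) are assumed defaults, not values taken from anomalib. The function names are illustrative only.

```python
import numpy as np


def focal_loss(logits, target, alpha=1.0, gamma=2.0, eps=1e-8):
    """Mean focal loss; logits (N, C, H, W), target (N, 1, H, W) class indices."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    # Probability assigned to the ground-truth class at every pixel.
    p_t = np.take_along_axis(probs, target.astype(np.int64), axis=1)
    # (1 - p_t) ** gamma down-weights pixels that are already classified well.
    return float(np.mean(-alpha * (1.0 - p_t) ** gamma * np.log(p_t + eps)))


def global_ssim(x, y, data_range=1.0):
    """Simplified SSIM computed once over all pixels (no sliding window)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    cov = ((x - mu_x) * (y - mu_y)).mean()
    num = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    den = (mu_x**2 + mu_y**2 + c1) * (x.var() + y.var() + c2)
    return num / den


def draem_style_loss(input_image, reconstruction, anomaly_mask, prediction):
    """Hypothetical l2_loss + ssim_loss + focal_loss combination."""
    l2_loss = float(np.mean((reconstruction - input_image) ** 2))
    ssim_loss = float(1.0 - global_ssim(input_image, reconstruction))
    return l2_loss + ssim_loss + focal_loss(prediction, anomaly_mask)


rng = np.random.default_rng(seed=0)
image = rng.random((2, 3, 8, 8))
mask = np.zeros((2, 1, 8, 8), dtype=np.int64)
mask[:, :, :4, :4] = 1
# Confident, correct logits: channel 0 = normal, channel 1 = anomalous.
logits = np.concatenate(
    [np.where(mask == 1, -10.0, 10.0), np.where(mask == 1, 10.0, -10.0)], axis=1
)
perfect = draem_style_loss(image, image, mask, logits)
degraded = draem_style_loss(image, 1.0 - image, mask, logits)
```

A perfect reconstruction with confident, correct mask predictions drives all three terms towards zero, while a poor reconstruction raises both the L2 and SSIM terms.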

forward(input_image, reconstruction, anomaly_mask, prediction)

Compute the combined loss over a batch for the DRAEM model.

Parameters:
  • input_image (Tensor) – Original input images of shape (batch_size, num_channels, height, width)

  • reconstruction (Tensor) – Reconstructed images from the model of shape (batch_size, num_channels, height, width)

  • anomaly_mask (Tensor) – Ground truth anomaly masks of shape (batch_size, 1, height, width)

  • prediction (Tensor) – Model predictions of shape (batch_size, num_classes, height, width)

Returns:

Combined loss value

Return type:

torch.Tensor