DRAEM#
DRÆM.
A discriminatively trained reconstruction embedding for surface anomaly detection.
Paper: https://arxiv.org/abs/2108.07610
This module implements the DRÆM model for surface anomaly detection. DRÆM uses a discriminatively trained reconstruction embedding approach to detect anomalies by comparing input images with their reconstructions.
- class anomalib.models.image.draem.lightning_model.Draem(enable_sspcab=False, sspcab_lambda=0.1, anomaly_source_path=None, beta=(0.1, 1.0), pre_processor=True, post_processor=True, evaluator=True, visualizer=True)#
Bases:
AnomalibModule
DRÆM.
A discriminatively trained reconstruction embedding for surface anomaly detection.
The model consists of two main components:
1. A reconstruction network that learns to reconstruct normal images
2. A discriminative network that learns to identify anomalous regions
- Parameters:
enable_sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.
sspcab_lambda (float, optional) – Weight factor for SSPCAB loss. Defaults to 0.1.
anomaly_source_path (str | None, optional) – Path to directory containing anomaly source images. If None, random noise is used. Defaults to None.
beta (float | tuple[float, float], optional) – Blend factor for anomaly generation. If tuple, represents range for random sampling. Defaults to (0.1, 1.0).
pre_processor (PreProcessor | bool, optional) – Pre-processor instance or flag to use default. Defaults to True.
post_processor (PostProcessor | bool, optional) – Post-processor instance or flag to use default. Defaults to True.
evaluator (Evaluator | bool, optional) – Evaluator instance or flag to use default. Defaults to True.
visualizer (Visualizer | bool, optional) – Visualizer instance or flag to use default. Defaults to True.
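Example

A minimal training sketch, assuming anomalib's Engine and an MVTec-style datamodule (exact class and argument names may differ across anomalib versions):

>>> from anomalib.data import MVTecAD
>>> from anomalib.engine import Engine
>>> from anomalib.models import Draem
>>> datamodule = MVTecAD(category="bottle")  # hypothetical dataset/category choice
>>> model = Draem(enable_sspcab=False, beta=(0.1, 1.0))
>>> engine = Engine()
>>> engine.fit(model=model, datamodule=datamodule)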
- configure_optimizers()#
Configure optimizer and learning rate scheduler.
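A typical DRAEM setup pairs Adam with a multi-step learning-rate decay; a sketch of an equivalent configuration (the learning rate and milestone values here are assumptions, check the source for the shipped defaults):

>>> import torch
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed learning rate
>>> scheduler = torch.optim.lr_scheduler.MultiStepLR(
...     optimizer, milestones=[400, 600], gamma=0.1)  # assumed milestones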
- static configure_transforms(image_size=None)#
Configure default transforms for DRAEM.
Note
Normalization is not needed as images are scaled to [0, 1] in Dataset.
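Conceptually the default reduces to a resize with no normalization step; a sketch under that assumption (torchvision v2 transforms, with a 256x256 target size assumed):

>>> from torchvision.transforms.v2 import Compose, Resize
>>> transform = Compose([Resize((256, 256), antialias=True)])  # assumed default size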
- property learning_type: LearningType#
Get the learning type of the model.
- Returns:
The learning type (LearningType.ONE_CLASS).
- Return type:
LearningType
- setup_sspcab()#
Set up SSPCAB forward hooks.
Prepares the model for SSPCAB training by adding forward hooks to capture layer activations from specific points in the network.
- Return type:
None
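The mechanism is standard PyTorch; a minimal sketch of capturing an activation with a forward hook (the target layer here is illustrative, not the actual DRAEM hook point):

>>> activations = {}
>>> def capture(module, args, output):
...     activations["sspcab"] = output
>>> handle = target_layer.register_forward_hook(capture)  # hypothetical layer
>>> _ = model(input_tensor)  # activations["sspcab"] now holds the layer output
>>> handle.remove()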
- training_step(batch, *args, **kwargs)#
Perform training step for DRAEM.
The step consists of:
1. Generating simulated anomalies
2. Computing reconstructions and predictions
3. Calculating the loss
- Parameters:
batch (Batch) – Input batch containing images and metadata.
args – Additional positional arguments (unused).
kwargs – Additional keyword arguments (unused).
- Returns:
Dictionary containing the training loss.
- Return type:
STEP_OUTPUT
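A conceptual sketch of the step (the anomaly generator's interface is an assumption; the loss follows DraemLoss below):

>>> # 1. Simulate anomalies on the normal input images
>>> augmented, anomaly_mask = anomaly_generator(batch.image)  # hypothetical generator API
>>> # 2. Reconstruct the original image and predict the anomaly mask
>>> reconstruction, prediction = model(augmented)
>>> # 3. Combine L2, SSIM and focal terms
>>> loss = criterion(batch.image, reconstruction, anomaly_mask, prediction)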
- validation_step(batch, *args, **kwargs)#
Perform validation step for DRAEM.
Uses softmax predictions of the anomalous class as anomaly maps.
- Parameters:
batch (Batch) – Input batch containing images and metadata.
args – Additional positional arguments (unused).
kwargs – Additional keyword arguments (unused).
- Returns:
Dictionary containing predictions and metadata.
- Return type:
STEP_OUTPUT
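The anomaly map reduces to a softmax over the two-channel prediction, keeping the anomalous channel; a minimal sketch:

>>> import torch
>>> prediction = torch.randn(8, 2, 256, 256)  # (normal, anomalous) logits
>>> anomaly_map = torch.softmax(prediction, dim=1)[:, 1, ...]  # anomalous-class probability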
PyTorch model for the DRAEM model implementation.
The DRAEM model consists of two sub-networks:
1. A reconstructive sub-network that learns to reconstruct input images
2. A discriminative sub-network that detects anomalies by comparing original and reconstructed images
- class anomalib.models.image.draem.torch_model.DraemModel(sspcab=False)#
Bases:
Module
DRAEM PyTorch model with reconstructive and discriminative sub-networks.
- Parameters:
sspcab (bool, optional) – Enable SSPCAB training. Defaults to False.
Example
>>> model = DraemModel(sspcab=True)
>>> input_tensor = torch.randn(32, 3, 256, 256)
>>> reconstruction, prediction = model(input_tensor)
- forward(batch)#
Forward pass through both sub-networks.
- Parameters:
batch (torch.Tensor) – Input batch of images of shape (batch_size, channels, height, width)
- Returns:
During training: a tuple containing the reconstructed images and the predicted anomaly masks.
During inference: an InferenceBatch containing the anomaly map and prediction score.
- Return type:
tuple[torch.Tensor, torch.Tensor] | InferenceBatch
Example
>>> model = DraemModel()
>>> batch = torch.randn(32, 3, 256, 256)
>>> reconstruction, prediction = model(batch)  # Training mode
>>> model.eval()
>>> output = model(batch)  # Inference mode
>>> assert isinstance(output, InferenceBatch)
Loss function for the DRAEM model implementation.
This module implements the loss function used to train the DRAEM model for anomaly detection. The loss combines L2 reconstruction loss, focal loss for anomaly segmentation, and structural similarity (SSIM) loss.
Example
>>> import torch
>>> from anomalib.models.image.draem.loss import DraemLoss
>>> criterion = DraemLoss()
>>> input_image = torch.randn(8, 3, 256, 256)
>>> reconstruction = torch.randn(8, 3, 256, 256)
>>> anomaly_mask = torch.randint(0, 2, (8, 1, 256, 256))
>>> prediction = torch.randn(8, 2, 256, 256)
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)
- class anomalib.models.image.draem.loss.DraemLoss#
Bases:
Module
Overall loss function of the DRAEM model.
The total loss consists of three components:
1. L2 loss between the reconstructed and input images
2. Focal loss between predicted and ground truth anomaly masks
3. Structural Similarity (SSIM) loss between reconstructed and input images
The final loss is computed as:
loss = l2_loss + ssim_loss + focal_loss
Example
>>> criterion = DraemLoss()
>>> loss = criterion(input_image, reconstruction, anomaly_mask, prediction)
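A sketch of how the three terms could be composed (the concrete loss modules, e.g. kornia's FocalLoss and SSIMLoss, and their settings are assumptions about the implementation; any internal weighting factors are elided):

>>> from torch import nn
>>> from kornia.losses import FocalLoss, SSIMLoss
>>> l2_loss = nn.MSELoss()
>>> focal_loss = FocalLoss(alpha=1, reduction="mean")
>>> ssim_loss = SSIMLoss(window_size=11)
>>> loss = (l2_loss(reconstruction, input_image)
...         + ssim_loss(reconstruction, input_image)
...         + focal_loss(prediction, anomaly_mask.squeeze(1).long()))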
- forward(input_image, reconstruction, anomaly_mask, prediction)#
Compute the combined loss over a batch for the DRAEM model.
- Parameters:
input_image (Tensor) – Original input images of shape (batch_size, num_channels, height, width)
reconstruction (Tensor) – Reconstructed images from the model of shape (batch_size, num_channels, height, width)
anomaly_mask (Tensor) – Ground truth anomaly masks of shape (batch_size, 1, height, width)
prediction (Tensor) – Model predictions of shape (batch_size, num_classes, height, width)
- Returns:
Combined loss value.
- Return type:
Tensor