anomalib.models.draem.loss

Loss function for the DRAEM model implementation.

Module Contents

Classes

DraemLoss

Overall loss function of the DRAEM model.

class anomalib.models.draem.loss.DraemLoss[source]

Bases: torch.nn.Module

Overall loss function of the DRAEM model.

The total loss is the sum of three terms: the L2 loss and the Structural Similarity (SSIM) loss between the reconstructed image and the input image, and the Focal loss between the predicted and ground-truth anomaly masks.

forward(input_image, reconstruction, anomaly_mask, prediction)[source]

Compute the loss over a batch for the DRAEM model.
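
The following is a minimal sketch of how such a combined loss could be computed, not the library's implementation. It assumes, as in the DRAEM paper, that the L2 and SSIM terms compare the reconstruction with the input image and that the Focal term compares the predicted mask with the ground-truth mask. The helper names (focal_loss, ssim_loss, draem_loss_sketch), the tensor shapes, the unweighted sum, the focal gamma of 2, and the uniform SSIM window are all assumptions made for illustration; the actual class may weight the terms or configure the component losses differently.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, target, gamma=2.0):
    """Multi-class focal loss (hypothetical helper).

    logits: (B, C, H, W) raw scores; target: (B, H, W) int64 class indices.
    """
    log_p = F.log_softmax(logits, dim=1)
    # Log-probability assigned to the true class of each pixel.
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    # Down-weight pixels that are already classified confidently.
    return (-((1.0 - pt) ** gamma) * log_pt).mean()


def ssim_loss(x, y, window_size=11, c1=0.01 ** 2, c2=0.03 ** 2):
    """1 - mean SSIM (hypothetical helper).

    Uses a uniform box window for brevity; SSIM is usually computed with
    a Gaussian window. Inputs are assumed to lie in [0, 1].
    """
    pad = window_size // 2
    mu_x = F.avg_pool2d(x, window_size, stride=1, padding=pad)
    mu_y = F.avg_pool2d(y, window_size, stride=1, padding=pad)
    var_x = F.avg_pool2d(x * x, window_size, stride=1, padding=pad) - mu_x * mu_x
    var_y = F.avg_pool2d(y * y, window_size, stride=1, padding=pad) - mu_y * mu_y
    cov = F.avg_pool2d(x * y, window_size, stride=1, padding=pad) - mu_x * mu_y
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    )
    return 1.0 - ssim.mean()


def draem_loss_sketch(input_image, reconstruction, anomaly_mask, prediction):
    """Unweighted sum of the three terms (weights are an assumption).

    Assumed shapes: input_image and reconstruction (B, C, H, W),
    anomaly_mask (B, 1, H, W) with values in {0, 1},
    prediction (B, 2, H, W) as two-class logits.
    """
    l2 = F.mse_loss(reconstruction, input_image)
    ssim = ssim_loss(reconstruction, input_image)
    focal = focal_loss(prediction, anomaly_mask.squeeze(1).long())
    return l2 + ssim + focal
```

The argument order mirrors the forward signature above, so the sketch can be called as draem_loss_sketch(input_image, reconstruction, anomaly_mask, prediction) and returns a scalar tensor suitable for backpropagation.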