anomalib.models.stfpm.loss

Loss function for the STFPM Model Implementation.

Module Contents

Classes

STFPMLoss

Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.

class anomalib.models.stfpm.loss.STFPMLoss

Bases: torch.nn.Module

Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.

Example

>>> import torch
>>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
>>> from anomalib.models.stfpm.loss import STFPMLoss
>>> from torchvision.models import resnet18
>>> layers = ['layer1', 'layer2', 'layer3']
>>> # The teacher uses pretrained weights; the student starts from random weights.
>>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers)
>>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers)
>>> loss = STFPMLoss()
>>> inp = torch.rand((4, 3, 256, 256))
>>> teacher_features = teacher_model(inp)
>>> student_features = student_model(inp)
>>> # Arguments follow the forward() signature: teacher first, then student.
>>> # The exact value depends on the random input and student initialization.
>>> loss(teacher_features, student_features)
tensor(51.2015, grad_fn=<SumBackward0>)

compute_layer_loss(self, teacher_feats: torch.Tensor, student_feats: torch.Tensor) → torch.Tensor

Compute the loss for a single feature pyramid layer, based on Equation (1) in Section 3.2 of the STFPM paper.
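A transcription of the per-layer loss in LaTeX, as we read the STFPM formulation; treat the notation as an assumption and check it against the paper. Here F^l_{t,ij} and F^l_{s,ij} are the teacher and student feature vectors at position (i, j) of layer l, the hat denotes L2 normalization along the channel dimension, and h_l × w_l is the spatial size of layer l.

```latex
\hat{F}^l_{ij} = \frac{F^l_{ij}}{\lVert F^l_{ij} \rVert_2},
\qquad
\ell^l_{ij} = \frac{1}{2} \left\lVert \hat{F}^l_{t,ij} - \hat{F}^l_{s,ij} \right\rVert_2^2,
\qquad
\ell^l = \frac{1}{h_l w_l} \sum_{i=1}^{h_l} \sum_{j=1}^{w_l} \ell^l_{ij}
```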

Parameters
  • teacher_feats (Tensor) – Teacher features

  • student_feats (Tensor) – Student features

Returns

Half the squared L2 distance between the channel-normalized teacher and student features, averaged over spatial positions.
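A minimal PyTorch sketch of a per-layer loss of this kind, assuming (N, C, H, W) feature maps. It assumes the loss is half the squared L2 distance between channel-normalized feature vectors, summed over the batch and positions and then divided by H × W. `layer_loss` is a hypothetical helper for illustration, not the library's `compute_layer_loss`.

```python
import torch
import torch.nn.functional as F


def layer_loss(teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-layer STFPM-style loss for (N, C, H, W) feature maps.

    Assumptions: features are L2-normalized along the channel dimension,
    0.5 * squared distance is summed over batch and positions, and the
    result is divided by H * W (averaging over spatial positions).
    """
    height, width = teacher_feats.shape[2:]
    # Scale each spatial feature vector to unit length (dim=1 is channels).
    t = F.normalize(teacher_feats, p=2, dim=1)
    s = F.normalize(student_feats, p=2, dim=1)
    # 0.5 * ||t - s||^2 at every position, summed, then divided by H * W.
    return 0.5 * ((t - s) ** 2).sum() / (height * width)


teacher = torch.rand(2, 64, 8, 8)
student = torch.rand(2, 64, 8, 8)
print(layer_loss(teacher, student))
```

For unit vectors a and b, ½‖a − b‖² = 1 − cos(a, b), so a loss on normalized features can equally be described in terms of cosine similarity.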

forward(self, teacher_features: Dict[str, torch.Tensor], student_features: Dict[str, torch.Tensor]) → torch.Tensor

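Compute the overall loss as the weighted sum of the per-layer losses. Because each layer loss compares channel-normalized features, it is equivalent to one minus the cosine similarity between the teacher and student feature vectors at each position.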

Parameters
  • teacher_features (Dict[str, Tensor]) – Teacher features

  • student_features (Dict[str, Tensor]) – Student features

Returns

Total loss, which is the weighted sum of the layer losses.
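A sketch of combining per-layer losses into a total loss over the feature pyramid, assuming the feature dictionaries share the same layer-name keys. The block re-declares a hypothetical `layer_loss` helper so it runs on its own, and the `weights` argument (per-layer α, defaulting to 1.0) is an illustrative assumption, not part of `STFPMLoss.forward`.

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F


def layer_loss(teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: 0.5 * squared distance of channel-normalized features / (H * W).
    height, width = teacher_feats.shape[2:]
    t = F.normalize(teacher_feats, p=2, dim=1)
    s = F.normalize(student_feats, p=2, dim=1)
    return 0.5 * ((t - s) ** 2).sum() / (height * width)


def total_loss(
    teacher_features: Dict[str, torch.Tensor],
    student_features: Dict[str, torch.Tensor],
    weights: Optional[Dict[str, float]] = None,  # hypothetical per-layer alphas
) -> torch.Tensor:
    losses = []
    for name, t_feats in teacher_features.items():
        # Layers are matched by key; missing weights default to 1.0.
        alpha = 1.0 if weights is None else weights.get(name, 1.0)
        losses.append(alpha * layer_loss(t_feats, student_features[name]))
    # Weighted sum over all pyramid levels.
    return torch.stack(losses).sum()


# Toy three-level pyramid with shrinking spatial size and growing channel count.
shapes = {"layer1": (2, 64, 16, 16), "layer2": (2, 128, 8, 8), "layer3": (2, 256, 4, 4)}
teacher = {k: torch.rand(s) for k, s in shapes.items()}
student = {k: torch.rand(s, requires_grad=True) for k, s in shapes.items()}
loss = total_loss(teacher, student)
loss.backward()  # gradients flow only into the student features
print(loss)
```

Keying both dictionaries by layer name keeps teacher and student levels aligned regardless of iteration order in the caller.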