anomalib.models.stfpm.loss
Loss function for the STFPM Model Implementation.
Module Contents
Classes
STFPMLoss: Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.
- class anomalib.models.stfpm.loss.STFPMLoss
Bases: torch.nn.Module
Feature Pyramid Loss. This class implements the feature pyramid loss function proposed in the STFPM paper.
Example
>>> import torch
>>> from torchvision.models import resnet18
>>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
>>> from anomalib.models.stfpm.loss import STFPMLoss
>>> layers = ['layer1', 'layer2', 'layer3']
>>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers)
>>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers)
>>> loss = STFPMLoss()
>>> inp = torch.rand((4, 3, 256, 256))
>>> teacher_features = teacher_model(inp)
>>> student_features = student_model(inp)
>>> loss(teacher_features, student_features)
tensor(51.2015, grad_fn=<SumBackward0>)
- compute_layer_loss(self, teacher_feats: torch.Tensor, student_feats: torch.Tensor) → torch.Tensor
Compute the layer loss based on Equation (1) in Section 3.2 of the STFPM paper.
- Parameters
teacher_feats (Tensor) – Teacher features
student_feats (Tensor) – Student features
- Returns
L2 distance between the normalized teacher and student features.
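The per-layer loss can be sketched as follows, under one reading of the STFPM formulation: each spatial feature vector is L2-normalized along the channel axis, the normalized teacher and student vectors are compared with half their squared L2 distance, and the result is averaged over the H × W positions. The helper name `layer_loss`, the (N, C, H, W) input layout, and summing (rather than averaging) over the batch are assumptions for illustration; this is not the `compute_layer_loss` source. Because both vectors have unit length, half their squared distance equals one minus their cosine similarity, which the sketch also checks.

```python
import torch
import torch.nn.functional as F


def layer_loss(teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-layer STFPM loss for feature maps of shape (N, C, H, W)."""
    height, width = teacher_feats.shape[-2:]
    # L2-normalize every spatial feature vector along the channel axis (dim=1).
    teacher_norm = F.normalize(teacher_feats, p=2, dim=1)
    student_norm = F.normalize(student_feats, p=2, dim=1)
    # Half the squared L2 distance, summed over channels, positions and batch.
    half_sq_dist = 0.5 * (teacher_norm - student_norm).pow(2).sum()
    # Average over the H * W spatial positions (the batch stays summed, an assumption).
    return half_sq_dist / (height * width)


# Usage: identical features give zero loss; opposite features give 2 per image.
feats = torch.rand(2, 64, 8, 8)
print(layer_loss(feats, feats))   # tensor(0.)
print(layer_loss(feats, -feats))  # approximately 4.0 (2 per image, N = 2)
```

Normalizing first makes the loss insensitive to the overall scale of each feature vector, so the student is trained to match the direction of the teacher's features rather than their magnitude.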
- forward(self, teacher_features: Dict[str, torch.Tensor], student_features: Dict[str, torch.Tensor]) → torch.Tensor
Compute the overall loss as the weighted sum of the layer losses, each of which compares the teacher and student features through their cosine similarity.
- Parameters
teacher_features (Dict[str, Tensor]) – Teacher features
student_features (Dict[str, Tensor]) – Student features
- Returns
Total loss, which is the weighted sum of the layer losses.
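The aggregation performed by `forward` can be sketched as a weighted sum over the layer names shared by the two dictionaries. `total_loss` and its optional `weights` mapping (standing in for per-layer weights) are hypothetical, and the default weight of 1.0 per layer is an assumption; summing rather than averaging matches the `SumBackward0` shown in the class example. A compact per-layer loss (channel-normalized features, half squared distance averaged over H × W, an assumed reading) is included so the block runs on its own.

```python
from typing import Dict, Optional

import torch
import torch.nn.functional as F


def layer_loss(teacher_feats: torch.Tensor, student_feats: torch.Tensor) -> torch.Tensor:
    """Per-layer loss: half squared L2 distance of channel-normalized features / (H * W)."""
    height, width = teacher_feats.shape[-2:]
    diff = F.normalize(teacher_feats, dim=1) - F.normalize(student_feats, dim=1)
    return 0.5 * diff.pow(2).sum() / (height * width)


def total_loss(
    teacher_features: Dict[str, torch.Tensor],
    student_features: Dict[str, torch.Tensor],
    weights: Optional[Dict[str, float]] = None,
) -> torch.Tensor:
    """Hypothetical weighted sum of per-layer losses over matching layer names."""
    losses = []
    for name, teacher_feats in teacher_features.items():
        # Layer weight; defaults to 1.0 for every layer (assumed).
        alpha = 1.0 if weights is None else weights.get(name, 1.0)
        losses.append(alpha * layer_loss(teacher_feats, student_features[name]))
    # Stack the weighted layer losses and sum them into one scalar tensor.
    return torch.stack(losses).sum()


# Usage with the feature shapes ResNet-18 produces for a 4 x 3 x 256 x 256 batch.
shapes = {"layer1": (4, 64, 64, 64), "layer2": (4, 128, 32, 32), "layer3": (4, 256, 16, 16)}
teacher = {name: torch.rand(shape) for name, shape in shapes.items()}
student = {name: torch.rand(shape) for name, shape in shapes.items()}
print(total_loss(teacher, student))
```

Keying the loop on the teacher's layer names means the student dictionary must contain at least those same keys; a missing layer raises a `KeyError` instead of being silently skipped.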