anomalib.models.components.feature_extractors.feature_extractor

Feature Extractor.

This module extracts features from a CNN.

Module Contents

Classes

FeatureExtractor

Extract features from a CNN.

class anomalib.models.components.feature_extractors.feature_extractor.FeatureExtractor(backbone: torch.nn.Module, layers: Iterable[str])

Bases: torch.nn.Module

Extract features from a CNN.

Parameters
  • backbone (nn.Module) – The backbone to which the feature extraction hooks are attached.

  • layers (Iterable[str]) – Names of the backbone layers to which the hooks are attached.
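The `layers` argument refers to submodules of the backbone by name. The following is a minimal sketch, assuming the names are the ones reported by PyTorch's `nn.Module.named_modules()`, so nested names such as `layer2.0` would also resolve. The toy backbone and the `resolve_layers` helper are hypothetical and not part of anomalib.

```python
from collections import OrderedDict
from typing import Dict, Iterable

from torch import nn

# Hypothetical toy backbone whose top-level children are named like ResNet stages.
backbone = nn.Sequential(
    OrderedDict(
        [
            ("stem", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
            ("layer1", nn.Sequential(nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU())),
            ("layer2", nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU())),
        ]
    )
)


def resolve_layers(backbone: nn.Module, layers: Iterable[str]) -> Dict[str, nn.Module]:
    """Map each requested layer name to its submodule (hypothetical helper)."""
    layers = list(layers)  # allow generators, which can only be iterated once
    modules = dict(backbone.named_modules())  # includes nested names such as "layer1.0"
    missing = [name for name in layers if name not in modules]
    if missing:
        raise KeyError(f"Layers not found in backbone: {missing}")
    return {name: modules[name] for name in layers}


resolved = resolve_layers(backbone, ["layer1", "layer2.0"])
print({name: type(module).__name__ for name, module in resolved.items()})
# {'layer1': 'Sequential', 'layer2.0': 'Conv2d'}
```

Checking all names up front reports every missing layer at once instead of stopping at the first one.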

Example

>>> import torch
>>> import torchvision
>>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
>>> model = FeatureExtractor(backbone=torchvision.models.resnet18(), layers=['layer1', 'layer2', 'layer3'])
>>> input = torch.rand((32, 3, 256, 256))
>>> features = model(input)
>>> list(features.keys())
['layer1', 'layer2', 'layer3']
>>> [feature.shape for feature in features.values()]
[torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]

get_features(self, layer_id: str) → Callable

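Get a forward hook that stores the output features of a layer.

Parameters

layer_id (str) – Name of the backbone layer whose output is captured.

Returns

Hook function that stores the layer's output feature map under layer_id.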
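Because `get_features` returns a `Callable`, it acts as a hook factory: each call produces a closure bound to one `layer_id`. The pure-Python sketch below illustrates that pattern, assuming the hook follows PyTorch's forward-hook signature `hook(module, args, output)`. The `FeatureStore` class and the string stand-ins for tensors are hypothetical.

```python
from typing import Any, Callable, Dict


class FeatureStore:
    """Hypothetical stand-in for the extractor's internal feature dictionary."""

    def __init__(self) -> None:
        self.features: Dict[str, Any] = {}

    def get_features(self, layer_id: str) -> Callable[[Any, Any, Any], None]:
        # Each call builds a new closure that remembers its own layer_id.
        def hook(_module: Any, _args: Any, output: Any) -> None:
            # Same positional arguments PyTorch passes to a forward hook.
            self.features[layer_id] = output

        return hook


store = FeatureStore()
hook_a = store.get_features("layer1")
hook_b = store.get_features("layer2")
# Simulate the backbone invoking the hooks after each layer's forward pass.
hook_a(None, ("x",), "feature-map-1")
hook_b(None, ("x",), "feature-map-2")
print(store.features)
# {'layer1': 'feature-map-1', 'layer2': 'feature-map-2'}
```

Creating each closure in a separate method call gives every hook its own `layer_id`, which avoids the late-binding pitfall of defining hooks directly inside a loop body.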

forward(self, input_tensor: torch.Tensor) → Dict[str, torch.Tensor]

Run a forward pass of the input tensor through the CNN.

Parameters

input_tensor (Tensor) – Input tensor

Returns

Dictionary mapping each requested layer name to the feature map extracted from that layer.
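The returned dictionary holds one entry per requested layer, filled by the hooks as a side effect of the backbone's forward pass. The following is a minimal functional sketch of that flow, assuming PyTorch is installed. `extract_features` is a hypothetical helper, and unlike the documented class it removes its hooks again after each call.

```python
from typing import Dict, Iterable

import torch
from torch import nn


def extract_features(
    backbone: nn.Module, layers: Iterable[str], input_tensor: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """Run one forward pass and return {layer name: feature map} (hypothetical helper)."""
    layers = list(layers)
    modules = dict(backbone.named_modules())
    captured: Dict[str, torch.Tensor] = {}
    handles = []
    for layer_id in layers:
        # The default argument binds the current layer_id to this hook.
        def hook(_module, _args, output, layer_id=layer_id):
            captured[layer_id] = output

        handles.append(modules[layer_id].register_forward_hook(hook))
    try:
        backbone(input_tensor)  # final output is discarded; the hooks fill `captured`
    finally:
        for handle in handles:
            handle.remove()  # leave the backbone without extra hooks
    # Return entries in the order the layers were requested.
    return {layer_id: captured[layer_id] for layer_id in layers}


backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
)
with torch.no_grad():
    features = extract_features(backbone, ["0", "2"], torch.rand(2, 3, 32, 32))
print({name: tuple(f.shape) for name, f in features.items()})
# {'0': (2, 8, 32, 32), '2': (2, 16, 16, 16)}
```

Removing the hooks after each pass leaves the backbone unmodified for other callers, at the cost of re-registering them on every call.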