Feature Extractors
Feature extractors for deep learning models.
This module provides feature extraction utilities and classes for extracting features from images using various backbone architectures.
- Classes:
TimmFeatureExtractor: Feature extractor using timm models.
TorchFXFeatureExtractor: Feature extractor using TorchFX for graph capture.
BackboneParams: Configuration parameters for backbone models.
- Functions:
dryrun_find_featuremap_dims: Utility to find feature map dimensions.
Example
>>> from anomalib.models.components.feature_extractors import (
...     TimmFeatureExtractor
... )
>>> # Create feature extractor
>>> feature_extractor = TimmFeatureExtractor(
...     backbone="resnet18",
...     layers=["layer1", "layer2"]
... )
>>> # Extract features
>>> features = feature_extractor(images)
- class anomalib.models.components.feature_extractors.BackboneParams(class_path, init_args=<factory>)
Bases: object
Used for serializing the backbone.
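As a rough illustration of what serializing a backbone could look like, the sketch below uses a stand-in dataclass (named BackboneParamsSketch here; it is not the library class) with the same two fields as the signature above. The empty-dict default for init_args and the example class path are assumptions made for illustration.

```python
from dataclasses import asdict, dataclass, field


@dataclass
class BackboneParamsSketch:
    """Hypothetical stand-in mirroring the documented (class_path, init_args) signature."""

    class_path: str
    # Assumption: the <factory> default produces an empty dict.
    init_args: dict = field(default_factory=dict)


# Illustrative values only; the class path below is an assumed example.
params = BackboneParamsSketch(
    class_path="torchvision.models.resnet18",
    init_args={"num_classes": 10},
)

# A plain dict like this can be written to a YAML or JSON config.
serialized = asdict(params)
print(serialized)
```

Keeping the class path and its constructor arguments as plain data is what makes such a description easy to store in a config file and rebuild later.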
- class anomalib.models.components.feature_extractors.TimmFeatureExtractor(backbone, layers, pre_trained=True, requires_grad=False)
Bases: Module
Extract intermediate features from timm models.
- Parameters:
backbone (str) – Name of the timm model architecture to use as backbone. Can include a custom weights URI in the format name__AT__uri.
layers (Sequence[str]) – Names of layers from which to extract features.
pre_trained (bool, optional) – Whether to use pre-trained weights. Defaults to True.
requires_grad (bool, optional) – Whether to compute gradients for the backbone. Required for training models like STFPM. Defaults to False.
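The name__AT__uri convention for the backbone argument can be pictured as a simple split on the __AT__ separator. The helper below is a hypothetical sketch, not the library's own parsing code, and the URI in the usage lines is a made-up placeholder.

```python
from __future__ import annotations


def split_backbone_spec(backbone: str) -> tuple[str, str | None]:
    """Split a 'name__AT__uri' string into (name, uri).

    Hypothetical helper for illustration; returns uri=None when no
    separator is present.
    """
    separator = "__AT__"
    if separator not in backbone:
        return backbone, None
    # Split only on the first separator so the URI is kept intact.
    name, uri = backbone.split(separator, maxsplit=1)
    return name, uri


print(split_backbone_spec("resnet18"))
print(split_backbone_spec("resnet18__AT__https://example.com/weights.pth"))
```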
- feature_extractor
The underlying timm model.
- Type:
nn.Module
Example
>>> import torch
>>> from anomalib.models.components.feature_extractors import (
...     TimmFeatureExtractor
... )
>>> # Create extractor
>>> model = TimmFeatureExtractor(
...     backbone="resnet18",
...     layers=["layer1", "layer2"]
... )
>>> # Extract features
>>> inputs = torch.randn(1, 3, 224, 224)
>>> features = model(inputs)
>>> # Print shapes
>>> for name, feat in features.items():
...     print(f"{name}: {feat.shape}")
layer1: torch.Size([1, 64, 56, 56])
layer2: torch.Size([1, 128, 28, 28])
- forward(inputs)
Extract features from the input tensor.
- Parameters:
inputs (torch.Tensor) – Input tensor of shape (batch_size, channels, height, width).
- Returns:
Dictionary mapping layer names to their feature tensors.
- Return type:
dict[str, torch.Tensor]
Example
>>> import torch
>>> from anomalib.models.components.feature_extractors import (
...     TimmFeatureExtractor
... )
>>> model = TimmFeatureExtractor(
...     backbone="resnet18",
...     layers=["layer1"]
... )
>>> inputs = torch.randn(1, 3, 224, 224)
>>> features = model(inputs)
>>> features["layer1"].shape
torch.Size([1, 64, 56, 56])
- class anomalib.models.components.feature_extractors.TorchFXFeatureExtractor(backbone, return_nodes, weights=None, requires_grad=False, tracer_kwargs=None)
Bases: Module
Extract features from a CNN using TorchFX.
- Parameters:
backbone (str | BackboneParams | dict | nn.Module) – The backbone to which the feature extraction hooks are attached. If a string name is provided, the model is loaded from torchvision. Alternatively, a model class can be provided, in which case the weights are loaded from the given weights file. Finally, an nn.Module instance can be passed directly.
return_nodes (list[str]) – List of layer names of the backbone to which the hooks are attached. You can find the names of these nodes with the get_graph_node_names function.
weights (str | WeightsEnum | None) – Weights enum to use for the model. Torchvision models require a WeightsEnum; these enums are defined in torchvision.models.<model>. For custom models, you can pass the path to the weights file. Defaults to None.
requires_grad (bool) – Models like stfpm use the feature extractor for training. In such cases, set requires_grad to True. Defaults to False.
tracer_kwargs (dict | None) – Dictionary of keyword arguments for NodePathTracer (which passes them on to its parent class, torch.fx.Tracer). Can be used to skip tracing through problematic modules by passing a list of leaf_modules as one of the tracer_kwargs. Defaults to None.
- feature_extractor
The TorchFX feature extractor module.
- Type:
GraphModule
Example
>>> import torch
>>> from anomalib.models.components.feature_extractors import (
...     TorchFXFeatureExtractor
... )
>>> # Initialize with torchvision model
>>> extractor = TorchFXFeatureExtractor(
...     backbone="resnet18",
...     return_nodes=["layer1", "layer2"]
... )
>>> # Extract features
>>> inputs = torch.randn(1, 3, 224, 224)
>>> features = extractor(inputs)
>>> # Access features by layer name
>>> print(features["layer1"].shape)
torch.Size([1, 64, 56, 56])
- forward(inputs)
Extract features from the input.
- Parameters:
inputs (torch.Tensor) – Input tensor.
- Returns:
Dictionary mapping layer names to their feature tensors.
- Return type:
dict[str, torch.Tensor]
- initialize_feature_extractor(backbone, return_nodes, weights=None, requires_grad=False, tracer_kwargs=None)
Initialize the feature extractor.
- Parameters:
backbone (BackboneParams | nn.Module) – The backbone to which the feature extraction hooks are attached.
return_nodes (list[str]) – List of layer names to extract features from.
weights (str | WeightsEnum | None) – Model weights specification. Defaults to None.
requires_grad (bool) – Whether to compute gradients. Defaults to False.
tracer_kwargs (dict | None) – Additional arguments for the tracer. Defaults to None.
- Returns:
Initialized feature extractor.
- Return type:
GraphModule
- Raises:
TypeError – If weights format is invalid.
- anomalib.models.components.feature_extractors.dryrun_find_featuremap_dims(feature_extractor, input_size, layers)
Get feature map dimensions by running an empty tensor through the model.
Performs a forward pass with an empty tensor to determine the output dimensions of specified feature maps.
- Parameters:
feature_extractor (TimmFeatureExtractor | GraphModule) – Feature extraction model, either a TimmFeatureExtractor or a GraphModule.
input_size (tuple[int, int]) – Tuple of (height, width) specifying input image dimensions.
layers (list[str]) – List of layer names from which to extract features.
- Return type:
dict[str, dict[str, int | tuple[int, int]]]
- Returns:
Dictionary mapping layer names to dimension information. For each layer, returns a dictionary with:
num_features: Number of feature channels (int)
resolution: Spatial dimensions as a (height, width) tuple
Example
>>> extractor = TimmFeatureExtractor("resnet18", layers=["layer1"])
>>> dims = dryrun_find_featuremap_dims(
...     extractor,
...     input_size=(256, 256),
...     layers=["layer1"]
... )
>>> print(dims["layer1"]["num_features"])  # channels
64
>>> print(dims["layer1"]["resolution"])  # (height, width)
(64, 64)
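The values in the example above can be cross-checked by hand. For a ResNet-18, layer1 outputs 64 channels at one quarter of the input resolution, so a 256x256 input gives (64, 64). The sketch below reproduces that arithmetic in the documented return structure without running a model. The helper name and the channel/stride table are assumptions for illustration and are not part of anomalib.

```python
from __future__ import annotations

# Assumed (channels, stride) per layer for a ResNet-18 backbone, consistent
# with the shapes shown in the examples in this section.
RESNET18_LAYERS = {"layer1": (64, 4), "layer2": (128, 8)}


def expected_featuremap_dims(
    input_size: tuple[int, int],
    layers: list[str],
) -> dict[str, dict[str, int | tuple[int, int]]]:
    """Hypothetical helper: predict dims in the documented return format.

    Assumes the input height and width are divisible by the layer stride.
    """
    height, width = input_size
    dims = {}
    for layer in layers:
        channels, stride = RESNET18_LAYERS[layer]
        dims[layer] = {
            "num_features": channels,
            "resolution": (height // stride, width // stride),
        }
    return dims


dims = expected_featuremap_dims((256, 256), ["layer1", "layer2"])
print(dims["layer1"])  # {'num_features': 64, 'resolution': (64, 64)}
print(dims["layer2"])  # {'num_features': 128, 'resolution': (32, 32)}
```

A hand-computed table like this only covers backbones whose strides are known in advance; the dry run is the general approach because it reads the shapes from an actual forward pass.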