Source code for anomalib.models.components.feature_extractors.feature_extractor
"""Feature Extractor.
This module extracts intermediate feature maps from a CNN backbone.
"""
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import warnings
from typing import Dict, List
import timm
import torch
from torch import Tensor, nn
class FeatureExtractor(nn.Module):
    """Extract intermediate feature maps from a CNN backbone built with ``timm``.

    Args:
        backbone (str): Name of the backbone model, passed to ``timm.create_model``.
        layers (List[str]): Names of the backbone layers whose outputs are returned.
            Names that are not direct children of the backbone are dropped with a warning.
        pre_trained (bool): Forwarded to ``timm.create_model`` as ``pretrained``.
            Defaults to True.

    Attributes:
        backbone: The backbone name given at construction.
        layers: The requested layer names that were found in the backbone.
        idx: The ``out_indices`` passed to ``timm`` (one per entry of ``layers``).
        feature_extractor: The ``timm`` model created with ``features_only=True``.
        out_dims: Channel counts reported by ``feature_extractor.feature_info.channels()``.

    Example:
        >>> import torch
        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
        >>> model = FeatureExtractor(backbone="resnet18", layers=['layer1', 'layer2', 'layer3'])
        >>> input = torch.rand((32, 3, 256, 256))
        >>> features = model(input)
        >>> [layer for layer in features.keys()]
        ['layer1', 'layer2', 'layer3']
        >>> [feature.shape for feature in features.values()]
        [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]
    """

    def __init__(self, backbone: str, layers: List[str], pre_trained: bool = True):
        super().__init__()
        self.backbone = backbone
        # Copy so that dropping unknown layer names never alters the caller's list.
        self.layers = list(layers)
        # Must run before the extractor is built: it also prunes ``self.layers``.
        self.idx = self._map_layer_to_idx()
        self.feature_extractor = timm.create_model(
            backbone,
            pretrained=pre_trained,
            features_only=True,
            exportable=True,
            out_indices=self.idx,
        )
        self.out_dims = self.feature_extractor.feature_info.channels()
        self._features = {layer: torch.empty(0) for layer in self.layers}

    @staticmethod
    def _resolve_layers(requested: List[str], available: List[str], offset: int) -> tuple:
        """Split requested layer names into found and missing ones.

        Args:
            requested (List[str]): Layer names asked for, in the caller's order.
            available (List[str]): Names of the backbone's direct children, in model order.
            offset (int): Value subtracted from each child position to obtain the index.

        Returns:
            A ``(kept, indices, missing)`` tuple: the requested names that exist in
            ``available``, their ``position - offset`` indices (same order as ``kept``),
            and the requested names that were not found.
        """
        position = {name: pos for pos, name in enumerate(available)}
        kept: List[str] = []
        indices: List[int] = []
        missing: List[str] = []
        for name in requested:
            if name in position:
                kept.append(name)
                indices.append(position[name] - offset)
            else:
                missing.append(name)
        return kept, indices, missing

    def _map_layer_to_idx(self, offset: int = 3) -> List[int]:
        """Map the requested layer names to ``timm`` feature indices.

        Layer names that are not direct children of the backbone trigger a warning
        and are removed from ``self.layers``, so ``self.layers`` and the returned
        indices always have the same length and order.

        Args:
            offset (int): `timm` ignores the first few layers when indexing; update
                the offset to match the backbone if needed. Defaults to 3.

        Returns:
            List[int]: One index per remaining entry of ``self.layers``.
        """
        features = timm.create_model(
            self.backbone,
            pretrained=False,
            features_only=False,
            exportable=True,
        )
        # Child names are computed once rather than once per requested layer.
        child_names = [name for name, _ in features.named_children()]
        kept, idx, missing = self._resolve_layers(self.layers, child_names, offset)
        for i in missing:
            warnings.warn(f"Layer {i} not found in model {self.backbone}")
        # Rebuild the list instead of calling ``remove`` inside a loop over it:
        # removing during iteration skips the element that follows each removal.
        self.layers = kept
        return idx

    def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
        """Forward-pass the input tensor through the backbone.

        Args:
            input_tensor (Tensor): Input tensor.

        Returns:
            Dict[str, Tensor]: Feature maps keyed by layer name.
        """
        # NOTE(review): pairing by position assumes timm yields feature maps in the
        # same order as ``out_indices`` - confirm for layers given out of model order.
        features = dict(zip(self.layers, self.feature_extractor(input_tensor)))
        return features