# Source code for anomalib.models.components.feature_extractors.feature_extractor
"""Feature Extractor.
This module extracts intermediate features from a CNN backbone using forward hooks.
"""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from typing import Callable, Dict, Iterable
import torch
from torch import Tensor, nn
class FeatureExtractor(nn.Module):
    """Extract features from a CNN.

    Forward hooks are registered on the requested sub-modules of ``backbone``. Each forward pass
    runs the full backbone and collects the output of every hooked layer.

    Args:
        backbone (nn.Module): The backbone to which the feature extraction hooks are attached.
        layers (Iterable[str]): Names of the backbone sub-modules (as reported by
            ``backbone.named_modules()``) to which the hooks are attached. The iterable is
            materialised into a list, so one-shot iterators such as generators are supported.

    Raises:
        KeyError: If any name in ``layers`` is not a sub-module of ``backbone``. The check runs
            before any hook is registered, so the backbone is left untouched on failure.

    Example:
        >>> import torch
        >>> import torchvision
        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
        >>> model = FeatureExtractor(backbone=torchvision.models.resnet18(), layers=['layer1', 'layer2', 'layer3'])
        >>> input = torch.rand((32, 3, 256, 256))
        >>> features = model(input)
        >>> [layer for layer in features.keys()]
        ['layer1', 'layer2', 'layer3']
        >>> [feature.shape for feature in features.values()]
        [torch.Size([32, 64, 64, 64]), torch.Size([32, 128, 32, 32]), torch.Size([32, 256, 16, 16])]
    """

    def __init__(self, backbone: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.backbone = backbone
        # Materialise once: the layer names are iterated here, in the registration loop and on
        # every forward pass. A generator would silently yield nothing after the first pass.
        self.layers = list(layers)
        self._features: Dict[str, Tensor] = {layer: torch.empty(0) for layer in self.layers}
        # Channel counts of the hooked layers, in ``layers`` order. A layer with no sub-module
        # exposing ``out_channels`` contributes no entry, so this list can be shorter than
        # ``self.layers``.
        self.out_dims = []

        # Build the name -> module lookup once instead of once per requested layer.
        modules = dict(self.backbone.named_modules())
        # Validate every name before touching the backbone so a typo does not leave hooks
        # registered on the earlier layers.
        missing = [layer_id for layer_id in self.layers if layer_id not in modules]
        if missing:
            raise KeyError(f"Layers {missing} not found in backbone.")

        for layer_id in self.layers:
            layer = modules[layer_id]
            layer.register_forward_hook(self.get_features(layer_id))
            # Record the channel count of the last sub-module (in ``modules()`` order, which
            # includes ``layer`` itself) that exposes ``out_channels``.
            for module in reversed(list(layer.modules())):
                if hasattr(module, "out_channels"):
                    self.out_dims.append(module.out_channels)
                    break

    def get_features(self, layer_id: str) -> Callable:
        """Create a forward hook that stores a layer's output.

        Args:
            layer_id (str): Layer name; used as the key under which the layer's output is stored.

        Returns:
            Callable: Forward hook with the ``(module, input, output)`` signature expected by
            ``nn.Module.register_forward_hook``. It writes ``output`` to
            ``self._features[layer_id]`` and returns ``None``, so the layer's output is passed on
            unchanged. If the layer runs more than once in a forward pass, only the last output
            is kept.
        """

        def hook(_, __, output):
            """Hook to extract features via a forward-pass.

            Args:
                output: Feature map collected after the forward-pass.
            """
            self._features[layer_id] = output

        return hook

    def forward(self, input_tensor: Tensor) -> Dict[str, Tensor]:
        """Forward-pass input tensor into the CNN.

        Args:
            input_tensor (Tensor): Input tensor passed to the backbone.

        Returns:
            Dict[str, Tensor]: Mapping from each requested layer name to the output captured by
            its hook during this pass. A layer the backbone did not execute keeps the empty
            placeholder ``torch.empty(0)``. The backbone's own output is discarded.
        """
        # A fresh dict per call, so a dict returned by a previous call is not mutated by this one.
        self._features = {layer: torch.empty(0) for layer in self.layers}
        _ = self.backbone(input_tensor)
        return self._features