"""PyTorch model for the PaDiM model implementation."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.fromrandomimportsamplefromtypingimportDict,List,Optional,Tupleimporttorchimporttorch.nn.functionalasFimporttorchvisionfromtorchimportTensor,nnfromanomalib.models.componentsimportFeatureExtractor,MultiVariateGaussianfromanomalib.models.padim.anomaly_mapimportAnomalyMapGeneratorfromanomalib.pre_processingimportTiler
class PadimModel(nn.Module):
    """Torch module that turns image batches into PaDiM embeddings or anomaly maps.

    Args:
        input_size (Tuple[int, int]): Input size for the model.
        layers (List[str]): Names of the backbone layers used for feature extraction.
        backbone (str, optional): Name of the pre-trained ``torchvision.models``
            backbone. Defaults to "resnet18".
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        layers: List[str],
        backbone: str = "resnet18",
    ):
        super().__init__()
        # No tiler by default; ``forward`` only tiles/untiles when one is assigned.
        self.tiler: Optional[Tiler] = None

        self.backbone = getattr(torchvision.models, backbone)
        self.layers = layers
        pretrained_network = self.backbone(pretrained=True)
        self.feature_extractor = FeatureExtractor(backbone=pretrained_network, layers=self.layers)

        # Per-backbone settings; this entry is read for the keys
        # "orig_dims", "reduced_dims" and "emb_scale" below.
        backbone_dims = DIMS[backbone]
        self.dims = backbone_dims

        # The channel subset is drawn at random, so it is stored as a buffer
        # and saved with the model to reproduce the same results later.
        channel_ids = sample(range(backbone_dims["orig_dims"]), backbone_dims["reduced_dims"])
        self.register_buffer("idx", torch.tensor(channel_ids))  # pylint: disable=not-callable
        self.idx: Tensor

        self.loss = None
        self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

        # Patch count: each input dimension divided by the embedding scale,
        # rounded up, then multiplied together.
        n_features = backbone_dims["reduced_dims"]
        patches_dims = torch.tensor(input_size) / backbone_dims["emb_scale"]
        n_patches = patches_dims.ceil().prod().int().item()
        self.gaussian = MultiVariateGaussian(n_features, n_patches)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Forward-pass an image batch through the feature extractor.

        The batch is tiled first when ``self.tiler`` is set, and the resulting
        embeddings are untiled again. Feature extraction and embedding
        generation run with gradients disabled.

        Args:
            input_tensor (Tensor): Image batch (N, C, H, W).

        Returns:
            The embeddings while the module is in training mode; otherwise the
            output of ``self.anomaly_map_generator`` computed from the
            embeddings and the mean / inverse covariance held by
            ``self.gaussian``.
        """
        if self.tiler:
            input_tensor = self.tiler.tile(input_tensor)

        with torch.no_grad():
            layer_features = self.feature_extractor(input_tensor)
            embeddings = self.generate_embedding(layer_features)

        if self.tiler:
            embeddings = self.tiler.untile(embeddings)

        if self.training:
            return embeddings

        return self.anomaly_map_generator(
            embedding=embeddings,
            mean=self.gaussian.mean,
            inv_covariance=self.gaussian.inv_covariance,
        )

    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
        """Generate an embedding from a hierarchical feature map.

        The feature map of the first configured layer sets the spatial size.
        Every further layer is resized to it with nearest-neighbour
        interpolation and stacked along the channel dimension. Only the
        channels listed in ``self.idx`` are kept.

        Args:
            features (Dict[str, Tensor]): Feature maps keyed by layer name; must
                contain every name in ``self.layers``.

        Returns:
            Embedding tensor holding the selected channels.
        """
        first_layer, *other_layers = self.layers
        base_map = features[first_layer]
        target_size = base_map.shape[-2:]

        stacked_maps = [base_map]
        for layer_name in other_layers:
            resized_map = F.interpolate(features[layer_name], size=target_size, mode="nearest")
            stacked_maps.append(resized_map)
        embeddings = torch.cat(stacked_maps, 1)

        # Subsample the channels using the indices stored with the model.
        selected_channels = self.idx.to(embeddings.device)
        return embeddings.index_select(1, selected_channels)