"""PyTorch model for the PaDiM model implementation."""# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0fromrandomimportsamplefromtypingimportDict,List,Optional,Tupleimporttorchimporttorch.nn.functionalasFfromtorchimportTensor,nnfromanomalib.models.componentsimportFeatureExtractor,MultiVariateGaussianfromanomalib.models.padim.anomaly_mapimportAnomalyMapGeneratorfromanomalib.pre_processingimportTiler
class PadimModel(nn.Module):
    """PaDiM torch module: backbone feature extraction, embedding and anomaly map.

    Args:
        input_size (Tuple[int, int]): Input size for the model.
        layers (List[str]): Names of the backbone layers used for feature extraction.
        backbone (str, optional): Pre-trained model backbone; it is used as a key
            into ``DIMS``. Defaults to "resnet18".
        pre_trained (bool, optional): Whether to use a pre-trained backbone.
            Defaults to True.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        layers: List[str],
        backbone: str = "resnet18",
        pre_trained: bool = True,
    ):
        super().__init__()
        # No tiler by default; ``forward`` only tiles when one has been assigned.
        self.tiler: Optional[Tiler] = None

        self.backbone = backbone
        self.layers = layers
        self.feature_extractor = FeatureExtractor(
            backbone=self.backbone,
            layers=layers,
            pre_trained=pre_trained,
        )

        # Per-backbone settings: "orig_dims", "reduced_dims" and "emb_scale".
        dims = DIMS[backbone]
        self.dims = dims

        # pylint: disable=not-callable
        # The channel subset is drawn at random, so it is registered as a buffer
        # and saved with the model in order to get the same results later.
        selected_channels = sample(range(dims["orig_dims"]), dims["reduced_dims"])
        self.register_buffer("idx", torch.tensor(selected_channels))
        self.idx: Tensor

        self.loss = None
        self.anomaly_map_generator = AnomalyMapGenerator(image_size=input_size)

        # Patch count is the product over both axes of ceil(input_size / emb_scale).
        patch_grid = (torch.tensor(input_size) / dims["emb_scale"]).ceil()
        n_patches = patch_grid.prod().int().item()
        self.gaussian = MultiVariateGaussian(dims["reduced_dims"], n_patches)

    def forward(self, input_tensor: Tensor) -> Tensor:
        """Run an image batch through the backbone and embed or score it.

        When a tiler is set, the input is tiled before feature extraction and the
        resulting embeddings are untiled afterwards. The feature extractor and the
        embedding generation both run under ``torch.no_grad()``.

        Args:
            input_tensor (Tensor): Image batch (N, C, H, W).

        Returns:
            Tensor: In training mode, the embeddings produced by
            ``generate_embedding``. Otherwise, the output of the anomaly map
            generator, computed from the embeddings and the mean and inverse
            covariance held by ``self.gaussian``.
        """
        if self.tiler:
            input_tensor = self.tiler.tile(input_tensor)

        with torch.no_grad():
            layer_features = self.feature_extractor(input_tensor)
            embeddings = self.generate_embedding(layer_features)

        if self.tiler:
            embeddings = self.tiler.untile(embeddings)

        if self.training:
            return embeddings

        return self.anomaly_map_generator(
            embedding=embeddings,
            mean=self.gaussian.mean,
            inv_covariance=self.gaussian.inv_covariance,
        )

    def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
        """Build one embedding tensor from the per-layer feature maps.

        The first configured layer sets the spatial size. Every further layer is
        resized to it with nearest-neighbour interpolation, all maps are
        concatenated along the channel axis (dim 1), and the channels listed in
        ``self.idx`` are then selected.

        Args:
            features (Dict[str, Tensor]): Feature maps keyed by layer name; must
                contain every name in ``self.layers``.

        Returns:
            Tensor: Embedding restricted to the channels in ``self.idx``.
        """
        first_layer, *other_layers = self.layers
        base_map = features[first_layer]
        target_size = base_map.shape[-2:]

        resized_maps = [base_map]
        for layer_name in other_layers:
            resized_maps.append(
                F.interpolate(features[layer_name], size=target_size, mode="nearest")
            )
        stacked = torch.cat(resized_maps, dim=1)

        # subsample embeddings
        channel_idx = self.idx.to(stacked.device)
        return torch.index_select(stacked, dim=1, index=channel_idx)