Source code for anomalib.models.patchcore.torch_model
"""PyTorch model for the PatchCore model implementation."""# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0fromtypingimportDict,List,Optional,Tuple,Unionimporttorchimporttorch.nn.functionalasFfromtorchimportTensor,nnfromanomalib.models.componentsimport(DynamicBufferModule,FeatureExtractor,KCenterGreedy,)fromanomalib.models.patchcore.anomaly_mapimportAnomalyMapGeneratorfromanomalib.pre_processingimportTiler
def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Produce training embeddings, or an anomaly map and score at test time.

    Pipeline:
        1. Optionally tile the input (when ``self.tiler`` is set).
        2. Extract CNN features without gradient tracking and pool them.
        3. Fuse the pooled features into one embedding and untile it if needed.
        4. Flatten the embedding to one row per spatial patch.
        5. In eval mode, score patches against the memory bank and build the anomaly map.

    Args:
        input_tensor (Tensor): Input image batch.

    Returns:
        Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The flattened embedding while
            ``self.training`` is true; otherwise the ``(anomaly_map, anomaly_score)`` pair returned
            by ``self.anomaly_map_generator``.
    """
    if self.tiler:
        input_tensor = self.tiler.tile(input_tensor)

    # Feature extraction runs outside autograd; no gradients are recorded for it.
    with torch.no_grad():
        raw_features = self.feature_extractor(input_tensor)

    pooled_features = {name: self.feature_pooler(fmap) for name, fmap in raw_features.items()}
    fused = self.generate_embedding(pooled_features)
    if self.tiler:
        fused = self.tiler.untile(fused)

    # Spatial size must be captured before flattening; the map generator needs it to rebuild the map.
    feature_map_shape = fused.shape[-2:]
    flat_embedding = self.reshape_embedding(fused)

    if self.training:
        return flat_embedding

    patch_scores = self.nearest_neighbors(embedding=flat_embedding, n_neighbors=self.num_neighbors)
    anomaly_map, anomaly_score = self.anomaly_map_generator(
        patch_scores=patch_scores, feature_map_shape=feature_map_shape
    )
    return anomaly_map, anomaly_score
def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor:
    """Fuse a hierarchical feature map into a single embedding tensor.

    The feature map of ``self.layers[0]`` sets the spatial resolution. Every subsequent layer
    listed in ``self.layers`` is resized to that resolution with nearest-neighbour interpolation
    and concatenated along dimension 1.

    Args:
        features (Dict[str, Tensor]): Feature maps from the CNN, keyed by layer name.

    Returns:
        torch.Tensor: The fused embedding. When only one layer is configured, the feature map
            of that layer is returned unchanged (same object).
    """
    anchor = features[self.layers[0]]
    extra_layers = self.layers[1:]
    if not extra_layers:
        return anchor

    spatial_size = anchor.shape[-2:]
    resized = [
        F.interpolate(features[name], size=spatial_size, mode="nearest") for name in extra_layers
    ]
    return torch.cat([anchor, *resized], 1)
@staticmethod
def reshape_embedding(embedding: Tensor) -> Tensor:
    """Flatten a 4-D embedding into one row per spatial position.

    ``[Batch, Embedding, Patch, Patch]`` becomes ``[Batch * Patch * Patch, Embedding]``.
    Rows are ordered by batch, then patch row, then patch column.

    Args:
        embedding (Tensor): Embedding tensor extracted from CNN features.

    Returns:
        Tensor: The flattened 2-D embedding.
    """
    channels = embedding.size(1)
    # Move channels last so each (batch, row, col) position becomes one contiguous row.
    channels_last = embedding.permute(0, 2, 3, 1)
    return channels_last.reshape(-1, channels)
def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
    """Select a coreset of the embedding and store it as the memory bank.

    Args:
        embedding (torch.Tensor): Embedding tensor produced from the CNN features.
        sampling_ratio (float): Coreset sampling ratio passed to ``KCenterGreedy``.
    """
    # Coreset subsampling: the sampled subset replaces whatever was in the memory bank before.
    self.memory_bank = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio).sample_coreset()
def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
    """Score patches by brute-force Euclidean nearest-neighbour search against the memory bank.

    Args:
        embedding (Tensor): Features to compare against ``self.memory_bank``.
        n_neighbors (int): Number of nearest neighbours to return per row. It is clamped to the
            number of memory bank entries, so a coreset smaller than ``n_neighbors`` does not
            make ``topk`` fail.

    Returns:
        Tensor: Patch scores. For a 2-D ``embedding`` the shape is ``(num_rows, k)``, with
            ``k = min(n_neighbors, memory bank size)``. Each row holds the ``k`` smallest
            distances in ascending order.

    Raises:
        RuntimeError: If the memory bank contains no entries.
    """
    if self.memory_bank.numel() == 0:
        raise RuntimeError(
            "Memory bank is empty; populate it (e.g. via subsample_embedding) before computing patch scores."
        )
    distances = torch.cdist(embedding, self.memory_bank, p=2.0)  # euclidean norm
    # topk raises when k exceeds the size of the searched dimension; clamp to the memory bank size.
    k = min(n_neighbors, distances.size(1))
    patch_scores, _ = distances.topk(k=k, largest=False, dim=1)
    return patch_scores