"""Anomaly Map Generator for the STFPM model implementation."""# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0fromtypingimportDict,Tuple,Unionimporttorchimporttorch.nn.functionalasFfromomegaconfimportListConfigfromtorchimportTensor,nn
def compute_layer_map(self, teacher_features: Tensor, student_features: Tensor) -> Tensor:
    """Build a per-pixel distance map between teacher and student features for one layer.

    Both feature tensors are unit-normalised along dim 1. Then half of the squared L2 distance
    between them is taken over dim -3. For unit vectors, 0.5 * ||a - b||^2 == 1 - cos(a, b), so
    this is a cosine-dissimilarity score. The single-channel result is bilinearly resized to
    ``self.image_size``.

    Args:
        teacher_features (Tensor): Teacher features
        student_features (Tensor): Student features

    Returns:
        Anomaly score based on cosine similarity, resized to ``self.image_size``.
    """
    # Explicit arguments equal F.normalize's defaults (p=2, dim=1).
    teacher_unit = F.normalize(teacher_features, p=2.0, dim=1)
    student_unit = F.normalize(student_features, p=2.0, dim=1)

    difference = teacher_unit - student_unit
    squared_distance = torch.norm(difference, p=2, dim=-3, keepdim=True) ** 2
    distance_map = squared_distance * 0.5

    return F.interpolate(distance_map, size=self.image_size, mode="bilinear", align_corners=False)
def compute_anomaly_map(
    self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]
) -> torch.Tensor:
    """Compute the overall anomaly map as the element-wise product of the per-layer maps.

    For every key in ``teacher_features``, ``self.compute_layer_map`` is called with the matching
    teacher and student tensors. The resulting maps are multiplied together.

    Args:
        teacher_features (Dict[str, Tensor]): Teacher features, keyed by layer name.
        student_features (Dict[str, Tensor]): Student features. Must contain every key of
            ``teacher_features``.

    Returns:
        Final anomaly map of shape ``(batch_size, 1, self.image_size[0], self.image_size[1])``.

    Raises:
        ValueError: If ``teacher_features`` is empty.
        KeyError: If a layer in ``teacher_features`` is missing from ``student_features``.
    """
    if not teacher_features:
        raise ValueError("Expected at least one layer in `teacher_features`, got an empty dict.")

    first_features = next(iter(teacher_features.values()))
    # Allocate the accumulator directly on the features' device. This avoids a CPU allocation
    # followed by a host-to-device copy on every call.
    anomaly_map = torch.ones(
        first_features.shape[0],
        1,
        self.image_size[0],
        self.image_size[1],
        device=first_features.device,
    )
    for layer, teacher_layer in teacher_features.items():
        layer_map = self.compute_layer_map(teacher_layer, student_features[layer])
        # No-op when devices already match; kept so layers on different devices still work.
        anomaly_map = anomaly_map.to(layer_map.device)
        anomaly_map *= layer_map
    return anomaly_map
def forward(self, **kwargs: Dict[str, Tensor]) -> torch.Tensor:
    """Return the anomaly map.

    Expects the ``teacher_features`` and ``student_features`` keywords to be passed explicitly.
    Both are forwarded to ``self.compute_anomaly_map``.

    Example:
        >>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size))
        >>> output = anomaly_map_generator(
                teacher_features=teacher_features,
                student_features=student_features
            )

    Raises:
        ValueError: If the ``teacher_features`` or ``student_features`` key is not found.

    Returns:
        torch.Tensor: anomaly map
    """
    if "teacher_features" not in kwargs or "student_features" not in kwargs:
        raise ValueError(f"Expected keys `teacher_features` and `student_features`. Found {kwargs.keys()}")

    teacher_features: Dict[str, Tensor] = kwargs["teacher_features"]
    student_features: Dict[str, Tensor] = kwargs["student_features"]
    return self.compute_anomaly_map(teacher_features, student_features)