Source code for anomalib.models.patchcore.lightning_model
"""Towards Total Recall in Industrial Anomaly Detection.Paper https://arxiv.org/abs/2106.08265."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.importloggingfromtypingimportList,Tuple,UnionimporttorchfromomegaconfimportDictConfig,ListConfigfrompytorch_lightning.utilities.cliimportMODEL_REGISTRYfromtorchimportTensorfromanomalib.models.componentsimportAnomalyModulefromanomalib.models.patchcore.torch_modelimportPatchcoreModel
class Patchcore(AnomalyModule):
    """PatchCore Lightning module (https://arxiv.org/abs/2106.08265).

    No weights are optimized by this module: ``training_step`` only collects the
    embeddings produced by ``self.model`` for the training images,
    ``on_validation_start`` stacks them and hands them to
    ``self.model.subsample_embedding`` for core-set subsampling, and
    ``validation_step`` returns the anomaly maps and scores produced by ``self.model``.

    Args:
        input_size (Tuple[int, int]): Size of the model input.
        backbone (str): Backbone CNN network.
        layers (List[str]): Layers to extract features from the backbone CNN.
        pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
            Defaults to True.
        coreset_sampling_ratio (float, optional): Coreset sampling ratio to subsample embedding.
            Defaults to 0.1.
        num_neighbors (int, optional): Number of nearest neighbors. Defaults to 9.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        backbone: str,
        layers: List[str],
        pre_trained: bool = True,
        coreset_sampling_ratio: float = 0.1,
        num_neighbors: int = 9,
    ) -> None:
        super().__init__()

        self.model: PatchcoreModel = PatchcoreModel(
            input_size=input_size,
            backbone=backbone,
            pre_trained=pre_trained,
            layers=layers,
            num_neighbors=num_neighbors,
        )
        self.coreset_sampling_ratio = coreset_sampling_ratio
        # Per-batch embeddings appended by `training_step` and stacked in `on_validation_start`.
        self.embeddings: List[Tensor] = []

    def configure_optimizers(self) -> None:
        """Configure optimizers.

        Returns:
            None: Do not set optimizers by returning None.
        """
        return None

    def training_step(self, batch, _batch_idx) -> None:  # pylint: disable=arguments-differ
        """Generate the feature embedding of the batch and store it.

        Args:
            batch (Dict[str, Any]): Batch containing image filename, image, label and mask.
            _batch_idx (int): Batch index (unused).

        Returns:
            None: No loss is returned. The embedding of ``batch["image"]`` is appended
            to ``self.embeddings`` instead.
        """
        self.model.feature_extractor.eval()
        embedding = self.model(batch["image"])

        # NOTE: `self.embeddings` collects each batch embedding to
        # store the training set embedding. We manually append these
        # values mainly due to the new order of hooks introduced after PL v1.4.0
        # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
        self.embeddings.append(embedding)

    def on_validation_start(self) -> None:
        """Apply core-set subsampling to the embeddings collected from the training set.

        Raises:
            RuntimeError: If no training embeddings have been collected yet, e.g. when
                validation runs before any ``training_step``.
        """
        # NOTE: Previous anomalib versions fit subsampling at the end of the epoch.
        # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
        # is run within train epoch.
        if not self.embeddings:
            # `torch.vstack([])` would fail with an opaque message; explain the cause instead.
            raise RuntimeError(
                "Patchcore has no training embeddings to subsample. Run training "
                "(`training_step`) before validation; if this happens during the "
                "Lightning sanity check, set `num_sanity_val_steps=0`."
            )

        logger.info("Aggregating the embedding extracted from the training set.")
        embeddings = torch.vstack(self.embeddings)

        logger.info("Applying core-set subsampling to get the embedding.")
        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio)

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        """Get batch of anomaly maps from input image batch.

        Args:
            batch (Dict[str, Any]): Batch containing image filename, image, label and mask.
            _ (int): Batch index (unused).

        Returns:
            Dict[str, Any]: The input batch with ``anomaly_maps`` and ``pred_scores`` added.
        """
        anomaly_maps, anomaly_score = self.model(batch["image"])
        batch["anomaly_maps"] = anomaly_maps
        # NOTE(review): `unsqueeze(0)` adds a leading dimension to the score, which looks like
        # it assumes `self.model` returns a single score for the whole batch (batch size 1).
        # Confirm against PatchcoreModel before validating with larger batches.
        batch["pred_scores"] = anomaly_score.unsqueeze(0)

        return batch
class PatchcoreLightning(Patchcore):
    """PatchcoreLightning Module to train PatchCore algorithm.

    Builds a :class:`Patchcore` from a configuration object and saves the
    configuration as the module's hyperparameters.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model params. Reads
            ``hparams.model.input_size``, ``backbone``, ``layers``,
            ``coreset_sampling_ratio`` and ``num_neighbors``. The optional
            ``hparams.model.pre_trained`` key defaults to True when absent.
    """

    def __init__(self, hparams) -> None:
        model_cfg = hparams.model
        super().__init__(
            input_size=model_cfg.input_size,
            backbone=model_cfg.backbone,
            layers=model_cfg.layers,
            # Forward `pre_trained` so a config can disable pretrained weights; configs
            # without the key keep the Patchcore default of True.
            pre_trained=model_cfg.get("pre_trained", True),
            coreset_sampling_ratio=model_cfg.coreset_sampling_ratio,
            num_neighbors=model_cfg.num_neighbors,
        )
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)