class TorchInferencer(Inferencer):
    """PyTorch implementation for the inference.

    Args:
        config (Union[str, Path, DictConfig, ListConfig]): Configurable parameters that are used
            during the training stage. A ``str``/``Path`` is read with
            ``get_configurable_parameters``; a ``DictConfig``/``ListConfig`` is used as-is.
        model_source (Union[str, Path, AnomalyModule]): Path to the model ckpt file or the
            Anomaly model.
        meta_data_path (Optional[Union[str, Path]], optional): Path to metadata file. If none,
            it tries to load the params from the model state_dict. Defaults to None.

    Raises:
        ValueError: If ``config`` is neither a path nor a ``DictConfig``/``ListConfig``.
    """

    def __init__(
        self,
        config: Union[str, Path, DictConfig, ListConfig],
        model_source: Union[str, Path, AnomalyModule],
        meta_data_path: Optional[Union[str, Path]] = None,
    ):
        # Check and load the configuration. This has to happen first because
        # ``load_model`` reads ``self.config``.
        if isinstance(config, (str, Path)):
            self.config = get_configurable_parameters(config_path=config)
        elif isinstance(config, (DictConfig, ListConfig)):
            self.config = config
        else:
            raise ValueError(f"Unknown config type {type(config)}")

        # Check and load the model weights. An already-built model is used
        # directly; anything else is treated as a checkpoint path.
        if isinstance(model_source, AnomalyModule):
            self.model = model_source
        else:
            self.model = self.load_model(model_source)

        # Needs ``self.model`` when no explicit metadata path is given.
        self.meta_data = self._load_meta_data(meta_data_path)
def _load_meta_data(self, path: Optional[Union[str, Path]] = None) -> Union[Dict, DictConfig]:
    """Fetch the inference metadata, either from a file or from the model.

    Args:
        path (Optional[Union[str, Path]], optional): Location of the metadata file. When
            omitted, the metadata is obtained from ``self.model`` through
            ``get_model_metadata``. Defaults to None.

    Returns:
        Dict: Dictionary containing the meta_data.
    """
    if path is not None:
        # An explicit file takes precedence; reading it is delegated to the parent class.
        return super()._load_meta_data(path)

    return get_model_metadata(self.model)
def load_model(self, path: Union[str, Path]) -> AnomalyModule:
    """Load the PyTorch model.

    Builds the model from ``self.config``, restores the weights stored under the
    ``"state_dict"`` key of the checkpoint and switches the model to eval mode.

    Args:
        path (Union[str, Path]): Path to model ckpt file.

    Returns:
        (AnomalyModule): PyTorch Lightning model.
    """
    model = get_model(self.config)

    # NOTE(security): ``torch.load`` unpickles the file, so only checkpoints
    # from a trusted source should be passed in here.
    # ``map_location="cpu"`` lets checkpoints that were saved from a CUDA
    # device be read on hosts without a GPU. ``load_state_dict`` copies the
    # values into the parameters the model already owns, so the device the
    # model lives on is not changed by this.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])

    model.eval()
    return model
def pre_process(self, image: np.ndarray) -> Tensor:
    """Pre process the input image by applying transformations.

    The transform configuration (``dataset.transform_config.val``, when present)
    and ``dataset.image_size`` are read from ``self.config`` and handed to
    ``PreProcessor``.

    Args:
        image (np.ndarray): Input image

    Returns:
        Tensor: pre-processed image. A rank-3 result gets a leading dimension
        of size one added; any other rank is returned unchanged.
    """
    transform_config = (
        self.config.dataset.transform_config.val if "transform_config" in self.config.dataset.keys() else None
    )
    image_size = tuple(self.config.dataset.image_size)
    pre_processor = PreProcessor(transform_config, image_size)
    processed_image = pre_processor(image=image)["image"]

    # Test the rank, not ``len()``: ``len`` is the size of the first dimension,
    # which only equals 3 by coincidence for a 3-channel image. It would skip
    # single-channel images and wrongly re-batch a batch of three.
    # Assumes a rank-3 result is one un-batched image -- TODO confirm layout.
    if processed_image.ndim == 3:
        processed_image = processed_image.unsqueeze(0)

    return processed_image
def forward(self, image: Tensor) -> Tensor:
    """Run the wrapped model on an input tensor.

    Args:
        image (Tensor): Input tensor.

    Returns:
        Tensor: Whatever ``self.model`` returns for ``image``.
    """
    predictions = self.model(image)
    return predictions
[docs]defpost_process(self,predictions:Tensor,meta_data:Optional[Union[Dict,DictConfig]]=None)->Dict[str,Any]:"""Post process the output predictions. Args: predictions (Tensor): Raw output predicted by the model. meta_data (Dict, optional): Meta data. Post-processing step sometimes requires additional meta data such as image shape. This variable comprises such info. Defaults to None. Returns: Dict[str, Union[str, float, np.ndarray]]: Post processed prediction results. """ifmeta_dataisNone:meta_data=self.meta_dataifisinstance(predictions,Tensor):anomaly_map=predictions.detach().cpu().numpy()pred_score=anomaly_map.reshape(-1).max()else:# NOTE: Patchcore `forward`` returns heatmap and score.# We need to add the following check to ensure the variables# are properly assigned. Without this check, the code# throws an error regarding type mismatch torch vs np.ifisinstance(predictions[1],(Tensor)):anomaly_map,pred_score=predictionsanomaly_map=anomaly_map.detach().cpu().numpy()pred_score=pred_score.detach().cpu().numpy()else:anomaly_map,pred_score=predictionspred_score=pred_score.detach()# Common practice in anomaly detection is to assign anomalous# label to the prediction if the prediction score is greater# than the image 
threshold.pred_label:Optional[str]=Noneif"image_threshold"inmeta_data:pred_idx=pred_score>=meta_data["image_threshold"]pred_label="Anomalous"ifpred_idxelse"Normal"pred_mask:Optional[np.ndarray]=Noneif"pixel_threshold"inmeta_data:pred_mask=(anomaly_map>=meta_data["pixel_threshold"]).squeeze().astype(np.uint8)anomaly_map=anomaly_map.squeeze()anomaly_map,pred_score=self._normalize(anomaly_map,pred_score,meta_data)ifisinstance(anomaly_map,Tensor):anomaly_map=anomaly_map.detach().cpu().numpy()if"image_shape"inmeta_dataandanomaly_map.shape!=meta_data["image_shape"]:image_height=meta_data["image_shape"][0]image_width=meta_data["image_shape"][1]anomaly_map=cv2.resize(anomaly_map,(image_width,image_height))ifpred_maskisnotNone:pred_mask=cv2.resize(pred_mask,(image_width,image_height))return{"anomaly_map":anomaly_map,"pred_label":pred_label,"pred_score":pred_score,"pred_mask":pred_mask,