"""Utilities for optimization and OpenVINO conversion."""# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0importjsonimportosfrompathlibimportPathfromtypingimportDict,List,Optional,Tuple,UnionimportnumpyasnpimporttorchfromtorchimportTensorfromtorch.typesimportNumberfromanomalib.models.componentsimportAnomalyModule
def get_model_metadata(model: AnomalyModule) -> Dict[str, Tensor]:
    """Collect the normalization-related metadata stored on a model.

    The image and pixel thresholds are always read from the model (as plain
    Python numbers, via ``.item()``). When the model exposes a
    ``normalization_metrics`` attribute, every entry of its state dict is added
    as well, moved to the CPU. Entries consisting solely of infinite values
    are treated as undefined and omitted from the result.

    Args:
        model (AnomalyModule): Anomaly model which contains metadata related to normalization.

    Returns:
        Dict[str, Tensor]: metadata
    """
    collected: Dict[str, Union[Number, Tensor]] = {}
    collected["image_threshold"] = model.image_threshold.cpu().value.item()
    collected["pixel_threshold"] = model.pixel_threshold.cpu().value.item()

    if hasattr(model, "normalization_metrics"):
        norm_state = model.normalization_metrics.state_dict()
        if norm_state is not None:
            for name, stat in norm_state.items():
                collected[name] = stat.cpu()

    # Keep only defined entries: anything that is entirely +/-inf is dropped.
    return {name: entry for name, entry in collected.items() if not np.isinf(entry).all()}
def export_convert(
    model: AnomalyModule,
    input_size: Union[List[int], Tuple[int, int]],
    export_mode: str,
    export_path: Optional[Union[str, Path]] = None,
):
    """Export the model to onnx format and optionally convert it to OpenVINO IR.

    The model is always exported to ``<export_path>/model.onnx``. When
    ``export_mode`` is ``"openvino"``, the OpenVINO Model Optimizer (``mo``) is
    then run on that file and writes its output to ``<export_path>/openvino``.
    Finally the model metadata is written as ``meta_data.json`` into the
    OpenVINO directory (openvino mode) or into ``export_path`` (any other mode).

    Args:
        model (AnomalyModule): Model to convert.
        input_size (Union[List[int], Tuple[int, int]]): Image size (height, width) used as the
            input for onnx converter.
        export_mode (str): Mode to export, ``"onnx"`` or ``"openvino"``.
        export_path (Union[str, Path]): Directory in which the exported files are stored. It is
            created if it does not exist. Although the parameter has a ``None`` default, a
            value is required.

    Raises:
        ValueError: If ``export_path`` is ``None``.
        subprocess.CalledProcessError: If the ``mo`` command exits with a non-zero status.
    """
    # Imported at function scope so the block does not rely on a module-level import.
    import subprocess

    if export_path is None:
        # Previously ``str(None)`` silently produced a directory literally named "None".
        raise ValueError("export_path must be provided to export the model.")

    height, width = input_size
    output_dir = Path(export_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    onnx_path = output_dir / "model.onnx"
    torch.onnx.export(
        model.model,
        torch.zeros((1, 3, height, width)).to(model.device),
        str(onnx_path),
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
    )

    if export_mode == "openvino":
        output_dir = output_dir / "openvino"
        output_dir.mkdir(parents=True, exist_ok=True)
        # Argument list instead of a shell string: paths containing spaces or shell
        # metacharacters are passed through intact, and a failed conversion raises
        # instead of being ignored.
        subprocess.run(
            ["mo", "--input_model", str(onnx_path), "--output_dir", str(output_dir)],
            check=True,
        )

    # Build the metadata before opening the file so that a failure here does not
    # leave an empty meta_data.json behind.
    # Tensors are not JSON serializable; convert them to (nested) Python lists.
    meta_data = {
        key: value.numpy().tolist() if isinstance(value, Tensor) else value
        for key, value in get_model_metadata(model).items()
    }
    with open(output_dir / "meta_data.json", "w", encoding="utf-8") as metadata_file:
        json.dump(meta_data, metadata_file, ensure_ascii=False, indent=4)