"""Inference Dataset."""# Copyright (C) 2020 Intel Corporation## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing,# software distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions# and limitations under the License.frompathlibimportPathfromtypingimportAny,Optional,Tuple,UnionimportalbumentationsasAfromtorch.utils.data.datasetimportDatasetfromanomalib.data.utilsimportget_image_filenames,read_imagefromanomalib.pre_processingimportPreProcessor
[docs]classInferenceDataset(Dataset):"""Inference Dataset to perform prediction."""def__init__(self,path:Union[str,Path],pre_process:Optional[PreProcessor]=None,image_size:Optional[Union[int,Tuple[int,int]]]=None,transform_config:Optional[Union[str,A.Compose]]=None,)->None:"""Inference Dataset to perform prediction. Args: path (Union[str, Path]): Path to an image or image-folder. pre_process (Optional[PreProcessor], optional): Pre-Processing transforms to pre-process the input dataset. Defaults to None. image_size (Optional[Union[int, Tuple[int, int]]], optional): Target image size to resize the original image. Defaults to None. transform_config (Optional[Union[str, A.Compose]], optional): Configuration file parse the albumentation transforms. Defaults to None. """super().__init__()self.image_filenames=get_image_filenames(path)ifpre_processisNone:self.pre_process=PreProcessor(transform_config,image_size)else:self.pre_process=pre_process
[docs]def__len__(self)->int:"""Get the number of images in the given path."""returnlen(self.image_filenames)
[docs]def__getitem__(self,index:int)->Any:"""Get the image based on the `index`."""image_filename=self.image_filenames[index]image=read_image(path=image_filename)pre_processed=self.pre_process(image=image)pre_processed["image_path"]=str(image_filename)returnpre_processed