Source code for anomalib.models.draem.utils.augmenter
"""Augmenter module to generate out-of-distribution samples for the DRAEM implementation."""# Original Code# Copyright (c) 2021 VitjanZ# https://github.com/VitjanZ/DRAEM.# SPDX-License-Identifier: MIT## Modified# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0importglobimportrandomfromtypingimportOptional,Tupleimportcv2importimgaug.augmentersasiaaimportnumpyasnpimporttorchfromtorchimportTensorfromtorchvision.datasets.folderimportIMG_EXTENSIONSfromanomalib.models.draem.utils.perlinimportrand_perlin_2d_np
class Augmenter:
    """Class that generates noisy augmentations of input images.

    Args:
        anomaly_source_path (Optional[str]): Path to a folder of images that will be used as source of the anomalous
            noise. If not specified, random noise will be used instead.
    """

    def __init__(self, anomaly_source_path: Optional[str] = None):
        # Collect, recursively, every file below the folder whose name ends with one of IMG_EXTENSIONS.
        self.anomaly_source_paths = []
        if anomaly_source_path is not None:
            for img_ext in IMG_EXTENSIONS:
                self.anomaly_source_paths.extend(glob.glob(anomaly_source_path + "/**/*" + img_ext, recursive=True))
            # glob yields matches in arbitrary (filesystem-dependent) order; sorting makes the
            # random selection in ``augment_batch`` reproducible for a fixed random seed.
            self.anomaly_source_paths.sort()

        # Pool of transforms from which ``rand_augmenter`` draws 3 at a time.
        self.augmenters = [
            iaa.GammaContrast((0.5, 2.0), per_channel=True),
            iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
            iaa.pillike.EnhanceSharpness(),
            iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
            iaa.Solarize(0.5, threshold=(32, 128)),
            iaa.Posterize(),
            iaa.Invert(),
            iaa.pillike.Autocontrast(),
            iaa.pillike.Equalize(),
            iaa.Affine(rotate=(-45, 45)),
        ]
        # Random rotation applied to the perlin noise before it is thresholded into a mask.
        self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])

    def rand_augmenter(self) -> "iaa.Sequential":
        """Selects 3 random transforms that will be applied to the anomaly source images.

        Returns:
            A selection of 3 transforms, drawn without replacement from ``self.augmenters``.
        """
        aug_ind = np.random.choice(len(self.augmenters), 3, replace=False)
        return iaa.Sequential([self.augmenters[idx] for idx in aug_ind])

    def generate_perturbation(
        self, height: int, width: int, anomaly_source_path: Optional[str]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate an image containing a random anomalous perturbation using a source image.

        Args:
            height (int): height of the generated image.
            width (int): width of the generated image.
            anomaly_source_path (Optional[str]): Path to an image file. If not provided, random noise will be used
                instead.

        Returns:
            Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask.

        Raises:
            OSError: If ``anomaly_source_path`` is given but the image cannot be read.
        """
        # Generate random perlin noise; each axis gets its own scale, a power of two in [2**0, 2**6].
        perlin_scale = 6
        min_perlin_scale = 0
        perlin_scalex = 2 ** random.randint(min_perlin_scale, perlin_scale)
        perlin_scaley = 2 ** random.randint(min_perlin_scale, perlin_scale)

        perlin_noise = rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley))
        perlin_noise = self.rot(image=perlin_noise)

        # Create mask from perlin noise: 1 where the noise exceeds 0.5, 0 elsewhere, with a trailing
        # axis added (presumably giving (height, width, 1) -- TODO confirm the helper's output shape).
        mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
        mask = np.expand_dims(mask, axis=2).astype(np.float32)

        # Load anomaly source image
        if anomaly_source_path:
            anomaly_source_img = cv2.imread(anomaly_source_path)
            if anomaly_source_img is None:
                # cv2.imread reports failure by returning None rather than raising; fail here with
                # the offending path instead of an opaque assertion inside cv2.resize.
                raise OSError(f"Could not read anomaly source image: {anomaly_source_path}")
            anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
        else:
            # if no anomaly source is specified, we use the perlin noise (replicated to 3 channels)
            # as anomalous source
            anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
            anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8)

        # Augment anomaly source image
        aug = self.rand_augmenter()
        anomaly_img_augmented = aug(image=anomaly_source_img)

        # Create anomalous perturbation that we will apply to the image: the augmented source is
        # rescaled by 1/255 and zeroed outside the mask.
        perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0

        return perturbation, mask

    def augment_batch(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
        """Generate anomalous augmentations for a batch of input images.

        Roughly half of the samples (those for which ``random.random() > 0.5``) are left unchanged and
        receive an all-zero mask; the others are blended with a perturbation from
        ``generate_perturbation``.

        Note:
            The noise fallback in ``generate_perturbation`` builds 3-channel perturbations, so the batch
            is presumably expected to have 3 channels -- TODO confirm against callers.

        Args:
            batch (Tensor): Batch of input images, unpacked as (batch_size, channels, height, width).

        Returns:
            - Augmented image to which anomalous perturbations have been added.
            - Ground truth masks corresponding to the anomalous perturbations.
        """
        batch_size, channels, height, width = batch.shape

        # Collect perturbations
        perturbations_list = []
        masks_list = []
        for _ in range(batch_size):
            if random.random() > 0.5:  # include 50% normal samples
                perturbations_list.append(torch.zeros((channels, height, width)))
                masks_list.append(torch.zeros((1, height, width)))
                continue

            # Pick one source image at random, or fall back to pure noise when none were found.
            anomaly_source_path = random.choice(self.anomaly_source_paths) if self.anomaly_source_paths else None
            perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path)
            # Channels-last numpy arrays -> channels-first tensors.
            perturbations_list.append(Tensor(perturbation).permute((2, 0, 1)))
            masks_list.append(Tensor(mask).permute((2, 0, 1)))

        perturbations = torch.stack(perturbations_list).to(batch.device)
        masks = torch.stack(masks_list).to(batch.device)

        # Apply perturbations batch wise. ``beta`` is drawn per sample from [0, 0.8) and is the share
        # of the original image retained inside the masked region.
        beta = torch.rand(batch_size) * 0.8
        beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device)

        augmented_batch = batch * (1 - masks) + (1 - beta) * perturbations + beta * batch * (masks)

        return augmented_batch, masks