Source code for anomalib.models.draem.utils.augmenter

"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation."""

# Original Code
# Copyright (c) 2021 VitjanZ
# https://github.com/VitjanZ/DRAEM.
# SPDX-License-Identifier: MIT
#
# Modified
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import glob
import random
from typing import Optional, Tuple

import cv2
import imgaug.augmenters as iaa
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets.folder import IMG_EXTENSIONS

from anomalib.models.draem.utils.perlin import rand_perlin_2d_np


class Augmenter:
    """Class that generates noisy augmentations of input images.

    Args:
        anomaly_source_path (Optional[str]): Path to a folder of images that will be used as source of the anomalous
            noise. The folder is searched recursively for every extension in ``IMG_EXTENSIONS``. If not specified,
            random noise will be used instead.
    """

    def __init__(self, anomaly_source_path: Optional[str] = None):
        # Collect every image file below the source folder (one recursive glob per known extension).
        self.anomaly_source_paths = []
        if anomaly_source_path is not None:
            for img_ext in IMG_EXTENSIONS:
                self.anomaly_source_paths.extend(glob.glob(anomaly_source_path + "/**/*" + img_ext, recursive=True))

        # Pool of transforms from which `rand_augmenter` draws.
        self.augmenters = [
            iaa.GammaContrast((0.5, 2.0), per_channel=True),
            iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
            iaa.pillike.EnhanceSharpness(),
            iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
            iaa.Solarize(0.5, threshold=(32, 128)),
            iaa.Posterize(),
            iaa.Invert(),
            iaa.pillike.Autocontrast(),
            iaa.pillike.Equalize(),
            iaa.Affine(rotate=(-45, 45)),
        ]
        # Random rotation applied to the perlin noise before it is thresholded into a mask.
        self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])

    def rand_augmenter(self) -> "iaa.Sequential":
        """Selects 3 random transforms that will be applied to the anomaly source images.

        Returns:
            A selection of 3 distinct transforms, wrapped in a single sequential augmenter.
        """
        # Sample indices without replacement so the same transform is never picked twice.
        aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
        return iaa.Sequential([self.augmenters[idx] for idx in aug_ind])

    def generate_perturbation(
        self, height: int, width: int, anomaly_source_path: Optional[str]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate an image containing a random anomalous perturbation using a source image.

        Args:
            height (int): height of the generated image.
            width: (int): width of the generated image.
            anomaly_source_path (Optional[str]): Path to an image file. If not provided, random noise will be used
                instead.

        Returns:
            Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask.
            Both are float32 arrays in (height, width, channel) layout; the mask has a single channel.

        Raises:
            ValueError: If ``anomaly_source_path`` is given but the image cannot be read.
        """
        # Generate random perlin noise; the scale along each axis is a random power of two in [2**0, 2**6].
        perlin_scale = 6
        min_perlin_scale = 0
        perlin_scalex = 2 ** random.randint(min_perlin_scale, perlin_scale)
        perlin_scaley = 2 ** random.randint(min_perlin_scale, perlin_scale)

        perlin_noise = rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley))
        perlin_noise = self.rot(image=perlin_noise)

        # Create a binary mask from the perlin noise and add a trailing channel axis.
        mask = (perlin_noise > 0.5).astype(np.float32)[..., np.newaxis]

        # Load anomaly source image
        if anomaly_source_path:
            anomaly_source_img = cv2.imread(anomaly_source_path)
            if anomaly_source_img is None:
                # cv2.imread signals failure by returning None instead of raising; fail with a clear message
                # rather than letting cv2.resize crash on an empty input.
                raise ValueError(f"Could not read anomaly source image: {anomaly_source_path}")
            anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
        else:
            # if no anomaly source is specified, we use the perlin noise as anomalous source
            anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
            anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8)

        # Augment anomaly source image
        aug = self.rand_augmenter()
        anomaly_img_augmented = aug(image=anomaly_source_img)

        # Create anomalous perturbation that we will apply to the image: the augmented source, rescaled by 1/255
        # and zeroed outside the mask.
        perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0

        return perturbation, mask

    def augment_batch(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
        """Generate anomalous augmentations for a batch of input images.

        Each sample is left untouched with probability 0.5; otherwise a masked perturbation is blended into it.

        Args:
            batch (Tensor): Batch of input images with shape (batch_size, channels, height, width).
                NOTE(review): generated perturbations have 3 channels, so a 3-channel batch appears to be
                expected -- confirm against callers.

        Returns:
            - Augmented image to which anomalous perturbations have been added.
            - Ground truth masks corresponding to the anomalous perturbations.
        """
        batch_size, channels, height, width = batch.shape

        # Collect perturbations
        perturbations_list = []
        masks_list = []
        for _ in range(batch_size):
            if random.random() > 0.5:  # include 50% normal samples
                perturbations_list.append(torch.zeros((channels, height, width)))
                masks_list.append(torch.zeros((1, height, width)))
            else:
                anomaly_source_path = random.choice(self.anomaly_source_paths) if self.anomaly_source_paths else None
                perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path)
                # Convert from (height, width, channel) arrays to (channel, height, width) tensors.
                perturbations_list.append(Tensor(perturbation).permute((2, 0, 1)))
                masks_list.append(Tensor(mask).permute((2, 0, 1)))

        perturbations = torch.stack(perturbations_list).to(batch.device)
        masks = torch.stack(masks_list).to(batch.device)

        # Apply perturbations batch wise. One blend factor in [0, 0.8) is drawn per sample and broadcast over
        # the channel and spatial dimensions.
        beta = (torch.rand(batch_size) * 0.8).view(batch_size, 1, 1, 1).to(batch.device)

        # Outside the mask the input is kept as is. Inside the mask the result is a blend of the perturbation
        # (weight 1 - beta) and the input (weight beta); the perturbation is already zero outside the mask.
        augmented_batch = batch * (1 - masks) + (1 - beta) * perturbations + beta * batch * (masks)

        return augmented_batch, masks