Data Transforms#

This tutorial will show how Anomalib applies transforms to the input images, and how these transforms can be configured. Anomalib uses the Torchvision Transforms v2 API to apply transforms to the input images.

Common transforms include the Resize transform, which resizes the input images to a fixed width and height, and the Normalize transform, which normalizes the pixel values of the input images to a pre-determined range. The normalization statistics are usually chosen to match the pre-training characteristics of the model’s backbone. For example, when the backbone was pre-trained on the ImageNet dataset, it is usually recommended to normalize the input images with the mean and standard deviation of the ImageNet pixel values. In addition, many other transforms can be useful for achieving the desired pre-processing of the input images and for applying data augmentations during training.
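
For example, a Resize and an ImageNet-style Normalize can be chained with Compose into a single Torchvision Transforms v2 pipeline (a minimal illustration; the exact target size and statistics depend on the model and its backbone):

from torchvision.transforms.v2 import Compose, Normalize, Resize

transform = Compose(
    [
        # resize to a fixed input size
        Resize((256, 256), antialias=True),
        # normalize with the ImageNet mean and standard deviation
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ],
)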

Using custom transforms for training and evaluation#

When we create a new datamodule, it will not be equipped with any transforms by default. When we load an image from the datamodule, it will have the same shape and pixel values as the original image from the file system.

from anomalib.data import MVTec

datamodule = MVTec()
datamodule.prepare_data()
datamodule.setup()

next(iter(datamodule.train_data))["image"].shape
# torch.Size([3, 900, 900])
next(iter(datamodule.test_data))["image"].shape
# torch.Size([3, 900, 900])

Now let’s create another datamodule, this time passing a simple resize transform to the datamodule using the transform argument.

from torchvision.transforms.v2 import Resize

transform = Resize((256, 256))
datamodule = MVTec(transform=transform)

datamodule.prepare_data()
datamodule.setup()

datamodule.train_transform
# Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=warn)
datamodule.eval_transform
# Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=warn)

next(iter(datamodule.train_data))["image"].shape
# torch.Size([3, 256, 256])
next(iter(datamodule.test_data))["image"].shape
# torch.Size([3, 256, 256])

In the CLI, we can specify custom transforms by providing the class path and init args of the Torchvision transform class:

class_path: anomalib.data.MVTec
init_args:
  root: ./datasets/MVTec
  category: bottle
  image_size: [256, 256]
  train_batch_size: 32
  eval_batch_size: 32
  num_workers: 8
  task: segmentation
  test_split_mode: from_dir
  test_split_ratio: 0.2
  val_split_mode: same_as_test
  val_split_ratio: 0.5
  seed: null
  transform:
    - class_path: torchvision.transforms.v2.Resize
      init_args:
        size: [256, 256]

As the image shapes above show, the datamodule now applies the custom transform when loading the images, resizing both the training and test data to the specified shape.
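
If the YAML configuration above is saved to a file, it can be passed to the CLI through the --data argument, in the same way as the complete example at the end of this tutorial (the file name datamodule.yaml below is chosen purely for illustration):

anomalib fit --model Patchcore --data datamodule.yaml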

In the example above, we used the transform argument to assign a single set of transforms to both the training and the evaluation subsets. In some cases, we might want to apply distinct sets of transforms to training and evaluation. This can be useful, for example, when we want to apply random data augmentations during training to improve the generalization of our model. Different transforms for training and evaluation can be used simply by specifying different values for the train_transform and eval_transform arguments. The train transforms will be applied to the images in the training subset, while the eval transforms will be applied to images in the validation, testing and prediction subsets.

from torchvision.transforms.v2 import Compose, Normalize, RandomAdjustSharpness, RandomHorizontalFlip, Resize

train_transform = Compose(
    [
        RandomAdjustSharpness(sharpness_factor=0.7, p=0.5),
        RandomHorizontalFlip(p=0.5),
        Resize((256, 256), antialias=True),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ],
)
eval_transform = Compose(
    [
        Resize((256, 256), antialias=True),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ],
)

datamodule = MVTec(train_transform=train_transform, eval_transform=eval_transform)
datamodule.prepare_data()
datamodule.setup()

datamodule.train_transform
# Compose(
#       RandomAdjustSharpness(p=0.5, sharpness_factor=0.7)
#       RandomHorizontalFlip(p=0.5)
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )
datamodule.eval_transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

train_transform and eval_transform can also be set separately from the CLI. Note that the CLI also supports stacking multiple transforms using a Compose object.

class_path: anomalib.data.MVTec
init_args:
  root: ./datasets/MVTec
  category: bottle
  train_batch_size: 32
  eval_batch_size: 32
  num_workers: 8
  task: segmentation
  test_split_mode: from_dir
  test_split_ratio: 0.2
  val_split_mode: same_as_test
  val_split_ratio: 0.5
  seed: null
  train_transform:
    class_path: torchvision.transforms.v2.Compose
    init_args:
      transforms:
        - class_path: torchvision.transforms.v2.RandomAdjustSharpness
          init_args:
            sharpness_factor: 0.7
            p: 0.5
        - class_path: torchvision.transforms.v2.RandomHorizontalFlip
          init_args:
            p: 0.5
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size: [256, 256]
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
  eval_transform:
    class_path: torchvision.transforms.v2.Compose
    init_args:
      transforms:
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size: [256, 256]
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]

Note

Please note that it is not recommended to pass only one of train_transform and eval_transform while leaving the other parameter empty. This could lead to unexpected behaviour, such as a mismatch between the training and testing subsets in terms of image shape and normalization characteristics.
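
For example, when only a train_transform is passed, the datamodule leaves the evaluation subsets untransformed, so the training and test images no longer agree in shape or normalization (a minimal sketch, assuming the same MVTec setup as in the earlier examples):

datamodule = MVTec(train_transform=Resize((256, 256)))
datamodule.prepare_data()
datamodule.setup()

next(iter(datamodule.train_data))["image"].shape
# torch.Size([3, 256, 256])
next(iter(datamodule.test_data))["image"].shape  # evaluation images keep their original shape and pixel values
# torch.Size([3, 900, 900])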

Model-specific transforms#

Each Anomalib model defines a default set of transforms that will be applied to the input data when the user does not specify any custom transforms. The default transforms of a model can be inspected using the configure_transforms method, for example:

from anomalib.models import Patchcore

model = Patchcore()
model.configure_transforms()
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       CenterCrop(size=(224, 224))
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

As shown in the example, the default transforms for PatchCore consist of resizing the image to 256x256 pixels, followed by center cropping to an image size of 224x224. Finally, the pixel values are normalized using the mean and standard deviation of the ImageNet dataset. These transforms correspond to the recommended pre-processing steps described in the original PatchCore paper.

These model-specific transforms ensure that Anomalib automatically applies the right transforms when the user does not pass any custom transforms to the datamodule. In that case, Anomalib’s engine assigns the model’s default transform to the train_transform and eval_transform of the datamodule at the start of the fit/val/test sequence:

from anomalib.engine import Engine

# instantiate the datamodule without passing custom transforms
datamodule = MVTec()
# initially, the datamodule will not have any transforms defined
datamodule.train_transform is None
# True

engine = Engine()
engine.fit(model, datamodule=datamodule)

# after running fit, the engine will have injected the model's default transform into the datamodule
datamodule.train_transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       CenterCrop(size=(224, 224))
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )
datamodule.eval_transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       CenterCrop(size=(224, 224))
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

Since the CLI uses the Anomalib engine under the hood, the same principles concerning model-specific transforms apply when running a model from the CLI. Hence, the following command will ensure that Patchcore’s model-specific default transform is used when fitting the model.

anomalib fit --model Patchcore --data MVTec

Transforms during inference#

To ensure consistent transforms between training and inference, Anomalib includes the eval transform in the exported model. During inference, the transforms are included in the model’s forward pass, which guarantees that they are always applied. The following example illustrates how Anomalib’s torch inferencer automatically applies the transforms stored in the model. The same principles apply to both Lightning inference and OpenVINO inference.

import torch
from torchvision.transforms.v2 import Compose, Normalize, RandomHorizontalFlip, Resize

from anomalib.data import MVTec
from anomalib.deploy import ExportType, TorchInferencer
from anomalib.engine import Engine
from anomalib.models import Patchcore

engine = Engine()
model = Patchcore()

train_transform = Compose(
    [
        RandomHorizontalFlip(p=0.5),
        Resize((256, 256)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ],
)

eval_transform = Compose(
    [
        Resize((256, 256)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ],
)

datamodule = MVTec(train_transform=train_transform, eval_transform=eval_transform)

engine.fit(model, datamodule=datamodule)

# after running fit, the used eval_transform will be stored in the model
model.transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

# when exporting the trained model, the transforms are added in the export
engine.export(model, export_type=ExportType.TORCH, export_root="./export_folder")

inferencer = TorchInferencer("export_folder/weights/torch/model.pt")
inferencer.model.transform
# Compose(
#       Resize(size=[256, 256], interpolation=InterpolationMode.BILINEAR, antialias=True)
#       Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
# )

# since the transforms are included in the forward pass, they are applied automatically
image = torch.rand((3, 900, 900))
prediction = inferencer.predict(image)
prediction.pred_label
# <LabelName.ABNORMAL: 1>

The CLI behaviour is equivalent to that of the API. When a model is trained with a custom eval_transform like in the example below, the eval_transform is included both in the saved Lightning model and in the exported torch model.

class_path: anomalib.data.MVTec
init_args:
  root: ./datasets/MVTec
  category: bottle
  image_size: [256, 256]
  train_batch_size: 32
  eval_batch_size: 32
  num_workers: 8
  task: segmentation
  test_split_mode: from_dir
  test_split_ratio: 0.2
  val_split_mode: same_as_test
  val_split_ratio: 0.5
  seed: null
  train_transform:
    class_path: torchvision.transforms.v2.Compose
    init_args:
      transforms:
        - class_path: torchvision.transforms.v2.RandomHorizontalFlip
          init_args:
            p: 0.5
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size: [256, 256]
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
  eval_transform:
    class_path: torchvision.transforms.v2.Compose
    init_args:
      transforms:
        - class_path: torchvision.transforms.v2.Resize
          init_args:
            size: [256, 256]
        - class_path: torchvision.transforms.v2.Normalize
          init_args:
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]

Assuming the config above is saved as mvtec.yaml, the model can be trained and exported from the CLI as follows:

anomalib fit --model Patchcore --data mvtec.yaml --default_root_dir export_path
anomalib export --model Patchcore --export_type TORCH --ckpt_path export_path/Patchcore/MVTec/bottle/latest/weights/lightning/model.ckpt