Datamodules#

This guide explains how Lightning DataModules work in Anomalib and how they integrate with datasets and dataclasses.

Overview#

DataModules encapsulate all the steps needed to process data:

  • Download/prepare the data

  • Set up train/val/test datasets

  • Apply transforms

  • Create data loaders

Basic Structure#

A typical Anomalib DataModule follows this structure:

from lightning.pytorch import LightningDataModule
from anomalib.data.dataclasses import ImageBatch
from anomalib.data.datasets.base.image import AnomalibDataset
from torch.utils.data import DataLoader

class AnomalibDataModule(LightningDataModule):
    def __init__(
        self,
        root: str = "./datasets",
        category: str = "bottle",
        image_size: tuple[int, int] = (256, 256),
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        num_workers: int = 8,
        transform=None,
    ):
        super().__init__()
        self.root = root
        self.category = category
        self.image_size = image_size
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.num_workers = num_workers
        self.transform = transform

Integration with Datasets#

DataModules create and manage dataset instances:

def setup(self, stage: str | None = None):
    """Set up train, validation and test datasets."""
    if stage == "fit" or stage is None:
        self.train_dataset = AnomalibDataset(
            root=self.root,
            category=self.category,
            transform=self.transform,
            split="train"
        )

        self.val_dataset = AnomalibDataset(
            root=self.root,
            category=self.category,
            transform=self.transform,
            split="val"
        )

    if stage == "test" or stage is None:
        self.test_dataset = AnomalibDataset(
            root=self.root,
            category=self.category,
            transform=self.transform,
            split="test"
        )

Integration with Dataclasses#

DataModules create DataLoaders whose collate_fn combines a list of dataset items into a single batch dataclass:

def train_dataloader(self) -> DataLoader:
    """Create the train dataloader."""
    return DataLoader(
        dataset=self.train_dataset,
        batch_size=self.train_batch_size,
        shuffle=True,
        num_workers=self.num_workers,
        collate_fn=ImageBatch.collate    # Converts list of ImageItems to ImageBatch
    )

The data flow is:

  1. Dataset returns ImageItem objects

  2. DataLoader collates them into ImageBatch objects

  3. Model receives ImageBatch for training/inference
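To make the collation step concrete, the sketch below reproduces the same flow with plain Python dataclasses. `ToyItem` and `ToyBatch` are hypothetical stand-ins, not Anomalib's `ImageItem` and `ImageBatch`, and nested lists stand in for tensors. The only point is that a `collate` classmethod turns a list of per-sample items into one batch whose fields gain a leading batch dimension.

```python
from dataclasses import dataclass


@dataclass
class ToyItem:
    """One sample: a tiny 'image' (rows of pixels) and its label."""

    image: list[list[float]]
    gt_label: int


@dataclass
class ToyBatch:
    """A batch: every field holds one entry per sample."""

    image: list[list[list[float]]]
    gt_label: list[int]

    @classmethod
    def collate(cls, items: list[ToyItem]) -> "ToyBatch":
        # Gather each field across samples, adding a leading batch dimension.
        return cls(
            image=[item.image for item in items],
            gt_label=[item.gt_label for item in items],
        )


items = [
    ToyItem(image=[[0.0, 0.1], [0.2, 0.3]], gt_label=0),
    ToyItem(image=[[0.9, 0.8], [0.7, 0.6]], gt_label=1),
]
batch = ToyBatch.collate(items)
print(len(batch.image), batch.gt_label)  # 2 [0, 1]
```

Because the batch keeps the same field names as the item, model code can read `batch.image` and `batch.gt_label` without knowing how many samples were collated.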

Example DataModules#

1. Image DataModule#

from anomalib.data import MVTec

datamodule = MVTec(
    root="./datasets/MVTec",
    category="bottle",
    train_batch_size=32,
    eval_batch_size=32,
    num_workers=8
)

# Setup creates the datasets
datamodule.setup()

# Get train dataloader
train_loader = datamodule.train_dataloader()

# Access batches
for batch in train_loader:
    print(batch.image.shape)      # torch.Size([32, 3, 256, 256])
    print(batch.gt_label.shape)   # torch.Size([32])

2. Video DataModule#

from anomalib.data import Avenue

datamodule = Avenue(
    clip_length_in_frames=2,
    frames_between_clips=1,
    target_frame="last",
)
datamodule.setup()
# Fetch one batch; fields are accessed as attributes, as in the image example
batch = next(iter(datamodule.train_dataloader()))
batch.image.shape
# torch.Size([32, 2, 3, 256, 256])
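How many clips a video yields depends on `clip_length_in_frames` and `frames_between_clips`. The arithmetic below assumes the convention used by torchvision's `VideoClips`, where `frames_between_clips` is the stride between the first frames of consecutive clips. Whether Anomalib's video datasets follow exactly this convention, and that `target_frame` selects the frame a clip is labelled by, are assumptions here; `clip_start_indices` and `target_frame_index` are hypothetical helpers.

```python
def clip_start_indices(num_frames: int, clip_length: int, stride: int) -> list[int]:
    """Start index of every full clip when sliding a window over the video."""
    if clip_length <= 0 or stride <= 0:
        raise ValueError("clip_length and stride must be positive")
    last_start = num_frames - clip_length
    # An empty list results when the video is shorter than one clip.
    return list(range(0, last_start + 1, stride))


def target_frame_index(start: int, clip_length: int, target_frame: str = "last") -> int:
    """Index of the frame a clip is labelled by (assumed meaning of target_frame)."""
    offsets = {"first": 0, "last": clip_length - 1}
    return start + offsets[target_frame]


# A 5-frame video with the settings from the example above.
starts = clip_start_indices(num_frames=5, clip_length=2, stride=1)
targets = [target_frame_index(s, clip_length=2, target_frame="last") for s in starts]
print(starts)   # [0, 1, 2, 3]
print(targets)  # [1, 2, 3, 4]
```

With a stride of 1, consecutive clips overlap by all but one frame; a larger stride trades temporal coverage for fewer, less redundant clips.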

3. Depth DataModule#

from anomalib.data import MVTec3D

datamodule = MVTec3D(
    root="./datasets/MVTec3D",
    category="bagel",
    train_batch_size=32,
)

# Setup creates the datasets
datamodule.setup()

# Access RGB-D batches
batch = next(iter(datamodule.train_dataloader()))
batch.image.shape
# torch.Size([32, 3, 256, 256])
batch.depth_map.shape
# torch.Size([32, 1, 256, 256])

Creating Custom DataModules#

To create a custom DataModule, subclass LightningDataModule and implement setup() and the dataloader methods you need:

from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from anomalib.data.dataclasses import ImageBatch

class CustomDataModule(LightningDataModule):
    def __init__(
        self,
        root: str,
        category: str,
        image_size: tuple[int, int] = (256, 256),
        train_batch_size: int = 32,
        **kwargs
    ):
        super().__init__()
        self.root = root
        self.category = category
        self.image_size = image_size
        self.train_batch_size = train_batch_size

    def setup(self, stage: str | None = None):
        """Initialize datasets."""
        if stage == "fit" or stage is None:
            # CustomDataset is a placeholder for your own dataset class
            self.train_dataset = CustomDataset(
                root=self.root,
                category=self.category,
                split="train"
            )

    def train_dataloader(self) -> DataLoader:
        """Create train dataloader."""
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            collate_fn=ImageBatch.collate
        )
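The `CustomDataset` used in `setup()` above is left undefined. A minimal stand-in might look like the sketch below. It implements only `__len__` and `__getitem__`, which is what a map-style dataset needs to work with `torch.utils.data.DataLoader`. The directory layout `<root>/<category>/<split>/*.png` and the returned dictionary key are assumptions made for illustration; a real implementation would decode the image and return an Anomalib item so that `ImageBatch.collate` can batch it.

```python
import tempfile
from pathlib import Path


class CustomDataset:
    """Hypothetical map-style dataset indexing image files for one split."""

    def __init__(self, root: str, category: str, split: str = "train") -> None:
        # Assumed layout: <root>/<category>/<split>/*.png
        self.split_dir = Path(root) / category / split
        if not self.split_dir.is_dir():
            raise FileNotFoundError(f"Missing split directory: {self.split_dir}")
        # Sort so that indexing is deterministic across runs.
        self.image_paths = sorted(self.split_dir.glob("*.png"))

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int) -> dict[str, str]:
        # Return only the path; decoding the image is left out of this sketch.
        return {"image_path": str(self.image_paths[index])}


# Build a throwaway directory tree to exercise the class.
with tempfile.TemporaryDirectory() as tmp:
    train_dir = Path(tmp) / "bottle" / "train"
    train_dir.mkdir(parents=True)
    for name in ("001.png", "000.png"):
        (train_dir / name).touch()
    dataset = CustomDataset(root=tmp, category="bottle", split="train")
    num_items = len(dataset)
    first_name = Path(dataset[0]["image_path"]).name

print(num_items, first_name)  # 2 000.png
```

Failing early in `__init__` when the split directory is missing surfaces path mistakes during `setup()` rather than in the middle of the first epoch.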

Best Practices#

  1. Data Organization:

    • Keep dataset creation in setup()

    • Use appropriate batch sizes for train/eval

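    • Handle multi-GPU scenarios: download in prepare_data(), which Lightning calls from a single process per node by default, and assign datasets in setup(), which runs on every process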

  2. Memory Management:

    • Choose num_workers to suit the available CPU cores; 0 loads data in the main process

    • Clear cache between epochs if needed

    • Handle GPU memory efficiently

  3. Transforms:

    • Apply consistent transforms across splits

    • Use torchvision.transforms.v2

    • Handle different input modalities

  4. Validation:

    • Verify data shapes and types

    • Check batch size consistency

    • Validate paths and parameters
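The path and parameter checks above can run before any data is loaded. The helper below is a sketch: `validate_datamodule_args` is a hypothetical function, and the specific rules (an existing root directory, positive batch sizes, a non-negative worker count) are assumptions about what is worth checking, not checks that Anomalib is known to perform.

```python
from pathlib import Path


def validate_datamodule_args(
    root: str,
    train_batch_size: int,
    eval_batch_size: int,
    num_workers: int,
) -> list[str]:
    """Return human-readable problems; an empty list means every check passed."""
    problems: list[str] = []
    if not Path(root).is_dir():
        problems.append(f"root is not a directory: {root}")
    batch_sizes = (("train_batch_size", train_batch_size), ("eval_batch_size", eval_batch_size))
    for name, value in batch_sizes:
        if value < 1:
            problems.append(f"{name} must be >= 1, got {value}")
    if num_workers < 0:
        problems.append(f"num_workers must be >= 0, got {num_workers}")
    return problems


ok = validate_datamodule_args(".", 32, 32, 8)
bad = validate_datamodule_args("./no-such-dir-anomalib-sketch", 0, 32, -1)
print(ok)
print(bad)
```

Collecting every problem in a list, instead of raising on the first one, lets a user fix a misconfigured datamodule in a single pass.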

See also