Callbacks#

  • Model Checkpoint: Save and manage model checkpoints during training.

  • Graph Logger: Log model computation graphs for visualization.

  • Load Model: Load pre-trained models and weights.

  • Tile Configuration: Configure and manage image tiling settings.

  • Timer: Track and measure execution times during training.

Model Checkpoint#

Anomalib Model Checkpoint Callback.

This module provides the ModelCheckpoint callback that extends PyTorch Lightning’s ModelCheckpoint to support zero-shot and few-shot learning scenarios.

The callback enables checkpoint saving without requiring training steps, which is particularly useful for zero-shot and few-shot learning models where the training process may only involve validation.

Example

Create and use a checkpoint callback:

>>> from anomalib.callbacks import ModelCheckpoint
>>> checkpoint_callback = ModelCheckpoint(
...     dirpath="checkpoints",
...     filename="best",
...     monitor="val_loss"
... )
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[checkpoint_callback])

Note

This callback is particularly important for zero-shot and few-shot models where traditional training-based checkpoint saving strategies may not be appropriate.

class anomalib.callbacks.checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)#

Bases: ModelCheckpoint

Custom ModelCheckpoint callback for Anomalib.

This callback extends PyTorch Lightning’s ModelCheckpoint to enable checkpoint saving without requiring training steps. This is particularly useful for zero-shot and few-shot learning models where the training process may only involve validation.

The callback overrides two key methods from the parent class, as sketched after the list:

  1. _should_save_on_train_epoch_end(): Controls whether checkpoints are saved at the end of training epochs or validation sequences. For zero-shot and few-shot models, it defaults to saving at validation end unless explicitly configured otherwise.

  2. _should_skip_saving_checkpoint(): Determines if checkpoint saving should be skipped. Modified to:

    • Allow saving during both FITTING and VALIDATING states

    • Permit saving even when global step hasn’t changed (for zero-shot/few-shot models)

    • Maintain standard checkpoint skipping conditions (fast_dev_run, sanity checking)
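
A minimal sketch of what these two overrides might look like, assuming PyTorch Lightning's Trainer attributes (fast_dev_run, sanity_checking, state.fn) and the parent class's private _save_on_train_epoch_end attribute; the actual anomalib implementation may differ in detail.

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint as LightningCheckpoint
>>> from lightning.pytorch.trainer.states import TrainerFn
>>> class SketchCheckpoint(LightningCheckpoint):
...     def _should_save_on_train_epoch_end(self, trainer: Trainer) -> bool:
...         # Respect an explicit user setting, otherwise save at validation end.
...         if self._save_on_train_epoch_end is not None:
...             return self._save_on_train_epoch_end
...         return False
...     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
...         # Keep the standard skip conditions, but allow saving during both the
...         # FITTING and VALIDATING states, even if the global step has not advanced.
...         return (
...             bool(trainer.fast_dev_run)
...             or trainer.state.fn not in {TrainerFn.FITTING, TrainerFn.VALIDATING}
...             or trainer.sanity_checking
...         )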

Example

Create and use a checkpoint callback:

>>> from anomalib.callbacks import ModelCheckpoint
>>> # Create a checkpoint callback
>>> checkpoint_callback = ModelCheckpoint(
...     dirpath="checkpoints",
...     filename="best",
...     monitor="val_loss"
... )
>>> # Use it with Lightning Trainer
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[checkpoint_callback])

Note

All arguments from PyTorch Lightning’s ModelCheckpoint are supported. See ModelCheckpoint for details.

Graph Logger#

Graph logging callback for model visualization.

This module provides the GraphLogger callback for visualizing model architectures in various logging backends. The callback supports TensorBoard, Comet, and Weights & Biases (W&B) logging.

The callback automatically detects which logger is being used and handles the graph logging appropriately for each backend.

Example

Log model graph to TensorBoard:

>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibTensorBoardLogger
>>> from anomalib.engine import Engine
>>> logger = AnomalibTensorBoardLogger()
>>> callbacks = [GraphLogger()]
>>> engine = Engine(logger=logger, callbacks=callbacks)

Log model graph to Comet:

>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibCometLogger
>>> from anomalib.engine import Engine
>>> logger = AnomalibCometLogger()
>>> callbacks = [GraphLogger()]
>>> engine = Engine(logger=logger, callbacks=callbacks)

Note

For TensorBoard and Comet, the graph is logged at the end of training. For W&B, the graph is logged at the start of training but requires one backward pass to be populated. This means it may not work for models that don’t require training (e.g., PaDiM).

class anomalib.callbacks.graph.GraphLogger#

Bases: Callback

Log model graph to respective logger.

This callback logs the model architecture graph to the configured logger. It supports multiple logging backends including TensorBoard, Comet, and Weights & Biases (W&B).

The callback automatically detects which logger is being used and handles the graph logging appropriately for each backend.

Example

Create and use a graph logger:

>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibTensorBoardLogger
>>> from lightning.pytorch import Trainer
>>> logger = AnomalibTensorBoardLogger()
>>> graph_logger = GraphLogger()
>>> trainer = Trainer(logger=logger, callbacks=[graph_logger])

Note

  • For TensorBoard and Comet, the graph is logged at the end of training

  • For W&B, the graph is logged at the start of training but requires one backward pass to be populated. This means it may not work for models that don’t require training (e.g., PaDiM)

static on_train_end(trainer, pl_module)#

Log model graph at training end and cleanup.

This method is called automatically at the end of training. It:

  • Logs the model graph for TensorBoard and Comet loggers

  • Unwatches the model for W&B logger

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance containing logger references.

  • pl_module (LightningModule) – Lightning module instance to be logged.

Return type:

None

Example

>>> from anomalib.callbacks import GraphLogger
>>> callback = GraphLogger()
>>> # Called automatically by trainer
>>> # callback.on_train_end(trainer, model)
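
A hedged sketch of the dispatch this method performs, assuming Lightning's TensorBoardLogger, CometLogger and WandbLogger classes and their log_graph / unwatch APIs; the dummy input tensor shape is an illustrative assumption, not a fixed anomalib requirement.

>>> import torch
>>> from lightning.pytorch.loggers import CometLogger, TensorBoardLogger, WandbLogger
>>> def log_graph_on_train_end(trainer, pl_module):
...     for logger in trainer.loggers:
...         if isinstance(logger, (TensorBoardLogger, CometLogger)):
...             # Both backends trace the model from a sample input.
...             logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256)))
...         elif isinstance(logger, WandbLogger):
...             # W&B graphs are collected via watch(); stop watching at train end.
...             logger.experiment.unwatch(pl_module)
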
static on_train_start(trainer, pl_module)#

Log model graph to respective logger at training start.

This method is called automatically at the start of training. For W&B logger, it sets up model watching with graph logging enabled.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance containing logger references.

  • pl_module (LightningModule) – Lightning module instance to be logged.

Return type:

None

Example

>>> from anomalib.callbacks import GraphLogger
>>> callback = GraphLogger()
>>> # Called automatically by trainer
>>> # callback.on_train_start(trainer, model)
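
A hedged sketch of the W&B branch at train start, assuming Lightning's WandbLogger.watch API; loggers other than W&B are left untouched here, since their graphs are logged at train end instead.

>>> from lightning.pytorch.loggers import WandbLogger
>>> def watch_model_on_train_start(trainer, pl_module):
...     for logger in trainer.loggers:
...         if isinstance(logger, WandbLogger):
...             # log_graph=True asks W&B to record the computation graph; it is
...             # only populated after the first backward pass.
...             logger.watch(pl_module, log_graph=True)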

Load Model#

Model loader callback.

This module provides the LoadModelCallback for loading pre-trained model weights from a state dict.

The callback loads model weights from a specified path when inference begins. This is useful for loading pre-trained models for inference or fine-tuning.

Example

Load pre-trained weights and create a trainer:

>>> from anomalib.callbacks import LoadModelCallback
>>> from anomalib.engine import Engine
>>> from anomalib.models import Padim
>>> model = Padim()
>>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")]
>>> engine = Engine(model=model, callbacks=callbacks)

Note

The weights file should be a PyTorch state dict saved with either a .pt or .pth extension. The state dict should contain a "state_dict" key with the model weights.
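
A minimal sketch of producing a file in that layout: a dict with a "state_dict" key saved via torch.save. Checkpoints written by Lightning's Trainer already include this key.

>>> import torch
>>> from anomalib.models import Padim
>>> model = Padim()
>>> # Save the weights under a "state_dict" key so LoadModelCallback can read them.
>>> torch.save({"state_dict": model.state_dict()}, "weights.pt")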

class anomalib.callbacks.model_loader.LoadModelCallback(weights_path)#

Bases: Callback

Callback that loads model weights from a state dict.

This callback loads pre-trained model weights from a specified path when inference begins. The weights are loaded into the model’s state dict using the device specified by the model.

Parameters:

weights_path (str) – Path to the model weights file (.pt or .pth). The file should contain a state dict with a "state_dict" key.

Examples

Create a callback and use it with a trainer:

>>> from anomalib.callbacks import LoadModelCallback
>>> from anomalib.engine import Engine
>>> from anomalib.models import Padim
>>> model = Padim()
>>> # Create callback with path to weights
>>> callback = LoadModelCallback(weights_path="path/to/weights.pt")
>>> # Use callback with engine
>>> engine = Engine(model=model, callbacks=[callback])

Note

The callback automatically handles device mapping when loading weights.

setup(trainer, pl_module, stage=None)#

Called when inference begins.

This method is called by PyTorch Lightning when inference begins. It loads the model weights from the specified path into the module’s state dict.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance.

  • pl_module (AnomalibModule) – The module to load weights into.

  • stage (str | None, optional) – Current stage of execution. Defaults to None.

Return type:

None

Note

The weights are loaded using torch.load with automatic device mapping based on the module’s device. The state dict is expected to have a "state_dict" key containing the model weights.
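
A hedged sketch of that loading step; the exact anomalib implementation may differ, but the map_location and "state_dict" handling follow the note above.

>>> import torch
>>> def load_weights(pl_module, weights_path):
...     # Map checkpoint tensors onto whatever device the module already uses.
...     checkpoint = torch.load(weights_path, map_location=pl_module.device)
...     pl_module.load_state_dict(checkpoint["state_dict"])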

Tile Configuration#

Configure and manage image tiling settings.

Timer#

Timer callback.

This module provides the TimerCallback for measuring training and testing time of Anomalib models. The callback tracks execution time and calculates throughput metrics.

Example

Add timer callback to track performance:

>>> from anomalib.callbacks import TimerCallback
>>> from lightning.pytorch import Trainer
>>> callback = TimerCallback()
>>> trainer = Trainer(callbacks=[callback])

The callback will automatically log:

  • Total training time when training completes

  • Total testing time and throughput (FPS) when testing completes

Note

  • The callback handles both single and multiple test dataloaders

  • Throughput is calculated as total number of images / total testing time (see the sketch below)

  • Batch size is included in throughput logging for reference
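
A minimal sketch, not the anomalib implementation, of how the logged quantities relate: wall-clock time bracketing the test loop and throughput as images per second. The run_test_loop and num_test_images names are illustrative placeholders.

>>> import time
>>> def timed_test(run_test_loop, num_test_images):
...     start = time.time()                   # recorded in on_test_start
...     run_test_loop()                       # the trainer's test loop runs here
...     testing_time = time.time() - start    # computed in on_test_end
...     return {"testing_time": testing_time, "throughput_fps": num_test_images / testing_time}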

class anomalib.callbacks.timer.TimerCallback#

Bases: Callback

Callback for measuring model training and testing time.

This callback tracks execution time metrics:

  • Training time: Total time taken for model training

  • Testing time: Total time taken for model testing

  • Testing throughput: Images processed per second during testing

Example

Add timer to track performance:

>>> from anomalib.callbacks import TimerCallback
>>> from lightning.pytorch import Trainer
>>> callback = TimerCallback()
>>> trainer = Trainer(callbacks=[callback])

Note

  • The callback automatically handles both single and multiple test dataloaders

  • Throughput is calculated as: num_test_images / testing_time

  • All metrics are logged using the logger specified in the trainer

on_fit_end(trainer, pl_module)#

Called when fit ends.

Calculates and logs the total training time.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance

  • pl_module (LightningModule) – The current training module

Return type:

None

Note

The trainer and module arguments are not used but kept for callback signature compatibility

on_fit_start(trainer, pl_module)#

Called when fit begins.

Records the start time of the training process.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance

  • pl_module (LightningModule) – The current training module

Return type:

None

Note

The trainer and module arguments are not used but kept for callback signature compatibility

on_test_end(trainer, pl_module)#

Called when test ends.

Calculates and logs testing time and throughput metrics.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance

  • pl_module (LightningModule) – The current training module

Return type:

None

Note

  • Calculates total testing time

  • Computes throughput in frames per second (FPS)

  • Logs batch size along with throughput for reference

  • The module argument is not used but kept for callback signature compatibility

on_test_start(trainer, pl_module)#

Called when test begins.

Records test start time and counts total number of test images.

Parameters:
  • trainer (Trainer) – PyTorch Lightning trainer instance

  • pl_module (LightningModule) – The current training module

Return type:

None

Note

  • Records start timestamp for testing phase

  • Counts total images across all test dataloaders if multiple are present (sketched below)

  • The module argument is not used but kept for callback signature compatibility
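
A hedged sketch of the counting step: with a list of test dataloaders the image total is summed across them, with a single dataloader its dataset length is used directly. The dataloaders argument stands in for whatever the trainer exposes at test start.

>>> def count_test_images(dataloaders):
...     # Multiple dataloaders: sum the dataset sizes; single dataloader: use it directly.
...     if isinstance(dataloaders, (list, tuple)):
...         return sum(len(dataloader.dataset) for dataloader in dataloaders)
...     return len(dataloaders.dataset)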