Callbacks#
Anomalib provides callbacks to:

- Save and manage model checkpoints during training.
- Log model computation graphs for visualization.
- Load pre-trained models and weights.
- Configure and manage image tiling settings.
- Track and measure execution times during training.
Model Checkpoint#
Anomalib Model Checkpoint Callback.
This module provides the ModelCheckpoint callback, which extends PyTorch Lightning's ModelCheckpoint to support zero-shot and few-shot learning scenarios.
The callback enables checkpoint saving without requiring training steps, which is particularly useful for zero-shot and few-shot learning models where the training process may only involve validation.
Example
Create and use a checkpoint callback:
>>> from anomalib.callbacks import ModelCheckpoint
>>> checkpoint_callback = ModelCheckpoint(
... dirpath="checkpoints",
... filename="best",
... monitor="val_loss"
... )
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[checkpoint_callback])
Note
This callback is particularly important for zero-shot and few-shot models where traditional training-based checkpoint saving strategies may not be appropriate.
- class anomalib.callbacks.checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)#
Bases: ModelCheckpoint
Custom ModelCheckpoint callback for Anomalib.
This callback extends PyTorch Lightning's ModelCheckpoint to enable checkpoint saving without requiring training steps. This is particularly useful for zero-shot and few-shot learning models where the training process may only involve validation.
The callback overrides two key methods from the parent class (a sketch of this logic follows the list):

- _should_save_on_train_epoch_end(): Controls whether checkpoints are saved at the end of training epochs or at the end of validation. For zero-shot and few-shot models, it defaults to saving at validation end unless explicitly configured otherwise.
- _should_skip_saving_checkpoint(): Determines if checkpoint saving should be skipped. Modified to:
  - Allow saving during both FITTING and VALIDATING states
  - Permit saving even when the global step hasn't changed (for zero-shot/few-shot models)
  - Maintain standard checkpoint skipping conditions (fast_dev_run, sanity checking)
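Below is a minimal sketch of this skipping logic. It is illustrative only, not the actual Anomalib implementation; in particular, the way the zero-/few-shot learning type is detected and the use of the parent class's private _last_global_step_saved attribute are assumptions.
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint
>>> from lightning.pytorch.trainer.states import TrainerFn
>>> class SketchModelCheckpoint(LightningModelCheckpoint):
...     def _should_skip_saving_checkpoint(self, trainer: Trainer) -> bool:
...         # Assumed way of flagging zero-/few-shot models; the real check may differ.
...         is_zero_or_few_shot = getattr(
...             trainer.lightning_module, "learning_type", None
...         ) in {"zero_shot", "few_shot"}
...         return (
...             bool(trainer.fast_dev_run)  # never save during fast_dev_run
...             or trainer.state.fn not in {TrainerFn.FITTING, TrainerFn.VALIDATING}
...             or trainer.sanity_checking  # never save during sanity checks
...             # zero-/few-shot models may save even with an unchanged global step
...             or (self._last_global_step_saved == trainer.global_step
...                 and not is_zero_or_few_shot)
...         )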
Example
Create and use a checkpoint callback:
>>> from anomalib.callbacks import ModelCheckpoint
>>> # Create a checkpoint callback
>>> checkpoint_callback = ModelCheckpoint(
...     dirpath="checkpoints",
...     filename="best",
...     monitor="val_loss"
... )
>>> # Use it with Lightning Trainer
>>> from lightning.pytorch import Trainer
>>> trainer = Trainer(callbacks=[checkpoint_callback])
Note
All arguments from PyTorch Lightning's ModelCheckpoint are supported. See ModelCheckpoint for details.
Graph Logger#
Graph logging callback for model visualization.
This module provides the GraphLogger callback for visualizing model architectures in various logging backends.
The callback supports TensorBoard, Comet, and Weights & Biases (W&B) logging.
The callback automatically detects which logger is being used and handles the graph logging appropriately for each backend.
Example
Log model graph to TensorBoard:
>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibTensorBoardLogger
>>> from anomalib.engine import Engine
>>> logger = AnomalibTensorBoardLogger()
>>> callbacks = [GraphLogger()]
>>> engine = Engine(logger=logger, callbacks=callbacks)
Log model graph to Comet:
>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibCometLogger
>>> from anomalib.engine import Engine
>>> logger = AnomalibCometLogger()
>>> callbacks = [GraphLogger()]
>>> engine = Engine(logger=logger, callbacks=callbacks)
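Log model graph to Weights & Biases (W&B); AnomalibWandbLogger is assumed here to be the W&B logger exposed by anomalib.loggers:
>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibWandbLogger
>>> from anomalib.engine import Engine
>>> logger = AnomalibWandbLogger()
>>> callbacks = [GraphLogger()]
>>> engine = Engine(logger=logger, callbacks=callbacks)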
Note
For TensorBoard and Comet, the graph is logged at the end of training.
For W&B, the graph is logged at the start of training but requires one backward pass to be populated. This means it may not work for models that don't require training (e.g., PaDiM).
- class anomalib.callbacks.graph.GraphLogger#
Bases: Callback
Log model graph to respective logger.
This callback logs the model architecture graph to the configured logger. It supports multiple logging backends including TensorBoard, Comet, and Weights & Biases (W&B).
The callback automatically detects which logger is being used and handles the graph logging appropriately for each backend.
Example
Create and use a graph logger:
>>> from anomalib.callbacks import GraphLogger
>>> from anomalib.loggers import AnomalibTensorBoardLogger
>>> from lightning.pytorch import Trainer
>>> logger = AnomalibTensorBoardLogger()
>>> graph_logger = GraphLogger()
>>> trainer = Trainer(logger=logger, callbacks=[graph_logger])
Note
- For TensorBoard and Comet, the graph is logged at the end of training.
- For W&B, the graph is logged at the start of training but requires one backward pass to be populated. This means it may not work for models that don't require training (e.g., PaDiM).
- static on_train_end(trainer, pl_module)#
Log model graph at training end and cleanup.
This method is called automatically at the end of training. It:

- Logs the model graph for TensorBoard and Comet loggers
- Unwatches the model for the W&B logger
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance containing logger references.
pl_module (LightningModule) – Lightning module instance to be logged.
- Return type: None
Example
>>> from anomalib.callbacks import GraphLogger
>>> callback = GraphLogger()
>>> # Called automatically by trainer
>>> # callback.on_train_end(trainer, model)
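Conceptually, logging the graph at train end resembles calling the backend logger's log_graph method. The sketch below assumes Lightning's TensorBoardLogger interface and an arbitrary example input shape; it is not the exact call this callback makes:
>>> import torch
>>> from lightning.pytorch.loggers import TensorBoardLogger
>>> def log_graph_sketch(logger, pl_module):
...     # Trace the module with a dummy input so the backend can build the graph.
...     if isinstance(logger, TensorBoardLogger):
...         logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256)))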
- static on_train_start(trainer, pl_module)#
Log model graph to respective logger at training start.
This method is called automatically at the start of training. For W&B logger, it sets up model watching with graph logging enabled.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance containing logger references.
pl_module (LightningModule) – Lightning module instance to be logged.
- Return type: None
Example
>>> from anomalib.callbacks import GraphLogger
>>> callback = GraphLogger()
>>> # Called automatically by trainer
>>> # callback.on_train_start(trainer, model)
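For the W&B case, setting up model watching is conceptually similar to the sketch below, which assumes Lightning's WandbLogger interface rather than the exact call this callback makes:
>>> from lightning.pytorch.loggers import WandbLogger
>>> def watch_model_sketch(logger, pl_module):
...     # Ask W&B to watch the module and log its graph after the first backward pass.
...     if isinstance(logger, WandbLogger):
...         logger.watch(pl_module, log_graph=True)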
Load Model#
Model loader callback.
This module provides the LoadModelCallback for loading pre-trained model weights from a state dict.
The callback loads model weights from a specified path when inference begins. This is useful for loading pre-trained models for inference or fine-tuning.
Example
Load pre-trained weights and create a trainer:
>>> from anomalib.callbacks import LoadModelCallback
>>> from anomalib.engine import Engine
>>> from anomalib.models import Padim
>>> model = Padim()
>>> callbacks = [LoadModelCallback(weights_path="path/to/weights.pt")]
>>> engine = Engine(model=model, callbacks=callbacks)
Note
The weights file should be a PyTorch state dict saved with either a .pt or .pth extension.
The state dict should contain a "state_dict" key with the model weights.
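For illustration, a weights file in the expected format could be produced as follows (the Padim model and the output path are placeholders):
>>> import torch
>>> from anomalib.models import Padim
>>> model = Padim()
>>> # Save a checkpoint dict containing a "state_dict" key, as expected by the callback.
>>> torch.save({"state_dict": model.state_dict()}, "path/to/weights.pt")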
- class anomalib.callbacks.model_loader.LoadModelCallback(weights_path)#
Bases: Callback
Callback that loads model weights from a state dict.
This callback loads pre-trained model weights from a specified path when inference begins. The weights are loaded into the model’s state dict using the device specified by the model.
- Parameters:
weights_path (str) – Path to the model weights file (.pt or .pth). The file should contain a state dict with a "state_dict" key.
Examples
Create a callback and use it with a trainer:
>>> from anomalib.callbacks import LoadModelCallback
>>> from anomalib.engine import Engine
>>> from anomalib.models import Padim
>>> model = Padim()
>>> # Create callback with path to weights
>>> callback = LoadModelCallback(weights_path="path/to/weights.pt")
>>> # Use callback with engine
>>> engine = Engine(model=model, callbacks=[callback])
Note
The callback automatically handles device mapping when loading weights.
- setup(trainer, pl_module, stage=None)#
Call when inference begins.
This method is called by PyTorch Lightning when inference begins. It loads the model weights from the specified path into the module’s state dict.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance.
pl_module (AnomalibModule) – The module to load weights into.
stage (str | None, optional) – Current stage of execution. Defaults to None.
- Return type: None
Note
The weights are loaded using torch.load with automatic device mapping based on the module's device. The state dict is expected to have a "state_dict" key containing the model weights.
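Conceptually, the loading step described in this note amounts to the following sketch (illustrative only, not the exact implementation):
>>> import torch
>>> def load_weights_sketch(pl_module, weights_path):
...     # Load the checkpoint onto the module's device and restore its weights.
...     checkpoint = torch.load(weights_path, map_location=pl_module.device)
...     pl_module.load_state_dict(checkpoint["state_dict"])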
Tile Configuration#
Timer#
Timer callback.
This module provides the TimerCallback for measuring training and testing time of Anomalib models. The callback tracks execution time and calculates throughput metrics.
Example
Add timer callback to track performance:
>>> from anomalib.callbacks import TimerCallback
>>> from lightning.pytorch import Trainer
>>> callback = TimerCallback()
>>> trainer = Trainer(callbacks=[callback])
The callback will automatically log:

- Total training time when training completes
- Total testing time and throughput (FPS) when testing completes
Note
- The callback handles both single and multiple test dataloaders
- Throughput is calculated as total number of images / total testing time
- Batch size is included in throughput logging for reference
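The reported throughput boils down to a simple ratio, sketched below with an assumed image count and time.perf_counter used for timing:
>>> import time
>>> start = time.perf_counter()
>>> # ... run the test loop here, e.g. trainer.test(...) ...
>>> testing_time = time.perf_counter() - start
>>> num_test_images = 1000  # assumed total across all test dataloaders
>>> throughput_fps = num_test_images / testing_time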
- class anomalib.callbacks.timer.TimerCallback#
Bases: Callback
Callback for measuring model training and testing time.
This callback tracks execution time metrics:

- Training time: Total time taken for model training
- Testing time: Total time taken for model testing
- Testing throughput: Images processed per second during testing
Example
Add timer to track performance:
>>> from anomalib.callbacks import TimerCallback
>>> from lightning.pytorch import Trainer
>>> callback = TimerCallback()
>>> trainer = Trainer(callbacks=[callback])
Note
- The callback automatically handles both single and multiple test dataloaders
- Throughput is calculated as num_test_images / testing_time
- All metrics are logged using the logger specified in the trainer
- on_fit_end(trainer, pl_module)#
Called when fit ends.
Calculates and logs the total training time.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance
pl_module (LightningModule) – The current training module
- Return type: None
Note
The trainer and module arguments are not used but kept for callback signature compatibility
- on_fit_start(trainer, pl_module)#
Called when fit begins.
Records the start time of the training process.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance
pl_module (LightningModule) – The current training module
- Return type: None
Note
The trainer and module arguments are not used but kept for callback signature compatibility
- on_test_end(trainer, pl_module)#
Called when test ends.
Calculates and logs testing time and throughput metrics.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance
pl_module (LightningModule) – The current training module
- Return type: None
Note
- Calculates total testing time
- Computes throughput in frames per second (FPS)
- Logs batch size along with throughput for reference
- The module argument is not used but kept for callback signature compatibility
- on_test_start(trainer, pl_module)#
Called when test begins.
Records test start time and counts total number of test images.
- Parameters:
trainer (Trainer) – PyTorch Lightning trainer instance
pl_module (LightningModule) – The current training module
- Return type: None
Note
- Records start timestamp for testing phase
- Counts total images across all test dataloaders if multiple are present
- The module argument is not used but kept for callback signature compatibility
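Counting the total number of test images across one or more dataloaders could look roughly like this sketch; the handling of trainer.test_dataloaders as either a single DataLoader or a list is an assumption:
>>> from torch.utils.data import DataLoader
>>> def count_test_images_sketch(trainer):
...     # Sum dataset sizes over a single test dataloader or a list of them.
...     loaders = trainer.test_dataloaders
...     if loaders is None:
...         return 0
...     if isinstance(loaders, DataLoader):
...         return len(loaders.dataset)
...     return sum(len(loader.dataset) for loader in loaders)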