Source code for anomalib.utils.callbacks.graph

"""Log model graph to respective logger."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

from anomalib.utils.loggers import AnomalibTensorBoardLogger, AnomalibWandbLogger


@CALLBACK_REGISTRY
[docs]class GraphLogger(Callback): """Log model graph to respective logger."""
[docs] def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: """Log model graph to respective logger. Args: trainer: Trainer object which contans reference to loggers. pl_module: LightningModule object which is logged. """ for logger in trainer.loggers: if isinstance(logger, AnomalibWandbLogger): # NOTE: log graph gets populated only after one backward pass. This won't work for models which do not # require training such as Padim logger.watch(pl_module, log_graph=True, log="all") break
[docs] def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Unwatch model if configured for wandb and log it model graph in Tensorboard if specified. Args: trainer: Trainer object which contans reference to loggers. pl_module: LightningModule object which is logged. """ for logger in trainer.loggers: if isinstance(logger, AnomalibTensorBoardLogger): logger.log_graph(pl_module, input_array=torch.ones((1, 3, 256, 256))) elif isinstance(logger, AnomalibWandbLogger): logger.unwatch(pl_module) # type: ignore