Source code for anomalib.utils.callbacks.timer

"""Callback to measure training and testing time of a PyTorch Lightning module."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
import time

from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

# Module-level logger used by the timer callback to report elapsed times.
logger = logging.getLogger(__name__)
@CALLBACK_REGISTRY
class TimerCallback(Callback):
    """Callback that measures the training and testing time of a PyTorch Lightning module."""

    # pylint: disable=unused-argument
    def __init__(self) -> None:
        super().__init__()
        # Reference timestamp set by on_fit_start / on_test_start and read by the
        # matching *_end hook to compute the elapsed interval.
        self.start: float
        # Number of test samples, summed over all test dataloaders in on_test_start.
        self.num_images: int = 0

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when fit begins.

        Records the time at which training started.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        # perf_counter is monotonic and high resolution, so the measured interval
        # cannot go negative or jump if the wall clock is adjusted mid-run.
        self.start = time.perf_counter()

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when fit ends.

        Logs the time taken for training.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        logger.info("Training took %5.2f seconds", time.perf_counter() - self.start)

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when the test begins.

        Records the time at which testing started and counts the images in all
        test dataloaders (the sum of ``len(dataloader.dataset)``).

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        self.start = time.perf_counter()
        self.num_images = 0
        if trainer.test_dataloaders is not None:  # Check to placate Mypy.
            self.num_images = sum(len(dataloader.dataset) for dataloader in trainer.test_dataloaders)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when the test ends.

        Logs the time taken for testing and the throughput in frames per second.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        testing_time = time.perf_counter() - self.start
        output = f"Testing took {testing_time} seconds\nThroughput "
        # An empty list is "not None" but has no first element, so test truthiness
        # before indexing to avoid an IndexError.
        if trainer.test_dataloaders:
            output += f"(batch_size={trainer.test_dataloaders[0].batch_size})"
        # If the interval is below clock resolution, report NaN rather than raising
        # ZeroDivisionError from a logging-only hook.
        throughput = self.num_images / testing_time if testing_time > 0 else float("nan")
        output += f" : {throughput} FPS"
        logger.info(output)