julien.blanchon
add app
c8c12e9
"""Callback to measure training and testing time of a PyTorch Lightning module."""
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
import time
from pytorch_lightning import Callback, LightningModule, Trainer
# Module-scoped logger, named after this module, used for the timing reports.
logger = logging.getLogger(name=__name__)
class TimerCallback(Callback):
    """Callback that measures the training and testing time of a PyTorch Lightning module.

    The duration of ``fit`` is logged when fitting ends. The duration of ``test`` is
    logged when testing ends, together with the throughput: the number of test samples
    divided by the elapsed testing time.
    """

    # pylint: disable=unused-argument
    def __init__(self) -> None:
        super().__init__()
        # Reference point (from ``time.perf_counter``) of the most recent fit/test start.
        # Initialised here so the attribute always exists, even if an "end" hook is
        # called without the matching "start" hook.
        self.start: float = time.perf_counter()
        # Total number of samples across all test dataloaders (set in ``on_test_start``).
        self.num_images: int = 0

    def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when fit begins.

        Sets the start time to the time training started.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        # perf_counter is monotonic, unlike time.time(), so the measured duration is
        # not affected by system clock adjustments.
        self.start = time.perf_counter()

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when fit ends.

        Logs the time taken for training.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        logger.info("Training took %5.2f seconds", (time.perf_counter() - self.start))

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when the test begins.

        Sets the start time to the time testing started and counts the total
        number of samples across all test dataloaders.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        self.start = time.perf_counter()
        self.num_images = 0
        if trainer.test_dataloaders is not None:  # Check to placate Mypy.
            # NOTE(review): assumes every dataloader's dataset defines __len__
            # (an IterableDataset without it would raise TypeError here).
            self.num_images = sum(len(dataloader.dataset) for dataloader in trainer.test_dataloaders)

    def on_test_end(self, trainer: Trainer, pl_module: LightningModule) -> None:  # pylint: disable=W0613
        """Call when the test ends.

        Logs the time taken for testing and the throughput in frames per second.

        Args:
            trainer (Trainer): PyTorch Lightning trainer.
            pl_module (LightningModule): Current training module.

        Returns:
            None
        """
        testing_time = time.perf_counter() - self.start
        output = f"Testing took {testing_time} seconds\nThroughput "
        # Truthiness rejects both None and an empty list, so indexing [0] is safe.
        if trainer.test_dataloaders:
            output += f"(batch_size={trainer.test_dataloaders[0].batch_size})"
        # Avoid ZeroDivisionError when the measured elapsed time is zero.
        throughput = self.num_images / testing_time if testing_time > 0 else float("inf")
        output += f" : {throughput} FPS"
        logger.info(output)