Spaces:
No application file
No application file
from contextlib import contextmanager
from typing import Iterator

import torch
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
from pytorch_lightning.utilities import rank_zero_only
class InferenceProfiler(SimpleProfiler):
    """
    Profiler that records the duration of actions, calling
    ``torch.cuda.synchronize()`` before starting and before stopping each
    timer so that pending CUDA work is counted in the measured duration.
    Use this at test time.
    """

    def __init__(self):
        super().__init__()
        # Wrap the recording/reporting methods with Lightning's rank_zero_only
        # so they only take effect on the rank-zero process.
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)

    @contextmanager
    def profile(self, action_name: str) -> Iterator[str]:
        """Record the duration of the enclosed ``with`` block under ``action_name``.

        CUDA is synchronized before the timer starts and again before it
        stops, so kernels launched inside the block are included in the
        recorded time. When CUDA is not available the synchronization is
        skipped (``torch.cuda.synchronize()`` would otherwise fail) and the
        timer is started/stopped directly.

        Args:
            action_name: Key under which the duration is recorded.

        Yields:
            ``action_name``, unchanged.
        """
        sync_cuda = torch.cuda.is_available()
        if sync_cuda:
            torch.cuda.synchronize()
        # Start outside the try: if starting fails there is nothing to stop,
        # and calling stop() would only mask the original error.
        self.start(action_name)
        try:
            yield action_name
        finally:
            if sync_cuda:
                torch.cuda.synchronize()
            self.stop(action_name)
def build_profiler(name):
    """Create the profiler selected by ``name``.

    Args:
        name: ``None`` for a no-op ``PassThroughProfiler``, ``"inference"``
            for ``InferenceProfiler``, or ``"pytorch"`` for Lightning's
            ``PyTorchProfiler`` configured with ``use_cuda=True``,
            ``profile_memory=True`` and ``row_limit=100``.

    Returns:
        A new profiler instance.

    Raises:
        ValueError: If ``name`` is not one of the values above.
    """
    if name is None:
        return PassThroughProfiler()
    if name == "inference":
        return InferenceProfiler()
    if name == "pytorch":
        # Imported only on this branch, so PyTorchProfiler is resolved
        # only when it is actually requested.
        from pytorch_lightning.profiler import PyTorchProfiler
        return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    raise ValueError(f"Invalid profiler: {name}")