import time
from dataclasses import dataclass
from enum import Enum

import torch

from finetrainers.constants import FINETRAINERS_ENABLE_TIMING
from finetrainers.logging import get_logger


logger = get_logger()


class TimerDevice(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"


@dataclass
class TimerData:
    name: str
    device: TimerDevice
    start_time: float = 0.0
    end_time: float = 0.0


class Timer:
    def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
        self.data = TimerData(name=name, device=device)
        # When True, CUDA timing synchronizes the device before recording the end event.
        self._device_sync = device_sync
        self._start_event = None
        self._end_event = None
        self._active = False
        # Timing is a no-op unless FINETRAINERS_ENABLE_TIMING is set.
        self._enabled = FINETRAINERS_ENABLE_TIMING

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()
        return False

    def start(self):
        if self._active:
            logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.")
            return
        self._active = True
        if not self._enabled:
            return
        if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
            self._start_cuda()
        else:
            self._start_cpu()
            if self.data.device != TimerDevice.CPU:
                logger.warning(
                    f"Timer device {self.data.device} is not supported or is unavailable. Falling back to CPU timing."
                )

    def end(self):
        if not self._active:
            logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.")
            return
        self._active = False
        if not self._enabled:
            return
        if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
            self._end_cuda()
        else:
            self._end_cpu()
            if self.data.device != TimerDevice.CPU:
                logger.warning(
                    f"Timer device {self.data.device} is not supported or is unavailable. Falling back to CPU timing."
                )

    def elapsed_time(self) -> float:
        if not self._enabled:
            # Timing is disabled: no events or timestamps were recorded.
            return 0.0
        if self._active:
            if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
                # Timer is still running: record a temporary end event to measure the time elapsed so far.
                premature_end_event = torch.cuda.Event(enable_timing=True)
                premature_end_event.record()
                premature_end_event.synchronize()
                # CUDA events report elapsed time in milliseconds; convert to seconds.
                return self._start_event.elapsed_time(premature_end_event) / 1000.0
            else:
                return time.time() - self.data.start_time
        else:
            if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
                return self._start_event.elapsed_time(self._end_event) / 1000.0
            else:
                return self.data.end_time - self.data.start_time

    def _start_cpu(self):
        self.data.start_time = time.time()

    def _start_cuda(self):
        # Synchronize first so the start event marks a clean boundary for the timed region.
        torch.cuda.synchronize()
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._start_event.record()

    def _end_cpu(self):
        self.data.end_time = time.time()

    def _end_cuda(self):
        # Optionally synchronize so the end event is recorded after all queued work completes.
        if self._device_sync:
            torch.cuda.synchronize()
        self._end_event.record()
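

# Minimal usage sketch (illustrative, not part of the module above): the
# context-manager protocol wraps a block of work and reports the elapsed time
# afterwards. Assumes FINETRAINERS_ENABLE_TIMING is truthy in the environment;
# if timing is disabled, elapsed_time() simply returns 0.0.
if __name__ == "__main__":
    with Timer("sleep_demo", TimerDevice.CPU) as timer:
        time.sleep(0.1)
    print(f"sleep_demo took {timer.elapsed_time():.3f}s")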