jbilcke-hf's picture
jbilcke-hf HF Staff
we are going to hack into finetrainers
9fd1204
import time
from dataclasses import dataclass
from enum import Enum
import torch
from finetrainers.constants import FINETRAINERS_ENABLE_TIMING
from finetrainers.logging import get_logger
# Module-level logger from finetrainers' logging helper; used by Timer to report
# misuse (double start / stop without start) and device fallback.
logger = get_logger()
class TimerDevice(str, Enum):
    """Clock backend that a ``Timer`` is asked to use.

    Because the enum mixes in ``str``, each member also compares equal to its
    plain string value (for example ``TimerDevice.CUDA == "cuda"``).
    """

    # Host-side timing based on the ``time`` module.
    CPU = "cpu"
    # Device-side timing based on CUDA events (when CUDA is available).
    CUDA = "cuda"
@dataclass
class TimerData:
    """Plain record holding a timer's identity and its CPU-clock timestamps.

    ``start_time`` and ``end_time`` are filled in only by the CPU timing path
    (via ``time.time()``); when CUDA events are used they keep their ``0.0``
    defaults.
    """

    # Human-readable label, used in log messages.
    name: str
    # Requested clock backend.
    device: TimerDevice
    # Timestamp (seconds) captured when CPU timing starts.
    start_time: float = 0.0
    # Timestamp (seconds) captured when CPU timing stops.
    end_time: float = 0.0
class Timer:
    """Measure the duration of a named region on the CPU clock or with CUDA events.

    Use either explicit ``start()`` / ``end()`` calls or a ``with`` block.
    Timing only happens when ``FINETRAINERS_ENABLE_TIMING`` is truthy.
    Otherwise ``start()`` / ``end()`` only track the running state and
    ``elapsed_time`` reports ``0.0``.

    Args:
        name: Label used in log messages.
        device: Requested clock. ``TimerDevice.CUDA`` uses CUDA events when
            ``torch.cuda.is_available()``; otherwise it falls back to the CPU
            clock and logs a warning.
        device_sync: If True, call ``torch.cuda.synchronize()`` before
            recording the CUDA end event, so all previously queued device work
            is included in the measurement.
    """

    def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
        self.data = TimerData(name=name, device=device)
        self._device_sync = device_sync
        # CUDA events; created lazily by _start_cuda().
        self._start_event = None
        self._end_event = None
        self._active = False
        # Captured once at construction; toggling the constant later has no effect.
        self._enabled = FINETRAINERS_ENABLE_TIMING

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def _use_cuda(self) -> bool:
        """Return True when this timer measures with CUDA events instead of the CPU clock."""
        return self.data.device == TimerDevice.CUDA and torch.cuda.is_available()

    def _warn_if_fallback(self) -> None:
        """Log a warning when a non-CPU device was requested but the CPU clock is used."""
        if self.data.device != TimerDevice.CPU:
            logger.warning(
                f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU."
            )

    def start(self) -> None:
        """Begin timing; warns and does nothing if the timer is already running."""
        if self._active:
            logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.")
            return
        self._active = True
        if not self._enabled:
            return
        if self._use_cuda():
            self._start_cuda()
        else:
            self._start_cpu()
            self._warn_if_fallback()

    def end(self) -> None:
        """Stop timing; warns and does nothing if the timer is not running."""
        if not self._active:
            logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.")
            return
        self._active = False
        if not self._enabled:
            return
        if self._use_cuda():
            self._end_cuda()
        else:
            self._end_cpu()
            self._warn_if_fallback()

    @property
    def elapsed_time(self) -> float:
        """Elapsed time in seconds.

        While the timer is running, this measures from ``start()`` up to now.
        After ``end()``, it measures from ``start()`` to ``end()``. Returns
        ``0.0`` when timing is disabled or the timer has never been started.
        """
        if not self._enabled:
            # Nothing was recorded, so there is no meaningful duration.
            return 0.0

        if self._use_cuda():
            if self._start_event is None:
                # Never started: no CUDA events exist yet.
                return 0.0
            if self._active:
                # Still running: measure against a fresh event recorded now.
                premature_end_event = torch.cuda.Event(enable_timing=True)
                premature_end_event.record()
                premature_end_event.synchronize()
                return self._start_event.elapsed_time(premature_end_event) / 1000.0
            # Event.elapsed_time requires both events to have completed. end()
            # only enqueues the end event, so wait for it before querying.
            self._end_event.synchronize()
            # CUDA events report milliseconds; convert to seconds.
            return self._start_event.elapsed_time(self._end_event) / 1000.0

        if self._active:
            return time.time() - self.data.start_time
        return self.data.end_time - self.data.start_time

    # NOTE: the CPU path uses time.time() (wall clock, not monotonic) because
    # TimerData exposes the raw timestamps; switching to time.perf_counter()
    # would change what those fields mean for any external reader.
    def _start_cpu(self) -> None:
        self.data.start_time = time.time()

    def _start_cuda(self) -> None:
        # Drain previously queued work so it is not attributed to this region.
        torch.cuda.synchronize()
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._start_event.record()

    def _end_cpu(self) -> None:
        self.data.end_time = time.time()

    def _end_cuda(self) -> None:
        if self._device_sync:
            torch.cuda.synchronize()
        self._end_event.record()