from multiprocessing import RLock import torch from jaxtyping import Int64 from torch import Tensor from torch.multiprocessing import Manager class StepTracker: lock: RLock step: Int64[Tensor, ""] def __init__(self): self.lock = Manager().RLock() self.step = torch.tensor(0, dtype=torch.int64).share_memory_() def set_step(self, step: int) -> None: with self.lock: self.step.fill_(step) def get_step(self) -> int: with self.lock: return self.step.item()