from multiprocessing import RLock | |
import torch | |
from jaxtyping import Int64 | |
from torch import Tensor | |
from torch.multiprocessing import Manager | |
class StepTracker:
    """Hold a single integer step counter intended to be shared across processes.

    The counter is stored in a zero-dimensional int64 tensor that has been moved
    into shared memory. Every read and write of it happens while holding a lock
    obtained from a ``Manager``.
    """

    # NOTE(review): the value assigned in __init__ is the proxy returned by
    # Manager().RLock(), not an object produced by multiprocessing.RLock
    # directly; the annotation is kept unchanged for compatibility.
    lock: RLock
    # Zero-dimensional int64 tensor holding the current step value.
    step: Int64[Tensor, ""]

    def __init__(self):
        # Presumably a Manager-backed lock (rather than a plain multiprocessing
        # lock) is used so the tracker can be pickled into other processes,
        # e.g. data-loader workers -- confirm against callers.
        manager = Manager()
        self.lock = manager.RLock()
        # Scalar int64 tensor starting at 0, placed in shared memory so that
        # writes made through one process's handle are visible to others
        # sharing the same storage.
        self.step = torch.zeros((), dtype=torch.int64).share_memory_()

    def set_step(self, step: int) -> None:
        """Overwrite the stored step with ``step`` while holding the lock."""
        with self.lock:
            self.step.fill_(step)

    def get_step(self) -> int:
        """Return the stored step as a Python number, read while holding the lock."""
        with self.lock:
            current = self.step.item()
        return current