File size: 536 Bytes
2568013 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from multiprocessing import RLock
import torch
from jaxtyping import Int64
from torch import Tensor
from torch.multiprocessing import Manager
class StepTracker:
    """Integer step counter kept in a shared-memory tensor.

    The counter is a 0-dim int64 tensor whose storage is moved into shared
    memory with ``share_memory_()``. Every read and write goes through a
    reentrant lock obtained from a ``torch.multiprocessing`` manager.
    """

    # NOTE(review): at runtime this holds the proxy returned by
    # ``Manager().RLock()``, not an object produced by ``multiprocessing.RLock``.
    lock: RLock
    # Scalar (0-dim) int64 tensor holding the current step.
    step: Int64[Tensor, ""]

    def __init__(self):
        # Manager() starts a separate manager process; the RLock lives there
        # and this object only holds a proxy to it.
        self.lock = Manager().RLock()
        counter = torch.zeros((), dtype=torch.int64)
        self.step = counter.share_memory_()

    def set_step(self, step: int) -> None:
        """Overwrite the shared counter with ``step`` while holding the lock."""
        with self.lock:
            self.step.fill_(step)

    def get_step(self) -> int:
        """Return the shared counter's current value as a Python int."""
        with self.lock:
            current = self.step.item()
        return current
|