AnySplat/src/misc/step_tracker.py
alexnasa's picture
Upload 243 files
2568013 verified
raw
history blame contribute delete
536 Bytes
from multiprocessing import RLock
import torch
from jaxtyping import Int64
from torch import Tensor
from torch.multiprocessing import Manager
class StepTracker:
    """Integer step counter stored in shared memory and guarded by a lock.

    The counter is a 0-dim ``int64`` tensor whose storage is moved to shared
    memory via ``share_memory_()``. Every read and write happens while holding
    a re-entrant lock obtained from ``torch.multiprocessing.Manager()``, i.e. a
    proxy to a lock that lives in the manager's server process.

    NOTE(review): presumably this lets one process (e.g. the training loop)
    publish the current step to other processes such as data-loader workers;
    confirm against callers.
    """

    # NOTE(review): ``RLock`` is ``multiprocessing.RLock``, a factory function,
    # not a class; the object actually stored is the proxy returned by
    # ``Manager().RLock()``. The annotation is kept unchanged for compatibility.
    lock: RLock
    # Scalar (0-dim) int64 tensor holding the current step.
    step: Int64[Tensor, ""]

    def __init__(self):
        # Zero-valued 0-dim int64 tensor. Its storage is moved to shared memory
        # so processes it is shared with can observe in-place updates.
        self.step = torch.zeros((), dtype=torch.int64).share_memory_()
        # Manager() starts a server process. The manager object itself is not
        # kept on ``self``; NOTE(review): this relies on the lock proxy keeping
        # a reference to its manager (true in CPython's BaseProxy).
        self.lock = Manager().RLock()

    def set_step(self, step: int) -> None:
        """Overwrite the shared counter with ``step`` while holding the lock."""
        with self.lock:
            # Fill in place so the same shared-memory storage is reused.
            self.step.fill_(step)

    def get_step(self) -> int:
        """Return the current step as a Python number, read under the lock."""
        with self.lock:
            current = self.step.item()
        return current