File size: 343 Bytes
b73936d
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import TypedDict, Dict, Any
import torch


class TrainState(TypedDict):
    """Typed schema for the dict that bundles the training loop's state.

    Because this is a ``TypedDict``, instances are ordinary ``dict`` objects
    at runtime; the per-key types below are enforced only by static type
    checkers. No ``total=`` argument is given, so every key is required.

    NOTE(review): the intended meaning of each key is inferred only from its
    name and annotation -- confirm against the code that builds and reads it.
    """

    dataloader: torch.utils.data.DataLoader  # presumably the loader being trained on
    optimizer: Dict[str, Any]  # presumably an optimizer ``state_dict()`` -- confirm
    scheduler: Dict[str, Any]  # presumably an LR-scheduler ``state_dict()`` -- confirm
    # ``torch.utils.data.sampler`` is a module, not a type, so it cannot serve
    # as an annotation; ``Any`` leaves the concrete sampler (or its state) open.
    sampler: Any
    profiler: bool  # flag; presumably whether profiling is switched on
    epoch: int  # presumably the current epoch number
    iteration: int  # presumably the current step/iteration number
    loss: float  # presumably the latest loss as a plain Python float
    wandb_state: int  # integer W&B-related value; semantics not visible here