kairunwen's picture
Update Code
57746f1
raw
history blame contribute delete
710 Bytes
r""" Helper functions """
import random
import torch
import numpy as np
def fix_randseed(seed):
    r""" Set random seeds for reproducibility

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU
    plus every visible CUDA device), and configures cuDNN for deterministic
    execution.

    Args:
        seed: integer seed, or ``None`` to draw one at random.

    Returns:
        The seed that was actually applied, so a randomly drawn seed can be
        logged and the run reproduced later.
    """
    if seed is None:
        # Drawn before `random` is re-seeded below, so it still varies per run.
        seed = int(random.random() * 1e5)
    # Python's own RNG must be seeded too; otherwise any code using `random`
    # (e.g. data shuffling/augmentation) is not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN autotuning may select different kernels from run to run;
    # disable it and request deterministic algorithms instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed
def mean(x):
    r""" Arithmetic mean of a sized sequence; ``0.0`` when it is empty """
    # An explicit length test (not truthiness) keeps this usable with
    # tensors/arrays, whose truth value is ambiguous.
    count = len(x)
    if count == 0:
        return 0.0
    return sum(x) / count
def to_cuda(batch):
    r""" Move every ``torch.Tensor`` value of ``batch`` to the GPU, in place

    Non-tensor values are left as they are. The same dict object is
    returned for convenient chaining.
    """
    tensor_items = [(name, t) for name, t in batch.items() if isinstance(t, torch.Tensor)]
    for name, t in tensor_items:
        batch[name] = t.cuda()
    return batch
def to_cpu(tensor):
    r""" Return a detached CPU copy of ``tensor`` with its own storage

    The result never shares memory with the input and does not track
    gradients.
    """
    # A single `.to(..., copy=True)` avoids the extra full-size device
    # allocation that `.clone()` before `.cpu()` would make, while still
    # forcing a fresh copy when the tensor already lives on the CPU.
    return tensor.detach().to("cpu", copy=True)