File size: 710 Bytes
57746f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
r""" Helper functions """
import random
import torch
import numpy as np
def fix_randseed(seed):
    r""" Set random seeds for reproducibility.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU plus every
    CUDA device), and switches cuDNN into deterministic, non-benchmark mode.

    Args:
        seed: integer seed, or ``None`` to draw a fresh one in ``[0, 100000)``.

    Returns:
        The seed that was actually applied. This is useful to log when
        ``seed`` was ``None``, so the run can be reproduced later.
    """
    if seed is None:
        seed = int(random.random() * 1e5)
    # The stdlib ``random`` module must be seeded too; otherwise any code
    # relying on it (e.g. data shuffling/augmentation) is not reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Benchmark mode selects convolution algorithms by timing them, which can
    # differ between runs; disable it and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    return seed
def mean(x):
    r""" Return the arithmetic mean of ``x``, or ``0.0`` if it is empty.

    ``len`` is used instead of truthiness so that sized containers whose
    truth value is ambiguous (e.g. arrays/tensors) are still accepted.
    """
    count = len(x)
    if count == 0:
        return 0.0
    return sum(x) / count
def to_cuda(batch):
    r""" Move every ``torch.Tensor`` value of ``batch`` to the GPU, in place.

    Non-tensor values are left untouched. The same dict object is mutated and
    returned, so callers may use either the argument or the return value.
    """
    tensor_keys = [name for name, item in batch.items()
                   if isinstance(item, torch.Tensor)]
    for name in tensor_keys:
        batch[name] = batch[name].cuda()
    return batch
def to_cpu(tensor):
    r""" Return a detached, independent CPU copy of ``tensor``.

    The result never shares storage with the input and carries no autograd
    history. ``copy=True`` forces a fresh tensor even when the input already
    lives on the CPU. Unlike ``clone().cpu()``, it does not first allocate a
    full duplicate in GPU memory before transferring it to the host.
    """
    return tensor.detach().to("cpu", copy=True)
|