|
r""" Helper functions """ |
|
import random |
|
|
|
import torch |
|
import numpy as np |
|
|
|
|
|
def fix_randseed(seed):
    r""" Set random seeds for reproducibility

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA generators)
    with the same value, then disables cuDNN benchmarking and turns on cuDNN
    deterministic mode.

    Args:
        seed: integer seed to apply, or ``None`` to draw one from ``random``
            (an integer in ``[0, 1e5)``).

    Returns:
        The seed that was actually applied, so that a seed drawn for
        ``seed=None`` can be logged and reused to reproduce the run.
    """
    if seed is None:
        seed = int(random.random() * 1e5)

    # The stdlib generator must be seeded as well; otherwise any code that
    # draws from `random` stays non-deterministic despite the fixed seed.
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    return seed
|
|
|
|
|
def mean(x):
    r""" Return the arithmetic mean of the items in x, or 0.0 when x is empty """
    count = len(x)
    if count == 0:
        return 0.0
    return sum(x) / count
|
|
|
|
|
def to_cuda(batch):
    r""" Replace every torch.Tensor value of batch with its .cuda() result

    The mapping is modified in place and the same object is returned;
    values that are not tensors are left untouched.
    """
    for name, item in batch.items():
        if not isinstance(item, torch.Tensor):
            continue
        # Rebinding an existing key does not resize the mapping, so this is
        # safe to do while iterating over it.
        batch[name] = item.cuda()
    return batch
|
|
|
|
|
def to_cpu(tensor):
    r""" Return a detached, cloned copy of tensor moved to the CPU """
    detached = tensor.detach()
    # Clone before the device move so the result never shares storage with
    # the input, even when the input already lives on the CPU.
    snapshot = detached.clone()
    return snapshot.cpu()
|
|