Spaces:
Running
Running
File size: 1,318 Bytes
406f22d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
import torch.nn as nn
def pad_x_to_y(x: torch.Tensor, y: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Right-pad or right-trim ``x`` so its length along ``axis`` matches ``y``.

    When ``x`` is shorter than ``y`` along ``axis``, zeros are appended at the
    end of that dimension. When it is longer, the excess trailing entries are
    removed, because ``nn.functional.pad`` treats negative padding as
    cropping. Only the size along ``axis`` is matched; every other dimension
    of ``x`` is left as it is.

    Args:
        x (torch.Tensor): Tensor to pad or trim.
        y (torch.Tensor): Reference tensor; only ``y.shape[axis]`` is read.
        axis (int): Dimension to match. Negative values count from the end,
            as in ordinary indexing. Defaults to the last dimension.

    Returns:
        torch.Tensor: ``x`` resized along ``axis`` to ``y.shape[axis]``.

    Raises:
        IndexError: If ``axis`` is out of range for ``y`` or ``x``.
    """
    inp_len = y.shape[axis]
    output_len = x.shape[axis]  # also rejects an axis outside x's rank
    dim = axis % x.dim()  # normalise negative axes to a positive index
    # ``nn.functional.pad`` reads its pad list from the LAST dimension
    # backwards, two entries (left, right) per dimension. Leave every
    # dimension after ``dim`` untouched, then right-pad ``dim`` itself.
    n_trailing_dims = x.dim() - 1 - dim
    pad = [0, 0] * n_trailing_dims + [0, inp_len - output_len]
    return nn.functional.pad(x, pad)
def shape_reconstructed(reconstructed, size):
    """Drop the leading dimension of ``reconstructed`` when ``size`` is 1-D.

    Args:
        reconstructed (torch.Tensor): Tensor to reshape.
        size: Any sized object, e.g. a ``torch.Size`` or tuple. Presumably
            the shape of the original input — confirm against callers.

    Returns:
        torch.Tensor: ``reconstructed.squeeze(0)`` when ``len(size) == 1``.
        Note that ``squeeze(0)`` only removes dimension 0 if it has size 1.
        Otherwise ``reconstructed`` itself is returned unchanged.
    """
    target_is_1d = len(size) == 1
    return reconstructed.squeeze(0) if target_is_1d else reconstructed
def tensors_to_device(tensors, device):
    """Recursively move every ``torch.Tensor`` found in ``tensors`` to ``device``.

    Args:
        tensors: A single tensor, or a (possibly nested) list, tuple or dict
            of tensors. Any other object is accepted and passed through.
        device (:class:`torch.device`): Target device, forwarded to
            ``torch.Tensor.to`` and to every recursive call.

    Returns:
        Union[:class:`torch.Tensor`, list, dict]: The result depends on the
        input type:

        - A tensor gives ``tensors.to(device)``.
        - A list or tuple gives a new ``list`` with each element converted
          recursively, so tuples come back as lists.
        - A dict gives the same dict object, with its values replaced in
          place by their converted versions.
        - Anything else is returned as is.
    """
    if isinstance(tensors, torch.Tensor):
        return tensors.to(device)
    if isinstance(tensors, (list, tuple)):
        return [tensors_to_device(element, device) for element in tensors]
    if isinstance(tensors, dict):
        # The caller's dict is mutated in place; only values are reassigned,
        # so the key set never changes during iteration.
        for name in tensors:
            tensors[name] = tensors_to_device(tensors[name], device)
        return tensors
    # Non-tensor leaves (numbers, strings, None, ...) are left untouched.
    return tensors
|