Spaces:
Running
on
Zero
Running
on
Zero
File size: 870 Bytes
4f7b5ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import torch
def _broadcast_tensor(a, broadcast_shape):
while len(a.shape) < len(broadcast_shape):
a = a[..., None]
return a.expand(broadcast_shape)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
:param arr: the 1-D numpy array.
:param timesteps: a tensor of indices into the array to extract.
:param broadcast_shape: a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|