Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
def _broadcast_tensor(a, broadcast_shape):
    """
    Expand a tensor to ``broadcast_shape`` by appending trailing singleton dims.

    New dimensions are added on the right (like ``a[..., None]``), so the
    existing dimensions of ``a`` line up with the *leading* dimensions of
    ``broadcast_shape``. Standard NumPy/PyTorch broadcasting aligns from the
    right instead.

    :param a: the tensor to broadcast.
    :param broadcast_shape: the target shape. Once trailing dims are appended,
                            it must be compatible with ``Tensor.expand``.
    :return: ``a`` expanded to ``broadcast_shape`` via ``Tensor.expand``.
    """
    # Number of trailing size-1 dims still needed; a negative count means
    # none are added and ``expand`` decides whether the shapes are compatible.
    missing_dims = len(broadcast_shape) - a.dim()
    for _ in range(missing_dims):
        a = a.unsqueeze(-1)
    return a.expand(broadcast_shape)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Extract values from a 1-D array for a batch of indices.

    :param arr: the 1-D array to index. It may be a numpy array, a torch tensor,
                or anything else ``torch.as_tensor`` accepts.
    :param timesteps: an integer tensor of indices into ``arr``. The result is
                      placed on ``timesteps.device``.
    :param broadcast_shape: a larger shape of K dimensions with the batch
                            dimension equal to the length of timesteps.
    :return: a newly allocated float32 tensor of shape ``broadcast_shape``.
             For 1-D ``timesteps``, every entry of ``result[i]`` equals
             ``arr[timesteps[i]]``.
    """
    # Cast to float32 as part of the device transfer. This way only float32
    # data is copied: half the bytes of a float64 source, and it also works on
    # backends that reject float64 tensors (e.g. Apple MPS). The cast is
    # elementwise, so indexing afterwards yields the same values as casting
    # after indexing.
    res = torch.as_tensor(arr).to(device=timesteps.device, dtype=torch.float32)
    res = res[timesteps]
    # Append trailing singleton dims so the batch dim lines up with dim 0 of
    # broadcast_shape (left-aligned, unlike default right-aligned broadcasting).
    while res.dim() < len(broadcast_shape):
        res = res[..., None]
    # Adding zeros, rather than calling ``expand``, materialises a fresh,
    # contiguous tensor that callers may modify in place. The zeros' dtype is
    # pinned to float32 so a changed global default dtype cannot promote the
    # result.
    return res + torch.zeros(broadcast_shape, dtype=res.dtype, device=res.device)