Spaces:
Build error
Build error
import torch | |
import torch.nn.functional as F | |
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor | |
from random import random | |
from einops.layers.torch import Rearrange | |
from einops import rearrange, repeat, reduce, pack, unpack | |
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count)."""
    if val is None:
        return False
    return True
def identity(t):
    """Return the argument itself, unchanged (a no-op function)."""
    result = t
    return result
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if not exists(val):
        return d
    return val
def divisible_by(num, den):
    """Return True if ``num`` leaves no remainder when divided by ``den``."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """Return True if ``n`` is not a multiple of 2."""
    even = divisible_by(n, 2)
    return not even
def coin_flip():
    """Return True or False with equal probability, using the module-level ``random()``."""
    draw = random()
    return draw < 0.5
def pack_one(t, pattern):
    """Pack a single tensor with einops ``pack``.

    Returns the ``(packed, packed_shapes)`` pair so the result can later be
    reversed with ``unpack_one``.
    """
    tensors = [t]
    packed, packed_shapes = pack(tensors, pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Undo ``pack_one``: unpack ``t`` with shapes ``ps`` and return the first tensor."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
# tensor helpers
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    The degenerate probabilities 0 and 1 short-circuit to deterministic
    all-False / all-True masks without drawing any random numbers.
    """
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)

    # uniform_ samples from [0, 1), so each entry is below prob with probability prob
    noise = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return noise < prob
def reduce_masks_with_and(*masks):
    """Combine the given masks with element-wise AND, skipping ``None`` entries.

    Returns ``None`` when no non-``None`` mask is supplied. A single surviving
    mask is returned as-is (the same object).
    """
    present = [m for m in masks if exists(m)]

    if not present:
        return None

    combined = present[0]
    for other in present[1:]:
        combined = combined & other

    return combined
def interpolate_1d(t, length, mode = 'bilinear'):
    """Resample ``t`` along its last axis to ``length`` points.

    The sequence is viewed as a width-1 image so that 2-D interpolation modes
    (such as the default ``'bilinear'``) can be passed to ``F.interpolate``.
    Accepts ``(b, n)`` or ``(b, d, n)`` input. Computation is done in float32,
    and the result is cast back to the input dtype.
    """
    original_dtype = t.dtype
    x = t.float()

    # a 2-D input is treated as having a single channel, which is removed again at the end
    squeeze_channel = x.ndim == 2
    if squeeze_channel:
        x = rearrange(x, 'b n -> b 1 n')

    # (b, d, n) -> (b, d, n, 1): resizing to (length, 1) only changes the n axis
    image = rearrange(x, 'b d n -> b d n 1')
    resized = F.interpolate(image, (length, 1), mode = mode)
    out = rearrange(resized, 'b d n 1 -> b d n')

    if squeeze_channel:
        out = rearrange(out, 'b 1 n -> b n')

    return out.to(original_dtype)
def curtail_or_pad(t, target_length):
    """Force the second-to-last axis of ``t`` to have size ``target_length``.

    Longer inputs are truncated at the end. Shorter inputs are right-padded
    with zeros. If the size already matches, ``t`` itself is returned.
    """
    shortfall = target_length - t.shape[-2]

    if shortfall > 0:
        # F.pad pairs run from the last axis backwards: (last_lo, last_hi, second_lo, second_hi)
        return F.pad(t, (0, 0, 0, shortfall), value = 0.)

    if shortfall < 0:
        return t[..., :target_length, :]

    return t
# mask construction helpers
def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    """Build a boolean mask selecting the half-open span ``[start, end)`` on a new last axis.

    Args:
        seq_len: size of the new trailing axis.
        start: span start per row. Any shape is allowed; values are converted
            with ``.long()``, so float starts are truncated toward zero.
        end: exclusive span end per row. Must have the same shape as ``start``
            and is converted the same way.

    Returns:
        Bool tensor of shape ``(*start.shape, seq_len)``. Position ``i`` of a row
        is True iff ``start <= i < end`` for that row.

    Raises:
        AssertionError: if ``start`` and ``end`` have different shapes.
    """
    assert start.shape == end.shape

    # A 1-D arange broadcasts against start[..., None] for any number of leading
    # dims. Reshaping it with one -1 per leading dim failed whenever
    # start.ndim > 1, because reshape can infer at most one dimension.
    seq = torch.arange(seq_len, device = start.device, dtype = torch.long)

    after_start = seq >= start[..., None].long()
    before_end = seq < end[..., None].long()
    return after_start & before_end
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    """Sample one random contiguous span mask per element of ``frac_lengths``.

    Each span is ``int(frac * seq_len)`` positions long. Its start is drawn
    uniformly from the integers ``0 .. seq_len - length``, clamped at 0 when
    ``frac > 1``.

    Args:
        seq_len: size of the masked (last) axis.
        frac_lengths: fraction of ``seq_len`` to mask, one value per row (any shape).

    Returns:
        Bool tensor of shape ``(*frac_lengths.shape, seq_len)``, True inside each span.
    """
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    # Scale by (max_start + 1) so that max_start itself, the last valid start,
    # can be drawn. A plain `max_start * rand` truncated to long essentially
    # never selects it.
    # torch.minimum guards against float rounding pushing the product up to
    # max_start + 1. clamp(min = 0) covers frac_lengths > 1 (negative max_start).
    # zeros_like already allocates on frac_lengths' device.
    rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
    start = ((max_start + 1) * rand).long()
    start = torch.minimum(start, max_start).clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
# sinusoidal positions