Spaces:
Build error
Build error
File size: 2,838 Bytes
8eb4303 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
from random import random
from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
def identity(t):
    """No-op passthrough: hand back the argument unchanged."""
    return t
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    if val is None:
        return d
    return val
def divisible_by(num, den):
    """True when `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0
def is_odd(n):
    """True when `n` is not evenly divisible by 2."""
    return (n % 2) != 0
def coin_flip():
    """Fair coin: True with probability 0.5."""
    return 0.5 > random()
def pack_one(t, pattern):
    """einops.pack for a lone tensor: wraps `t` in a list before packing.

    Returns the (packed_tensor, packed_shapes) pair from einops.pack.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """einops.unpack for a lone tensor: returns the first (only) unpacked piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
# tensor helpers
def prob_mask_like(shape, prob, device):
    """Return a boolean tensor of `shape` where each entry is True with probability `prob`.

    The prob == 1 and prob == 0 cases are short-circuited to constant masks so
    no random numbers are drawn for them.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    uniform = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform < prob
def reduce_masks_with_and(*masks):
    """Logically AND together any number of optional boolean masks.

    None entries are dropped; returns None when no masks remain after filtering.
    """
    present = [m for m in masks if m is not None]
    if not present:
        return None
    result = present[0]
    for m in present[1:]:
        result = result & m
    return result
def interpolate_1d(t, length, mode = 'bilinear'):
    """Interpolate a 1d signal to `length` samples along its last axis.

    PyTorch has no direct 1d counterpart for these modes, so we append a
    trailing singleton spatial dim and reuse the 2d kernels.

    Args:
        t: tensor of shape (batch, n) or (batch, dim, n)
        length: target number of samples along the last axis
        mode: any 2d F.interpolate mode (default 'bilinear')

    Returns:
        tensor of shape (batch, length) or (batch, dim, length),
        cast back to t's original dtype
    """
    dtype = t.dtype
    t = t.float()
    # remember whether the channel dim was implicit so the shape can be restored
    implicit_one_channel = t.ndim == 2
    if implicit_one_channel:
        t = t.unsqueeze(1)      # (b, n) -> (b, 1, n)
    t = t.unsqueeze(-1)         # (b, d, n) -> (b, d, n, 1)
    t = F.interpolate(t, (length, 1), mode = mode)
    t = t.squeeze(-1)           # (b, d, length, 1) -> (b, d, length)
    if implicit_one_channel:
        t = t.squeeze(1)        # (b, 1, length) -> (b, length)
    return t.to(dtype)
def curtail_or_pad(t, target_length):
    """Trim or zero-pad `t` along dim -2 so that t.shape[-2] == target_length."""
    current = t.shape[-2]
    if current > target_length:
        return t[..., :target_length, :]
    if current < target_length:
        # pad spec is (last-dim lo, last-dim hi, dim -2 lo, dim -2 hi)
        return F.pad(t, (0, 0, 0, target_length - current), value = 0.)
    return t
# mask construction helpers
def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    """Boolean mask, True at positions p with start <= p < end, per sequence.

    Args:
        seq_len: length of the mask's last dimension
        start: inclusive start indices, any batch shape
        end: exclusive end indices, same shape as start

    Returns:
        BoolTensor of shape (*start.shape, seq_len)
    """
    assert start.shape == end.shape
    device = start.device
    seq = torch.arange(seq_len, device = device, dtype = torch.long)
    # prepend singleton dims so positions broadcast over all batch dims.
    # NOTE: the original used (-1,) * start.ndim, which raises for ndim >= 2
    # (torch.reshape permits only a single -1); (1,) * ndim is equivalent for
    # ndim == 1 and correct for higher ranks.
    seq = seq.reshape(*((1,) * start.ndim), seq_len)
    seq = seq.expand(*start.shape, seq_len)
    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    """Mask a random contiguous span covering `frac` of each sequence.

    Span length is floor(frac * seq_len); the start is drawn uniformly from
    the valid positions, so the span always fits inside [0, seq_len).

    Args:
        seq_len: total sequence length of the mask's last dimension
        frac_lengths: fractions in [0, 1], one per sequence

    Returns:
        BoolTensor of shape (*frac_lengths.shape, seq_len)

    NOTE: the original bound `device = frac_lengths` (the tensor itself, not
    its device); the variable was dead code and has been removed.
    """
    lengths = (frac_lengths * seq_len).long()
    # latest index at which a span of this length can still start
    max_start = seq_len - lengths
    rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min = 0)
    end = start + lengths
    return mask_from_start_end_indices(seq_len, start, end)
# sinusoidal positions