import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor

from random import random

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

# helpers

def exists(val):
    return val is not None

def identity(t):
    return t

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def coin_flip():
    return random() < 0.5

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

# tensor helpers

def prob_mask_like(shape, prob, device):
    # boolean mask where each element is True with probability `prob`
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def reduce_masks_with_and(*masks):
    # logical-and any number of optional masks, ignoring Nones
    masks = [*filter(exists, masks)]

    if len(masks) == 0:
        return None

    mask, *rest_masks = masks

    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask

def interpolate_1d(t, length, mode = 'bilinear'):
    """ pytorch does not offer bilinear interpolation in 1d, so hack by converting to 2d and back """

    dtype = t.dtype
    t = t.float()

    implicit_one_channel = t.ndim == 2

    if implicit_one_channel:
        t = rearrange(t, 'b n -> b 1 n')

    t = rearrange(t, 'b d n -> b d n 1')
    t = F.interpolate(t, (length, 1), mode = mode)
    t = rearrange(t, 'b d n 1 -> b d n')

    if implicit_one_channel:
        t = rearrange(t, 'b 1 n -> b n')

    t = t.to(dtype)
    return t

def curtail_or_pad(t, target_length):
    # trim or zero-pad the sequence dimension (dim -2) to the target length
    length = t.shape[-2]

    if length > target_length:
        t = t[..., :target_length, :]
    elif length < target_length:
        t = F.pad(t, (0, 0, 0, target_length - length), value = 0.)

    return t

# mask construction helpers

def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    # boolean mask that is True on [start, end) for each batch element
    assert start.shape == end.shape
    device = start.device

    seq = torch.arange(seq_len, device = device, dtype = torch.long)
    seq = seq.reshape(*((-1,) * start.ndim), seq_len)
    seq = seq.expand(*start.shape, seq_len)

    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask

def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    # mask a randomly placed contiguous span covering `frac_lengths` of the sequence
    device = frac_lengths.device

    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)

# sinusoidal positions
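
# usage sketch for the mask helpers above (illustrative only, not part of the module):
# shapes and values below are assumptions chosen purely for demonstration
#
#   frac_lengths = torch.full((2,), 0.5)                   # mask half of each of 2 sequences
#   span_mask = mask_from_frac_lengths(16, frac_lengths)   # (2, 16) boolean mask over a random span
#
#   start = torch.tensor([2, 4])
#   end   = torch.tensor([10, 12])
#   mask  = mask_from_start_end_indices(16, start, end)    # True within [start, end) per row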