Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,318 Bytes
d13869d d7f22c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math
import torch
from torch import nn
class TokenEmbedding(nn.Module):
    """Token-id lookup table followed by dropout.

    Args:
        embedding_dim: Size of each embedding vector.
        vocab_size: Number of rows in the lookup table.
        dropout: Dropout probability applied to the looked-up vectors.
    """

    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        # Registration order (dropout first, then the table) is kept stable.
        self.dropout = nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        """The underlying embedding matrix, shape (vocab_size, embedding_dim)."""
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        """Return row ``index`` of the embedding matrix as a (1, embedding_dim) slice.

        No dropout is applied here; the row is sliced straight from the weights.
        """
        table = self.word_embeddings.weight
        return table[index : index + 1]

    def forward(self, x: torch.Tensor):
        """Embed the token ids in ``x`` and pass the result through dropout."""
        return self.dropout(self.word_embeddings(x))
class SinePositionalEmbeddingNested(nn.Module):
    """Sinusoidal positional embedding served from a precomputed table.

    A non-persistent buffer ``pe`` of shape
    ``(max_batch_size, max_seq_len, embedding_dim)`` is filled once at
    construction time. ``forward`` adds the encoding of a single position per
    batch row (one decoding step); ``prefill`` adds the encodings of every
    position of each sequence in a nested tensor.

    Args:
        embedding_dim: Size of the embedding vectors (odd sizes are supported).
        dropout: Probability for the ``dropout`` submodule. The submodule is
            created but is not applied by ``forward`` or ``prefill``.
        scale: When true, inputs are multiplied by ``sqrt(embedding_dim)``.
        alpha: Whether the scalar weight on the positional term is trainable.
        max_batch_size: Number of batch rows held by the table.
        max_seq_len: Number of positions held by the table.
    """

    def __init__(
        self,
        embedding_dim: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
        max_batch_size: int = 20,
        max_seq_len: int = 2500,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        # ``alpha`` (the flag) only controls requires_grad; the value starts at 1.
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.reverse = False
        # Non-persistent: the table is recomputed, never stored in state_dict.
        self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
        self.pe: torch.Tensor
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings.

        Fills ``self.pe`` in place: even channels get ``sin``, odd channels get
        ``cos``. Positions run backwards when ``self.reverse`` is set.
        """
        device = self.pe.device
        if self.reverse:
            position = torch.arange(
                self.max_seq_len - 1, -1, -1.0, dtype=torch.float32, device=device
            ).unsqueeze(1)
        else:
            position = torch.arange(self.max_seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32, device=device)
            * -(math.log(10000.0) / self.embedding_dim)
        )
        angles = position * div_term  # (max_seq_len, ceil(embedding_dim / 2))
        pe = self.pe
        pe[:, :, 0::2] = torch.sin(angles)
        # For an odd embedding_dim there is one fewer odd channel than even
        # channels, so drop the surplus frequency instead of failing on shape.
        pe[:, :, 1::2] = torch.cos(angles[:, : self.embedding_dim // 2])

    def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding for one step per batch row.

        Row ``b`` of the table is read at position ``input_pos[b] - 1``.

        Args:
            input_pos (Tensor): [batch_size, ]
            x (Tensor): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, 1, embed_dim]
        """
        batch_size = x.shape[0]
        # Build the row index on the table's device to avoid a host-side index.
        rows = torch.arange(batch_size, device=self.pe.device)
        pe_values = self.pe[rows, input_pos - 1]  # (batch_size, embed_dim)
        return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to every position of each nested sequence.

        Args:
            x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]

        Returns:
            embedded_x (Tensor): Nested Seqlen [batch_size, seq_len, embed_dim]
        """
        lengths = [seq.shape[0] for seq in x.unbind()]
        pe_values = torch.nested.nested_tensor(
            [self.pe[row, :length, :] for row, length in enumerate(lengths)]
        )
        # ``.item()`` turns alpha into a plain Python number, so no gradient
        # reaches ``alpha`` through this path.
        return x * self.x_scale + self.alpha.item() * pe_values
|