# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py
import math

import torch
from torch import nn


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.dropout = torch.nn.Dropout(p=dropout)
        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> torch.Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> torch.Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def forward(self, x: torch.Tensor):
        x = self.word_embeddings(x)
        x = self.dropout(x)
        return x


class SinePositionalEmbeddingNested(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        dropout: float = 0.0,
        scale: bool = False,
        alpha: bool = False,
        max_batch_size: int = 20,
        max_seq_len: int = 2500,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.reverse = False
        self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
        self.pe: torch.Tensor
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings."""
        if self.reverse:
            position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
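        # Standard transformer sinusoid frequencies: div_term[i] = 1 / 10000^(2i / embedding_dim).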
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
        )
        pe = self.pe
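        # The (max_seq_len, embedding_dim / 2) values broadcast over the batch dimension,
        # so every one of the max_batch_size rows holds the same encoding table.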
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)

    def forward(self, input_pos: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input_pos (Tensor): [batch_size, ] 1-based position of the current token in each sequence
            x (Tensor): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, 1, embed_dim]
        """

        batch_size = x.shape[0]
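        # Positions are 1-based, so index row (input_pos - 1) of each sequence's encoding table.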
        pe_values = self.pe[torch.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)

        return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (Tensor): Nested tensor [batch_size, seq_len, embed_dim] with per-sequence lengths

        Returns:
            embedded_x (Tensor): Nested tensor [batch_size, seq_len, embed_dim]
        """

        input_pos: torch.Tensor = torch.tensor([i.shape[0] for i in x.unbind()])
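        # Slice the first `length` positional rows for each sequence and re-pack them as a nested tensor.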
        pe_values = torch.nested.nested_tensor([self.pe[i, : input_pos[i], :] for i in range(input_pos.size(0))])
        return x * self.x_scale + self.alpha.item() * pe_values
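

if __name__ == "__main__":
    # Minimal usage sketch, not part of the upstream module: the embedding size,
    # vocabulary size, and token ids below are illustrative assumptions chosen
    # only to exercise TokenEmbedding and SinePositionalEmbeddingNested together.
    torch.manual_seed(0)

    embed = TokenEmbedding(embedding_dim=64, vocab_size=100)
    pos_emb = SinePositionalEmbeddingNested(embedding_dim=64, max_batch_size=2, max_seq_len=16)

    with torch.no_grad():
        # Prefill: two prompts of different lengths packed into one nested tensor.
        prompts = [torch.randint(0, 100, (5,)), torch.randint(0, 100, (3,))]
        x = torch.nested.nested_tensor([embed(t) for t in prompts])
        prefilled = pos_emb.prefill(x)
        print([tuple(t.shape) for t in prefilled.unbind()])  # [(5, 64), (3, 64)]

        # Single decode step: one new token per sequence, at 1-based positions 6 and 4.
        step = embed(torch.randint(0, 100, (2, 1)))  # [2, 1, 64]
        out = pos_emb(torch.tensor([6, 4]), step)
        print(out.shape)  # torch.Size([2, 1, 64])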