import torch
import torch.nn as nn

from .usta_decoder_block import UstaDecoderBlock
from .usta_embedding import UstaEmbedding


class UstaModel(nn.Module):
  def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers, device):
    super().__init__()

    self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length, device)
    self.layers = nn.Sequential(
      *[UstaDecoderBlock(embedding_dim, num_heads, context_length, device) for _ in range(num_layers)]
    )

    # Final projection from the embedding space back to vocabulary logits.
    self.lm_head = nn.Linear(embedding_dim, vocab_size, device=device)
    self.device = device

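  # Shape note (a sketch, assuming UstaEmbedding maps token ids to vectors):
  # input of shape (batch, seq_len) comes out of forward as logits of shape
  # (batch, seq_len, vocab_size).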
  def forward(self, x: torch.Tensor):
    x = self.embedding(x)  # token ids -> embedding vectors (the "dictionary meaning" of each token)
    x = self.layers(x)     # contextualize through the stacked decoder blocks
    x = self.lm_head(x)    # project back to per-position vocabulary logits

    return x

  # Development scratch note, kept as a comment: a single greedy readout
  # from the model looked like this.
  #   out = u_model(torch.tensor(new_tokens))
  #   probs = torch.softmax(out[-1], dim=-1)
  #   max_prob, max_index = torch.max(probs, dim=-1)

  def top_p_filtering(self, logits, top_p):
    # Nucleus (top-p) filtering over a 1-D logits vector: keep the smallest
    # set of tokens whose cumulative probability exceeds top_p; mask the rest.
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark tokens beyond the nucleus. Shift the mask right by one so the
    # first token that crosses the top_p threshold is still kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    sorted_logits[sorted_indices_to_remove] = -float('inf')
    # Scatter the masked logits back to their original (unsorted) positions.
    filtered_logits = torch.full_like(logits, -float('inf'))
    filtered_logits.scatter_(0, sorted_indices, sorted_logits)
    return filtered_logits
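
  # Illustrative walkthrough (toy numbers, not from the project): for logits
  # [2.0, 1.0, 0.5, -1.0], softmax gives roughly [0.61, 0.22, 0.14, 0.03] and
  # cumulative sums [0.61, 0.83, 0.97, 1.00]. With top_p = 0.9 the shifted
  # mask keeps the first three tokens and masks only the last, so sampling is
  # restricted to that "nucleus".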

  def generate(self,
               x: torch.Tensor,
               max_new_tokens: int = 3,
               temperature: float = 1.0,
               top_k: int = 64,
               top_p: float = 1.0
              ):
    tokens = x.tolist()

    for _ in range(max_new_tokens):
      x = x.unsqueeze(0).to(self.device)  # add a batch dimension
      out = self.forward(x)
      out = out.squeeze(0)
      logits = out[-1]  # logits for the next-token position only

      # Temperature scaling (applied before filtering, as is conventional):
      # values below 1 sharpen the distribution, values above 1 flatten it.
      if temperature > 0 and temperature != 1.0:
        logits = logits / temperature

      # Top-k filtering: keep only the k highest-scoring tokens.
      if top_k > 0:
        values, indexes = torch.topk(logits, k=min(top_k, logits.size(-1)))
        logits = torch.full_like(logits, -float('inf'))
        logits.scatter_(0, indexes, values)

      # Top-p (nucleus) filtering on whatever top-k left in place.
      if 0 < top_p < 1:
        logits = self.top_p_filtering(logits, top_p)

      # Sample from the filtered distribution over the full vocabulary, so
      # temperature, top-k, and top-p all take effect.
      probs = torch.softmax(logits, dim=-1)
      next_token = torch.multinomial(probs, 1).item()
      tokens.append(next_token)

      if next_token == 59 or len(tokens) > 32:  # hard-coded <eos> id and max context length
        break

      x = torch.tensor(tokens)

    return tokens
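

# A minimal usage sketch. The hyperparameters and prompt token ids below are
# illustrative assumptions, not the project's real configuration; run as a
# module (e.g. `python -m <package>.usta_model`) so the relative imports resolve.
if __name__ == "__main__":
  device = "cuda" if torch.cuda.is_available() else "cpu"
  model = UstaModel(
    vocab_size=64,        # assumed toy vocabulary size
    embedding_dim=128,    # assumed embedding width
    num_heads=4,
    context_length=32,    # matches the hard-coded limit in generate
    num_layers=2,
    device=device,
  )

  prompt = torch.tensor([3, 17, 42])  # hypothetical token ids
  generated = model.generate(prompt, max_new_tokens=5, temperature=0.8, top_k=16, top_p=0.9)
  print(generated)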