import torch
import torch.nn as nn

from .usta_decoder_block import UstaDecoderBlock
from .usta_embedding import UstaEmbedding


class UstaModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers, device):
        super().__init__()

        self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length, device)
        self.layers = nn.Sequential(
            *[UstaDecoderBlock(embedding_dim, num_heads, context_length, device) for _ in range(num_layers)]
        )
        self.lm_head = nn.Linear(embedding_dim, vocab_size, device=device)
        self.device = device
    def forward(self, x: torch.Tensor):
        x = self.embedding(x)  # dictionary meaning of the tokens (words)
        x = self.layers(x)     # stack of decoder blocks
        x = self.lm_head(x)    # project hidden states to vocabulary logits

        return x
""" out = u_model(torch.tensor(new_tokens))
probs = torch.softmax(out[-1], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs
"""
    def top_p_filtering(self, logits, top_p):
        # Nucleus (top-p) filtering: keep the smallest set of tokens whose
        # cumulative probability exceeds top_p and mask out everything else.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        sorted_logits[sorted_indices_to_remove] = -float('inf')

        # Scatter the filtered values back to their original vocabulary positions.
        filtered_logits = torch.full_like(logits, -float('inf'))
        filtered_logits.scatter_(0, sorted_indices, sorted_logits)
        return filtered_logits
    def generate(self,
                 x: torch.Tensor,
                 max_new_tokens: int = 3,
                 temperature: float = 1.0,
                 top_k: int = 64,
                 top_p: float = 1.0
                 ):
        tokens = x.tolist()

        for _ in range(max_new_tokens):
            x = x.unsqueeze(0).to(self.device)
            out = self.forward(x)
            out = out.squeeze(0)
            logits = out[-1]  # logits for the next-token position

            if top_k > 0:
                # Keep only the top_k highest logits; mask the rest with -inf.
                values, indexes = torch.topk(logits, k=top_k)
                logits = torch.full_like(logits, -float('inf'))
                logits.scatter_(0, indexes, values)

            if 0 < top_p < 1:
                logits = self.top_p_filtering(logits, top_p)

            if temperature > 0 and temperature != 1.0:
                logits = logits / temperature

            # Sample from the filtered, temperature-scaled distribution
            # (not from the raw top-k values, which would ignore top_p and temperature).
            probs = torch.softmax(logits, dim=-1)
            max_index = torch.multinomial(probs, 1)

            tokens.append(max_index.item())
            if max_index.item() == 59 or len(tokens) > 32:  # <eos> and max context length
                break

            x = torch.tensor(tokens)

        return tokens
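

# A minimal usage sketch, in the same spirit as the scratch block inside the
# class above. The hyperparameters, import path, and prompt ids below are
# illustrative assumptions, not values taken from this repository; adjust them
# to the real tokenizer and training configuration. Kept inside a string
# because this file uses relative imports and is meant to be imported as part
# of the package, not executed directly.
"""
import torch
from .usta_model import UstaModel  # assumed module path

model = UstaModel(
    vocab_size=64,        # assumed demo vocabulary (large enough for <eos> id 59)
    embedding_dim=128,    # assumed embedding width
    num_heads=4,          # assumed number of attention heads
    context_length=32,    # matches the hard-coded limit in generate()
    num_layers=2,         # assumed number of decoder blocks
    device="cpu",
)

prompt = torch.tensor([1, 2, 3])  # hypothetical token ids from the project's tokenizer
generated = model.generate(prompt, max_new_tokens=5, temperature=0.8, top_k=8, top_p=0.9)
print(generated)  # prompt ids followed by the sampled continuation
"""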