alibayram committed
Commit 8d4b0c7 · Parent: 66af716

space update
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__/
app.py CHANGED
@@ -1,11 +1,79 @@
+import os
+
 import gradio as gr
-from huggingface_hub import InferenceClient
+import torch
 
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+from v1.usta_model import UstaModel
+from v1.usta_tokenizer import UstaTokenizer
 
 
+# Load the model and tokenizer
+def load_model():
+    try:
+        u_tokenizer = UstaTokenizer("v1/tokenizer.json")
+
+        # Model parameters - adjust these to match your trained model
+        context_length = 32
+        vocab_size = len(u_tokenizer.vocab)
+        embedding_dim = 12
+        num_heads = 4
+        num_layers = 8
+
+        # Load the model
+        u_model = UstaModel(
+            vocab_size=vocab_size,
+            embedding_dim=embedding_dim,
+            num_heads=num_heads,
+            context_length=context_length,
+            num_layers=num_layers
+        )
+
+        # Load the trained weights if available
+        model_path = "v1/u_model.pth"
+
+        if not os.path.exists(model_path):
+            # Download the model file from GitHub
+            try:
+                print("📥 Downloading model weights from GitHub...")
+                import requests
+                url = "https://github.com/malibayram/llm-from-scratch/raw/main/u_model.pth"
+                response = requests.get(url)
+                response.raise_for_status()  # Raise an exception for bad status codes
+
+                # Create v1 directory if it doesn't exist
+                os.makedirs("v1", exist_ok=True)
+
+                with open(model_path, "wb") as f:
+                    f.write(response.content)
+                print("✅ Model weights downloaded successfully!")
+            except Exception as e:
+                print(f"❌ Failed to download model weights: {e}")
+                print("Using random initialization.")
+
+        if os.path.exists(model_path):
+            try:
+                u_model.load_state_dict(torch.load(model_path, map_location="cpu"))
+                u_model.eval()
+                print("✅ Model weights loaded successfully!")
+            except Exception as e:
+                print(f"⚠️ Warning: Could not load trained weights: {e}")
+                print("Using random initialization.")
+        else:
+            print(f"⚠️ Model file not found at {model_path}. Using random initialization.")
+
+        return u_model, u_tokenizer
+
+    except Exception as e:
+        print(f"❌ Error loading model: {e}")
+        raise e
+
+# Initialize model and tokenizer globally
+try:
+    model, tokenizer = load_model()
+    print("🚀 UstaModel and tokenizer initialized successfully!")
+except Exception as e:
+    print(f"❌ Failed to initialize model: {e}")
+    model, tokenizer = None, None
+
 def respond(
     message,
@@ -15,30 +83,49 @@ def respond(
     temperature,
     top_p,
 ):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
+    """
+    Generate a response using the UstaModel
+    """
+    if model is None or tokenizer is None:
+        yield "Sorry, the UstaModel is not available. Please try again later."
+        return
+
+    try:
+        # For UstaModel, we'll use the message directly (ignoring system_message for now)
+        # since it's a simpler model focused on geographical knowledge
+
+        # Encode the input message
+        tokens = tokenizer.encode(message)
+
+        # Make sure we don't exceed context length
+        if len(tokens) > 25:  # Leave some room for generation
+            tokens = tokens[-25:]
+
+        # Generate response
+        with torch.no_grad():
+            # Use max_tokens parameter, but cap it at a reasonable limit for this model
+            actual_max_tokens = min(max_tokens, 32 - len(tokens))
+            generated_tokens = model.generate(tokens, actual_max_tokens)
+
+        # Decode the generated tokens
+        response = tokenizer.decode(generated_tokens)
+
+        # Clean up the response (remove the original input)
+        original_text = tokenizer.decode(tokens.tolist())
+        if response.startswith(original_text):
+            response = response[len(original_text):]
+
+        # Clean up any unwanted tokens
+        response = response.replace("<unk>", "").replace("<pad>", "").strip()
+
+        if not response:
+            response = "I'm not sure how to respond to that with my geographical knowledge."
+
+        # Yield the response (to maintain compatibility with the streaming interface)
         yield response
-
+
+    except Exception as e:
+        yield f"Sorry, I encountered an error: {str(e)}"
 
 """
 For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
@@ -46,19 +133,32 @@ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Textbox(
+            value="You are Usta, a geographical knowledge assistant trained from scratch.",
+            label="System message",
+            info="Note: This model focuses on geographical knowledge (countries, capitals, cities)"
+        ),
+        gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature"),
         gr.Slider(
             minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
+            info="Note: This parameter is not used by UstaModel but kept for interface compatibility"
         ),
     ],
+    title="🤖 Usta Model Chat",
+    description="Chat with a custom transformer language model built from scratch! This model specializes in geographical knowledge including countries, capitals, and cities.",
+    examples=[
+        "the capital of france",
+        "tell me about spain",
+        "what is the capital of united states",
+        "paris is in",
+        "germany and its capital"
+    ]
 )
 
-
 if __name__ == "__main__":
     demo.launch()
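The new `respond` is a plain generator, so it can be smoke-tested outside Gradio. A minimal sketch, assuming the middle parameters kept from the template (`history`, `system_message`, `max_tokens`), which the hunk elides:

```python
# Illustrative only — run from the Space root so "v1/tokenizer.json" resolves;
# importing app triggers the weight download / random-init fallback above.
from app import respond

for partial in respond(
    "the capital of france",  # message
    [],                       # history (unused by UstaModel)
    "You are Usta.",          # system_message (ignored, per the comment in respond)
    20,                       # max_tokens
    1.0,                      # temperature (unused)
    0.95,                     # top_p (unused)
):
    print(partial)
```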
requirements.txt CHANGED
@@ -1 +1,2 @@
-huggingface_hub==0.25.2
+torch>=2.7.1
+requests>=2.32.4
v1/__init__.py ADDED
File without changes
v1/tokenizer.json ADDED
@@ -0,0 +1,66 @@
+{
+  "the": 0,
+  "capital": 1,
+  "of": 2,
+  "united": 3,
+  "state": 4,
+  "is": 5,
+  "not": 6,
+  "london": 7,
+  "france": 8,
+  "paris": 9,
+  "and": 10,
+  "berlin": 11,
+  "germany": 12,
+  "rome": 13,
+  "in": 14,
+  "italy": 15,
+  "madrid": 16,
+  "spain": 17,
+  "lisbon": 18,
+  "portugal": 19,
+  "kingdom": 20,
+  "washington": 21,
+  "although": 22,
+  "these": 23,
+  "place": 24,
+  "are": 25,
+  "often": 26,
+  "mention": 27,
+  "together": 28,
+  "each": 29,
+  "country": 30,
+  "has": 31,
+  "its": 32,
+  "own": 33,
+  "identity": 34,
+  "any": 35,
+  "european": 36,
+  "city": 37,
+  "remain": 38,
+  "important": 39,
+  "with": 40,
+  "a": 41,
+  "rich": 42,
+  "history": 43,
+  "culture": 44,
+  "europe": 45,
+  "made": 46,
+  "many": 47,
+  "unique": 48,
+  "world": 49,
+  "while": 50,
+  "known": 51,
+  "for": 52,
+  "art": 53,
+  "fashion": 54,
+  "famous": 55,
+  "they": 56,
+  "ed": 57,
+  "s": 58,
+  ".": 59,
+  ",": 60,
+  " ": 61,
+  "<unk>": 62,
+  "<pad>": 63
+}
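The vocabulary is a flat word-to-id map (64 entries, ids 0–63), not a full `tokenizers` config: whole words plus the suffix pieces `"ed"` and `"s"`, punctuation, an explicit space token, and `<unk>`/`<pad>`. A quick sanity check, assuming the file sits at `v1/tokenizer.json` as above:

```python
import json

with open("v1/tokenizer.json") as f:
    vocab = json.load(f)

assert len(vocab) == 64                  # sized to match the model's output layer
assert vocab["<unk>"] == 62 and vocab["<pad>"] == 63
print(sorted(vocab, key=vocab.get)[:3])  # ['the', 'capital', 'of']
```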
v1/usta_causal_attention.py ADDED
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+
+class UstaCausalAttention(nn.Module):
+    def __init__(self, embedding_dim, output_dim, context_length, dropout_rate=0):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+
+        self.q_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+        self.k_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+        self.v_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.register_buffer("mask", torch.tril(torch.ones(context_length, context_length)))
+        self.context_length = context_length
+
+    def forward(self, x):
+        number_of_tokens = x.shape[0]
+        # truncate the input to the model's context length
+        x = x[:self.context_length]
+        q = self.q_weights(x)
+        k = self.k_weights(x)
+        v = self.v_weights(x)
+
+        attention_scores = q @ k.T
+        attention_scores = attention_scores.masked_fill_(
+            self.mask.bool()[:number_of_tokens, :number_of_tokens] == 0, -torch.inf
+        )
+        attention_scores = torch.softmax(attention_scores / k.shape[-1] ** 0.5, dim=1)
+        attention_scores = self.dropout(attention_scores)
+
+        return attention_scores @ v
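A quick shape check for the single-head module above (illustrative; it expects unbatched `[tokens, embedding_dim]` input):

```python
import torch

from v1.usta_causal_attention import UstaCausalAttention

attn = UstaCausalAttention(embedding_dim=12, output_dim=12, context_length=32)
out = attn(torch.randn(5, 12))  # 5 tokens, each a 12-dim embedding
print(out.shape)                # torch.Size([5, 12]) — one contextualized vector per token
```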
v1/usta_decoder_block.py ADDED
@@ -0,0 +1,31 @@
+import torch.nn as nn
+
+from .usta_layer_norm import UstaLayerNorm
+from .usta_mlp import UstaMLP
+from .usta_multi_head_attention import UstaMultiHeadAttention
+
+
+class UstaDecoderBlock(nn.Module):
+    def __init__(self, embedding_dim, num_heads, context_length):
+        super().__init__()
+
+        self.self_attention = UstaMultiHeadAttention(embedding_dim, embedding_dim, context_length, num_heads, dropout_rate=0.5)
+        self.norm1 = UstaLayerNorm(embedding_dim)
+        self.mlp = UstaMLP(embedding_dim, embedding_dim)
+        self.norm2 = UstaLayerNorm(embedding_dim)
+
+    def forward(self, x):
+        res = self.norm1(x)
+
+        x = self.self_attention(x)
+        x = self.norm1(x)
+
+        x = x + res
+
+        res = self.norm2(x)
+        x = self.mlp(x)
+        x = self.norm2(x)
+
+        x = x + res
+
+        return x
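Note the unusual residual arrangement above: each norm is applied twice and the residual carries the *normalized* input. For comparison only, a textbook pre-norm block looks like this (an illustrative sketch, not a drop-in replacement — the trained weights expect the layout above):

```python
import torch.nn as nn

class PreNormBlock(nn.Module):
    # conventional pre-norm transformer block, for contrast with UstaDecoderBlock
    def __init__(self, dim, attn, mlp):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual carries the raw input
        x = x + self.mlp(self.norm2(x))
        return x
```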
v1/usta_embedding.py ADDED
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+
+
+def get_rotary_position_encoding(input: torch.Tensor, base=10000, device="cpu"):
+    context_length, dimension = input.shape
+
+    assert dimension % 2 == 0, "dimension must be even"
+
+    half_dimension = dimension // 2
+
+    freqs_indices = torch.arange(0, half_dimension, device=device, dtype=torch.float32)
+
+    freqs = 1.0 / (base ** (freqs_indices / dimension))
+
+    positions = torch.arange(0, context_length, device=device, dtype=torch.float32).unsqueeze(1)
+
+    angles = positions * freqs
+
+    sin_angles = torch.sin(angles)
+    cos_angles = torch.cos(angles)
+
+    input_even = input[:, :dimension // 2]  # first half of the features (the "even" part in this half-split variant)
+    input_odd = input[:, dimension // 2:]   # second half (the "odd" part)
+
+    input_even_rotated = input_even * cos_angles - input_odd * sin_angles
+    input_odd_rotated = input_even * sin_angles + input_odd * cos_angles
+
+    input_rotated = torch.empty_like(input)
+
+    input_rotated[:, :dimension // 2] = input_even_rotated
+    input_rotated[:, dimension // 2:] = input_odd_rotated
+
+    return input_rotated
+
+
+class UstaEmbedding(nn.Module):
+    def __init__(self, vocab_size, embedding_dim, context_length):
+        super().__init__()
+        # a learned position embedding is kept below for educational purposes,
+        # but it is not used in the forward pass
+        # self.pos_embedding = nn.Embedding(context_length, embedding_dim)
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+        self.get_pos = get_rotary_position_encoding
+
+    def forward(self, x):
+        x = self.embedding(x)
+        x = self.get_pos(x)
+        return x
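A quick numerical check of the rotary encoding above: position 0 gets a zero angle (cos 0 = 1, sin 0 = 0), so its vector passes through unchanged:

```python
import torch

from v1.usta_embedding import get_rotary_position_encoding

x = torch.ones(4, 12)                # 4 positions, 12-dim (even, as the assert requires)
rot = get_rotary_position_encoding(x)
print(torch.allclose(rot[0], x[0]))  # True — position 0 is unrotated
print(rot[1, :3])                    # later positions are rotated by per-pair angles
```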
v1/usta_layer_norm.py ADDED
@@ -0,0 +1,18 @@
+import torch
+import torch.nn as nn
+
+
+class UstaLayerNorm(nn.Module):
+    def __init__(self, embedding_dim, eps=1e-5):
+        super().__init__()
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(embedding_dim))
+
+    def forward(self, x):
+        mean = x.mean(dim=-1, keepdim=True)
+        variance = x.var(dim=-1, keepdim=True, unbiased=False)
+        normalized_x = (x - mean) / torch.sqrt(variance + self.eps)
+        return self.weight * normalized_x
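With a scale-only affine term and biased variance, this should match PyTorch's built-in layer norm without bias (a sketch; `bias=False` needs a recent PyTorch, which the pinned `torch>=2.7.1` provides):

```python
import torch
import torch.nn as nn

from v1.usta_layer_norm import UstaLayerNorm

x = torch.randn(5, 12)
print(torch.allclose(UstaLayerNorm(12)(x), nn.LayerNorm(12, bias=False)(x), atol=1e-6))  # True
```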
v1/usta_mlp.py ADDED
@@ -0,0 +1,36 @@
+import torch
+import torch.nn as nn
+
+
+class GELU(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # tanh approximation of GELU
+        return 0.5 * x * (
+            1 + torch.tanh(
+                torch.sqrt(torch.tensor(2 / torch.pi)) * (x + 0.044715 * torch.pow(x, 3))
+            )
+        )
+
+
+class UstaMLP(nn.Module):
+    def __init__(self, embedding_dim, hidden_dim):
+        super().__init__()
+
+        self.gate_proj = nn.Linear(embedding_dim, hidden_dim)
+        self.up_proj = nn.Linear(embedding_dim, hidden_dim)
+        self.down_proj = nn.Linear(hidden_dim, embedding_dim)
+        self.gelu = GELU()
+
+    def forward(self, x):
+        # equivalent F-based formulation:
+        #   gate = F.gelu(self.gate_proj(x), approximate="tanh")
+        #   outputs = self.down_proj(gate * self.up_proj(x))
+        gate = self.gate_proj(x)
+        gate = self.gelu(gate)
+        up = self.up_proj(x)
+        fuse = gate * up
+        outputs = self.down_proj(fuse)
+        return outputs
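The hand-rolled `GELU` is the tanh approximation, so it should agree with `torch.nn.functional.gelu(..., approximate="tanh")` to float precision:

```python
import torch
import torch.nn.functional as F

from v1.usta_mlp import GELU

x = torch.linspace(-3, 3, 7)
print(torch.allclose(GELU()(x), F.gelu(x, approximate="tanh"), atol=1e-6))  # True
```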
v1/usta_model.py ADDED
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+
+from .usta_decoder_block import UstaDecoderBlock
+from .usta_embedding import UstaEmbedding
+
+
+class UstaModel(nn.Module):
+    def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers):
+        super().__init__()
+
+        self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length)
+        self.layers = nn.Sequential(
+            *[UstaDecoderBlock(embedding_dim, num_heads, context_length) for _ in range(num_layers)]
+        )
+
+        self.lm_head = nn.Linear(embedding_dim, vocab_size)
+
+    def forward(self, x: torch.Tensor):
+        x = self.embedding(x)  # token embeddings ("dictionary meaning") plus rotary position information
+
+        x = self.layers(x)
+        x = self.lm_head(x)
+
+        return x
+
+    # leftover exploration snippet (next-token probabilities for a prompt):
+    #   out = u_model(torch.tensor(new_tokens))
+    #   probs = torch.softmax(out[-1], dim=-1)
+    #   max_prob, max_index = torch.max(probs, dim=-1)
+
+    def generate(self, x: torch.Tensor, max_new_tokens: int):  # greedy; no top_k / top_p / temperature yet
+        tokens = x.detach().cpu().numpy().tolist()
+
+        for _ in range(max_new_tokens):
+            out = self.forward(x)
+            probs = torch.softmax(out[-1], dim=-1)
+            _, max_index = torch.max(probs, dim=-1)
+            tokens.append(max_index.item())
+            if max_index == 59 or len(tokens) > 32:  # token 59 is "." (used as a stop token); 32 is the context length
+                break
+
+            x = torch.tensor(tokens)
+
+        return tokens
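End-to-end, the pieces compose as in `load_model()` above. A minimal sketch (weights are randomly initialized here, so the output is nonsense until `u_model.pth` is loaded):

```python
import torch

from v1.usta_model import UstaModel
from v1.usta_tokenizer import UstaTokenizer

tok = UstaTokenizer("v1/tokenizer.json")
model = UstaModel(vocab_size=len(tok.vocab), embedding_dim=12, num_heads=4,
                  context_length=32, num_layers=8)
model.eval()  # the decoder blocks use dropout_rate=0.5, so eval mode matters
with torch.no_grad():
    ids = model.generate(tok.encode("the capital of france"), max_new_tokens=10)
print(tok.decode(ids))
```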
v1/usta_multi_head_attention.py ADDED
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+
+class UstaMultiHeadAttention(nn.Module):
+    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0):
+        super().__init__()
+
+        self.context_length = context_length
+
+        self.multi_head_attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout_rate)
+        self.projection = nn.Linear(embedding_dim, output_dim)
+
+        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())
+
+    def forward(self, x):
+        number_of_tokens = x.shape[0]
+        x = x[:self.context_length]
+        attention_mask = self.mask[:number_of_tokens, :number_of_tokens]
+        out, _ = self.multi_head_attention(x, x, x, attn_mask=attention_mask)
+        out = self.projection(out)
+        return out
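`nn.MultiheadAttention` accepts unbatched `[tokens, dim]` input, and a boolean `attn_mask` marks the positions a query may *not* attend to. A shape check (illustrative):

```python
import torch

from v1.usta_multi_head_attention import UstaMultiHeadAttention

mha = UstaMultiHeadAttention(embedding_dim=12, output_dim=12, context_length=32, num_heads=4)
mha.eval()                      # disable attention dropout for a deterministic pass
out = mha(torch.randn(5, 12))
print(out.shape)                # torch.Size([5, 12])
```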
v1/usta_multi_head_attention_old.py ADDED
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+from .usta_causal_attention import UstaCausalAttention
+
+
+class UstaMultiHeadAttention(nn.Module):
+    def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate=0):
+        super().__init__()
+
+        self.heads = nn.ModuleList(
+            [UstaCausalAttention(embedding_dim, output_dim, context_length, dropout_rate) for _ in range(num_heads)]
+        )
+
+        # note: the concatenation below yields num_heads * output_dim features,
+        # so this projection only lines up when num_heads * output_dim == embedding_dim
+        self.projection = nn.Linear(embedding_dim, output_dim)
+
+    def forward(self, x):
+        attention_outs = []
+        for head in self.heads:
+            head_out = head(x)
+            attention_outs.append(head_out)
+
+        attention_out = torch.cat(attention_outs, dim=1)
+
+        return self.projection(attention_out)
v1/usta_self_attention.py ADDED
@@ -0,0 +1,22 @@
+import torch
+import torch.nn as nn
+
+
+class UstaSelfAttention(nn.Module):
+    def __init__(self, embedding_dim, output_dim):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+
+        self.q_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+        self.k_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+        self.v_weights = nn.Linear(embedding_dim, output_dim, bias=False)
+
+    def forward(self, x):
+        q = self.q_weights(x)
+        k = self.k_weights(x)
+        v = self.v_weights(x)
+
+        attention_scores = q @ k.T
+        attention_weights = torch.softmax(attention_scores / k.shape[-1] ** 0.5, dim=1)
+        return attention_weights @ v
v1/usta_tokenizer.py ADDED
@@ -0,0 +1,48 @@
+import json
+
+import torch
+
+
+class UstaTokenizer:
+    def __init__(self, vocab_file):
+        with open(vocab_file, "r") as f:
+            self.vocab = json.load(f)
+        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
+
+    def encode(self, text):
+        tokens = []
+
+        for word in text.split():
+            i = 0
+            # greedy longest match, e.g. "states" -> "state" (4) + "s" (58)
+            while i < len(word):
+                found_match = False
+                for j in range(len(word), i, -1):
+                    sub_word = word[i:j]
+                    if sub_word in self.vocab:
+                        tokens.append(self.vocab[sub_word])
+                        i = j
+                        found_match = True
+                        break
+                if not found_match:
+                    tokens.append(self.vocab["<unk>"])
+                    i += 1
+            tokens.append(self.vocab[" "])
+
+        tokens.pop()  # drop the trailing space token
+        return torch.tensor(tokens)
+
+    def tokenize(self, text):
+        token_ids = self.encode(text)
+        # convert the tensor of ids back to a plain Python list
+        token_ids = token_ids.detach().numpy().tolist()
+
+        return [self.reverse_vocab[id] for id in token_ids]
+
+    def decode(self, ids):
+        text = ""
+        for id in ids:
+            text += self.reverse_vocab[id]
+        return text
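A round trip through the greedy longest-match tokenizer (assuming `v1/tokenizer.json` from this commit):

```python
from v1.usta_tokenizer import UstaTokenizer

tok = UstaTokenizer("v1/tokenizer.json")
print(tok.tokenize("united states"))  # ['united', ' ', 'state', 's'] — "states" falls back to "state" + "s"
print(tok.encode("united states"))    # tensor([ 3, 61,  4, 58])
print(tok.decode(tok.encode("paris is in france").tolist()))  # 'paris is in france'
```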