alibayram committed
Commit 6563ff2 · 1 Parent(s): 67856b9

v2 implemented
.DS_Store ADDED
Binary file (6.15 kB)
 
app.py CHANGED
@@ -3,14 +3,14 @@ import os
 import gradio as gr
 import torch
 
-from v1.usta_model import UstaModel
-from v1.usta_tokenizer import UstaTokenizer
+from v2.usta_model import UstaModel
+from v2.usta_tokenizer import UstaTokenizer
 
 
 # Load the model and tokenizer
 def load_model(custom_model_path=None):
   try:
-    u_tokenizer = UstaTokenizer("v1/tokenizer.json")
+    u_tokenizer = UstaTokenizer("v2/tokenizer.json")
     print("✅ Tokenizer loaded successfully! vocab size:", len(u_tokenizer.vocab))
 
     # Model parameters - adjust these to match your trained model
@@ -19,6 +19,7 @@ def load_model(custom_model_path=None):
     embedding_dim = 12
     num_heads = 4
     num_layers = 8
+    device = "cpu" # Use CPU for compatibility
 
     # Load the model
     u_model = UstaModel(
@@ -26,7 +27,8 @@ def load_model(custom_model_path=None):
       embedding_dim=embedding_dim,
       num_heads=num_heads,
       context_length=context_length,
-      num_layers=num_layers
+      num_layers=num_layers,
+      device=device
     )
 
     # Determine which model file to use
@@ -34,7 +36,7 @@ def load_model(custom_model_path=None):
       model_path = custom_model_path
      print(f"🎯 Using uploaded model: {model_path}")
     else:
-      model_path = "v1/u_model.pth"
+      model_path = "v2/u_model_4000.pth"
 
     if not os.path.exists(model_path):
       print("❌ Model file not found at", model_path)
@@ -58,8 +60,8 @@ def load_model(custom_model_path=None):
 
       print(f"📦 Downloaded {len(response.content)} bytes")
 
-      # Create v1 directory if it doesn't exist
-      os.makedirs("v1", exist_ok=True)
+      # Create v2 directory if it doesn't exist
+      os.makedirs("v2", exist_ok=True)
 
       # Save the model weights to the local file system
       with open(model_path, "wb") as f:
@@ -195,7 +197,7 @@ def load_model_from_file(uploaded_file):
     model_status = error_msg
     return error_msg
 
-def chat_with_usta(message, history, max_tokens=20):
+def chat_with_usta(message, history, max_tokens=20, temperature=1.0, top_k=64, top_p=1.0):
   """Simple chat function"""
   if model is None or tokenizer is None:
     return history + [["Error", "UstaModel is not available. Please try again later."]]
@@ -211,7 +213,13 @@ def chat_with_usta(message, history, max_tokens=20):
   # Generate response
   with torch.no_grad():
     actual_max_tokens = min(max_tokens, 32 - len(tokens))
-    generated_tokens = model.generate(tokens, actual_max_tokens)
+    generated_tokens = model.generate(
+      tokens,
+      max_new_tokens=actual_max_tokens,
+      temperature=temperature,
+      top_k=top_k,
+      top_p=top_p
+    )
 
   # Decode the generated tokens
   response = tokenizer.decode(generated_tokens)
@@ -249,7 +257,14 @@ with gr.Blocks(title="🤖 Usta Model Chat") as demo:
   clear_btn = gr.Button("Clear")
 
   # Generation settings
-  max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
+  gr.Markdown("## ⚙️ Generation Settings")
+  with gr.Row():
+    max_tokens = gr.Slider(minimum=1, maximum=30, value=20, step=1, label="Max tokens")
+    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
+
+  with gr.Row():
+    top_k = gr.Slider(minimum=1, maximum=64, value=40, step=1, label="Top-k")
+    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.05, label="Top-p (nucleus sampling)")
 
   # Model loading (simplified)
   gr.Markdown("## 🔧 Load Custom Model (Optional)")
@@ -268,20 +283,20 @@ with gr.Blocks(title="🤖 Usta Model Chat") as demo:
   status = gr.Textbox(label="Status", value=model_status, interactive=False)
 
   # Event handlers
-  def send_message(message, history, max_tok):
+  def send_message(message, history, max_tok, temp, k, p):
     if not message.strip():
       return history, ""
-    return chat_with_usta(message, history, max_tok), ""
+    return chat_with_usta(message, history, max_tok, temp, k, p), ""
 
   send_btn.click(
     send_message,
-    inputs=[msg, chatbot, max_tokens],
+    inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
    outputs=[chatbot, msg]
   )
 
   msg.submit(
     send_message,
-    inputs=[msg, chatbot, max_tokens],
+    inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p],
     outputs=[chatbot, msg]
   )
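
Note on the slider wiring above: Gradio passes the components listed in `inputs` to the event handler positionally, so the list order must match the handler's parameter order. A minimal self-contained sketch of the same pattern (illustrative names, not code from this repo):

import gradio as gr

def send_message(message, history, max_tok, temp, k, p):
  # stand-in for chat_with_usta: echoes the sampling settings back
  if not message.strip():
    return history, ""
  reply = f"max_tokens={max_tok}, temperature={temp}, top_k={k}, top_p={p}"
  return history + [[message, reply]], ""

with gr.Blocks() as demo:
  chatbot = gr.Chatbot()
  msg = gr.Textbox(label="Message")
  max_tokens = gr.Slider(1, 30, value=20, step=1, label="Max tokens")
  temperature = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="Temperature")
  top_k = gr.Slider(1, 64, value=40, step=1, label="Top-k")
  top_p = gr.Slider(0.1, 1.0, value=1.0, step=0.05, label="Top-p")
  # the component order here must match send_message's signature
  msg.submit(send_message, inputs=[msg, chatbot, max_tokens, temperature, top_k, top_p], outputs=[chatbot, msg])

# demo.launch()  # uncomment to serve locally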
 
module_3_3.ipynb ADDED
@@ -0,0 +1,366 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device: mps\n",
+      "tensor([ 0, 61, 1, 61, 2, 61, 0, 61, 3], device='mps:0')\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([4, 32])"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import torch\n",
+    "\n",
+    "from usta_model import UstaModel\n",
+    "from usta_tokenizer import UstaTokenizer\n",
+    "\n",
+    "device = \"cpu\"\n",
+    "\n",
+    "if torch.cuda.is_available():\n",
+    "  device = \"cuda\"\n",
+    "elif torch.backends.mps.is_available():\n",
+    "  device = \"mps\"\n",
+    "  \n",
+    "\n",
+    "print(f\"Using device: {device}\")\n",
+    "\n",
+    "u_tokenizer = UstaTokenizer(\"tokenizer.json\")\n",
+    "\n",
+    "prompts = [\n",
+    "  \"the capital of the united\",\n",
+    "  \"madrid is in\",\n",
+    "  \"the capital of france is\",\n",
+    "  \"the capital of germany is\"\n",
+    "]\n",
+    "\n",
+    "tokens = u_tokenizer.encode(prompts[0])\n",
+    "tokens = tokens.to(device)\n",
+    "print(tokens)\n",
+    "batch_tokens = u_tokenizer.encode_batch(prompts, 32)\n",
+    "batch_tokens = batch_tokens.to(device)\n",
+    "batch_tokens.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "torch.manual_seed(1)\n",
+    "context_length = 32\n",
+    "\n",
+    "u_model = UstaModel(\n",
+    "  vocab_size=len(u_tokenizer.vocab),\n",
+    "  embedding_dim=12,\n",
+    "  num_heads=4,\n",
+    "  context_length=context_length,\n",
+    "  num_layers=8,\n",
+    "  device=device\n",
+    ")\n",
+    "\n",
+    "# load model\n",
+    "u_model.load_state_dict(torch.load(\"../u_model_4000.pth\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "torch.Size([4, 32, 64])"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "out = u_model(batch_tokens)\n",
+    "out.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# temperature\n",
+    "# top_k \n",
+    "# top_p\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "top_k = 10"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([17.6884, 14.0799, 9.0104, 8.4548, 7.3207, 7.2960, 6.8096, 6.6073,\n",
+       "  6.6009, 6.3761]),\n",
+       " [61, 60, 35, 58, 9, 38, 59, 4, 18, 49])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sorted_outs = sorted(out[-1][-1].tolist(), reverse=True)\n",
+    "sorted_indexes = []\n",
+    "for so in sorted_outs[:top_k]:\n",
+    "  so_index = out[-1][-1].tolist().index(so)\n",
+    "  sorted_indexes.append(so_index)\n",
+    "sorted_outs = torch.tensor(sorted_outs[:top_k])\n",
+    "sorted_outs, sorted_indexes\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(tensor([17.6884, 14.0799, 9.0104, 8.4548, 7.3207, 7.2960, 6.8096, 6.6073,\n",
+       "  6.6009, 6.3761], device='mps:0', grad_fn=<TopkBackward0>),\n",
+       " tensor([61, 60, 35, 58, 9, 38, 59, 4, 18, 49], device='mps:0'))"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "values, indexes = torch.topk(out[-1][-1], k=10)\n",
+    "values, indexes"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/var/folders/z7/wrd0w0hn7pvb9g97kmdn17640000gn/T/ipykernel_91075/2885985782.py:2: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.detach().clone() or sourceTensor.detach().clone().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+      "  adjusted_outs = torch.tensor(sorted_outs) / temperature\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "tensor([1.6830, 1.3397, 0.8573, 0.8045, 0.6965, 0.6942, 0.6479, 0.6287, 0.6281,\n",
+       "  0.6067])"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "temperature = 10.51\n",
+    "adjusted_outs = torch.tensor(sorted_outs) / temperature\n",
+    "adjusted_outs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([0.2128, 0.1509, 0.0932, 0.0884, 0.0793, 0.0791, 0.0756, 0.0741, 0.0741,\n",
+       "  0.0725])"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "probs = torch.softmax(adjusted_outs, dim=-1)\n",
+    "probs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "top_p = 0.7"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor(0.5453)"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "[0.2128, 0.36, 0.37, 0.38, 0.70, 0.71]\n",
+    "torch.sum(torch.tensor([0.2128, 0.1509, 0.0932, 0.0884]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{0: 212, 4: 82, 5: 87, 9: 83, 2: 74, 6: 73, 1: 154, 3: 91, 8: 80, 7: 64}"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "sample_count = {}\n",
+    "for _ in range(1000):\n",
+    "  sample = torch.multinomial(probs, 1)\n",
+    "  sample_count[sample.item()] = sample_count.get(sample.item(), 0) + 1\n",
+    "sample_count"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'the capital of the united.': 3,\n",
+       " 'the capital of the united the ': 22,\n",
+       " 'the capital of the united identity,': 1,\n",
+       " 'the capital of the united capitals': 5,\n",
+       " 'the capital of the united country ': 8,\n",
+       " 'the capital of the united europe ': 26,\n",
+       " 'the capital of the united is ': 7,\n",
+       " 'the capital of the united place ': 4,\n",
+       " 'the capital of the united europe,': 3,\n",
+       " 'the capital of the united united ': 6,\n",
+       " 'the capital of the united for ': 1,\n",
+       " 'the capital of the united spain,': 2,\n",
+       " 'the capital of the united europe.': 1,\n",
+       " 'the capital of the united italy,': 4,\n",
+       " 'the capital of the united art ': 1,\n",
+       " 'the capital of the united of ': 1,\n",
+       " 'the capital of the united united': 1,\n",
+       " 'the capital of the united capitaled': 1,\n",
+       " 'the capital of the united, country': 1,\n",
+       " 'the capital of the united place.': 1,\n",
+       " 'the capital of the united, europe': 1}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "outs = {}\n",
+    "for _ in range(100):\n",
+    "  out = u_model.generate(tokens, max_new_tokens = 3, temperature = 1.7, top_k = 10, top_p = 0.7)\n",
+    "  decoded = u_tokenizer.decode(out)\n",
+    "  outs[decoded] = outs.get(decoded, 0) + 1\n",
+    "outs"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
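
The notebook above develops top-k selection, temperature scaling, and nucleus (top-p) filtering one cell at a time. For reference, a compact standalone sketch composing the same steps in the notebook's order, on made-up logits (no model required):

import torch

torch.manual_seed(0)
logits = torch.randn(64)  # stand-in for the model's last-position logits

# top-k: keep only the k largest logits (returned in descending order)
k = 10
values, indexes = torch.topk(logits, k=k)

# temperature: divide before softmax; T > 1 flattens the distribution, T < 1 sharpens it
temperature = 1.7
probs = torch.softmax(values / temperature, dim=-1)

# top-p (nucleus): keep the smallest prefix whose cumulative probability reaches p
p = 0.7
cumulative = torch.cumsum(probs, dim=-1)
cutoff = int(torch.searchsorted(cumulative, p)) + 1
probs = probs[:cutoff] / probs[:cutoff].sum()  # renormalize the surviving mass

# draw one token id from the filtered distribution
next_token = indexes[torch.multinomial(probs, 1)]
print(next_token.item())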
v1/u_model.pth DELETED
Binary file (97.2 kB)
 
{v1 → v2}/__init__.py RENAMED
File without changes
{v1 → v2}/tokenizer.json RENAMED
File without changes
v2/u_model_4000.pth ADDED
Binary file (96.1 kB)
 
{v1 → v2}/usta_causal_attention.py RENAMED
File without changes
{v1 → v2}/usta_decoder_block.py RENAMED
@@ -6,17 +6,23 @@ from .usta_multi_head_attention import UstaMultiHeadAttention
 
 
 class UstaDecoderBlock(nn.Module):
-  def __init__(self, embedding_dim, num_heads, context_length):
+  def __init__(self, embedding_dim, num_heads, context_length, device):
     super().__init__()
 
-    self.self_attention = UstaMultiHeadAttention(embedding_dim, embedding_dim, context_length, num_heads, dropout_rate=0.5)
-    self.norm1 = UstaLayerNorm(embedding_dim)
-    self.mlp = UstaMLP(embedding_dim, embedding_dim)
-    self.norm2 = UstaLayerNorm(embedding_dim)
+    self.self_attention = UstaMultiHeadAttention(
+      embedding_dim,
+      embedding_dim,
+      context_length,
+      num_heads,
+      dropout_rate=0.5,
+      device=device
+    )
+    self.norm1 = UstaLayerNorm(embedding_dim, device=device)
+    self.mlp = UstaMLP(embedding_dim, embedding_dim, device=device)
+    self.norm2 = UstaLayerNorm(embedding_dim, device=device)
 
   def forward(self, x):
     res = self.norm1(x)
-
     x = self.self_attention(x)
     x = self.norm1(x)
{v1 → v2}/usta_embedding.py RENAMED
@@ -3,7 +3,7 @@ import torch.nn as nn
 
 
 def get_rotary_position_encoding(input: torch.Tensor, base=10000, device="cpu"):
-  context_length, dimension = input.shape
+  batch_size, context_length, dimension = input.shape
 
   assert dimension % 2 == 0, "dimension must be even"
 
@@ -20,30 +20,31 @@ def get_rotary_position_encoding(input: torch.Tensor, base=10000, device="cpu"):
   sin_angles = torch.sin(angles)
   cos_angles = torch.cos(angles)
 
-  input_even = input[:, :dimension // 2] # [0, 2, 4, ..]
-  input_odd = input[:, dimension // 2:] # [1, 3, 5, ..]
+  input_even = input[:, :, :dimension // 2] # [0, 2, 4, ..]
+  input_odd = input[:, :, dimension // 2:] # [1, 3, 5, ..]
 
   input_even_rotated = input_even * cos_angles - input_odd * sin_angles
   input_odd_rotated = input_even * sin_angles + input_odd * cos_angles
 
-  input_rotated = torch.empty_like(input)
+  input_rotated = torch.empty_like(input, device=device)
 
-  input_rotated[:, :dimension // 2] = input_even_rotated
-  input_rotated[:, dimension // 2:] = input_odd_rotated
+  input_rotated[:, :, :dimension // 2] = input_even_rotated
+  input_rotated[:, :, dimension // 2:] = input_odd_rotated
 
   return input_rotated
 
 class UstaEmbedding(nn.Module):
-  def __init__(self, vocab_size, embedding_dim, context_length):
+  def __init__(self, vocab_size, embedding_dim, context_length, device):
     super().__init__()
     # position embedding but not being used in the forward pass
     # it is just for educational purposes
     # self.pos_embedding = nn.Embedding(context_length, embedding_dim)
     # self.get_pos = get_rotary_position_encoding
-    self.embedding = nn.Embedding(vocab_size, embedding_dim)
+    self.embedding = nn.Embedding(vocab_size, embedding_dim, device=device)
     self.get_pos = get_rotary_position_encoding
+    self.device = device
 
   def forward(self, x):
     x = self.embedding(x)
-    x = self.get_pos(x)
+    x = self.get_pos(x, device=self.device)
     return x
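
The change above makes get_rotary_position_encoding operate on batched [batch, context_length, dimension] input. Since each (even-half, odd-half) coordinate pair is rotated by a position-dependent angle, the transformation preserves vector norms. A quick standalone check of that property (a reimplementation with a standard RoPE-style angle schedule, since the angle computation itself is elided from this hunk):

import torch

def rotary(x, base=10000):
  # x: [batch, context_length, dimension], dimension must be even
  _, context_length, dimension = x.shape
  half = dimension // 2
  positions = torch.arange(context_length, dtype=torch.float32).unsqueeze(1)
  inv_freqs = 1.0 / base ** (torch.arange(half, dtype=torch.float32) / half)
  angles = positions * inv_freqs  # [context_length, half], broadcasts over batch
  sin, cos = torch.sin(angles), torch.cos(angles)
  even, odd = x[:, :, :half], x[:, :, half:]
  out = torch.empty_like(x)
  out[:, :, :half] = even * cos - odd * sin  # 2-D rotation applied pairwise
  out[:, :, half:] = even * sin + odd * cos
  return out

x = torch.randn(4, 32, 12)
rotated = rotary(x)
# rotations preserve norms, so token embedding magnitudes are unchanged
print(torch.allclose(x.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))  # True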
{v1 → v2}/usta_layer_norm.py RENAMED
@@ -3,13 +3,12 @@ import torch.nn as nn
 
 
 class UstaLayerNorm(nn.Module):
-  def __init__(self, embedding_dim, eps=1e-5):
+  def __init__(self, embedding_dim, eps=1e-5, device="cpu"):
     super().__init__()
     self.eps = eps
-
-    self.weight = nn.Parameter(torch.ones(embedding_dim))
-
+    self.weight = nn.Parameter(torch.ones(embedding_dim, device=device))
+    self.device = device
 
   def forward(self, x):
     mean = x.mean(dim=-1, keepdim=True)
     variance = x.var(dim=-1, keepdim=True, unbiased=False)
{v1 → v2}/usta_mlp.py RENAMED
@@ -14,13 +14,13 @@ class GELU(nn.Module):
   )
 
 class UstaMLP(nn.Module):
-  def __init__(self, embedding_dim, hidden_dim):
+  def __init__(self, embedding_dim, hidden_dim, device="cpu"):
     super().__init__()
 
-    self.gate_proj = nn.Linear(embedding_dim, hidden_dim)
-    self.up_proj = nn.Linear(embedding_dim, hidden_dim)
-    self.down_proj = nn.Linear(hidden_dim, embedding_dim)
-    self.gelu = GELU()
+    self.gate_proj = nn.Linear(embedding_dim, hidden_dim, device=device)
+    self.up_proj = nn.Linear(embedding_dim, hidden_dim, device=device)
+    self.down_proj = nn.Linear(hidden_dim, embedding_dim, device=device)
+    self.gelu = GELU().to(device)
 
   def forward(self, x):
     """ gate = self.gate_proj(x)
{v1 → v2}/usta_model.py RENAMED
@@ -6,15 +6,16 @@ from .usta_embedding import UstaEmbedding
 
 
 class UstaModel(nn.Module):
-  def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers):
+  def __init__(self, vocab_size, embedding_dim, num_heads, context_length, num_layers, device):
     super().__init__()
 
-    self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length)
+    self.embedding = UstaEmbedding(vocab_size, embedding_dim, context_length, device)
     self.layers = nn.Sequential(
-      *[UstaDecoderBlock(embedding_dim, num_heads, context_length) for _ in range(num_layers)]
+      *[UstaDecoderBlock(embedding_dim, num_heads, context_length, device) for _ in range(num_layers)]
     )
 
-    self.lm_head = nn.Linear(embedding_dim, vocab_size)
+    self.lm_head = nn.Linear(embedding_dim, vocab_size, device=device)
+    self.device = device
 
   def forward(self, x: torch.Tensor):
     x = self.embedding(x) # dictionary meaning of the tokens (words)
@@ -32,13 +33,49 @@ class UstaModel(nn.Module):
     max_prob, max_index, probs
     """
 
-  def generate(self, x: torch.Tensor, max_new_tokens: int): # top_k, top_p, temperature
-    tokens = x.detach().cpu().numpy().tolist()
+  def top_p_filtering(self, logits, top_p):
+    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+    sorted_indices_to_remove = cumulative_probs > top_p
+    # shift right so the token that crosses top_p is kept, and never drop the top token
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = False
+
+    sorted_logits[sorted_indices_to_remove] = -float('inf')
+    filtered_logits = sorted_logits.clone()
+    filtered_logits.scatter_(0, sorted_indices, sorted_logits)  # undo the sort
+    return filtered_logits
+
+  def generate(self,
+    x: torch.Tensor,
+    max_new_tokens: int = 3,
+    temperature: float = 1.0,
+    top_k: int = 64,
+    top_p: float = 1.0
+  ):
+    tokens = x.tolist()
 
     for _ in range(max_new_tokens):
+      x = x.unsqueeze(0).to(self.device)  # add a batch dimension
       out = self.forward(x)
-      probs = torch.softmax(out[-1], dim=-1)
-      _, max_index = torch.max(probs, dim=-1)
+      out = out.squeeze(0)
+      x = x.squeeze(0)  # drop the batch dimension again before the next iteration
+      logits = out[-1]
+      if top_k > 0:
+        values, indexes = torch.topk(logits, k=top_k)
+        logits = torch.full_like(logits, -float('inf'))
+        logits.scatter_(0, indexes, values)  # keep only the top-k logits
+
+      if top_p > 0 and top_p < 1:
+        logits = self.top_p_filtering(logits, top_p)
+
+      if temperature != 1.0 and temperature > 0:
+        logits = logits / temperature
+
+      probs = torch.softmax(logits, dim=-1)  # distribution over the filtered, scaled logits
+      max_index = torch.multinomial(probs, 1)
+      x = torch.cat([x, max_index])  # condition the next step on the sampled token
       tokens.append(max_index.item())
       if max_index == 59 or len(tokens) > 32: # <eos> and max context length
         break
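
A quick sanity check for the nucleus filter added above: on a small hand-made logits vector, every position outside the smallest set whose cumulative probability exceeds top_p should come back as -inf. A standalone sketch of the same logic for a 1-D logits vector:

import torch

def top_p_filter(logits, top_p):
  # mirrors UstaModel.top_p_filtering above
  sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
  remove = cumulative > top_p
  remove[1:] = remove[:-1].clone()  # shift right: keep the token that crosses top_p
  remove[0] = False  # always keep the single most likely token
  sorted_logits[remove] = -float("inf")
  out = torch.empty_like(logits)
  out.scatter_(0, sorted_indices, sorted_logits)  # undo the sort
  return out

logits = torch.tensor([2.0, 0.5, 1.0, -1.0])
print(top_p_filter(logits, top_p=0.8))
# tensor([2., -inf, 1., -inf]) - only the high-probability head survives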
{v1 → v2}/usta_multi_head_attention.py RENAMED
@@ -3,15 +3,15 @@ import torch.nn as nn
 
 
 class UstaMultiHeadAttention(nn.Module):
-  def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate = 0):
+  def __init__(self, embedding_dim, output_dim, context_length, num_heads, dropout_rate = 0, device="cpu"):
     super().__init__()
 
     self.context_length = context_length
 
-    self.multi_head_attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout_rate)
-    self.projection = nn.Linear(embedding_dim, output_dim)
+    self.multi_head_attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout_rate, device=device)
+    self.projection = nn.Linear(embedding_dim, output_dim, device=device)
 
-    self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())
+    self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool().to(device))
 
   def forward(self, x):
     number_of_tokens = x.shape[0]
{v1 → v2}/usta_multi_head_attention_old.py RENAMED
@@ -22,5 +22,4 @@ class UstaMultiHeadAttention(nn.Module):
 
     attention_out = torch.cat(attention_outs, dim=1)
 
-    return self.projection(attention_out)
-
+    return self.projection(attention_out)
{v1 → v2}/usta_self_attention.py RENAMED
File without changes
{v1 → v2}/usta_tokenizer.py RENAMED
@@ -9,6 +9,19 @@ class UstaTokenizer:
       self.vocab = json.load(f)
     self.reverse_vocab = {v: k for k, v in self.vocab.items()}
 
+  def encode_batch(self, texts, context_length):
+    sentences_tokens = []
+    for text in texts:
+      tokens = self.encode(text).tolist()
+      if len(tokens) > context_length:
+        tokens = tokens[:context_length]
+      else:
+        tokens = tokens + [self.vocab["<pad>"]] * (context_length - len(tokens))
+
+      sentences_tokens.append(tokens)
+
+    return torch.tensor(sentences_tokens)
+
   def encode(self, text):
     tokens = []
 
@@ -31,7 +44,9 @@ class UstaTokenizer:
       i += 1
       tokens.append(self.vocab[" "])
 
-    tokens.pop()
+    # drop the trailing space token only when the text does not end with a space
+    if not text.endswith(" "):
+      tokens.pop()
     return torch.tensor(tokens)
 
   def tokenize(self, text):
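
The new encode_batch pads or truncates every prompt to a fixed context length so a whole batch can be stacked into one tensor, as the notebook's first cell does. A hypothetical usage sketch (assumes the v2 module layout and that tokenizer.json defines a "<pad>" entry, which the padding code requires):

from v2.usta_tokenizer import UstaTokenizer

tok = UstaTokenizer("v2/tokenizer.json")

prompts = ["the capital of france is", "madrid is in"]
batch = tok.encode_batch(prompts, context_length=32)
print(batch.shape)  # torch.Size([2, 32]); short prompts are right-padded with <pad>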