asadsandhu commited on
Commit
e06e912
·
1 Parent(s): 60a4f1e

Coded updated.

Browse files
Files changed (4) hide show
  1. app.py +229 -0
  2. model.pth +3 -0
  3. requirements.txt +2 -0
  4. train.ipynb +656 -0
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # gradio_app.py
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import math
6
+ import torch.nn as nn
7
+ import re
8
+ import sys
9
+ import asyncio
10
+
11
+ if sys.platform.startswith('win'):
12
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
13
+
14
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ MAX_LEN = 128
16
+ EMBED_DIM = 256
17
+ NHEAD = 4
18
+ NUM_ENCODER_LAYERS = 2
19
+ NUM_DECODER_LAYERS = 2
20
+ FF_DIM = 512
21
+
22
+ PAD_TOKEN = "<pad>"
23
+ SOS_TOKEN = "<sos>"
24
+ EOS_TOKEN = "<eos>"
25
+ UNK_TOKEN = "<unk>"
26
+
27
+ def tokenize_line(text: str):
28
+ return re.findall(r"[A-Za-z0-9]+|[^\sA-Za-z0-9]", text)
29
+
30
+ def numericalize(text: str, stoi: dict):
31
+ tokens = tokenize_line(text)
32
+ return [stoi.get(tok, stoi[UNK_TOKEN]) for tok in tokens]
33
+
34
+ def pad_sequence(seq, max_len, pad_id):
35
+ seq = seq[:max_len-1]
36
+ seq = seq + [tgt_stoi[EOS_TOKEN]]
37
+ if len(seq) < max_len:
38
+ seq += [pad_id] * (max_len - len(seq))
39
+ return seq
40
+
41
+ class PositionalEncoding(nn.Module):
42
+ def __init__(self, d_model, max_len=5000):
43
+ super().__init__()
44
+ pe = torch.zeros(max_len, d_model)
45
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
46
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
47
+ pe[:, 0::2] = torch.sin(position * div_term)
48
+ pe[:, 1::2] = torch.cos(position * div_term)
49
+ pe = pe.unsqueeze(0)
50
+ self.register_buffer("pe", pe)
51
+ def forward(self, x):
52
+ return x + self.pe[:, :x.size(1), :]
53
+
54
+ class MultiHeadAttention(nn.Module):
55
+ def __init__(self, d_model, n_heads):
56
+ super().__init__()
57
+ assert d_model % n_heads == 0
58
+ self.d_model = d_model
59
+ self.n_heads = n_heads
60
+ self.head_dim = d_model // n_heads
61
+ self.query_linear = nn.Linear(d_model, d_model)
62
+ self.key_linear = nn.Linear(d_model, d_model)
63
+ self.value_linear = nn.Linear(d_model, d_model)
64
+ self.out_linear = nn.Linear(d_model, d_model)
65
+ def forward(self, query, key, value, mask=None):
66
+ B, Q_len, _ = query.size()
67
+ B, K_len, _ = key.size()
68
+ Q = self.query_linear(query)
69
+ K = self.key_linear(key)
70
+ V = self.value_linear(value)
71
+ Q = Q.view(B, Q_len, self.n_heads, self.head_dim).transpose(1,2)
72
+ K = K.view(B, K_len, self.n_heads, self.head_dim).transpose(1,2)
73
+ V = V.view(B, K_len, self.n_heads, self.head_dim).transpose(1,2)
74
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
75
+ if mask is not None:
76
+ scores = scores.masked_fill(mask == 0, float('-inf'))
77
+ attn = torch.softmax(scores, dim=-1)
78
+ context = torch.matmul(attn, V)
79
+ context = context.transpose(1,2).contiguous().view(B, Q_len, self.d_model)
80
+ return self.out_linear(context)
81
+
82
+ class FeedForward(nn.Module):
83
+ def __init__(self, d_model, dim_feedforward):
84
+ super().__init__()
85
+ self.fc1 = nn.Linear(d_model, dim_feedforward)
86
+ self.fc2 = nn.Linear(dim_feedforward, d_model)
87
+ self.relu = nn.ReLU()
88
+ def forward(self, x):
89
+ return self.fc2(self.relu(self.fc1(x)))
90
+
91
+ class EncoderLayer(nn.Module):
92
+ def __init__(self, d_model, n_heads, dim_feedforward):
93
+ super().__init__()
94
+ self.self_attn = MultiHeadAttention(d_model, n_heads)
95
+ self.ff = FeedForward(d_model, dim_feedforward)
96
+ self.norm1 = nn.LayerNorm(d_model)
97
+ self.norm2 = nn.LayerNorm(d_model)
98
+ self.dropout = nn.Dropout(0.1)
99
+ def forward(self, src, src_mask=None):
100
+ attn_out = self.self_attn(src, src, src, mask=src_mask)
101
+ src = self.norm1(src + self.dropout(attn_out))
102
+ ff_out = self.ff(src)
103
+ return self.norm2(src + self.dropout(ff_out))
104
+
105
+ class DecoderLayer(nn.Module):
106
+ def __init__(self, d_model, n_heads, dim_feedforward):
107
+ super().__init__()
108
+ self.self_attn = MultiHeadAttention(d_model, n_heads)
109
+ self.cross_attn = MultiHeadAttention(d_model, n_heads)
110
+ self.ff = FeedForward(d_model, dim_feedforward)
111
+ self.norm1 = nn.LayerNorm(d_model)
112
+ self.norm2 = nn.LayerNorm(d_model)
113
+ self.norm3 = nn.LayerNorm(d_model)
114
+ self.dropout = nn.Dropout(0.1)
115
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
116
+ tgt = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt, tgt, mask=tgt_mask)))
117
+ tgt = self.norm2(tgt + self.dropout(self.cross_attn(tgt, memory, memory, mask=memory_mask)))
118
+ ff_out = self.ff(tgt)
119
+ return self.norm3(tgt + self.dropout(ff_out))
120
+
121
+ class Encoder(nn.Module):
122
+ def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):
123
+ super().__init__()
124
+ self.embedding = nn.Embedding(vocab_size, d_model)
125
+ self.pos_encoding = PositionalEncoding(d_model)
126
+ self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, dim_feedforward) for _ in range(num_layers)])
127
+ def forward(self, src, src_mask=None):
128
+ x = self.embedding(src)
129
+ x = self.pos_encoding(x)
130
+ for layer in self.layers:
131
+ x = layer(x, src_mask)
132
+ return x
133
+
134
+ class Decoder(nn.Module):
135
+ def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):
136
+ super().__init__()
137
+ self.embedding = nn.Embedding(vocab_size, d_model)
138
+ self.pos_encoding = PositionalEncoding(d_model)
139
+ self.layers = nn.ModuleList([DecoderLayer(d_model, n_heads, dim_feedforward) for _ in range(num_layers)])
140
+ self.fc_out = nn.Linear(d_model, vocab_size)
141
+ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
142
+ x = self.embedding(tgt)
143
+ x = self.pos_encoding(x)
144
+ for layer in self.layers:
145
+ x = layer(x, memory, tgt_mask, memory_mask)
146
+ return self.fc_out(x)
147
+
148
+ class TransformerSeq2Seq(nn.Module):
149
+ def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads,
150
+ num_encoder_layers, num_decoder_layers, dim_feedforward):
151
+ super().__init__()
152
+ self.encoder = Encoder(src_vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward)
153
+ self.decoder = Decoder(tgt_vocab_size, d_model, n_heads, num_decoder_layers, dim_feedforward)
154
+ def forward(self, src, tgt, src_mask=None, tgt_mask=None):
155
+ memory = self.encoder(src, src_mask)
156
+ return self.decoder(tgt, memory, tgt_mask)
157
+
158
+ def generate_subsequent_mask(size):
159
+ mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
160
+ return ~mask
161
+
162
+ def greedy_decode(model, src, src_stoi, tgt_stoi, tgt_itos, max_len=MAX_LEN):
163
+ model.eval()
164
+ src = torch.tensor(src, dtype=torch.long, device=DEVICE).unsqueeze(0)
165
+ memory = model.encoder(src)
166
+ ys = torch.tensor([tgt_stoi[SOS_TOKEN]], dtype=torch.long, device=DEVICE).unsqueeze(0)
167
+ for i in range(max_len-1):
168
+ tgt_mask = generate_subsequent_mask(ys.size(1)).to(DEVICE)
169
+ out = model.decoder(ys, memory, tgt_mask)
170
+ prob = out[:, -1, :]
171
+ next_token = torch.argmax(prob, dim=1).item()
172
+ ys = torch.cat([ys, torch.tensor([[next_token]], device=DEVICE)], dim=1)
173
+ if next_token == tgt_stoi[EOS_TOKEN]:
174
+ break
175
+ out_tokens = ys.squeeze(0).tolist()[1:]
176
+ if tgt_stoi[EOS_TOKEN] in out_tokens:
177
+ out_tokens = out_tokens[:out_tokens.index(tgt_stoi[EOS_TOKEN])]
178
+ return " ".join(tgt_itos[t] for t in out_tokens)
179
+
180
+ # Load model and vocabulary
181
+ if not os.path.exists("model.pth"):
182
+ raise FileNotFoundError("Model file 'model.pth' not found. Please train first.")
183
+
184
+ checkpoint = torch.load("model.pth", map_location=DEVICE)
185
+ src_stoi = checkpoint['src_stoi']
186
+ src_itos = checkpoint['src_itos']
187
+ tgt_stoi = checkpoint['tgt_stoi']
188
+ tgt_itos = checkpoint['tgt_itos']
189
+
190
+ model = TransformerSeq2Seq(
191
+ src_vocab_size=len(src_stoi),
192
+ tgt_vocab_size=len(tgt_stoi),
193
+ d_model=EMBED_DIM,
194
+ n_heads=NHEAD,
195
+ num_encoder_layers=NUM_ENCODER_LAYERS,
196
+ num_decoder_layers=NUM_DECODER_LAYERS,
197
+ dim_feedforward=FF_DIM
198
+ ).to(DEVICE)
199
+ model.load_state_dict(checkpoint['model_state_dict'])
200
+ model.eval()
201
+
202
+ def convert_pseudocode(text):
203
+ lines = text.strip().split('\n')
204
+ outputs = []
205
+ for i, line in enumerate(lines):
206
+ line = line.strip()
207
+ if not line:
208
+ outputs.append("")
209
+ elif line == "}":
210
+ outputs.append("}")
211
+ else:
212
+ try:
213
+ src_ids = numericalize(line, src_stoi)
214
+ src_ids = pad_sequence(src_ids, MAX_LEN, src_stoi[PAD_TOKEN])
215
+ output_line = greedy_decode(model, src_ids, src_stoi, tgt_stoi, tgt_itos)
216
+ outputs.append(output_line)
217
+ except Exception as e:
218
+ outputs.append(f"// [Error in line {i+1}]: {e}")
219
+ return "int main() {\n" + '\n'.join(outputs) + "\nreturn 0;\n}"
220
+
221
+ iface = gr.Interface(
222
+ fn=convert_pseudocode,
223
+ inputs=gr.Textbox(label="Enter pseudocode (line-by-line)", lines=10),
224
+ outputs=gr.Code(language="cpp", label="Generated C++ Code"),
225
+ title="PseudoCode to C++ Converter (Transformer from Scratch)"
226
+ )
227
+
228
+ if __name__ == "__main__":
229
+ iface.launch()
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f52e43c6473d1f1726347eccb1608d1de92a9eaabd5491d3bac4692f31e3662
3
+ size 41398742
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ streamlit==1.35.0
2
+ torch==2.2.2
train.ipynb ADDED
@@ -0,0 +1,656 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU"
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "colab": {
24
+ "base_uri": "https://localhost:8080/"
25
+ },
26
+ "collapsed": true,
27
+ "id": "12APLOKE15uD",
28
+ "outputId": "fb61078b-a249-476a-af53-e43ca978c8c1"
29
+ },
30
+ "outputs": [
31
+ {
32
+ "output_type": "stream",
33
+ "name": "stdout",
34
+ "text": [
35
+ "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)\n",
36
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
37
+ "Requirement already satisfied: streamlit in /usr/local/lib/python3.11/dist-packages (1.42.2)\n",
38
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.17.0)\n",
39
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.12.2)\n",
40
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n",
41
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.5)\n",
42
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2024.10.0)\n",
43
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
44
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
45
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
46
+ "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
47
+ "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
48
+ "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
49
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
50
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
51
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
52
+ "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
53
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
54
+ "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
55
+ "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.0)\n",
56
+ "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
57
+ "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
58
+ "Requirement already satisfied: altair<6,>=4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (5.5.0)\n",
59
+ "Requirement already satisfied: blinker<2,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (1.9.0)\n",
60
+ "Requirement already satisfied: cachetools<6,>=4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (5.5.1)\n",
61
+ "Requirement already satisfied: click<9,>=7.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (8.1.8)\n",
62
+ "Requirement already satisfied: numpy<3,>=1.23 in /usr/local/lib/python3.11/dist-packages (from streamlit) (1.26.4)\n",
63
+ "Requirement already satisfied: packaging<25,>=20 in /usr/local/lib/python3.11/dist-packages (from streamlit) (24.2)\n",
64
+ "Requirement already satisfied: pandas<3,>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (2.2.2)\n",
65
+ "Requirement already satisfied: pillow<12,>=7.1.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (11.1.0)\n",
66
+ "Requirement already satisfied: protobuf<6,>=3.20 in /usr/local/lib/python3.11/dist-packages (from streamlit) (4.25.6)\n",
67
+ "Requirement already satisfied: pyarrow>=7.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (17.0.0)\n",
68
+ "Requirement already satisfied: requests<3,>=2.27 in /usr/local/lib/python3.11/dist-packages (from streamlit) (2.32.3)\n",
69
+ "Requirement already satisfied: rich<14,>=10.14.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (13.9.4)\n",
70
+ "Requirement already satisfied: tenacity<10,>=8.1.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (9.0.0)\n",
71
+ "Requirement already satisfied: toml<2,>=0.10.1 in /usr/local/lib/python3.11/dist-packages (from streamlit) (0.10.2)\n",
72
+ "Requirement already satisfied: watchdog<7,>=2.1.5 in /usr/local/lib/python3.11/dist-packages (from streamlit) (6.0.0)\n",
73
+ "Requirement already satisfied: gitpython!=3.1.19,<4,>=3.0.7 in /usr/local/lib/python3.11/dist-packages (from streamlit) (3.1.44)\n",
74
+ "Requirement already satisfied: pydeck<1,>=0.8.0b4 in /usr/local/lib/python3.11/dist-packages (from streamlit) (0.9.1)\n",
75
+ "Requirement already satisfied: tornado<7,>=6.0.3 in /usr/local/lib/python3.11/dist-packages (from streamlit) (6.4.2)\n",
76
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.11/dist-packages (from altair<6,>=4.0->streamlit) (4.23.0)\n",
77
+ "Requirement already satisfied: narwhals>=1.14.2 in /usr/local/lib/python3.11/dist-packages (from altair<6,>=4.0->streamlit) (1.27.1)\n",
78
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from gitpython!=3.1.19,<4,>=3.0.7->streamlit) (4.0.12)\n",
79
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2.8.2)\n",
80
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2025.1)\n",
81
+ "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2025.1)\n",
82
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
83
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (3.4.1)\n",
84
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (3.10)\n",
85
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (2.3.0)\n",
86
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (2025.1.31)\n",
87
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich<14,>=10.14.0->streamlit) (3.0.0)\n",
88
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich<14,>=10.14.0->streamlit) (2.18.0)\n",
89
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.19,<4,>=3.0.7->streamlit) (5.0.2)\n",
90
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (25.1.0)\n",
91
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (2024.10.1)\n",
92
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.36.2)\n",
93
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.22.3)\n",
94
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich<14,>=10.14.0->streamlit) (0.1.2)\n",
95
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3,>=1.4.0->streamlit) (1.17.0)\n"
96
+ ]
97
+ }
98
+ ],
99
+ "source": [
100
+ "!pip install torch tqdm streamlit"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "source": [
106
+ "######################################\n",
107
+ "# Pseudocode2Cpp.py\n",
108
+ "######################################\n",
109
+ "import os\n",
110
+ "import streamlit as st\n",
111
+ "import torch\n",
112
+ "import torch.nn as nn\n",
113
+ "import torch.optim as optim\n",
114
+ "import math\n",
115
+ "import re\n",
116
+ "from tqdm import tqdm\n",
117
+ "from typing import List, Tuple\n",
118
+ "import random\n",
119
+ "import requests\n",
120
+ "from torch.utils.data import DataLoader, TensorDataset"
121
+ ],
122
+ "metadata": {
123
+ "id": "tEYW8hGR19sm"
124
+ },
125
+ "execution_count": null,
126
+ "outputs": []
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "source": [
131
+ "# ----------------------------\n",
132
+ "# 1. Hyperparameters\n",
133
+ "# ----------------------------\n",
134
+ "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
135
+ "MAX_LEN = 128 # maximum sequence length\n",
136
+ "EMBED_DIM = 256 # embedding dimension\n",
137
+ "FF_DIM = 512 # feedforward dimension in Transformer\n",
138
+ "NHEAD = 4 # number of heads in multihead attention\n",
139
+ "NUM_ENCODER_LAYERS = 2\n",
140
+ "NUM_DECODER_LAYERS = 2\n",
141
+ "BATCH_SIZE = 64\n",
142
+ "EPOCHS = 10 # Increase for real training\n",
143
+ "LEARNING_RATE = 1e-4\n",
144
+ "\n",
145
+ "# Special tokens\n",
146
+ "PAD_TOKEN = \"<pad>\"\n",
147
+ "SOS_TOKEN = \"<sos>\"\n",
148
+ "EOS_TOKEN = \"<eos>\"\n",
149
+ "UNK_TOKEN = \"<unk>\""
150
+ ],
151
+ "metadata": {
152
+ "id": "HelkrJ-01-2B"
153
+ },
154
+ "execution_count": null,
155
+ "outputs": []
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "source": [
160
+ "# ----------------------------\n",
161
+ "# 2. Data Loading & Preprocessing\n",
162
+ "# ----------------------------\n",
163
+ "\n",
164
+ "def load_spoc_data(file_path: str):\n",
165
+ " \"\"\"\n",
166
+ " Loads (pseudo_code, cpp_code) pairs from a TSV file or raw GitHub link.\n",
167
+ " Each line is assumed to have: pseudocode <tab> c++ code.\n",
168
+ " \"\"\"\n",
169
+ " pairs = []\n",
170
+ "\n",
171
+ " # If file_path is a URL, fetch it with requests\n",
172
+ " if file_path.startswith(\"http\"):\n",
173
+ " response = requests.get(file_path)\n",
174
+ " response.raise_for_status()\n",
175
+ " lines = response.text.strip().split(\"\\n\")\n",
176
+ " else:\n",
177
+ " # Otherwise, assume it's a local file path\n",
178
+ " with open(file_path, 'r', encoding='utf-8') as f:\n",
179
+ " lines = f.readlines()\n",
180
+ "\n",
181
+ " for line in lines:\n",
182
+ " line = line.strip()\n",
183
+ " if not line:\n",
184
+ " continue\n",
185
+ " cols = line.split('\\t')\n",
186
+ " if len(cols) >= 2:\n",
187
+ " pseudo = cols[0].strip()\n",
188
+ " cpp = cols[1].strip()\n",
189
+ " pairs.append((pseudo, cpp))\n",
190
+ "\n",
191
+ " return pairs\n",
192
+ "\n",
193
+ "def create_dataloader(pairs, src_stoi, tgt_stoi, batch_size):\n",
194
+ " src_batches = []\n",
195
+ " tgt_batches = []\n",
196
+ " for pseudo, cpp in pairs:\n",
197
+ " src_ids = pad_sequence(numericalize(pseudo, src_stoi), MAX_LEN, src_stoi[PAD_TOKEN])\n",
198
+ " tgt_ids = pad_sequence(numericalize(cpp, tgt_stoi), MAX_LEN, tgt_stoi[PAD_TOKEN])\n",
199
+ " src_batches.append(src_ids)\n",
200
+ " tgt_batches.append(tgt_ids)\n",
201
+ "\n",
202
+ " src_tensor = torch.tensor(src_batches, dtype=torch.long)\n",
203
+ " tgt_tensor = torch.tensor(tgt_batches, dtype=torch.long)\n",
204
+ " dataset = TensorDataset(src_tensor, tgt_tensor)\n",
205
+ " return DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)\n",
206
+ "\n",
207
+ "def tokenize_line(text: str) -> List[str]:\n",
208
+ " \"\"\"Enhanced tokenizer for pseudocode/C++ patterns\"\"\"\n",
209
+ " # Separate operators and punctuation\n",
210
+ " text = re.sub(r'([=+\\-*/%<>!&|^~])', r' \\1 ', text) # Operators\n",
211
+ " text = re.sub(r'(?<!:):(?!:)', r' : ', text) # Single colon\n",
212
+ " return re.findall(r'\\b\\w+\\b|[-+*/%=<>!&|^~]+|[:;{},()\\[\\]\\.]', text)\n",
213
+ "\n",
214
+ "def build_vocab(pairs: List[Tuple[str, str]]) -> Tuple[dict, dict, dict, dict]:\n",
215
+ " \"\"\"\n",
216
+ " Build source (pseudo) and target (cpp) vocabularies from training data.\n",
217
+ " Returns:\n",
218
+ " src_stoi, src_itos, tgt_stoi, tgt_itos\n",
219
+ " \"\"\"\n",
220
+ " src_words = set()\n",
221
+ " tgt_words = set()\n",
222
+ "\n",
223
+ " for (pseudo, cpp) in pairs:\n",
224
+ " for tok in tokenize_line(pseudo):\n",
225
+ " src_words.add(tok)\n",
226
+ " for tok in tokenize_line(cpp):\n",
227
+ " tgt_words.add(tok)\n",
228
+ "\n",
229
+ " # Add special tokens\n",
230
+ " src_vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(list(src_words))\n",
231
+ " tgt_vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(list(tgt_words))\n",
232
+ "\n",
233
+ " src_stoi = {w: i for i, w in enumerate(src_vocab)}\n",
234
+ " src_itos = {i: w for i, w in enumerate(src_vocab)}\n",
235
+ " tgt_stoi = {w: i for i, w in enumerate(tgt_vocab)}\n",
236
+ " tgt_itos = {i: w for i, w in enumerate(tgt_vocab)}\n",
237
+ "\n",
238
+ " return src_stoi, src_itos, tgt_stoi, tgt_itos\n",
239
+ "\n",
240
+ "def numericalize(text: str, stoi: dict) -> List[int]:\n",
241
+ " \"\"\"\n",
242
+ " Convert text string to a list of token IDs.\n",
243
+ " \"\"\"\n",
244
+ " tokens = tokenize_line(text)\n",
245
+ " ids = []\n",
246
+ " for t in tokens:\n",
247
+ " if t in stoi:\n",
248
+ " ids.append(stoi[t])\n",
249
+ " else:\n",
250
+ " ids.append(stoi[UNK_TOKEN])\n",
251
+ " return ids\n",
252
+ "\n",
253
+ "def pad_sequence(seq: List[int], max_len: int, pad_id: int) -> List[int]:\n",
254
+ " \"\"\"Proper padding with SOS/EOS handling\"\"\"\n",
255
+ " seq = seq[:max_len-2] # Leave space for SOS/EOS\n",
256
+ " seq = [src_stoi[SOS_TOKEN]] + seq + [src_stoi[EOS_TOKEN]] # Add control tokens\n",
257
+ " padding = [pad_id] * (max_len - len(seq))\n",
258
+ " return seq + padding\n",
259
+ "\n",
260
+ "def create_batches(pairs, src_stoi, tgt_stoi, batch_size):\n",
261
+ " \"\"\"\n",
262
+ " Yield batches of data (source_ids, target_ids).\n",
263
+ " \"\"\"\n",
264
+ " random.shuffle(pairs)\n",
265
+ " for i in range(0, len(pairs), batch_size):\n",
266
+ " batch_pairs = pairs[i:i+batch_size]\n",
267
+ " src_batch = []\n",
268
+ " tgt_batch = []\n",
269
+ " for pseudo, cpp in batch_pairs:\n",
270
+ " src_ids = numericalize(pseudo, src_stoi)\n",
271
+ " tgt_ids = numericalize(cpp, tgt_stoi)\n",
272
+ "\n",
273
+ " # Pad/truncate\n",
274
+ " src_ids = pad_sequence(src_ids, MAX_LEN, src_stoi[PAD_TOKEN])\n",
275
+ " tgt_ids = pad_sequence(tgt_ids, MAX_LEN, tgt_stoi[PAD_TOKEN])\n",
276
+ "\n",
277
+ " src_batch.append(src_ids)\n",
278
+ " tgt_batch.append(tgt_ids)\n",
279
+ "\n",
280
+ " src_batch = torch.tensor(src_batch, dtype=torch.long, device=DEVICE)\n",
281
+ " tgt_batch = torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE)\n",
282
+ " yield src_batch, tgt_batch"
283
+ ],
284
+ "metadata": {
285
+ "id": "2lFlkj-t2AGg"
286
+ },
287
+ "execution_count": null,
288
+ "outputs": []
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "source": [
293
+ "# ----------------------------\n",
294
+ "# 3. Transformer Model Implementation (from scratch)\n",
295
+ "# ----------------------------\n",
296
+ "\n",
297
+ "class PositionalEncoding(nn.Module):\n",
298
+ " def __init__(self, d_model, max_len=5000):\n",
299
+ " super(PositionalEncoding, self).__init__()\n",
300
+ " pe = torch.zeros(max_len, d_model)\n",
301
+ " position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
302
+ " div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
303
+ " pe[:, 0::2] = torch.sin(position * div_term)\n",
304
+ " pe[:, 1::2] = torch.cos(position * div_term)\n",
305
+ " pe = pe.unsqueeze(0) # shape (1, max_len, d_model)\n",
306
+ " self.register_buffer('pe', pe)\n",
307
+ "\n",
308
+ " def forward(self, x):\n",
309
+ " # x shape: (batch_size, seq_len, d_model)\n",
310
+ " seq_len = x.size(1)\n",
311
+ " x = x + self.pe[:, :seq_len, :]\n",
312
+ " return x\n",
313
+ "\n",
314
+ "class MultiHeadAttention(nn.Module):\n",
315
+ " def __init__(self, d_model, n_heads):\n",
316
+ " super(MultiHeadAttention, self).__init__()\n",
317
+ " assert d_model % n_heads == 0\n",
318
+ " self.d_model = d_model\n",
319
+ " self.n_heads = n_heads\n",
320
+ " self.head_dim = d_model // n_heads\n",
321
+ "\n",
322
+ " self.query_linear = nn.Linear(d_model, d_model)\n",
323
+ " self.key_linear = nn.Linear(d_model, d_model)\n",
324
+ " self.value_linear = nn.Linear(d_model, d_model)\n",
325
+ " self.out_linear = nn.Linear(d_model, d_model)\n",
326
+ "\n",
327
+ " def forward(self, query, key, value, mask=None):\n",
328
+ " # query/key/value shape: (batch_size, seq_len, d_model)\n",
329
+ " B, Q_len, _ = query.size()\n",
330
+ " B, K_len, _ = key.size()\n",
331
+ " B, V_len, _ = value.size()\n",
332
+ "\n",
333
+ " # Linear projections\n",
334
+ " Q = self.query_linear(query) # (B, Q_len, d_model)\n",
335
+ " K = self.key_linear(key) # (B, K_len, d_model)\n",
336
+ " V = self.value_linear(value) # (B, V_len, d_model)\n",
337
+ "\n",
338
+ " # Reshape for multi-head\n",
339
+ " Q = Q.view(B, Q_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, Q_len, head_dim)\n",
340
+ " K = K.view(B, K_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, K_len, head_dim)\n",
341
+ " V = V.view(B, V_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, V_len, head_dim)\n",
342
+ "\n",
343
+ " # Scaled dot-product attention\n",
344
+ " scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (B, n_heads, Q_len, K_len)\n",
345
+ " if mask is not None:\n",
346
+ " scores = scores.masked_fill(mask == 0, float('-inf'))\n",
347
+ " attn = torch.softmax(scores, dim=-1) # (B, n_heads, Q_len, K_len)\n",
348
+ "\n",
349
+ " context = torch.matmul(attn, V) # (B, n_heads, Q_len, head_dim)\n",
350
+ " context = context.transpose(1,2).contiguous().view(B, Q_len, self.d_model)\n",
351
+ " out = self.out_linear(context)\n",
352
+ " return out\n",
353
+ "\n",
354
+ "class FeedForward(nn.Module):\n",
355
+ " def __init__(self, d_model, dim_feedforward):\n",
356
+ " super(FeedForward, self).__init__()\n",
357
+ " self.fc1 = nn.Linear(d_model, dim_feedforward)\n",
358
+ " self.fc2 = nn.Linear(dim_feedforward, d_model)\n",
359
+ " self.relu = nn.ReLU()\n",
360
+ "\n",
361
+ " def forward(self, x):\n",
362
+ " return self.fc2(self.relu(self.fc1(x)))\n",
363
+ "\n",
364
+ "class EncoderLayer(nn.Module):\n",
365
+ " def __init__(self, d_model, n_heads, dim_feedforward):\n",
366
+ " super(EncoderLayer, self).__init__()\n",
367
+ " self.self_attn = MultiHeadAttention(d_model, n_heads)\n",
368
+ " self.ff = FeedForward(d_model, dim_feedforward)\n",
369
+ " self.norm1 = nn.LayerNorm(d_model)\n",
370
+ " self.norm2 = nn.LayerNorm(d_model)\n",
371
+ " self.dropout = nn.Dropout(0.1)\n",
372
+ "\n",
373
+ " def forward(self, src, src_mask=None):\n",
374
+ " # Self-attention\n",
375
+ " attn_out = self.self_attn(src, src, src, mask=src_mask)\n",
376
+ " src = self.norm1(src + self.dropout(attn_out))\n",
377
+ " # Feed Forward\n",
378
+ " ff_out = self.ff(src)\n",
379
+ " src = self.norm2(src + self.dropout(ff_out))\n",
380
+ " return src\n",
381
+ "\n",
382
+ "class DecoderLayer(nn.Module):\n",
383
+ " def __init__(self, d_model, n_heads, dim_feedforward):\n",
384
+ " super(DecoderLayer, self).__init__()\n",
385
+ " self.self_attn = MultiHeadAttention(d_model, n_heads)\n",
386
+ " self.cross_attn = MultiHeadAttention(d_model, n_heads)\n",
387
+ " self.ff = FeedForward(d_model, dim_feedforward)\n",
388
+ "\n",
389
+ " self.norm1 = nn.LayerNorm(d_model)\n",
390
+ " self.norm2 = nn.LayerNorm(d_model)\n",
391
+ " self.norm3 = nn.LayerNorm(d_model)\n",
392
+ " self.dropout = nn.Dropout(0.1)\n",
393
+ "\n",
394
+ " def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n",
395
+ " # Self-attention (mask future tokens)\n",
396
+ " _tgt = tgt\n",
397
+ " tgt = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt, tgt, mask=tgt_mask)))\n",
398
+ " # Cross-attention\n",
399
+ " _tgt2 = tgt\n",
400
+ " tgt = self.norm2(tgt + self.dropout(self.cross_attn(tgt, memory, memory, mask=memory_mask)))\n",
401
+ " # Feed Forward\n",
402
+ " ff_out = self.ff(tgt)\n",
403
+ " tgt = self.norm3(tgt + self.dropout(ff_out))\n",
404
+ " return tgt\n",
405
+ "\n",
406
+ "class Encoder(nn.Module):\n",
407
+ " def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):\n",
408
+ " super(Encoder, self).__init__()\n",
409
+ " self.embedding = nn.Embedding(vocab_size, d_model)\n",
410
+ " self.pos_encoding = PositionalEncoding(d_model)\n",
411
+ " self.layers = nn.ModuleList([\n",
412
+ " EncoderLayer(d_model, n_heads, dim_feedforward)\n",
413
+ " for _ in range(num_layers)\n",
414
+ " ])\n",
415
+ "\n",
416
+ " def forward(self, src, src_mask=None):\n",
417
+ " # src shape: (batch_size, seq_len)\n",
418
+ " x = self.embedding(src) # (batch_size, seq_len, d_model)\n",
419
+ " x = self.pos_encoding(x)\n",
420
+ " for layer in self.layers:\n",
421
+ " x = layer(x, src_mask)\n",
422
+ " return x\n",
423
+ "\n",
424
+ "class Decoder(nn.Module):\n",
425
+ " def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):\n",
426
+ " super(Decoder, self).__init__()\n",
427
+ " self.embedding = nn.Embedding(vocab_size, d_model)\n",
428
+ " self.pos_encoding = PositionalEncoding(d_model)\n",
429
+ " self.layers = nn.ModuleList([\n",
430
+ " DecoderLayer(d_model, n_heads, dim_feedforward)\n",
431
+ " for _ in range(num_layers)\n",
432
+ " ])\n",
433
+ " self.fc_out = nn.Linear(d_model, vocab_size)\n",
434
+ "\n",
435
+ " def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n",
436
+ " x = self.embedding(tgt)\n",
437
+ " x = self.pos_encoding(x)\n",
438
+ " for layer in self.layers:\n",
439
+ " x = layer(x, memory, tgt_mask, memory_mask)\n",
440
+ " logits = self.fc_out(x) # (batch_size, seq_len, vocab_size)\n",
441
+ " return logits\n",
442
+ "\n",
443
+ "class TransformerSeq2Seq(nn.Module):\n",
444
+ " def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, num_encoder_layers,\n",
445
+ " num_decoder_layers, dim_feedforward):\n",
446
+ " super(TransformerSeq2Seq, self).__init__()\n",
447
+ " self.encoder = Encoder(src_vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward)\n",
448
+ " self.decoder = Decoder(tgt_vocab_size, d_model, n_heads, num_decoder_layers, dim_feedforward)\n",
449
+ "\n",
450
+ " def forward(self, src, tgt, src_mask=None, tgt_mask=None):\n",
451
+ " # src: (batch_size, src_seq_len)\n",
452
+ " # tgt: (batch_size, tgt_seq_len)\n",
453
+ " memory = self.encoder(src, src_mask) # (batch_size, src_seq_len, d_model)\n",
454
+ " outputs = self.decoder(tgt, memory, tgt_mask) # (batch_size, tgt_seq_len, vocab_size)\n",
455
+ " return outputs"
456
+ ],
457
+ "metadata": {
458
+ "id": "f8HioKcS2ZRy"
459
+ },
460
+ "execution_count": null,
461
+ "outputs": []
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "source": [
466
+ "# ----------------------------\n",
467
+ "# 4. Training Setup\n",
468
+ "# ----------------------------\n",
469
+ "import torch\n",
470
+ "import torch.nn as nn\n",
471
+ "from torch.utils.data import DataLoader, TensorDataset\n",
472
+ "from typing import List, Tuple\n",
473
+ "import random\n",
474
+ "def generate_subsequent_mask(size):\n",
475
+ " # Mask out subsequent positions (for decoding)\n",
476
+ " mask = torch.triu(torch.ones(size, size), diagonal=1).bool()\n",
477
+ " return ~mask # True where we can attend, False where we cannot\n",
478
+ "\n",
479
+ "def train_one_epoch(model, optimizer, criterion, train_data, src_stoi, tgt_stoi):\n",
480
+ " model.train()\n",
481
+ " total_loss = 0\n",
482
+ " steps = 0\n",
483
+ "\n",
484
+ " data_loader = create_dataloader(train_pairs, src_stoi, tgt_stoi, BATCH_SIZE)\n",
485
+ " for src_batch, tgt_batch in data_loader:\n",
486
+ " src_batch = src_batch.to(DEVICE)\n",
487
+ " tgt_batch = tgt_batch.to(DEVICE)\n",
488
+ "\n",
489
+ " # Prepare the target inputs and outputs (shifted by one token)\n",
490
+ " tgt_inp = tgt_batch[:, :-1]\n",
491
+ " tgt_out = tgt_batch[:, 1:]\n",
492
+ "\n",
493
+ " # Create subsequent mask for the target sequence\n",
494
+ " tgt_seq_len = tgt_inp.size(1)\n",
495
+ " tgt_mask = generate_subsequent_mask(tgt_seq_len).to(DEVICE)\n",
496
+ "\n",
497
+ " optimizer.zero_grad()\n",
498
+ " logits = model(src_batch, tgt_inp, None, tgt_mask) # (B, seq_len, vocab_size)\n",
499
+ "\n",
500
+ " # Use .reshape() instead of .view() to avoid runtime errors\n",
501
+ " loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))\n",
502
+ " loss.backward()\n",
503
+ " optimizer.step()\n",
504
+ "\n",
505
+ " total_loss += loss.item()\n",
506
+ " steps += 1\n",
507
+ "\n",
508
+ " return total_loss / steps\n",
509
+ "\n",
510
+ "def evaluate(model, criterion, eval_data, src_stoi, tgt_stoi):\n",
511
+ " model.eval()\n",
512
+ " total_loss = 0\n",
513
+ " steps = 0\n",
514
+ " with torch.no_grad():\n",
515
+ " for src_batch, tgt_batch in create_batches(eval_data, src_stoi, tgt_stoi, BATCH_SIZE):\n",
516
+ " tgt_inp = tgt_batch[:, :-1]\n",
517
+ " tgt_out = tgt_batch[:, 1:]\n",
518
+ " tgt_seq_len = tgt_inp.size(1)\n",
519
+ " tgt_mask = generate_subsequent_mask(tgt_seq_len).to(DEVICE)\n",
520
+ "\n",
521
+ " logits = model(src_batch, tgt_inp, None, tgt_mask)\n",
522
+ " # Use .reshape() instead of .view()\n",
523
+ " loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))\n",
524
+ "\n",
525
+ " total_loss += loss.item()\n",
526
+ " steps += 1\n",
527
+ " return total_loss / steps\n",
528
+ "\n",
529
+ "def greedy_decode(model, src, src_stoi, tgt_stoi, tgt_itos, max_len=MAX_LEN):\n",
530
+ " \"\"\"\n",
531
+ " Given a single source sequence (1D list of token IDs),\n",
532
+ " generate a decoded target sequence using greedy search.\n",
533
+ " \"\"\"\n",
534
+ " model.eval()\n",
535
+ " src = torch.tensor(src, dtype=torch.long, device=DEVICE).unsqueeze(0) # (1, seq_len)\n",
536
+ " memory = model.encoder(src) # (1, seq_len, d_model)\n",
537
+ "\n",
538
+ " ys = torch.tensor([tgt_stoi[SOS_TOKEN]], dtype=torch.long, device=DEVICE).unsqueeze(0) # (1, 1)\n",
539
+ " for i in range(max_len-1):\n",
540
+ " tgt_mask = generate_subsequent_mask(ys.size(1)).to(DEVICE)\n",
541
+ " out = model.decoder(ys, memory, tgt_mask) # (1, seq_len, vocab_size)\n",
542
+ " prob = out[:, -1, :] # last timestep\n",
543
+ " next_token = torch.argmax(prob, dim=1).item()\n",
544
+ " ys = torch.cat([ys, torch.tensor([[next_token]], device=DEVICE)], dim=1)\n",
545
+ " if next_token == tgt_stoi[EOS_TOKEN]:\n",
546
+ " break\n",
547
+ "\n",
548
+ " # Convert back to tokens\n",
549
+ " out_tokens = ys.squeeze(0).tolist() # e.g. [SOS, ..., EOS]\n",
550
+ " # Remove the initial SOS\n",
551
+ " out_tokens = out_tokens[1:]\n",
552
+ " # Stop at EOS if present\n",
553
+ " if tgt_stoi[EOS_TOKEN] in out_tokens:\n",
554
+ " eos_idx = out_tokens.index(tgt_stoi[EOS_TOKEN])\n",
555
+ " out_tokens = out_tokens[:eos_idx]\n",
556
+ "\n",
557
+ " return \" \".join(tgt_itos[t] for t in out_tokens)"
558
+ ],
559
+ "metadata": {
560
+ "id": "ffYgGSXy2a4B"
561
+ },
562
+ "execution_count": null,
563
+ "outputs": []
564
+ },
565
+ {
566
+ "cell_type": "code",
567
+ "source": [
568
+ "# ----------------------------\n",
569
+ "# 5. Main: Train the Model\n",
570
+ "# ----------------------------\n",
571
+ "if __name__ == \"__main__\":\n",
572
+ " # Hardcode the file paths from your GitHub repo (raw URLs):\n",
573
+ " train_path = \"https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/spoc-train.tsv\"\n",
574
+ " eval_path = \"https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/split/spoc-train-eval.tsv\"\n",
575
+ "\n",
576
+ " print(f\"Loading training data from {train_path} ...\")\n",
577
+ " train_pairs = load_spoc_data(train_path)\n",
578
+ " print(f\"Loaded {len(train_pairs)} training pairs.\")\n",
579
+ "\n",
580
+ " print(f\"Loading eval data from {eval_path} ...\")\n",
581
+ " eval_pairs = load_spoc_data(eval_path)\n",
582
+ " print(f\"Loaded {len(eval_pairs)} eval pairs.\")\n",
583
+ "\n",
584
+ " print(\"Building vocab...\")\n",
585
+ " src_stoi, src_itos, tgt_stoi, tgt_itos = build_vocab(train_pairs)\n",
586
+ " global stoi_eos\n",
587
+ " stoi_eos = tgt_stoi[EOS_TOKEN] # for pad_sequence usage\n",
588
+ "\n",
589
+ " print(\"Creating model...\")\n",
590
+ " model = TransformerSeq2Seq(\n",
591
+ " src_vocab_size=len(src_stoi),\n",
592
+ " tgt_vocab_size=len(tgt_stoi),\n",
593
+ " d_model=EMBED_DIM,\n",
594
+ " n_heads=NHEAD,\n",
595
+ " num_encoder_layers=NUM_ENCODER_LAYERS,\n",
596
+ " num_decoder_layers=NUM_DECODER_LAYERS,\n",
597
+ " dim_feedforward=FF_DIM\n",
598
+ " ).to(DEVICE)\n",
599
+ "\n",
600
+ " criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi[PAD_TOKEN])\n",
601
+ " optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
602
+ "\n",
603
+ " print(\"Starting training...\")\n",
604
+ " for epoch in range(1, EPOCHS+1):\n",
605
+ " train_loss = train_one_epoch(model, optimizer, criterion, train_pairs, src_stoi, tgt_stoi)\n",
606
+ " eval_loss = evaluate(model, criterion, eval_pairs, src_stoi, tgt_stoi)\n",
607
+ " print(f\"Epoch [{epoch}/{EPOCHS}] - Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}\")\n",
608
+ "\n",
609
+ " # Save model & vocab\n",
610
+ " torch.save({\n",
611
+ " 'model_state_dict': model.state_dict(),\n",
612
+ " 'src_stoi': src_stoi,\n",
613
+ " 'src_itos': src_itos,\n",
614
+ " 'tgt_stoi': tgt_stoi,\n",
615
+ " 'tgt_itos': tgt_itos\n",
616
+ " }, \"model.pth\")\n",
617
+ "\n",
618
+ " print(\"Model and vocab saved to model.pth\")"
619
+ ],
620
+ "metadata": {
621
+ "colab": {
622
+ "base_uri": "https://localhost:8080/"
623
+ },
624
+ "id": "iffrMhkc2cVt",
625
+ "outputId": "38839989-38e5-4b10-fbea-90767dca60e3"
626
+ },
627
+ "execution_count": null,
628
+ "outputs": [
629
+ {
630
+ "output_type": "stream",
631
+ "name": "stdout",
632
+ "text": [
633
+ "Loading training data from https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/spoc-train.tsv ...\n",
634
+ "Loaded 293855 training pairs.\n",
635
+ "Loading eval data from https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/split/spoc-train-eval.tsv ...\n",
636
+ "Loaded 27289 eval pairs.\n",
637
+ "Building vocab...\n",
638
+ "Creating model...\n",
639
+ "Starting training...\n",
640
+ "Epoch [1/10] - Train Loss: 0.9915, Eval Loss: 0.4901\n",
641
+ "Epoch [2/10] - Train Loss: 0.4401, Eval Loss: 0.3597\n",
642
+ "Epoch [3/10] - Train Loss: 0.3326, Eval Loss: 0.2897\n",
643
+ "Epoch [4/10] - Train Loss: 0.2752, Eval Loss: 0.2735\n",
644
+ "Epoch [5/10] - Train Loss: 0.2401, Eval Loss: 0.2281\n",
645
+ "Epoch [6/10] - Train Loss: 0.2166, Eval Loss: 0.2111\n",
646
+ "Epoch [7/10] - Train Loss: 0.2002, Eval Loss: 0.2015\n",
647
+ "Epoch [8/10] - Train Loss: 0.1883, Eval Loss: 0.1919\n",
648
+ "Epoch [9/10] - Train Loss: 0.1793, Eval Loss: 0.1848\n",
649
+ "Epoch [10/10] - Train Loss: 0.1724, Eval Loss: 0.1819\n",
650
+ "Model and vocab saved to transformer_spoc.pth\n"
651
+ ]
652
+ }
653
+ ]
654
+ }
655
+ ]
656
+ }