asadsandhu committed on
Commit
a0ce0ae
·
1 Parent(s): cfc0e75
Files changed (2)
  1. README.md +192 -2
  2. train.ipynb +0 -656
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Pseudocode2Cpp
3
  emoji: 👀
4
  colorFrom: yellow
5
  colorTo: gray
@@ -11,4 +11,194 @@ license: mit
11
  short_description: Convert pseudocode to C++ using a Transformer model.
12
  ---
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: Pseudo2Code
3
  emoji: 👀
4
  colorFrom: yellow
5
  colorTo: gray
 
11
  short_description: Convert pseudocode to C++ using a Transformer model.
12
  ---
13
 
14
+ # 🚀 Pseudo2Code: Transformer-based Pseudocode to C++ Converter
15
+
16
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
17
+ [![Python](https://img.shields.io/badge/Python-3.10+-blue.svg)](https://www.python.org/)
18
+ [![Hugging Face](https://img.shields.io/badge/HuggingFace-Spaces-orange)](https://huggingface.co/spaces/asadsandhu/Pseudo2Code)
19
+ [![GitHub Repo](https://img.shields.io/badge/GitHub-asadsandhu/Pseudo2Code-black?logo=github)](https://github.com/asadsandhu/Pseudo2Code)
20
+
21
+ > A fully custom Transformer-based sequence-to-sequence model, built from scratch in PyTorch, that converts human-written pseudocode into executable C++ code. Trained on the [SPoC dataset](https://arxiv.org/abs/1906.04908) from Stanford.
22
+
23
+ ---
24
+
25
+ ## 🖼️ Demo
26
+
27
+ Try it live on **Hugging Face Spaces**:
28
+ 👉 https://huggingface.co/spaces/asadsandhu/Pseudo2Code
29
+
30
+ ![App Demo](assets/demo.png)
31
+
32
+ ---
33
+
34
+ ## 🧠 Model Architecture
35
+
36
+ - Developed using the **Transformer** architecture from scratch in PyTorch
37
+ - No pre-trained models (pure from-scratch implementation)
38
+ - Token-level sequence generation using greedy decoding
39
+ - Custom vocabulary construction for both pseudocode and C++ output
40
+
41
+ ```
42
+
43
+ Input: Pseudocode lines (line-by-line)
44
+ Model: Transformer (Encoder-Decoder)
45
+ Output: C++ code line for each pseudocode line
46
+
47
+ ```
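The greedy decoding named in the list above can be illustrated with a minimal, framework-free sketch. The `step` callback and `toy_step` are hypothetical stand-ins for the trained decoder, and the token ids (`SOS = 1`, `EOS = 2`) are assumed to follow the order in which the training notebook registers its special tokens. Stopping at `EOS` is also an assumption rather than a confirmed detail of `app.py`.

```python
from typing import Callable, List, Sequence

SOS, EOS = 1, 2  # assumed ids of the <sos> and <eos> special tokens


def greedy_decode(step: Callable[[List[int]], Sequence[float]],
                  max_len: int = 128) -> List[int]:
    """Repeatedly append the highest-scoring next token until EOS or max_len."""
    ys = [SOS]
    for _ in range(max_len - 1):
        scores = step(ys)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=scores.__getitem__)
        ys.append(next_id)
        if next_id == EOS:
            break
    return ys


def toy_step(prefix: List[int]) -> List[float]:
    """Hypothetical decoder stand-in: emits tokens 4, 5, then EOS."""
    script = [4, 5, EOS]
    target = script[min(len(prefix) - 1, len(script) - 1)]
    return [1.0 if i == target else 0.0 for i in range(6)]


print(greedy_decode(toy_step))
```

Greedy search keeps only the single best token at each step, which makes it cheap but unable to recover from an early wrong choice.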
48
+
49
+ ---
50
+
51
+ ## 📊 Dataset
52
+
53
+ We used the **SPoC dataset** from Stanford:
54
+
55
+ - ✅ Clean pseudocode–C++ line pairs
56
+ - ✅ Token-level annotations for syntax handling
57
+ - ✅ Multiple test splits (generalization to problems/workers)
58
+ - ✅ Custom preprocessing and vocabulary building implemented
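The preprocessing and vocabulary step can be sketched as follows. The tokenizer regexes are taken from the `train.ipynb` removed in this commit; the `build_vocab` here is a simplified, single-vocabulary variant written for illustration (the notebook builds source and target vocabularies in one call), so treat the helper shapes as assumptions.

```python
import re
from typing import Dict, List

SPECIALS = ["<pad>", "<sos>", "<eos>", "<unk>"]  # ids 0..3, in this order


def tokenize_line(text: str) -> List[str]:
    """Split a pseudocode or C++ line into word, operator and punctuation tokens."""
    text = re.sub(r'([=+\-*/%<>!&|^~])', r' \1 ', text)  # pad every operator char
    text = re.sub(r'(?<!:):(?!:)', r' : ', text)         # pad single colons
    return re.findall(r'\b\w+\b|[-+*/%=<>!&|^~]+|[:;{},()\[\]\.]', text)


def build_vocab(lines: List[str]) -> Dict[str, int]:
    """Map each token to an id, with the special tokens first."""
    words = {tok for line in lines for tok in tokenize_line(line)}
    return {w: i for i, w in enumerate(SPECIALS + sorted(words))}


def numericalize(text: str, stoi: Dict[str, int]) -> List[int]:
    """Convert a line to token ids, falling back to <unk> for unseen tokens."""
    unk = stoi["<unk>"]
    return [stoi.get(tok, unk) for tok in tokenize_line(text)]


src_stoi = build_vocab(["read n"])
tgt_stoi = build_vocab(["cin >> n ;"])
print(tokenize_line("cin >> n ;"))
print(numericalize("read x", src_stoi))
```

Because every operator character is padded with spaces before matching, `>>` is split into two `>` tokens, which is consistent with the spaced operators (`cin > > n`) in the sample output further down.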
59
+
60
+ > 📎 Licensed under [CC BY 4.0](https://creativecommons.org/licenses/by/4.0/)
61
+
62
+ ---
63
+
64
+ ## 📁 Directory Structure
65
+
66
+ ```
67
+
68
+ .
69
+ ├── app.py # Gradio web app for inference
70
+ ├── train.py # Transformer training code
71
+ ├── model.pth # Trained model weights
72
+ ├── spoc/ # Dataset directory
73
+ │ └── train/
74
+ │ ├── spoc-train.tsv
75
+ │ └── split/spoc-train-eval.tsv
76
+ ├── assets/
77
+ │ └── demo.png # App screenshot
78
+ └── README.md # You're here
79
+
80
+ ```
81
+
82
+ ---
83
+
84
+ ## 🛠️ How to Run Locally
85
+
86
+ ### ⚙️ 1. Clone Repo & Install Requirements
87
+
88
+ ```bash
89
+ git clone https://github.com/asadsandhu/Pseudo2Code.git
90
+ cd Pseudo2Code
91
+ pip install -r requirements.txt
92
+ ```
93
+
94
+ Or manually install:
95
+
96
+ ```bash
97
+ pip install torch gradio tqdm
98
+ ```
99
+
100
+ ### 🚀 2. Launch the App
101
+
102
+ Make sure `model.pth` is present (or train using `train.py`):
103
+
104
+ ```bash
105
+ python app.py
106
+ ```
107
+
108
+ The app will open in your browser.
109
+
110
+ ---
111
+
112
+ ## 🧪 Training the Model
113
+
114
+ You can retrain the model using the `train.py` script:
115
+
116
+ ```bash
117
+ python train.py
118
+ ```
119
+
120
+ By default, the script downloads the data from the public repo and trains for 10 epochs.
121
+ It then writes a `model.pth` file containing the learned weights and the vocabulary.
122
+
123
+ ---
124
+
125
+ ## 🔧 Key Hyperparameters
126
+
127
+ | Parameter | Value |
128
+ | -------------- | ----------- |
129
+ | Model Type | Transformer |
130
+ | Max Length | 128 |
131
+ | Embedding Dim | 256 |
132
+ | FFN Dim | 512 |
133
+ | Heads | 4 |
134
+ | Encoder Layers | 2 |
135
+ | Decoder Layers | 2 |
136
+ | Batch Size | 64 |
137
+ | Epochs | 10 |
138
+ | Optimizer | Adam |
139
+ | Learning Rate | 1e-4 |
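To give a feel for what the values in this table imply, the sketch below derives the per-head dimension and a rough parameter count per layer. It assumes the layer layout of the `train.ipynb` removed in this commit: four biased `d x d` linear projections per attention block, a two-layer feed-forward network, and LayerNorm with a weight and a bias. Embeddings and the output projection are left out because the vocabulary sizes are not stated.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def attention_params(d_model: int) -> int:
    """Query, key, value and output projections."""
    return 4 * linear_params(d_model, d_model)


def ffn_params(d_model: int, ff_dim: int) -> int:
    """Two-layer position-wise feed-forward network."""
    return linear_params(d_model, ff_dim) + linear_params(ff_dim, d_model)


def layer_norm_params(d_model: int) -> int:
    """One scale and one shift vector."""
    return 2 * d_model


def encoder_layer_params(d_model: int, ff_dim: int) -> int:
    # self-attention + feed-forward + two LayerNorms
    return (attention_params(d_model) + ffn_params(d_model, ff_dim)
            + 2 * layer_norm_params(d_model))


def decoder_layer_params(d_model: int, ff_dim: int) -> int:
    # self-attention + cross-attention + feed-forward + three LayerNorms
    return (2 * attention_params(d_model) + ffn_params(d_model, ff_dim)
            + 3 * layer_norm_params(d_model))


EMBED_DIM, FF_DIM, NHEAD = 256, 512, 4  # values from the table above
head_dim = EMBED_DIM // NHEAD

print(head_dim)
print(encoder_layer_params(EMBED_DIM, FF_DIM))
print(decoder_layer_params(EMBED_DIM, FF_DIM))
```

With these settings each attention head works in 64 dimensions, and `Embedding Dim` must stay divisible by `Heads` if either value is changed.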
140
+
141
+ ---
142
+
143
+ ## 🧩 Example Input
144
+
145
+ ```text
146
+ n , nn, ans = integers with ans =0
147
+ Read n
148
+ for i=2 to n-1 execute
149
+ set nn to n
150
+ while nn is not equal to 0, set ans to ans + nn%i, and also set nn= nn/i
151
+ }
152
+ set o to gcd(ans, n-2)
153
+ print out ans/o "/" (n-2)/o
154
+ ```
155
+
156
+ ### ⏩ Output C++
157
+
158
+ ```cpp
159
+ int main() {
160
+ int n , nn , ans = 0 ;
161
+ cin > > n ;
162
+ for ( int i = 2 ; i < = n - 1 ; i + + ) {
163
+ nn = n ;
164
+ while ( nn = = 0 ) ans + = nn % i , nn / = i ;
165
+ }
166
+ o = gcd ( ans , n - 2 ) ;
167
+ cout < < ans / 2 / o ( n - 2 ) / o < < endl ;
168
+ return 0;
169
+ }
170
+ ```
171
+
172
+ ---
173
+
174
+ ## 📦 Deployment
175
+
176
+ This app is deployed live on:
177
+
178
+ * **Hugging Face Spaces**: [Pseudo2Code](https://huggingface.co/spaces/asadsandhu/Pseudo2Code)
179
+ * **GitHub**: [github.com/asadsandhu/Pseudo2Code](https://github.com/asadsandhu/Pseudo2Code)
180
+
181
+ ---
182
+
183
+ ## 🙌 Acknowledgements
184
+
185
+ * 📘 **SPoC Dataset** by Stanford University
186
+ Kulal, S., Pasupat, P., Chandra, K., Lee, M., Padon, O., Aiken, A., & Liang, P. (2019). [SPoC: Search-based Pseudocode to Code](https://arxiv.org/abs/1906.04908)
187
+
188
+ * 🧠 Transformer Paper: ["Attention is All You Need"](https://arxiv.org/abs/1706.03762)
189
+
190
+ ---
191
+
192
+ ## 🧑‍💻 Author
193
+
194
+ **Asad Ali**
195
+ [GitHub: asadsandhu](https://github.com/asadsandhu)
196
+ [Hugging Face: asadsandhu](https://huggingface.co/asadsandhu)
197
+ [LinkedIn: asadxali](https://www.linkedin.com/in/asadxali)
198
+
199
+ ---
200
+
201
+ ## 📄 License
202
+
203
+ This project is licensed under the MIT License.
204
+ Feel free to use, modify, and share with credit.
train.ipynb DELETED
@@ -1,656 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "T4"
8
- },
9
- "kernelspec": {
10
- "name": "python3",
11
- "display_name": "Python 3"
12
- },
13
- "language_info": {
14
- "name": "python"
15
- },
16
- "accelerator": "GPU"
17
- },
18
- "cells": [
19
- {
20
- "cell_type": "code",
21
- "execution_count": null,
22
- "metadata": {
23
- "colab": {
24
- "base_uri": "https://localhost:8080/"
25
- },
26
- "collapsed": true,
27
- "id": "12APLOKE15uD",
28
- "outputId": "fb61078b-a249-476a-af53-e43ca978c8c1"
29
- },
30
- "outputs": [
31
- {
32
- "output_type": "stream",
33
- "name": "stdout",
34
- "text": [
35
- "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.5.1+cu124)\n",
36
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
37
- "Requirement already satisfied: streamlit in /usr/local/lib/python3.11/dist-packages (1.42.2)\n",
38
- "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.17.0)\n",
39
- "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.12.2)\n",
40
- "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.4.2)\n",
41
- "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.5)\n",
42
- "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2024.10.0)\n",
43
- "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
44
- "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
45
- "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
46
- "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
47
- "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
48
- "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
49
- "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
50
- "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
51
- "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
52
- "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
53
- "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
54
- "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
55
- "Requirement already satisfied: triton==3.1.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.0)\n",
56
- "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
57
- "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
58
- "Requirement already satisfied: altair<6,>=4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (5.5.0)\n",
59
- "Requirement already satisfied: blinker<2,>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (1.9.0)\n",
60
- "Requirement already satisfied: cachetools<6,>=4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (5.5.1)\n",
61
- "Requirement already satisfied: click<9,>=7.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (8.1.8)\n",
62
- "Requirement already satisfied: numpy<3,>=1.23 in /usr/local/lib/python3.11/dist-packages (from streamlit) (1.26.4)\n",
63
- "Requirement already satisfied: packaging<25,>=20 in /usr/local/lib/python3.11/dist-packages (from streamlit) (24.2)\n",
64
- "Requirement already satisfied: pandas<3,>=1.4.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (2.2.2)\n",
65
- "Requirement already satisfied: pillow<12,>=7.1.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (11.1.0)\n",
66
- "Requirement already satisfied: protobuf<6,>=3.20 in /usr/local/lib/python3.11/dist-packages (from streamlit) (4.25.6)\n",
67
- "Requirement already satisfied: pyarrow>=7.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (17.0.0)\n",
68
- "Requirement already satisfied: requests<3,>=2.27 in /usr/local/lib/python3.11/dist-packages (from streamlit) (2.32.3)\n",
69
- "Requirement already satisfied: rich<14,>=10.14.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (13.9.4)\n",
70
- "Requirement already satisfied: tenacity<10,>=8.1.0 in /usr/local/lib/python3.11/dist-packages (from streamlit) (9.0.0)\n",
71
- "Requirement already satisfied: toml<2,>=0.10.1 in /usr/local/lib/python3.11/dist-packages (from streamlit) (0.10.2)\n",
72
- "Requirement already satisfied: watchdog<7,>=2.1.5 in /usr/local/lib/python3.11/dist-packages (from streamlit) (6.0.0)\n",
73
- "Requirement already satisfied: gitpython!=3.1.19,<4,>=3.0.7 in /usr/local/lib/python3.11/dist-packages (from streamlit) (3.1.44)\n",
74
- "Requirement already satisfied: pydeck<1,>=0.8.0b4 in /usr/local/lib/python3.11/dist-packages (from streamlit) (0.9.1)\n",
75
- "Requirement already satisfied: tornado<7,>=6.0.3 in /usr/local/lib/python3.11/dist-packages (from streamlit) (6.4.2)\n",
76
- "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.11/dist-packages (from altair<6,>=4.0->streamlit) (4.23.0)\n",
77
- "Requirement already satisfied: narwhals>=1.14.2 in /usr/local/lib/python3.11/dist-packages (from altair<6,>=4.0->streamlit) (1.27.1)\n",
78
- "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from gitpython!=3.1.19,<4,>=3.0.7->streamlit) (4.0.12)\n",
79
- "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2.8.2)\n",
80
- "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2025.1)\n",
81
- "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3,>=1.4.0->streamlit) (2025.1)\n",
82
- "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
83
- "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (3.4.1)\n",
84
- "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (3.10)\n",
85
- "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (2.3.0)\n",
86
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.27->streamlit) (2025.1.31)\n",
87
- "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich<14,>=10.14.0->streamlit) (3.0.0)\n",
88
- "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich<14,>=10.14.0->streamlit) (2.18.0)\n",
89
- "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.19,<4,>=3.0.7->streamlit) (5.0.2)\n",
90
- "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (25.1.0)\n",
91
- "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (2024.10.1)\n",
92
- "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.36.2)\n",
93
- "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.22.3)\n",
94
- "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich<14,>=10.14.0->streamlit) (0.1.2)\n",
95
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3,>=1.4.0->streamlit) (1.17.0)\n"
96
- ]
97
- }
98
- ],
99
- "source": [
100
- "!pip install torch tqdm streamlit"
101
- ]
102
- },
103
- {
104
- "cell_type": "code",
105
- "source": [
106
- "######################################\n",
107
- "# Pseudocode2Cpp.py\n",
108
- "######################################\n",
109
- "import os\n",
110
- "import streamlit as st\n",
111
- "import torch\n",
112
- "import torch.nn as nn\n",
113
- "import torch.optim as optim\n",
114
- "import math\n",
115
- "import re\n",
116
- "from tqdm import tqdm\n",
117
- "from typing import List, Tuple\n",
118
- "import random\n",
119
- "import requests\n",
120
- "from torch.utils.data import DataLoader, TensorDataset"
121
- ],
122
- "metadata": {
123
- "id": "tEYW8hGR19sm"
124
- },
125
- "execution_count": null,
126
- "outputs": []
127
- },
128
- {
129
- "cell_type": "code",
130
- "source": [
131
- "# ----------------------------\n",
132
- "# 1. Hyperparameters\n",
133
- "# ----------------------------\n",
134
- "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
135
- "MAX_LEN = 128 # maximum sequence length\n",
136
- "EMBED_DIM = 256 # embedding dimension\n",
137
- "FF_DIM = 512 # feedforward dimension in Transformer\n",
138
- "NHEAD = 4 # number of heads in multihead attention\n",
139
- "NUM_ENCODER_LAYERS = 2\n",
140
- "NUM_DECODER_LAYERS = 2\n",
141
- "BATCH_SIZE = 64\n",
142
- "EPOCHS = 10 # Increase for real training\n",
143
- "LEARNING_RATE = 1e-4\n",
144
- "\n",
145
- "# Special tokens\n",
146
- "PAD_TOKEN = \"<pad>\"\n",
147
- "SOS_TOKEN = \"<sos>\"\n",
148
- "EOS_TOKEN = \"<eos>\"\n",
149
- "UNK_TOKEN = \"<unk>\""
150
- ],
151
- "metadata": {
152
- "id": "HelkrJ-01-2B"
153
- },
154
- "execution_count": null,
155
- "outputs": []
156
- },
157
- {
158
- "cell_type": "code",
159
- "source": [
160
- "# ----------------------------\n",
161
- "# 2. Data Loading & Preprocessing\n",
162
- "# ----------------------------\n",
163
- "\n",
164
- "def load_spoc_data(file_path: str):\n",
165
- " \"\"\"\n",
166
- " Loads (pseudo_code, cpp_code) pairs from a TSV file or raw GitHub link.\n",
167
- " Each line is assumed to have: pseudocode <tab> c++ code.\n",
168
- " \"\"\"\n",
169
- " pairs = []\n",
170
- "\n",
171
- " # If file_path is a URL, fetch it with requests\n",
172
- " if file_path.startswith(\"http\"):\n",
173
- " response = requests.get(file_path)\n",
174
- " response.raise_for_status()\n",
175
- " lines = response.text.strip().split(\"\\n\")\n",
176
- " else:\n",
177
- " # Otherwise, assume it's a local file path\n",
178
- " with open(file_path, 'r', encoding='utf-8') as f:\n",
179
- " lines = f.readlines()\n",
180
- "\n",
181
- " for line in lines:\n",
182
- " line = line.strip()\n",
183
- " if not line:\n",
184
- " continue\n",
185
- " cols = line.split('\\t')\n",
186
- " if len(cols) >= 2:\n",
187
- " pseudo = cols[0].strip()\n",
188
- " cpp = cols[1].strip()\n",
189
- " pairs.append((pseudo, cpp))\n",
190
- "\n",
191
- " return pairs\n",
192
- "\n",
193
- "def create_dataloader(pairs, src_stoi, tgt_stoi, batch_size):\n",
194
- " src_batches = []\n",
195
- " tgt_batches = []\n",
196
- " for pseudo, cpp in pairs:\n",
197
- " src_ids = pad_sequence(numericalize(pseudo, src_stoi), MAX_LEN, src_stoi[PAD_TOKEN])\n",
198
- " tgt_ids = pad_sequence(numericalize(cpp, tgt_stoi), MAX_LEN, tgt_stoi[PAD_TOKEN])\n",
199
- " src_batches.append(src_ids)\n",
200
- " tgt_batches.append(tgt_ids)\n",
201
- "\n",
202
- " src_tensor = torch.tensor(src_batches, dtype=torch.long)\n",
203
- " tgt_tensor = torch.tensor(tgt_batches, dtype=torch.long)\n",
204
- " dataset = TensorDataset(src_tensor, tgt_tensor)\n",
205
- " return DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)\n",
206
- "\n",
207
- "def tokenize_line(text: str) -> List[str]:\n",
208
- " \"\"\"Enhanced tokenizer for pseudocode/C++ patterns\"\"\"\n",
209
- " # Separate operators and punctuation\n",
210
- " text = re.sub(r'([=+\\-*/%<>!&|^~])', r' \\1 ', text) # Operators\n",
211
- " text = re.sub(r'(?<!:):(?!:)', r' : ', text) # Single colon\n",
212
- " return re.findall(r'\\b\\w+\\b|[-+*/%=<>!&|^~]+|[:;{},()\\[\\]\\.]', text)\n",
213
- "\n",
214
- "def build_vocab(pairs: List[Tuple[str, str]]) -> Tuple[dict, dict, dict, dict]:\n",
215
- " \"\"\"\n",
216
- " Build source (pseudo) and target (cpp) vocabularies from training data.\n",
217
- " Returns:\n",
218
- " src_stoi, src_itos, tgt_stoi, tgt_itos\n",
219
- " \"\"\"\n",
220
- " src_words = set()\n",
221
- " tgt_words = set()\n",
222
- "\n",
223
- " for (pseudo, cpp) in pairs:\n",
224
- " for tok in tokenize_line(pseudo):\n",
225
- " src_words.add(tok)\n",
226
- " for tok in tokenize_line(cpp):\n",
227
- " tgt_words.add(tok)\n",
228
- "\n",
229
- " # Add special tokens\n",
230
- " src_vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(list(src_words))\n",
231
- " tgt_vocab = [PAD_TOKEN, SOS_TOKEN, EOS_TOKEN, UNK_TOKEN] + sorted(list(tgt_words))\n",
232
- "\n",
233
- " src_stoi = {w: i for i, w in enumerate(src_vocab)}\n",
234
- " src_itos = {i: w for i, w in enumerate(src_vocab)}\n",
235
- " tgt_stoi = {w: i for i, w in enumerate(tgt_vocab)}\n",
236
- " tgt_itos = {i: w for i, w in enumerate(tgt_vocab)}\n",
237
- "\n",
238
- " return src_stoi, src_itos, tgt_stoi, tgt_itos\n",
239
- "\n",
240
- "def numericalize(text: str, stoi: dict) -> List[int]:\n",
241
- " \"\"\"\n",
242
- " Convert text string to a list of token IDs.\n",
243
- " \"\"\"\n",
244
- " tokens = tokenize_line(text)\n",
245
- " ids = []\n",
246
- " for t in tokens:\n",
247
- " if t in stoi:\n",
248
- " ids.append(stoi[t])\n",
249
- " else:\n",
250
- " ids.append(stoi[UNK_TOKEN])\n",
251
- " return ids\n",
252
- "\n",
253
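- "def pad_sequence(seq: List[int], max_len: int, pad_id: int, sos_id: int = 1, eos_id: int = 2) -> List[int]:\n",
254
- " \"\"\"Truncate, wrap in SOS/EOS, then pad to max_len. The default ids match build_vocab, which puts SOS at index 1 and EOS at index 2 in both vocabularies.\"\"\"\n",
255
- " seq = seq[:max_len-2] # Leave space for SOS/EOS\n",
256
- " seq = [sos_id] + seq + [eos_id] # Add control tokens without relying on a global src_stoi\n",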
257
- " padding = [pad_id] * (max_len - len(seq))\n",
258
- " return seq + padding\n",
259
- "\n",
260
- "def create_batches(pairs, src_stoi, tgt_stoi, batch_size):\n",
261
- " \"\"\"\n",
262
- " Yield batches of data (source_ids, target_ids).\n",
263
- " \"\"\"\n",
264
- " random.shuffle(pairs)\n",
265
- " for i in range(0, len(pairs), batch_size):\n",
266
- " batch_pairs = pairs[i:i+batch_size]\n",
267
- " src_batch = []\n",
268
- " tgt_batch = []\n",
269
- " for pseudo, cpp in batch_pairs:\n",
270
- " src_ids = numericalize(pseudo, src_stoi)\n",
271
- " tgt_ids = numericalize(cpp, tgt_stoi)\n",
272
- "\n",
273
- " # Pad/truncate\n",
274
- " src_ids = pad_sequence(src_ids, MAX_LEN, src_stoi[PAD_TOKEN])\n",
275
- " tgt_ids = pad_sequence(tgt_ids, MAX_LEN, tgt_stoi[PAD_TOKEN])\n",
276
- "\n",
277
- " src_batch.append(src_ids)\n",
278
- " tgt_batch.append(tgt_ids)\n",
279
- "\n",
280
- " src_batch = torch.tensor(src_batch, dtype=torch.long, device=DEVICE)\n",
281
- " tgt_batch = torch.tensor(tgt_batch, dtype=torch.long, device=DEVICE)\n",
282
- " yield src_batch, tgt_batch"
283
- ],
284
- "metadata": {
285
- "id": "2lFlkj-t2AGg"
286
- },
287
- "execution_count": null,
288
- "outputs": []
289
- },
290
- {
291
- "cell_type": "code",
292
- "source": [
293
- "# ----------------------------\n",
294
- "# 3. Transformer Model Implementation (from scratch)\n",
295
- "# ----------------------------\n",
296
- "\n",
297
- "class PositionalEncoding(nn.Module):\n",
298
- " def __init__(self, d_model, max_len=5000):\n",
299
- " super(PositionalEncoding, self).__init__()\n",
300
- " pe = torch.zeros(max_len, d_model)\n",
301
- " position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
302
- " div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
303
- " pe[:, 0::2] = torch.sin(position * div_term)\n",
304
- " pe[:, 1::2] = torch.cos(position * div_term)\n",
305
- " pe = pe.unsqueeze(0) # shape (1, max_len, d_model)\n",
306
- " self.register_buffer('pe', pe)\n",
307
- "\n",
308
- " def forward(self, x):\n",
309
- " # x shape: (batch_size, seq_len, d_model)\n",
310
- " seq_len = x.size(1)\n",
311
- " x = x + self.pe[:, :seq_len, :]\n",
312
- " return x\n",
313
- "\n",
314
- "class MultiHeadAttention(nn.Module):\n",
315
- " def __init__(self, d_model, n_heads):\n",
316
- " super(MultiHeadAttention, self).__init__()\n",
317
- " assert d_model % n_heads == 0\n",
318
- " self.d_model = d_model\n",
319
- " self.n_heads = n_heads\n",
320
- " self.head_dim = d_model // n_heads\n",
321
- "\n",
322
- " self.query_linear = nn.Linear(d_model, d_model)\n",
323
- " self.key_linear = nn.Linear(d_model, d_model)\n",
324
- " self.value_linear = nn.Linear(d_model, d_model)\n",
325
- " self.out_linear = nn.Linear(d_model, d_model)\n",
326
- "\n",
327
- " def forward(self, query, key, value, mask=None):\n",
328
- " # query/key/value shape: (batch_size, seq_len, d_model)\n",
329
- " B, Q_len, _ = query.size()\n",
330
- " B, K_len, _ = key.size()\n",
331
- " B, V_len, _ = value.size()\n",
332
- "\n",
333
- " # Linear projections\n",
334
- " Q = self.query_linear(query) # (B, Q_len, d_model)\n",
335
- " K = self.key_linear(key) # (B, K_len, d_model)\n",
336
- " V = self.value_linear(value) # (B, V_len, d_model)\n",
337
- "\n",
338
- " # Reshape for multi-head\n",
339
- " Q = Q.view(B, Q_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, Q_len, head_dim)\n",
340
- " K = K.view(B, K_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, K_len, head_dim)\n",
341
- " V = V.view(B, V_len, self.n_heads, self.head_dim).transpose(1,2) # (B, n_heads, V_len, head_dim)\n",
342
- "\n",
343
- " # Scaled dot-product attention\n",
344
- " scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim) # (B, n_heads, Q_len, K_len)\n",
345
- " if mask is not None:\n",
346
- " scores = scores.masked_fill(mask == 0, float('-inf'))\n",
347
- " attn = torch.softmax(scores, dim=-1) # (B, n_heads, Q_len, K_len)\n",
348
- "\n",
349
- " context = torch.matmul(attn, V) # (B, n_heads, Q_len, head_dim)\n",
350
- " context = context.transpose(1,2).contiguous().view(B, Q_len, self.d_model)\n",
351
- " out = self.out_linear(context)\n",
352
- " return out\n",
353
- "\n",
354
- "class FeedForward(nn.Module):\n",
355
- " def __init__(self, d_model, dim_feedforward):\n",
356
- " super(FeedForward, self).__init__()\n",
357
- " self.fc1 = nn.Linear(d_model, dim_feedforward)\n",
358
- " self.fc2 = nn.Linear(dim_feedforward, d_model)\n",
359
- " self.relu = nn.ReLU()\n",
360
- "\n",
361
- " def forward(self, x):\n",
362
- " return self.fc2(self.relu(self.fc1(x)))\n",
363
- "\n",
364
- "class EncoderLayer(nn.Module):\n",
365
- " def __init__(self, d_model, n_heads, dim_feedforward):\n",
366
- " super(EncoderLayer, self).__init__()\n",
367
- " self.self_attn = MultiHeadAttention(d_model, n_heads)\n",
368
- " self.ff = FeedForward(d_model, dim_feedforward)\n",
369
- " self.norm1 = nn.LayerNorm(d_model)\n",
370
- " self.norm2 = nn.LayerNorm(d_model)\n",
371
- " self.dropout = nn.Dropout(0.1)\n",
372
- "\n",
373
- " def forward(self, src, src_mask=None):\n",
374
- " # Self-attention\n",
375
- " attn_out = self.self_attn(src, src, src, mask=src_mask)\n",
376
- " src = self.norm1(src + self.dropout(attn_out))\n",
377
- " # Feed Forward\n",
378
- " ff_out = self.ff(src)\n",
379
- " src = self.norm2(src + self.dropout(ff_out))\n",
380
- " return src\n",
381
- "\n",
382
- "class DecoderLayer(nn.Module):\n",
383
- " def __init__(self, d_model, n_heads, dim_feedforward):\n",
384
- " super(DecoderLayer, self).__init__()\n",
385
- " self.self_attn = MultiHeadAttention(d_model, n_heads)\n",
386
- " self.cross_attn = MultiHeadAttention(d_model, n_heads)\n",
387
- " self.ff = FeedForward(d_model, dim_feedforward)\n",
388
- "\n",
389
- " self.norm1 = nn.LayerNorm(d_model)\n",
390
- " self.norm2 = nn.LayerNorm(d_model)\n",
391
- " self.norm3 = nn.LayerNorm(d_model)\n",
392
- " self.dropout = nn.Dropout(0.1)\n",
393
- "\n",
394
- " def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n",
395
- " # Self-attention (mask future tokens)\n",
396
- " _tgt = tgt\n",
397
- " tgt = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt, tgt, mask=tgt_mask)))\n",
398
- " # Cross-attention\n",
399
- " _tgt2 = tgt\n",
400
- " tgt = self.norm2(tgt + self.dropout(self.cross_attn(tgt, memory, memory, mask=memory_mask)))\n",
401
- " # Feed Forward\n",
402
- " ff_out = self.ff(tgt)\n",
403
- " tgt = self.norm3(tgt + self.dropout(ff_out))\n",
404
- " return tgt\n",
405
- "\n",
406
- "class Encoder(nn.Module):\n",
407
- " def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):\n",
408
- " super(Encoder, self).__init__()\n",
409
- " self.embedding = nn.Embedding(vocab_size, d_model)\n",
410
- " self.pos_encoding = PositionalEncoding(d_model)\n",
411
- " self.layers = nn.ModuleList([\n",
412
- " EncoderLayer(d_model, n_heads, dim_feedforward)\n",
413
- " for _ in range(num_layers)\n",
414
- " ])\n",
415
- "\n",
416
- " def forward(self, src, src_mask=None):\n",
417
- " # src shape: (batch_size, seq_len)\n",
418
- " x = self.embedding(src) # (batch_size, seq_len, d_model)\n",
419
- " x = self.pos_encoding(x)\n",
420
- " for layer in self.layers:\n",
421
- " x = layer(x, src_mask)\n",
422
- " return x\n",
423
- "\n",
424
- "class Decoder(nn.Module):\n",
425
- " def __init__(self, vocab_size, d_model, n_heads, num_layers, dim_feedforward):\n",
426
- " super(Decoder, self).__init__()\n",
427
- " self.embedding = nn.Embedding(vocab_size, d_model)\n",
428
- " self.pos_encoding = PositionalEncoding(d_model)\n",
429
- " self.layers = nn.ModuleList([\n",
430
- " DecoderLayer(d_model, n_heads, dim_feedforward)\n",
431
- " for _ in range(num_layers)\n",
432
- " ])\n",
433
- " self.fc_out = nn.Linear(d_model, vocab_size)\n",
434
- "\n",
435
- " def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):\n",
436
- " x = self.embedding(tgt)\n",
437
- " x = self.pos_encoding(x)\n",
438
- " for layer in self.layers:\n",
439
- " x = layer(x, memory, tgt_mask, memory_mask)\n",
440
- " logits = self.fc_out(x) # (batch_size, seq_len, vocab_size)\n",
441
- " return logits\n",
442
- "\n",
443
- "class TransformerSeq2Seq(nn.Module):\n",
444
- " def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_heads, num_encoder_layers,\n",
445
- " num_decoder_layers, dim_feedforward):\n",
446
- " super(TransformerSeq2Seq, self).__init__()\n",
447
- " self.encoder = Encoder(src_vocab_size, d_model, n_heads, num_encoder_layers, dim_feedforward)\n",
448
- " self.decoder = Decoder(tgt_vocab_size, d_model, n_heads, num_decoder_layers, dim_feedforward)\n",
449
- "\n",
450
- " def forward(self, src, tgt, src_mask=None, tgt_mask=None):\n",
451
- " # src: (batch_size, src_seq_len)\n",
452
- " # tgt: (batch_size, tgt_seq_len)\n",
453
- " memory = self.encoder(src, src_mask) # (batch_size, src_seq_len, d_model)\n",
454
- " outputs = self.decoder(tgt, memory, tgt_mask) # (batch_size, tgt_seq_len, vocab_size)\n",
455
- " return outputs"
456
- ],
457
- "metadata": {
458
- "id": "f8HioKcS2ZRy"
459
- },
460
- "execution_count": null,
461
- "outputs": []
462
- },
463
- {
464
- "cell_type": "code",
465
- "source": [
466
- "# ----------------------------\n",
467
- "# 4. Training Setup\n",
468
- "# ----------------------------\n",
469
- "import torch\n",
470
- "import torch.nn as nn\n",
471
- "from torch.utils.data import DataLoader, TensorDataset\n",
472
- "from typing import List, Tuple\n",
473
- "import random\n",
474
- "def generate_subsequent_mask(size):\n",
475
- " # Mask out subsequent positions (for decoding)\n",
476
- " mask = torch.triu(torch.ones(size, size), diagonal=1).bool()\n",
477
- " return ~mask # True where we can attend, False where we cannot\n",
478
- "\n",
479
- "def train_one_epoch(model, optimizer, criterion, train_data, src_stoi, tgt_stoi):\n",
480
- " model.train()\n",
481
- " total_loss = 0\n",
482
- " steps = 0\n",
483
- "\n",
484
- " data_loader = create_dataloader(train_data, src_stoi, tgt_stoi, BATCH_SIZE)\n",
485
- " for src_batch, tgt_batch in data_loader:\n",
486
- " src_batch = src_batch.to(DEVICE)\n",
487
- " tgt_batch = tgt_batch.to(DEVICE)\n",
488
- "\n",
489
- " # Prepare the target inputs and outputs (shifted by one token)\n",
490
- " tgt_inp = tgt_batch[:, :-1]\n",
491
- " tgt_out = tgt_batch[:, 1:]\n",
492
- "\n",
493
- " # Create subsequent mask for the target sequence\n",
494
- " tgt_seq_len = tgt_inp.size(1)\n",
495
- " tgt_mask = generate_subsequent_mask(tgt_seq_len).to(DEVICE)\n",
496
- "\n",
497
- " optimizer.zero_grad()\n",
498
- " logits = model(src_batch, tgt_inp, None, tgt_mask) # (B, seq_len, vocab_size)\n",
499
- "\n",
500
- " # Use .reshape() rather than .view(): tgt_out is a non-contiguous slice, so .view(-1) would raise a runtime error\n",
501
- " loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))\n",
502
- " loss.backward()\n",
503
- " optimizer.step()\n",
504
- "\n",
505
- " total_loss += loss.item()\n",
506
- " steps += 1\n",
507
- "\n",
508
- " return total_loss / steps\n",
509
- "\n",
510
- "def evaluate(model, criterion, eval_data, src_stoi, tgt_stoi):\n",
511
- " model.eval()\n",
512
- " total_loss = 0\n",
513
- " steps = 0\n",
514
- " with torch.no_grad():\n",
515
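- " for src_batch, tgt_batch in create_batches(eval_data, src_stoi, tgt_stoi, BATCH_SIZE):\n",
- " # Move the batch to the model's device, as train_one_epoch does (a no-op if it is already there)\n",
- " src_batch = src_batch.to(DEVICE)\n",
- " tgt_batch = tgt_batch.to(DEVICE)\n",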
516
- " tgt_inp = tgt_batch[:, :-1]\n",
517
- " tgt_out = tgt_batch[:, 1:]\n",
518
- " tgt_seq_len = tgt_inp.size(1)\n",
519
- " tgt_mask = generate_subsequent_mask(tgt_seq_len).to(DEVICE)\n",
520
- "\n",
521
- " logits = model(src_batch, tgt_inp, None, tgt_mask)\n",
522
- " # Use .reshape() instead of .view()\n",
523
- " loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))\n",
524
- "\n",
525
- " total_loss += loss.item()\n",
526
- " steps += 1\n",
527
- " return total_loss / steps\n",
528
- "\n",
529
- "@torch.no_grad() # inference only: skip building the autograd graph\n",
- "def greedy_decode(model, src, src_stoi, tgt_stoi, tgt_itos, max_len=MAX_LEN):\n",
530
- " \"\"\"\n",
531
- " Given a single source sequence (1D list of token IDs),\n",
532
- " generate a decoded target sequence using greedy search.\n",
533
- " \"\"\"\n",
534
- " model.eval()\n",
535
- " src = torch.tensor(src, dtype=torch.long, device=DEVICE).unsqueeze(0) # (1, seq_len)\n",
536
- " memory = model.encoder(src) # (1, seq_len, d_model)\n",
537
- "\n",
538
- " ys = torch.tensor([tgt_stoi[SOS_TOKEN]], dtype=torch.long, device=DEVICE).unsqueeze(0) # (1, 1)\n",
539
- " for _ in range(max_len - 1):\n",
540
- " tgt_mask = generate_subsequent_mask(ys.size(1)).to(DEVICE)\n",
541
- " out = model.decoder(ys, memory, tgt_mask) # (1, seq_len, vocab_size)\n",
542
- " prob = out[:, -1, :] # logits (not normalized probabilities) for the last timestep\n",
543
- " next_token = torch.argmax(prob, dim=1).item()\n",
544
- " ys = torch.cat([ys, torch.tensor([[next_token]], device=DEVICE)], dim=1)\n",
545
- " if next_token == tgt_stoi[EOS_TOKEN]:\n",
546
- " break\n",
547
- "\n",
548
- " # Convert back to tokens\n",
549
- " out_tokens = ys.squeeze(0).tolist() # e.g. [SOS, ..., EOS]\n",
550
- " # Remove the initial SOS\n",
551
- " out_tokens = out_tokens[1:]\n",
552
- " # Stop at EOS if present\n",
553
- " if tgt_stoi[EOS_TOKEN] in out_tokens:\n",
554
- " eos_idx = out_tokens.index(tgt_stoi[EOS_TOKEN])\n",
555
- " out_tokens = out_tokens[:eos_idx]\n",
556
- "\n",
557
- " return \" \".join(tgt_itos[t] for t in out_tokens)"
558
- ],
559
- "metadata": {
560
- "id": "ffYgGSXy2a4B"
561
- },
562
- "execution_count": null,
563
- "outputs": []
564
- },
565
- {
566
- "cell_type": "code",
567
- "source": [
568
- "# ----------------------------\n",
569
- "# 5. Main: Train the Model\n",
570
- "# ----------------------------\n",
571
- "if __name__ == \"__main__\":\n",
572
- " # Hardcode the file paths from your GitHub repo (raw URLs):\n",
573
- " train_path = \"https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/spoc-train.tsv\"\n",
574
- " eval_path = \"https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/split/spoc-train-eval.tsv\"\n",
575
- "\n",
576
- " print(f\"Loading training data from {train_path} ...\")\n",
577
- " train_pairs = load_spoc_data(train_path)\n",
578
- " print(f\"Loaded {len(train_pairs)} training pairs.\")\n",
579
- "\n",
580
- " print(f\"Loading eval data from {eval_path} ...\")\n",
581
- " eval_pairs = load_spoc_data(eval_path)\n",
582
- " print(f\"Loaded {len(eval_pairs)} eval pairs.\")\n",
583
- "\n",
584
- " print(\"Building vocab...\")\n",
585
- " src_stoi, src_itos, tgt_stoi, tgt_itos = build_vocab(train_pairs)\n",
586
- " global stoi_eos\n",
587
- " stoi_eos = tgt_stoi[EOS_TOKEN] # for pad_sequence usage\n",
588
- "\n",
589
- " print(\"Creating model...\")\n",
590
- " model = TransformerSeq2Seq(\n",
591
- " src_vocab_size=len(src_stoi),\n",
592
- " tgt_vocab_size=len(tgt_stoi),\n",
593
- " d_model=EMBED_DIM,\n",
594
- " n_heads=NHEAD,\n",
595
- " num_encoder_layers=NUM_ENCODER_LAYERS,\n",
596
- " num_decoder_layers=NUM_DECODER_LAYERS,\n",
597
- " dim_feedforward=FF_DIM\n",
598
- " ).to(DEVICE)\n",
599
- "\n",
600
- " criterion = nn.CrossEntropyLoss(ignore_index=tgt_stoi[PAD_TOKEN])\n",
601
- " optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
602
- "\n",
603
- " print(\"Starting training...\")\n",
604
- " for epoch in range(1, EPOCHS+1):\n",
605
- " train_loss = train_one_epoch(model, optimizer, criterion, train_pairs, src_stoi, tgt_stoi)\n",
606
- " eval_loss = evaluate(model, criterion, eval_pairs, src_stoi, tgt_stoi)\n",
607
- " print(f\"Epoch [{epoch}/{EPOCHS}] - Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}\")\n",
608
- "\n",
609
- " # Save model & vocab\n",
610
- " torch.save({\n",
611
- " 'model_state_dict': model.state_dict(),\n",
612
- " 'src_stoi': src_stoi,\n",
613
- " 'src_itos': src_itos,\n",
614
- " 'tgt_stoi': tgt_stoi,\n",
615
- " 'tgt_itos': tgt_itos\n",
616
- " }, \"model.pth\")\n",
617
- "\n",
618
- " print(\"Model and vocab saved to model.pth\")"
619
- ],
620
- "metadata": {
621
- "colab": {
622
- "base_uri": "https://localhost:8080/"
623
- },
624
- "id": "iffrMhkc2cVt",
625
- "outputId": "38839989-38e5-4b10-fbea-90767dca60e3"
626
- },
627
- "execution_count": null,
628
- "outputs": [
629
- {
630
- "output_type": "stream",
631
- "name": "stdout",
632
- "text": [
633
- "Loading training data from https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/spoc-train.tsv ...\n",
634
- "Loaded 293855 training pairs.\n",
635
- "Loading eval data from https://raw.githubusercontent.com/asadsandhu/Pseudocode2Cpp/main/spoc/train/split/spoc-train-eval.tsv ...\n",
636
- "Loaded 27289 eval pairs.\n",
637
- "Building vocab...\n",
638
- "Creating model...\n",
639
- "Starting training...\n",
640
- "Epoch [1/10] - Train Loss: 0.9915, Eval Loss: 0.4901\n",
641
- "Epoch [2/10] - Train Loss: 0.4401, Eval Loss: 0.3597\n",
642
- "Epoch [3/10] - Train Loss: 0.3326, Eval Loss: 0.2897\n",
643
- "Epoch [4/10] - Train Loss: 0.2752, Eval Loss: 0.2735\n",
644
- "Epoch [5/10] - Train Loss: 0.2401, Eval Loss: 0.2281\n",
645
- "Epoch [6/10] - Train Loss: 0.2166, Eval Loss: 0.2111\n",
646
- "Epoch [7/10] - Train Loss: 0.2002, Eval Loss: 0.2015\n",
647
- "Epoch [8/10] - Train Loss: 0.1883, Eval Loss: 0.1919\n",
648
- "Epoch [9/10] - Train Loss: 0.1793, Eval Loss: 0.1848\n",
649
- "Epoch [10/10] - Train Loss: 0.1724, Eval Loss: 0.1819\n",
650
- "Model and vocab saved to transformer_spoc.pth\n"
651
- ]
652
- }
653
- ]
654
- }
655
- ]
656
- }
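The shift-by-one targets and the subsequent mask used in `train_one_epoch` above can be illustrated without PyTorch. The following is a minimal sketch, not the notebook's code: `subsequent_mask` is a hypothetical pure-Python stand-in for `generate_subsequent_mask`, and the token spellings (`<sos>`, `<eos>`) are assumptions, since the values of `SOS_TOKEN` and `EOS_TOKEN` are not visible in this excerpt. It keeps the notebook's convention that `True` marks a position the decoder may attend to.

```python
def subsequent_mask(size):
    # Row i may attend to columns 0..i only (True = allowed),
    # mirroring ~torch.triu(torch.ones(size, size), diagonal=1).bool().
    return [[j <= i for j in range(size)] for i in range(size)]


# Hypothetical target sequence; the special-token spellings are assumed.
tokens = ["<sos>", "int", "x", ";", "<eos>"]

# Teacher forcing: the decoder reads everything but the last token
# and is trained to predict everything but the first.
tgt_inp = tokens[:-1]
tgt_out = tokens[1:]

mask = subsequent_mask(len(tgt_inp))

# For each position, list the input tokens the decoder can see
# and the token it is asked to predict.
pairs = []
for i, row in enumerate(mask):
    visible = [tok for tok, allowed in zip(tgt_inp, row) if allowed]
    pairs.append((visible, tgt_out[i]))

for visible, target in pairs:
    print(visible, "->", target)
```

Each row of the mask exposes one more input token than the row before it, so the prediction at position `i` never depends on the token it is supposed to produce. Note that this `True = allowed` convention is the opposite of the boolean masks accepted by PyTorch's built-in `nn.Transformer`, where `True` marks positions that are blocked.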