Spaces:
Runtime error
Runtime error
Add model details and set training parameters
Browse files- train_llm.ipynb +302 -0
- train_llm.py +96 -31
train_llm.ipynb
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 32,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": [
|
| 9 |
+
"import os\n",
|
| 10 |
+
"from uuid import uuid4\n",
|
| 11 |
+
"import pandas as pd\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"from datasets import load_dataset\n",
|
| 14 |
+
"import subprocess\n",
|
| 15 |
+
"from transformers import AutoTokenizer"
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"cell_type": "code",
|
| 20 |
+
"execution_count": 33,
|
| 21 |
+
"metadata": {},
|
| 22 |
+
"outputs": [],
|
| 23 |
+
"source": [
|
| 24 |
+
"# from dotenv import load_dotenv,find_dotenv\n",
|
| 25 |
+
"# load_dotenv(find_dotenv(),override=True)\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"def max_token_len(dataset):\n",
|
| 28 |
+
" max_seq_length = 0\n",
|
| 29 |
+
" for row in dataset:\n",
|
| 30 |
+
" tokens = len(tokenizer(row['text'])['input_ids'])\n",
|
| 31 |
+
" if tokens > max_seq_length:\n",
|
| 32 |
+
" max_seq_length = tokens\n",
|
| 33 |
+
" return max_seq_length"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": 34,
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [
|
| 41 |
+
{
|
| 42 |
+
"name": "stdout",
|
| 43 |
+
"output_type": "stream",
|
| 44 |
+
"text": [
|
| 45 |
+
"Model Max Length: 1000000000000000019884624838656\n"
|
| 46 |
+
]
|
| 47 |
+
}
|
| 48 |
+
],
|
| 49 |
+
"source": [
|
| 50 |
+
"# model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'\n",
|
| 51 |
+
"model_name = 'mistralai/Mistral-7B-v0.1'\n",
|
| 52 |
+
"# model_name = 'distilbert-base-uncased'\n",
|
| 53 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
| 54 |
+
"model_max_length = tokenizer.model_max_length\n",
|
| 55 |
+
"print(\"Model Max Length:\", model_max_length)"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 37,
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [
|
| 63 |
+
{
|
| 64 |
+
"name": "stdout",
|
| 65 |
+
"output_type": "stream",
|
| 66 |
+
"text": [
|
| 67 |
+
"Max token length train: 1121\n",
|
| 68 |
+
"Max token length validation: 38\n",
|
| 69 |
+
"Block size: 2242\n"
|
| 70 |
+
]
|
| 71 |
+
}
|
| 72 |
+
],
|
| 73 |
+
"source": [
|
| 74 |
+
"# Load dataset\n",
|
| 75 |
+
"dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'\n",
|
| 76 |
+
"dataset=load_dataset(dataset_name)\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"# Write dataset files into data directory\n",
|
| 79 |
+
"data_directory = './fine_tune_data/'\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# Create the data directory if it doesn't exist\n",
|
| 82 |
+
"os.makedirs(data_directory, exist_ok=True)\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"# Write the train data to a CSV file\n",
|
| 85 |
+
"train_data='train_data'\n",
|
| 86 |
+
"train_filename = os.path.join(data_directory, train_data)\n",
|
| 87 |
+
"dataset['train'].to_pandas().to_csv(train_filename+'.csv', columns=['text'], index=False)\n",
|
| 88 |
+
"max_token_length_train=max_token_len(dataset['train'])\n",
|
| 89 |
+
"print('Max token length train: '+str(max_token_length_train))\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# Write the validation data to a CSV file\n",
|
| 92 |
+
"validation_data='validation_data'\n",
|
| 93 |
+
"validation_filename = os.path.join(data_directory, validation_data)\n",
|
| 94 |
+
"dataset['validation'].to_pandas().to_csv(validation_filename+'.csv', columns=['text'], index=False)\n",
|
| 95 |
+
"max_token_length_validation=max_token_len(dataset['validation'])\n",
|
| 96 |
+
"print('Max token length validation: '+str(max_token_length_validation))\n",
|
| 97 |
+
" \n",
|
| 98 |
+
"max_token_length=max(max_token_length_train,max_token_length_validation)\n",
|
| 99 |
+
"if max_token_length > model_max_length:\n",
|
| 100 |
+
" raise ValueError(\"Maximum token length exceeds model limits.\")\n",
|
| 101 |
+
"block_size=2*max_token_length\n",
|
| 102 |
+
"print('Block size: '+str(block_size))\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"# Define project parameters\n",
|
| 105 |
+
"username='ai-aerospace'\n",
|
| 106 |
+
"project_name='./llms/'+'ams_data_train-100_'+str(uuid4())\n",
|
| 107 |
+
"repo_name='ams-data-train-100-'+str(uuid4())"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "code",
|
| 112 |
+
"execution_count": 46,
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [
|
| 115 |
+
{
|
| 116 |
+
"name": "stdout",
|
| 117 |
+
"output_type": "stream",
|
| 118 |
+
"text": [
|
| 119 |
+
"{'project_name': './llms/ams_data_train-100_6abb23dc-cb9d-428e-9079-e47deee0edd9', 'model_name': 'mistralai/Mistral-7B-v0.1', 'repo_id': 'ai-aerospace/ams-data-train-100-4601c8c8-0903-4f18-a6e8-1d2a40a697ce', 'train_data': 'train_data', 'validation_data': 'validation_data', 'data_directory': './fine_tune_data/', 'block_size': 2242, 'model_max_length': 1121, 'logging_steps': -1, 'evaluation_strategy': 'epoch', 'save_total_limit': 1, 'save_strategy': 'epoch', 'mixed_precision': 'fp16', 'lr': 3e-05, 'epochs': 3, 'batch_size': 2, 'warmup_ratio': 0.1, 'gradient_accumulation': 1, 'optimizer': 'adamw_torch', 'scheduler': 'linear', 'weight_decay': 0, 'max_grad_norm': 1, 'seed': 42, 'quantization': 'int4', 'lora_r': 16, 'lora_alpha': 32, 'lora_dropout': 0.05}\n"
|
| 120 |
+
]
|
| 121 |
+
}
|
| 122 |
+
],
|
| 123 |
+
"source": [
|
| 124 |
+
"\"\"\"\n",
|
| 125 |
+
"This set of parameters runs on a low memory gpu on hugging face spaces:\n",
|
| 126 |
+
"{\n",
|
| 127 |
+
" \"block_size\": 1024,\n",
|
| 128 |
+
" \"model_max_length\": 2048,\n",
|
| 129 |
+
" x\"use_flash_attention_2\": false,\n",
|
| 130 |
+
" x\"disable_gradient_checkpointing\": false,\n",
|
| 131 |
+
" \"logging_steps\": -1,\n",
|
| 132 |
+
" \"evaluation_strategy\": \"epoch\",\n",
|
| 133 |
+
" \"save_total_limit\": 1,\n",
|
| 134 |
+
" \"save_strategy\": \"epoch\",\n",
|
| 135 |
+
" x\"auto_find_batch_size\": false,\n",
|
| 136 |
+
" \"mixed_precision\": \"fp16\",\n",
|
| 137 |
+
" \"lr\": 0.00003,\n",
|
| 138 |
+
" \"epochs\": 3,\n",
|
| 139 |
+
" \"batch_size\": 2,\n",
|
| 140 |
+
" \"warmup_ratio\": 0.1,\n",
|
| 141 |
+
" \"gradient_accumulation\": 1,\n",
|
| 142 |
+
" \"optimizer\": \"adamw_torch\",\n",
|
| 143 |
+
" \"scheduler\": \"linear\",\n",
|
| 144 |
+
" \"weight_decay\": 0,\n",
|
| 145 |
+
" \"max_grad_norm\": 1,\n",
|
| 146 |
+
" \"seed\": 42,\n",
|
| 147 |
+
" \"apply_chat_template\": false,\n",
|
| 148 |
+
" \"quantization\": \"int4\",\n",
|
| 149 |
+
" \"target_modules\": \"\",\n",
|
| 150 |
+
" x\"merge_adapter\": false,\n",
|
| 151 |
+
" \"peft\": true,\n",
|
| 152 |
+
" \"lora_r\": 16,\n",
|
| 153 |
+
" \"lora_alpha\": 32,\n",
|
| 154 |
+
" \"lora_dropout\": 0.05\n",
|
| 155 |
+
"}\n",
|
| 156 |
+
"\"\"\"\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"model_params={\n",
|
| 159 |
+
" \"project_name\": project_name,\n",
|
| 160 |
+
" \"model_name\": model_name,\n",
|
| 161 |
+
" \"repo_id\": username+'/'+repo_name,\n",
|
| 162 |
+
" \"train_data\": train_data,\n",
|
| 163 |
+
" \"validation_data\": validation_data,\n",
|
| 164 |
+
" \"data_directory\": data_directory,\n",
|
| 165 |
+
" \"block_size\": block_size,\n",
|
| 166 |
+
" \"model_max_length\": max_token_length,\n",
|
| 167 |
+
" \"logging_steps\": -1,\n",
|
| 168 |
+
" \"evaluation_strategy\": \"epoch\",\n",
|
| 169 |
+
" \"save_total_limit\": 1,\n",
|
| 170 |
+
" \"save_strategy\": \"epoch\",\n",
|
| 171 |
+
" \"mixed_precision\": \"fp16\",\n",
|
| 172 |
+
" \"lr\": 0.00003,\n",
|
| 173 |
+
" \"epochs\": 3,\n",
|
| 174 |
+
" \"batch_size\": 2,\n",
|
| 175 |
+
" \"warmup_ratio\": 0.1,\n",
|
| 176 |
+
" \"gradient_accumulation\": 1,\n",
|
| 177 |
+
" \"optimizer\": \"adamw_torch\",\n",
|
| 178 |
+
" \"scheduler\": \"linear\",\n",
|
| 179 |
+
" \"weight_decay\": 0,\n",
|
| 180 |
+
" \"max_grad_norm\": 1,\n",
|
| 181 |
+
" \"seed\": 42,\n",
|
| 182 |
+
" \"quantization\": \"int4\",\n",
|
| 183 |
+
" \"lora_r\": 16,\n",
|
| 184 |
+
" \"lora_alpha\": 32,\n",
|
| 185 |
+
" \"lora_dropout\": 0.05\n",
|
| 186 |
+
"}\n",
|
| 187 |
+
"for key, value in model_params.items():\n",
|
| 188 |
+
" os.environ[key] = str(value)\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"print(model_params)\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"\n",
|
| 193 |
+
"# Save parameters to environment variables\n",
|
| 194 |
+
"# os.environ[\"project_name\"] = project_name\n",
|
| 195 |
+
"# os.environ[\"model_name\"] = model_name\n",
|
| 196 |
+
"# os.environ[\"repo_id\"] = username+'/'+repo_name\n",
|
| 197 |
+
"# os.environ[\"train_data\"] = train_data \n",
|
| 198 |
+
"# os.environ[\"validation_data\"] = validation_data\n",
|
| 199 |
+
"# os.environ[\"data_directory\"] = data_directory"
|
| 200 |
+
]
|
| 201 |
+
},
|
| 202 |
+
{
|
| 203 |
+
"cell_type": "code",
|
| 204 |
+
"execution_count": 49,
|
| 205 |
+
"metadata": {},
|
| 206 |
+
"outputs": [
|
| 207 |
+
{
|
| 208 |
+
"name": "stderr",
|
| 209 |
+
"output_type": "stream",
|
| 210 |
+
"text": [
|
| 211 |
+
"⚠️ WARNING | 2023-12-22 10:39:42 | autotrain.cli.run_dreambooth:<module>:14 - ❌ Some DreamBooth components are missing! Please run `autotrain setup` to install it. Ignore this warning if you are not using DreamBooth or running `autotrain setup` already.\n",
|
| 212 |
+
"usage: autotrain <command> [<args>]\n",
|
| 213 |
+
"AutoTrain advanced CLI: error: unrecognized arguments: --batch_size 2\n"
|
| 214 |
+
]
|
| 215 |
+
},
|
| 216 |
+
{
|
| 217 |
+
"ename": "CalledProcessError",
|
| 218 |
+
"evalue": "Command '\nautotrain llm --train --trainer sft --project_name ./llms/ams_data_train-100_6abb23dc-cb9d-428e-9079-e47deee0edd9 --model mistralai/Mistral-7B-v0.1 --data_path ./fine_tune_data/ --train_split train_data --valid_split validation_data --repo_id ai-aerospace/ams-data-train-100-4601c8c8-0903-4f18-a6e8-1d2a40a697ce --push_to_hub --token HUGGINGFACE_TOKEN --block_size 2242 --model_max_length 1121 --logging_steps -1 --evaluation_strategy epoch --save_total_limit 1 --save_strategy epoch --fp16 --lr 3e-05 --num_train_epochs 3 --batch_size 2 --warmup_ratio 0.1 --gradient_accumulation 1 --optimizer adamw_torch --scheduler linear --weight_decay 0 --max_grad_norm 1 --seed 42 --use_int4 --use-peft --lora_r 16 --lora_alpha 32 --lora_dropout 0.05\n' returned non-zero exit status 2.",
|
| 219 |
+
"output_type": "error",
|
| 220 |
+
"traceback": [
|
| 221 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
| 222 |
+
"\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)",
|
| 223 |
+
"Cell \u001b[0;32mIn[49], line 40\u001b[0m\n\u001b[1;32m 4\u001b[0m command\u001b[38;5;241m=\u001b[39m\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\"\"\u001b[39m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124mautotrain llm --train \u001b[39m\u001b[38;5;130;01m\\\u001b[39;00m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m --trainer sft \u001b[39m\u001b[38;5;130;01m\\\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;124m --lora_dropout \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel_params[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlora_dropout\u001b[39m\u001b[38;5;124m'\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;124m\"\"\"\u001b[39m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;66;03m# Use subprocess.run() to execute the command\u001b[39;00m\n\u001b[0;32m---> 40\u001b[0m \u001b[43msubprocess\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcommand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshell\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcheck\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
|
| 224 |
+
"File \u001b[0;32m/usr/lib/python3.11/subprocess.py:571\u001b[0m, in \u001b[0;36mrun\u001b[0;34m(input, capture_output, timeout, check, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 569\u001b[0m retcode \u001b[38;5;241m=\u001b[39m process\u001b[38;5;241m.\u001b[39mpoll()\n\u001b[1;32m 570\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check \u001b[38;5;129;01mand\u001b[39;00m retcode:\n\u001b[0;32m--> 571\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m CalledProcessError(retcode, process\u001b[38;5;241m.\u001b[39margs,\n\u001b[1;32m 572\u001b[0m output\u001b[38;5;241m=\u001b[39mstdout, stderr\u001b[38;5;241m=\u001b[39mstderr)\n\u001b[1;32m 573\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m CompletedProcess(process\u001b[38;5;241m.\u001b[39margs, retcode, stdout, stderr)\n",
|
| 225 |
+
"\u001b[0;31mCalledProcessError\u001b[0m: Command '\nautotrain llm --train --trainer sft --project_name ./llms/ams_data_train-100_6abb23dc-cb9d-428e-9079-e47deee0edd9 --model mistralai/Mistral-7B-v0.1 --data_path ./fine_tune_data/ --train_split train_data --valid_split validation_data --repo_id ai-aerospace/ams-data-train-100-4601c8c8-0903-4f18-a6e8-1d2a40a697ce --push_to_hub --token HUGGINGFACE_TOKEN --block_size 2242 --model_max_length 1121 --logging_steps -1 --evaluation_strategy epoch --save_total_limit 1 --save_strategy epoch --fp16 --lr 3e-05 --num_train_epochs 3 --batch_size 2 --warmup_ratio 0.1 --gradient_accumulation 1 --optimizer adamw_torch --scheduler linear --weight_decay 0 --max_grad_norm 1 --seed 42 --use_int4 --use-peft --lora_r 16 --lora_alpha 32 --lora_dropout 0.05\n' returned non-zero exit status 2."
|
| 226 |
+
]
|
| 227 |
+
}
|
| 228 |
+
],
|
| 229 |
+
"source": [
|
| 230 |
+
"\n",
|
| 231 |
+
"# Set .venv and execute the autotrain script\n",
|
| 232 |
+
"# To see all parameters: autotrain llm --help\n",
|
| 233 |
+
"# !autotrain llm --train --project_name my-llm --model TinyLlama/TinyLlama-1.1B-Chat-v0.1 --data_path . --use-peft --use_int4 --learning_rate 2e-4 --train_batch_size 6 --num_train_epochs 3 --trainer sft\n",
|
| 234 |
+
"command=f\"\"\"\n",
|
| 235 |
+
"autotrain llm --train \\\n",
|
| 236 |
+
" --trainer sft \\\n",
|
| 237 |
+
" --project_name {model_params['project_name']} \\\n",
|
| 238 |
+
" --model {model_params['model_name']} \\\n",
|
| 239 |
+
" --data_path {model_params['data_directory']} \\\n",
|
| 240 |
+
" --train_split {model_params['train_data']} \\\n",
|
| 241 |
+
" --valid_split {model_params['validation_data']} \\\n",
|
| 242 |
+
" --repo_id {model_params['repo_id']} \\\n",
|
| 243 |
+
" --push_to_hub \\\n",
|
| 244 |
+
" --token HUGGINGFACE_TOKEN \\\n",
|
| 245 |
+
" --block_size {model_params['block_size']} \\\n",
|
| 246 |
+
" --model_max_length {model_params['model_max_length']} \\\n",
|
| 247 |
+
" --logging_steps {model_params['logging_steps']} \\\n",
|
| 248 |
+
" --evaluation_strategy {model_params['evaluation_strategy']} \\\n",
|
| 249 |
+
" --save_total_limit {model_params['save_total_limit']} \\\n",
|
| 250 |
+
" --save_strategy {model_params['save_strategy']} \\\n",
|
| 251 |
+
" --fp16 \\\n",
|
| 252 |
+
" --lr {model_params['lr']} \\\n",
|
| 253 |
+
" --num_train_epochs {model_params['epochs']} \\\n",
|
| 254 |
+
" --train_batch_size {model_params['batch_size']} \\\n",
|
| 255 |
+
" --warmup_ratio {model_params['warmup_ratio']} \\\n",
|
| 256 |
+
" --gradient_accumulation {model_params['gradient_accumulation']} \\\n",
|
| 257 |
+
" --optimizer {model_params['optimizer']} \\\n",
|
| 258 |
+
" --scheduler linear \\\n",
|
| 259 |
+
" --weight_decay {model_params['weight_decay']} \\\n",
|
| 260 |
+
" --max_grad_norm {model_params['max_grad_norm']} \\\n",
|
| 261 |
+
" --seed {model_params['seed']} \\\n",
|
| 262 |
+
" --use_int4 \\\n",
|
| 263 |
+
" --use-peft \\\n",
|
| 264 |
+
" --lora_r {model_params['lora_r']} \\\n",
|
| 265 |
+
" --lora_alpha {model_params['lora_alpha']} \\\n",
|
| 266 |
+
" --lora_dropout {model_params['lora_dropout']}\n",
|
| 267 |
+
"\"\"\"\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"# Use subprocess.run() to execute the command\n",
|
| 270 |
+
"subprocess.run(command, shell=True, check=True)"
|
| 271 |
+
]
|
| 272 |
+
},
|
| 273 |
+
{
|
| 274 |
+
"cell_type": "code",
|
| 275 |
+
"execution_count": null,
|
| 276 |
+
"metadata": {},
|
| 277 |
+
"outputs": [],
|
| 278 |
+
"source": []
|
| 279 |
+
}
|
| 280 |
+
],
|
| 281 |
+
"metadata": {
|
| 282 |
+
"kernelspec": {
|
| 283 |
+
"display_name": ".venv",
|
| 284 |
+
"language": "python",
|
| 285 |
+
"name": "python3"
|
| 286 |
+
},
|
| 287 |
+
"language_info": {
|
| 288 |
+
"codemirror_mode": {
|
| 289 |
+
"name": "ipython",
|
| 290 |
+
"version": 3
|
| 291 |
+
},
|
| 292 |
+
"file_extension": ".py",
|
| 293 |
+
"mimetype": "text/x-python",
|
| 294 |
+
"name": "python",
|
| 295 |
+
"nbconvert_exporter": "python",
|
| 296 |
+
"pygments_lexer": "ipython3",
|
| 297 |
+
"version": "3.11.7"
|
| 298 |
+
}
|
| 299 |
+
},
|
| 300 |
+
"nbformat": 4,
|
| 301 |
+
"nbformat_minor": 2
|
| 302 |
+
}
|
train_llm.py
CHANGED
|
@@ -4,10 +4,30 @@ import pandas as pd
|
|
| 4 |
|
| 5 |
from datasets import load_dataset
|
| 6 |
import subprocess
|
|
|
|
| 7 |
|
|
|
|
| 8 |
# from dotenv import load_dotenv,find_dotenv
|
| 9 |
# load_dotenv(find_dotenv(),override=True)
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
# Load dataset
|
| 12 |
dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
|
| 13 |
dataset=load_dataset(dataset_name)
|
|
@@ -21,54 +41,99 @@ os.makedirs(data_directory, exist_ok=True)
|
|
| 21 |
# Write the train data to a CSV file
|
| 22 |
train_data='train_data'
|
| 23 |
train_filename = os.path.join(data_directory, train_data)
|
| 24 |
-
dataset['train'].to_pandas().to_csv(train_filename, columns=['text'], index=False)
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Write the validation data to a CSV file
|
| 27 |
validation_data='validation_data'
|
| 28 |
validation_filename = os.path.join(data_directory, validation_data)
|
| 29 |
-
dataset['validation'].to_pandas().to_csv(validation_filename, columns=['text'], index=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
# Define project parameters
|
| 32 |
username='ai-aerospace'
|
| 33 |
project_name='./llms/'+'ams_data_train-100_'+str(uuid4())
|
| 34 |
repo_name='ams-data-train-100-'+str(uuid4())
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
|
|
|
| 54 |
# Set .venv and execute the autotrain script
|
| 55 |
# To see all parameters: autotrain llm --help
|
| 56 |
# !autotrain llm --train --project_name my-llm --model TinyLlama/TinyLlama-1.1B-Chat-v0.1 --data_path . --use-peft --use_int4 --learning_rate 2e-4 --train_batch_size 6 --num_train_epochs 3 --trainer sft
|
| 57 |
-
command="""
|
| 58 |
autotrain llm --train \
|
| 59 |
-
--project_name ${project_name} \
|
| 60 |
-
--model ${model_name} \
|
| 61 |
-
--data_path ${data_directory} \
|
| 62 |
-
--train_split ${train_data} \
|
| 63 |
-
--valid_split ${validation_data} \
|
| 64 |
-
--use-peft \
|
| 65 |
-
--learning_rate 2e-4 \
|
| 66 |
-
--train_batch_size 6 \
|
| 67 |
-
--num_train_epochs 3 \
|
| 68 |
--trainer sft \
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
--push_to_hub \
|
| 70 |
-
--
|
| 71 |
-
--
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
"""
|
| 73 |
|
| 74 |
# Use subprocess.run() to execute the command
|
|
|
|
| 4 |
|
| 5 |
from datasets import load_dataset
|
| 6 |
import subprocess
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
|
| 9 |
+
### Read environment variables
|
| 10 |
# from dotenv import load_dotenv,find_dotenv
|
| 11 |
# load_dotenv(find_dotenv(),override=True)
|
| 12 |
|
| 13 |
+
### Functions
|
| 14 |
+
def max_token_len(dataset, tokenize=None):
    """Return the largest token count among the ``'text'`` fields of *dataset*.

    Args:
        dataset: Iterable of rows; each row is indexed as ``row['text']``.
        tokenize: Optional callable applied to each text. Its result is
            indexed with ``['input_ids']`` and the length of that value is
            the row's token count. When omitted, the module-level
            ``tokenizer`` is used, which keeps existing one-argument calls
            working unchanged.

    Returns:
        The maximum token count over all rows, or ``0`` for an empty dataset.
    """
    if tokenize is None:
        # Resolved at call time, so the module-level tokenizer may be
        # assigned after this function is defined.
        tokenize = tokenizer
    return max(
        (len(tokenize(row['text'])['input_ids']) for row in dataset),
        default=0,
    )
|
| 21 |
+
|
| 22 |
+
### Model details
|
| 23 |
+
# model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1'
|
| 24 |
+
model_name = 'mistralai/Mistral-7B-v0.1'
|
| 25 |
+
# model_name = 'distilbert-base-uncased'
|
| 26 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 27 |
+
model_max_length = tokenizer.model_max_length
|
| 28 |
+
print("Model Max Length:", model_max_length)
|
| 29 |
+
|
| 30 |
+
### Repo name, dataset initialization, and data directory
|
| 31 |
# Load dataset
|
| 32 |
dataset_name = 'ai-aerospace/ams_data_train_generic_v0.1_100'
|
| 33 |
dataset=load_dataset(dataset_name)
|
|
|
|
| 41 |
# Write the train data to a CSV file
|
| 42 |
train_data='train_data'
|
| 43 |
train_filename = os.path.join(data_directory, train_data)
|
| 44 |
+
dataset['train'].to_pandas().to_csv(train_filename+'.csv', columns=['text'], index=False)
|
| 45 |
+
max_token_length_train=max_token_len(dataset['train'])
|
| 46 |
+
print('Max token length train: '+str(max_token_length_train))
|
| 47 |
|
| 48 |
# Write the validation data to a CSV file
|
| 49 |
validation_data='validation_data'
|
| 50 |
validation_filename = os.path.join(data_directory, validation_data)
|
| 51 |
+
dataset['validation'].to_pandas().to_csv(validation_filename+'.csv', columns=['text'], index=False)
|
| 52 |
+
max_token_length_validation=max_token_len(dataset['validation'])
|
| 53 |
+
print('Max token length validation: '+str(max_token_length_validation))
|
| 54 |
+
|
| 55 |
+
max_token_length=max(max_token_length_train,max_token_length_validation)
|
| 56 |
+
if max_token_length > model_max_length:
|
| 57 |
+
raise ValueError("Maximum token length exceeds model limits.")
|
| 58 |
+
block_size=2*max_token_length
|
| 59 |
|
| 60 |
# Define project parameters
|
| 61 |
username='ai-aerospace'
|
| 62 |
project_name='./llms/'+'ams_data_train-100_'+str(uuid4())
|
| 63 |
repo_name='ams-data-train-100-'+str(uuid4())
|
| 64 |
|
| 65 |
+
### Set training params
|
| 66 |
+
model_params={
|
| 67 |
+
"project_name": project_name,
|
| 68 |
+
"model_name": model_name,
|
| 69 |
+
"repo_id": username+'/'+repo_name,
|
| 70 |
+
"train_data": train_data,
|
| 71 |
+
"validation_data": validation_data,
|
| 72 |
+
"data_directory": data_directory,
|
| 73 |
+
"block_size": block_size,
|
| 74 |
+
"model_max_length": max_token_length,
|
| 75 |
+
"logging_steps": -1,
|
| 76 |
+
"evaluation_strategy": "epoch",
|
| 77 |
+
"save_total_limit": 1,
|
| 78 |
+
"save_strategy": "epoch",
|
| 79 |
+
"mixed_precision": "fp16",
|
| 80 |
+
"lr": 0.00003,
|
| 81 |
+
"epochs": 3,
|
| 82 |
+
"batch_size": 2,
|
| 83 |
+
"warmup_ratio": 0.1,
|
| 84 |
+
"gradient_accumulation": 1,
|
| 85 |
+
"optimizer": "adamw_torch",
|
| 86 |
+
"scheduler": "linear",
|
| 87 |
+
"weight_decay": 0,
|
| 88 |
+
"max_grad_norm": 1,
|
| 89 |
+
"seed": 42,
|
| 90 |
+
"quantization": "int4",
|
| 91 |
+
"target_modules": "",
|
| 92 |
+
"lora_r": 16,
|
| 93 |
+
"lora_alpha": 32,
|
| 94 |
+
"lora_dropout": 0.05
|
| 95 |
+
}
|
| 96 |
+
for key, value in model_params.items():
|
| 97 |
+
os.environ[key] = str(value)
|
| 98 |
|
| 99 |
+
### Feed into and run autotrain command
|
| 100 |
# Set .venv and execute the autotrain script
|
| 101 |
# To see all parameters: autotrain llm --help
|
| 102 |
# !autotrain llm --train --project_name my-llm --model TinyLlama/TinyLlama-1.1B-Chat-v0.1 --data_path . --use-peft --use_int4 --learning_rate 2e-4 --train_batch_size 6 --num_train_epochs 3 --trainer sft
|
| 103 |
+
# An empty --target_modules value would leave the flag without an argument,
# so the flag is only emitted when a value is configured.
target_modules_flag = (
    f"--target_modules {model_params['target_modules']}"
    if model_params.get('target_modules')
    else ""
)

# Build the autotrain CLI invocation from model_params.
# Every line inside the literal except the last must end with a backslash:
# Python joins backslash-newline pairs inside a non-raw string, so the whole
# command reaches the shell as one line. A missing backslash leaves a real
# newline and the shell would run the remaining flags as a separate command.
# NOTE(review): `--token HUGGINGFACE_TOKEN` passes that literal text to the
# CLI; if the token is meant to be read from the environment this probably
# needs to be `$HUGGINGFACE_TOKEN` -- confirm before relying on --push_to_hub.
command=f"""
autotrain llm --train \
    --trainer sft \
    --project_name {model_params['project_name']} \
    --model {model_params['model_name']} \
    --data_path {model_params['data_directory']} \
    --train_split {model_params['train_data']} \
    --valid_split {model_params['validation_data']} \
    --repo_id {model_params['repo_id']} \
    --push_to_hub \
    --token HUGGINGFACE_TOKEN \
    --block_size {model_params['block_size']} \
    --model_max_length {model_params['model_max_length']} \
    --logging_steps {model_params['logging_steps']} \
    --evaluation_strategy {model_params['evaluation_strategy']} \
    --save_total_limit {model_params['save_total_limit']} \
    --save_strategy {model_params['save_strategy']} \
    --fp16 \
    --lr {model_params['lr']} \
    --num_train_epochs {model_params['epochs']} \
    --train_batch_size {model_params['batch_size']} \
    --warmup_ratio {model_params['warmup_ratio']} \
    --gradient_accumulation {model_params['gradient_accumulation']} \
    --optimizer {model_params['optimizer']} \
    --scheduler {model_params['scheduler']} \
    --weight_decay {model_params['weight_decay']} \
    --max_grad_norm {model_params['max_grad_norm']} \
    --seed {model_params['seed']} \
    --use_int4 \
    {target_modules_flag} \
    --use-peft \
    --lora_r {model_params['lora_r']} \
    --lora_alpha {model_params['lora_alpha']} \
    --lora_dropout {model_params['lora_dropout']}
"""
|
| 138 |
|
| 139 |
# Use subprocess.run() to execute the command
|