Spaces: zetavg
Commit 68255ee · wandb fix
Parent(s): 0537112

Files changed:
- llama_lora/lib/finetune.py +41 -22
- llama_lora/ui/finetune_ui.py +9 -1
llama_lora/lib/finetune.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 import sys
+import importlib
 from typing import Any, List
 
 import json
@@ -54,16 +55,38 @@ def train(
     # wandb params
     wandb_api_key = None,
     wandb_project: str = "",
+    wandb_group = None,
     wandb_run_name: str = "",
+    wandb_tags: List[str] = [],
     wandb_watch: str = "false",  # options: false | gradients | all
     wandb_log_model: str = "true",  # options: false | true
 ):
+    # for logging
+    finetune_args = {
+        'micro_batch_size': micro_batch_size,
+        'gradient_accumulation_steps': gradient_accumulation_steps,
+        'num_train_epochs': num_train_epochs,
+        'learning_rate': learning_rate,
+        'cutoff_len': cutoff_len,
+        'lora_r': lora_r,
+        'lora_alpha': lora_alpha,
+        'lora_dropout': lora_dropout,
+        'lora_target_modules': lora_target_modules,
+        'train_on_inputs': train_on_inputs,
+        'group_by_length': group_by_length,
+        'save_steps': save_steps,
+        'save_total_limit': save_total_limit,
+        'logging_steps': logging_steps,
+    }
+
     if wandb_api_key:
         os.environ["WANDB_API_KEY"] = wandb_api_key
-    if wandb_project:
-        os.environ["WANDB_PROJECT"] = wandb_project
-    if wandb_run_name:
-        os.environ["WANDB_RUN_NAME"] = wandb_run_name
+
+    # wandb: WARNING Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to https://wandb.me/wandb-init.
+    # if wandb_project:
+    #     os.environ["WANDB_PROJECT"] = wandb_project
+    # if wandb_run_name:
+    #     os.environ["WANDB_RUN_NAME"] = wandb_run_name
     if wandb_watch:
         os.environ["WANDB_WATCH"] = wandb_watch
     if wandb_log_model:
@@ -73,6 +96,18 @@ def train(
     )
     if use_wandb:
         os.environ['WANDB_MODE'] = "online"
+        wandb = importlib.import_module("wandb")
+        wandb.init(
+            project=wandb_project,
+            resume="auto",
+            group=wandb_group,
+            name=wandb_run_name,
+            tags=wandb_tags,
+            reinit=True,
+            magic=True,
+            config={'finetune_args': finetune_args},
+            # id=None  # used for resuming
+        )
     else:
         os.environ['WANDB_MODE'] = "disabled"
 
@@ -243,24 +278,8 @@ def train(
         os.makedirs(output_dir)
     with open(os.path.join(output_dir, "finetune_params.json"), 'w') as trainer_args_json_file:
        json.dump(trainer.args.to_dict(), trainer_args_json_file, indent=2)
-    with open(os.path.join(output_dir, "finetune_params.json"), 'w') as finetune_params_json_file:
-        finetune_params = {
-            'micro_batch_size': micro_batch_size,
-            'gradient_accumulation_steps': gradient_accumulation_steps,
-            'num_train_epochs': num_train_epochs,
-            'learning_rate': learning_rate,
-            'cutoff_len': cutoff_len,
-            'lora_r': lora_r,
-            'lora_alpha': lora_alpha,
-            'lora_dropout': lora_dropout,
-            'lora_target_modules': lora_target_modules,
-            'train_on_inputs': train_on_inputs,
-            'group_by_length': group_by_length,
-            'save_steps': save_steps,
-            'save_total_limit': save_total_limit,
-            'logging_steps': logging_steps,
-        }
-        json.dump(finetune_params, finetune_params_json_file, indent=2)
+    with open(os.path.join(output_dir, "finetune_args.json"), 'w') as finetune_args_json_file:
+        json.dump(finetune_args, finetune_args_json_file, indent=2)
 
     # Not working, will only give us ["prompt", "completion", "input_ids", "attention_mask", "labels"]
     # if train_data:
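For context, the pattern this change adopts is: set the WANDB_* environment variables first, then import wandb lazily and start the run explicitly with wandb.init(), passing the collected hyperparameters as the run config. Below is a minimal standalone sketch of that pattern; the helper name init_wandb_run and the placeholder values are mine, not the repository's, and it assumes wandb is installed and an API key (or disabled mode) is available.

import os
import importlib
from typing import Any, Dict, List, Optional


def init_wandb_run(
    use_wandb: bool,
    wandb_api_key: Optional[str],
    wandb_project: str,
    wandb_group: Optional[str],
    wandb_run_name: str,
    wandb_tags: List[str],
    finetune_args: Dict[str, Any],
):
    # Environment variables must be set before a wandb session starts;
    # once a session exists, changes to WANDB_* are ignored.
    if wandb_api_key:
        os.environ["WANDB_API_KEY"] = wandb_api_key

    if not use_wandb:
        os.environ["WANDB_MODE"] = "disabled"
        return None

    os.environ["WANDB_MODE"] = "online"
    # Lazy import: wandb is only loaded when logging is actually enabled.
    wandb = importlib.import_module("wandb")
    return wandb.init(
        project=wandb_project,
        resume="auto",
        group=wandb_group,      # e.g. group runs by prompt template / dataset
        name=wandb_run_name,
        tags=wandb_tags,
        reinit=True,
        config={"finetune_args": finetune_args},  # hyperparameters land in the run config
    )


# Illustrative call with placeholder values:
# init_wandb_run(
#     use_wandb=True,
#     wandb_api_key=None,               # or an API key string
#     wandb_project="my-project",
#     wandb_group="alpaca/my-dataset",
#     wandb_run_name="my-model",
#     wandb_tags=["template:alpaca"],
#     finetune_args={"learning_rate": 3e-4, "lora_r": 8},
# )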
llama_lora/ui/finetune_ui.py
CHANGED
@@ -415,6 +415,12 @@ Train data (first 10):
         if not should_training_progress_track_tqdm:
             progress(0, desc="Train starting...")
 
+        wandb_group = template
+        wandb_tags = [f"template:{template}"]
+        if load_dataset_from == "Data Dir" and dataset_from_data_dir:
+            wandb_group += f"/{dataset_from_data_dir}"
+            wandb_tags.append(f"dataset:{dataset_from_data_dir}")
+
         train_output = Global.train_fn(
             base_model,  # base_model
             tokenizer,  # tokenizer
@@ -440,7 +446,9 @@ Train data (first 10):
             training_callbacks,  # callbacks
             Global.wandb_api_key,  # wandb_api_key
             Global.default_wandb_project if Global.enable_wandb else None,  # wandb_project
-
+            wandb_group,  # wandb_group
+            model_name,  # wandb_run_name
+            wandb_tags  # wandb_tags
         )
 
         logs_str = "\n".join([json.dumps(log)
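Read as a standalone helper, the UI-side change derives the run's W&B group and tags from the prompt template and, when training from a data directory, the dataset name, then passes them to train_fn along with the model name as the run name. A minimal sketch of that derivation; the function name and example values are mine, not the repository's.

from typing import List, Optional, Tuple


def build_wandb_group_and_tags(
    template: str,
    load_dataset_from: str,
    dataset_from_data_dir: Optional[str],
) -> Tuple[str, List[str]]:
    # Group runs by prompt template and tag them with the template name.
    wandb_group = template
    wandb_tags = [f"template:{template}"]
    # When the dataset comes from the data directory, fold its name into
    # the group and add a dataset tag as well.
    if load_dataset_from == "Data Dir" and dataset_from_data_dir:
        wandb_group += f"/{dataset_from_data_dir}"
        wandb_tags.append(f"dataset:{dataset_from_data_dir}")
    return wandb_group, wandb_tags


# Example with illustrative values:
#   build_wandb_group_and_tags("alpaca", "Data Dir", "my_dataset")
#   -> ("alpaca/my_dataset", ["template:alpaca", "dataset:my_dataset"])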