Commit 38fb491 by zetavg
Parent: 00263ef

support resume_from_checkpoint

Files changed:
- llama_lora/lib/finetune.py (+12 -3)
- llama_lora/ui/finetune_ui.py (+23 -2)
llama_lora/lib/finetune.py CHANGED

@@ -33,7 +33,7 @@ def train(
     num_train_epochs: int = 3,
     learning_rate: float = 3e-4,
     cutoff_len: int = 256,
-    val_set_size: int = 2000,
+    val_set_size: int = 2000,
     # lora hyperparams
     lora_r: int = 8,
     lora_alpha: int = 16,
@@ -46,7 +46,7 @@ def train(
     train_on_inputs: bool = True,  # if False, masks out inputs in loss
     group_by_length: bool = False,  # faster, but produces an odd training loss curve
     # either training checkpoint or final adapter
-    resume_from_checkpoint
+    resume_from_checkpoint = None,
     save_steps: int = 200,
     save_total_limit: int = 3,
     logging_steps: int = 10,
@@ -68,6 +68,7 @@ def train(
         'num_train_epochs': num_train_epochs,
         'learning_rate': learning_rate,
         'cutoff_len': cutoff_len,
+        'val_set_size': val_set_size,
         'lora_r': lora_r,
         'lora_alpha': lora_alpha,
         'lora_dropout': lora_dropout,
@@ -78,7 +79,12 @@ def train(
         'save_total_limit': save_total_limit,
         'logging_steps': logging_steps,
     }
+    if val_set_size and val_set_size > 0:
+        finetune_args['val_set_size'] = val_set_size
+    if resume_from_checkpoint:
+        finetune_args['resume_from_checkpoint'] = resume_from_checkpoint
 
+    wandb = None
     if wandb_api_key:
         os.environ["WANDB_API_KEY"] = wandb_api_key
 
@@ -220,7 +226,7 @@ def train(
             adapters_weights = torch.load(checkpoint_name)
             model = set_peft_model_state_dict(model, adapters_weights)
         else:
-
+            raise ValueError(f"Checkpoint {checkpoint_name} not found")
 
     # Be more transparent about the % of trainable params.
     model.print_trainable_parameters()
@@ -315,4 +321,7 @@ def train(
     with open(os.path.join(output_dir, "train_output.json"), 'w') as train_output_json_file:
         json.dump(train_output, train_output_json_file, indent=2)
 
+    if use_wandb and wandb:
+        wandb.finish()
+
     return train_output
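For context, the checkpoint-loading branch that this commit hardens (the `torch.load` / `set_peft_model_state_dict` lines above, plus the new `raise ValueError`) typically works as in the following minimal sketch. It is not the repository's exact code: the helper name `load_resume_checkpoint` and the file names `pytorch_model.bin` / `adapter_model.bin` are assumptions based on alpaca-lora-style checkpoint layouts.

import os
import torch
from peft import set_peft_model_state_dict


def load_resume_checkpoint(model, resume_from_checkpoint):
    # Assumed layout: a training checkpoint stores "pytorch_model.bin",
    # a final adapter stores "adapter_model.bin" (see the
    # "either training checkpoint or final adapter" comment in train()).
    checkpoint_name = os.path.join(resume_from_checkpoint, "pytorch_model.bin")
    if not os.path.exists(checkpoint_name):
        checkpoint_name = os.path.join(resume_from_checkpoint, "adapter_model.bin")

    if os.path.exists(checkpoint_name):
        adapters_weights = torch.load(checkpoint_name, map_location="cpu")
        model = set_peft_model_state_dict(model, adapters_weights)
    else:
        # Mirrors the behaviour added in this commit: fail loudly instead of
        # silently training from scratch.
        raise ValueError(f"Checkpoint {checkpoint_name} not found")
    return model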
llama_lora/ui/finetune_ui.py CHANGED

@@ -306,6 +306,17 @@ def do_train(
 ):
     try:
         base_model_name = Global.base_model_name
+
+        resume_from_checkpoint = None
+        if continue_from_model == "-" or continue_from_model == "None":
+            continue_from_model = None
+        if continue_from_checkpoint == "-" or continue_from_checkpoint == "None":
+            continue_from_checkpoint = None
+        if continue_from_model:
+            resume_from_checkpoint = os.path.join(Global.data_dir, "lora_models", continue_from_model)
+            if continue_from_checkpoint:
+                resume_from_checkpoint = os.path.join(resume_from_checkpoint, continue_from_checkpoint)
+
         output_dir = os.path.join(Global.data_dir, "lora_models", model_name)
         if os.path.exists(output_dir):
             if (not os.path.isdir(output_dir)) or os.path.exists(os.path.join(output_dir, 'adapter_config.json')):
@@ -376,6 +387,8 @@ Train options: {json.dumps({
     'lora_dropout': lora_dropout,
     'lora_target_modules': lora_target_modules,
     'model_name': model_name,
+    'continue_from_model': continue_from_model,
+    'continue_from_checkpoint': continue_from_checkpoint,
 }, indent=2)}
 
 Train data (first 10):
@@ -386,7 +399,7 @@ Train data (first 10):
             return message
 
         if not should_training_progress_track_tqdm:
-            progress(0, desc="Preparing model for training...")
+            progress(0, desc=f"Preparing model {base_model_name} for training...")
 
         log_history = []
 
@@ -461,6 +474,10 @@ Train data (first 10):
                 # 'lora_dropout': lora_dropout,
                 # 'lora_target_modules': lora_target_modules,
             }
+            if continue_from_model:
+                info['continued_from_model'] = continue_from_model
+            if continue_from_checkpoint:
+                info['continued_from_checkpoint'] = continue_from_checkpoint
             json.dump(info, info_json_file, indent=2)
 
         if not should_training_progress_track_tqdm:
@@ -490,7 +507,7 @@ Train data (first 10):
             lora_target_modules,  # lora_target_modules
             train_on_inputs,  # train_on_inputs
             False,  # group_by_length
-
+            resume_from_checkpoint,  # resume_from_checkpoint
             save_steps,  # save_steps
             save_total_limit,  # save_total_limit
             logging_steps,  # logging_steps
@@ -582,6 +599,8 @@ def handle_load_params_from_model(
                 cutoff_len = value
             elif key == "evaluate_data_count":
                 evaluate_data_count = value
+            elif key == "val_set_size":
+                evaluate_data_count = value
             elif key == "micro_batch_size":
                 micro_batch_size = value
             elif key == "gradient_accumulation_steps":
@@ -610,6 +629,8 @@ def handle_load_params_from_model(
                 logging_steps = value
             elif key == "group_by_length":
                 pass
+            elif key == "resume_from_checkpoint":
+                pass
             else:
                 unknown_keys.append(key)
         except Exception as e:
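To see what the new `do_train` logic produces, here is a small self-contained sketch of the path resolution with hypothetical inputs; `./data`, `my-lora` and `checkpoint-200` are illustrative stand-ins for `Global.data_dir` and the two dropdown values, while the `-`/`None` handling and the `lora_models` join mirror the diff above.

import os

# Hypothetical inputs standing in for Global.data_dir and the UI dropdowns.
data_dir = "./data"
continue_from_model = "my-lora"
continue_from_checkpoint = "checkpoint-200"

# "-" or "None" coming from the dropdowns means "nothing selected".
if continue_from_model in ("-", "None"):
    continue_from_model = None
if continue_from_checkpoint in ("-", "None"):
    continue_from_checkpoint = None

resume_from_checkpoint = None
if continue_from_model:
    resume_from_checkpoint = os.path.join(
        data_dir, "lora_models", continue_from_model)
    if continue_from_checkpoint:
        resume_from_checkpoint = os.path.join(
            resume_from_checkpoint, continue_from_checkpoint)

# The resolved path, e.g. ./data/lora_models/my-lora/checkpoint-200, is what
# gets passed to the fine-tuning function as resume_from_checkpoint.
print(resume_from_checkpoint)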