Spaces:
Runtime error
Runtime error
zetavg
committed on
let finetune ui support showing training progress
Browse files
- llama_lora/globals.py +3 -0
- llama_lora/ui/finetune_ui.py +36 -1
llama_lora/globals.py
CHANGED
|
@@ -17,6 +17,9 @@ class Global:
|
|
| 17 |
# Functions
|
| 18 |
train_fn: Any = None
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
# UI related
|
| 21 |
ui_title: str = "LLaMA-LoRA"
|
| 22 |
ui_emoji: str = "🦙🎛️"
|
|
|
|
| 17 |
# Functions
|
| 18 |
train_fn: Any = None
|
| 19 |
|
| 20 |
+
# Training Control
|
| 21 |
+
should_stop_training = False
|
| 22 |
+
|
| 23 |
# UI related
|
| 24 |
ui_title: str = "LLaMA-LoRA"
|
| 25 |
ui_emoji: str = "🦙🎛️"
|
llama_lora/ui/finetune_ui.py
CHANGED
|
@@ -5,6 +5,8 @@ from datetime import datetime
|
|
| 5 |
import gradio as gr
|
| 6 |
from random_word import RandomWords
|
| 7 |
|
|
|
|
|
|
|
| 8 |
from ..globals import Global
|
| 9 |
from ..models import get_base_model, get_tokenizer
|
| 10 |
from ..utils.data import (
|
|
@@ -331,6 +333,31 @@ Train data (first 10):
|
|
| 331 |
time.sleep(2)
|
| 332 |
return message
|
| 333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
return Global.train_fn(
|
| 335 |
get_base_model(), # base_model
|
| 336 |
get_tokenizer(), # tokenizer
|
|
@@ -351,11 +378,16 @@ Train data (first 10):
|
|
| 351 |
True, # train_on_inputs
|
| 352 |
False, # group_by_length
|
| 353 |
None, # resume_from_checkpoint
|
|
|
|
| 354 |
)
|
| 355 |
except Exception as e:
|
| 356 |
raise gr.Error(e)
|
| 357 |
|
| 358 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
def finetune_ui():
|
| 360 |
with gr.Blocks() as finetune_ui_blocks:
|
| 361 |
with gr.Column(elem_id="finetune_ui_content"):
|
|
@@ -580,7 +612,10 @@ def finetune_ui():
|
|
| 580 |
|
| 581 |
# controlled by JS, shows the confirm_abort_button
|
| 582 |
abort_button.click(None, None, None, None)
|
| 583 |
-
confirm_abort_button.click(
|
|
|
|
|
|
|
|
|
|
| 584 |
|
| 585 |
finetune_ui_blocks.load(_js="""
|
| 586 |
function finetune_ui_blocks_js() {
|
|
|
|
| 5 |
import gradio as gr
|
| 6 |
from random_word import RandomWords
|
| 7 |
|
| 8 |
+
from transformers import TrainerCallback
|
| 9 |
+
|
| 10 |
from ..globals import Global
|
| 11 |
from ..models import get_base_model, get_tokenizer
|
| 12 |
from ..utils.data import (
|
|
|
|
| 333 |
time.sleep(2)
|
| 334 |
return message
|
| 335 |
|
| 336 |
+
class UiTrainerCallback(TrainerCallback):
    """Trainer callback that reports progress to the UI and honors aborts.

    Both hooks do the same two things: propagate a pending stop request
    from ``Global.should_stop_training`` to the trainer's ``control``
    object, and push the current step count to ``progress``.

    ``progress`` and ``epochs`` are not defined in this class; they are
    resolved from the surrounding scope at call time.
    """

    def _sync_with_ui(self, state, control):
        """Apply a pending stop request, then report the current step."""
        if Global.should_stop_training:
            control.should_training_stop = True

        # NOTE(review): the fallback branch reads ``state.steps_per_epoch``;
        # confirm the trainer state actually exposes that attribute, since
        # it is only touched when ``state.max_steps`` is None.
        total_steps = (
            state.max_steps
            if state.max_steps is not None
            else state.num_train_epochs * state.steps_per_epoch
        )
        progress(
            (state.global_step, total_steps),
            desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
        )

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Refresh the UI and check for an abort when an epoch starts."""
        self._sync_with_ui(state, control)

    def on_step_end(self, args, state, control, **kwargs):
        """Refresh the UI and check for an abort after each training step."""
        self._sync_with_ui(state, control)
|
| 356 |
+
|
| 357 |
+
training_callbacks = [UiTrainerCallback]
|
| 358 |
+
|
| 359 |
+
Global.should_stop_training = False
|
| 360 |
+
|
| 361 |
return Global.train_fn(
|
| 362 |
get_base_model(), # base_model
|
| 363 |
get_tokenizer(), # tokenizer
|
|
|
|
| 378 |
True, # train_on_inputs
|
| 379 |
False, # group_by_length
|
| 380 |
None, # resume_from_checkpoint
|
| 381 |
+
training_callbacks # callbacks
|
| 382 |
)
|
| 383 |
except Exception as e:
|
| 384 |
raise gr.Error(e)
|
| 385 |
|
| 386 |
|
| 387 |
+
def do_abort_training():
    """Raise the global stop flag that the training callback polls."""
    Global.should_stop_training = True
|
| 389 |
+
|
| 390 |
+
|
| 391 |
def finetune_ui():
|
| 392 |
with gr.Blocks() as finetune_ui_blocks:
|
| 393 |
with gr.Column(elem_id="finetune_ui_content"):
|
|
|
|
| 612 |
|
| 613 |
# controlled by JS, shows the confirm_abort_button
|
| 614 |
abort_button.click(None, None, None, None)
|
| 615 |
+
confirm_abort_button.click(
|
| 616 |
+
fn=do_abort_training,
|
| 617 |
+
inputs=None, outputs=None,
|
| 618 |
+
cancels=[train_progress])
|
| 619 |
|
| 620 |
finetune_ui_blocks.load(_js="""
|
| 621 |
function finetune_ui_blocks_js() {
|