Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -3,12 +3,12 @@ import uuid
|
|
3 |
import os
|
4 |
from huggingface_hub import snapshot_download
|
5 |
from flux_train import build_job
|
6 |
-
|
7 |
import sys
|
8 |
-
sys.path.append("/app/ai-toolkit") # Tell Python to look here
|
9 |
|
10 |
-
|
|
|
11 |
|
|
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
@@ -20,10 +20,11 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
|
|
20 |
|
21 |
# Shared, module-level record of the training run; mutated in place by the
# training function and read by the status/train handlers.
status = {
    "running": False,   # True while a training run is in progress
    "last_job": None,   # job object of the most recent successful run
    "error": None,      # str(exception) of the most recent failed run
}
|
22 |
|
23 |
-
def run_lora_training():
|
24 |
try:
|
25 |
status.update({"running": True, "error": None})
|
26 |
local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
|
|
|
27 |
snapshot_download(
|
28 |
repo_id=REPO_ID,
|
29 |
repo_type="dataset",
|
@@ -31,9 +32,11 @@ def run_lora_training():
|
|
31 |
local_dir=local_dir,
|
32 |
local_dir_use_symlinks=False
|
33 |
)
|
|
|
34 |
training_path = os.path.join(local_dir, FOLDER_IN_REPO)
|
35 |
-
job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME)
|
36 |
run_job(job)
|
|
|
37 |
status.update({"running": False, "last_job": job})
|
38 |
except Exception as e:
|
39 |
status.update({"running": False, "error": str(e)})
|
@@ -46,9 +49,14 @@ def root():
|
|
46 |
def get_status():
    """Return the module-level ``status`` dict (``running``, ``last_job``, ``error``)."""
    return status
|
48 |
|
|
|
|
|
|
|
|
|
|
|
49 |
@app.post("/train")
def start_training(background_tasks: BackgroundTasks):
    """Schedule a LoRA training run as a background task.

    Args:
        background_tasks: FastAPI task queue for this request; the training
            function is added to it.

    Returns:
        A dict with a ``message`` key. If ``status["running"]`` is already
        set, the "already running" message is returned and nothing is
        scheduled.
    """
    if status["running"]:
        return {"message": "A training job is already running."}

    # Mark the run as active before queuing it. The background task sets the
    # same flag itself, but only once it starts executing; claiming it here
    # keeps a second request arriving in that window from queuing a duplicate.
    status.update({"running": True, "error": None})

    background_tasks.add_task(run_lora_training)
    return {"message": "Training started in background."}
|
|
|
3 |
import os
|
4 |
from huggingface_hub import snapshot_download
|
5 |
from flux_train import build_job
|
|
|
6 |
import sys
|
|
|
7 |
|
8 |
+
# Add ai-toolkit to sys.path for toolkit imports
|
9 |
+
sys.path.append("/app/ai-toolkit")
|
10 |
|
11 |
+
from toolkit.job import run_job
|
12 |
|
13 |
app = FastAPI()
|
14 |
|
|
|
20 |
|
21 |
# Shared, module-level record of the training run; mutated in place by the
# training function and read by the status/train handlers.
status = {
    "running": False,   # True while a training run is in progress
    "last_job": None,   # job object of the most recent successful run
    "error": None,      # str(exception) of the most recent failed run
}
|
22 |
|
23 |
+
def run_lora_training(push_to_hub: bool = False):
|
24 |
try:
|
25 |
status.update({"running": True, "error": None})
|
26 |
local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
|
27 |
+
|
28 |
snapshot_download(
|
29 |
repo_id=REPO_ID,
|
30 |
repo_type="dataset",
|
|
|
32 |
local_dir=local_dir,
|
33 |
local_dir_use_symlinks=False
|
34 |
)
|
35 |
+
|
36 |
training_path = os.path.join(local_dir, FOLDER_IN_REPO)
|
37 |
+
job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
|
38 |
run_job(job)
|
39 |
+
|
40 |
status.update({"running": False, "last_job": job})
|
41 |
except Exception as e:
|
42 |
status.update({"running": False, "error": str(e)})
|
|
|
49 |
def get_status():
    """Return the module-level ``status`` dict (``running``, ``last_job``, ``error``)."""
    return status
|
51 |
|
52 |
+
from pydantic import BaseModel
|
53 |
+
|
54 |
+
# Request body accepted by POST /train.
class TrainRequest(BaseModel):
    # Forwarded unchanged to run_lora_training and from there to build_job;
    # defaults to False when the field is omitted from the body.
    push_to_hub: bool = False
|
56 |
+
|
57 |
@app.post("/train")
def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
    """Schedule a LoRA training run as a background task.

    Args:
        background_tasks: FastAPI task queue for this request; the training
            function is added to it.
        request: Request body; its ``push_to_hub`` flag is forwarded to
            ``run_lora_training``.

    Returns:
        A dict with a ``message`` key. When a run is scheduled it also echoes
        the ``push_to_hub`` value. If ``status["running"]`` is already set,
        only the "already running" message is returned and nothing is
        scheduled.
    """
    if status["running"]:
        return {"message": "A training job is already running."}

    # Mark the run as active before queuing it. The background task sets the
    # same flag itself, but only once it starts executing; claiming it here
    # keeps a second request arriving in that window from queuing a duplicate.
    status.update({"running": True, "error": None})

    push_to_hub = request.push_to_hub
    background_tasks.add_task(run_lora_training, push_to_hub=push_to_hub)
    return {"message": "Training started in background.", "push_to_hub": push_to_hub}
|