rahul7star commited on
Commit
f29d0cf
·
verified ·
1 Parent(s): 4235d82

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -3,12 +3,12 @@ import uuid
3
  import os
4
  from huggingface_hub import snapshot_download
5
  from flux_train import build_job
6
-
7
  import sys
8
- sys.path.append("/app/ai-toolkit") # Tell Python to look here
9
 
10
- from toolkit.job import run_job
 
11
 
 
12
 
13
  app = FastAPI()
14
 
@@ -20,10 +20,11 @@ HF_TOKEN = os.environ.get("HF_TOKEN", "")
20
 
21
  status = {"running": False, "last_job": None, "error": None}
22
 
23
- def run_lora_training():
24
  try:
25
  status.update({"running": True, "error": None})
26
  local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
 
27
  snapshot_download(
28
  repo_id=REPO_ID,
29
  repo_type="dataset",
@@ -31,9 +32,11 @@ def run_lora_training():
31
  local_dir=local_dir,
32
  local_dir_use_symlinks=False
33
  )
 
34
  training_path = os.path.join(local_dir, FOLDER_IN_REPO)
35
- job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME)
36
  run_job(job)
 
37
  status.update({"running": False, "last_job": job})
38
  except Exception as e:
39
  status.update({"running": False, "error": str(e)})
@@ -46,9 +49,14 @@ def root():
46
  def get_status():
47
  return status
48
 
 
 
 
 
 
49
  @app.post("/train")
50
- def start_training(background_tasks: BackgroundTasks):
51
  if status["running"]:
52
  return {"message": "A training job is already running."}
53
- background_tasks.add_task(run_lora_training)
54
- return {"message": "Training started in background."}
 
3
  import os
4
  from huggingface_hub import snapshot_download
5
  from flux_train import build_job
 
6
  import sys
 
7
 
8
+ # Add ai-toolkit to sys.path for toolkit imports
9
+ sys.path.append("/app/ai-toolkit")
10
 
11
+ from toolkit.job import run_job
12
 
13
  app = FastAPI()
14
 
 
20
 
21
  status = {"running": False, "last_job": None, "error": None}
22
 
23
+ def run_lora_training(push_to_hub: bool = False):
24
  try:
25
  status.update({"running": True, "error": None})
26
  local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
27
+
28
  snapshot_download(
29
  repo_id=REPO_ID,
30
  repo_type="dataset",
 
32
  local_dir=local_dir,
33
  local_dir_use_symlinks=False
34
  )
35
+
36
  training_path = os.path.join(local_dir, FOLDER_IN_REPO)
37
+ job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
38
  run_job(job)
39
+
40
  status.update({"running": False, "last_job": job})
41
  except Exception as e:
42
  status.update({"running": False, "error": str(e)})
 
49
  def get_status():
50
  return status
51
 
52
+ from pydantic import BaseModel
53
+
54
+ class TrainRequest(BaseModel):
55
+ push_to_hub: bool = False
56
+
57
  @app.post("/train")
58
+ def start_training(background_tasks: BackgroundTasks, request: TrainRequest):
59
  if status["running"]:
60
  return {"message": "A training job is already running."}
61
+ background_tasks.add_task(run_lora_training, push_to_hub=request.push_to_hub)
62
+ return {"message": "Training started in background.", "push_to_hub": request.push_to_hub}