Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -9,6 +9,16 @@ import sys
|
|
9 |
sys.path.append("/app/ai-toolkit")
|
10 |
|
11 |
from toolkit.job import run_job
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# Application object; the `@app.get(...)` decorators further down register
# their handlers on it.
app = FastAPI()
|
14 |
|
@@ -23,8 +33,10 @@ status = {"running": False, "last_job": None, "error": None}
|
|
23 |
def run_lora_training(push_to_hub: bool = False):
|
24 |
try:
|
25 |
status.update({"running": True, "error": None})
|
|
|
|
|
26 |
local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
|
27 |
-
|
28 |
snapshot_download(
|
29 |
repo_id=REPO_ID,
|
30 |
repo_type="dataset",
|
@@ -32,15 +44,27 @@ def run_lora_training(push_to_hub: bool = False):
|
|
32 |
local_dir=local_dir,
|
33 |
local_dir_use_symlinks=False
|
34 |
)
|
35 |
-
|
36 |
training_path = os.path.join(local_dir, FOLDER_IN_REPO)
|
|
|
37 |
job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
|
|
|
38 |
run_job(job)
|
39 |
-
|
40 |
status.update({"running": False, "last_job": job})
|
41 |
except Exception as e:
|
|
|
42 |
status.update({"running": False, "error": str(e)})
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
@app.get("/")
def root():
    """Return a fixed message confirming that the service is reachable."""
    payload = {"message": "LoRA training FastAPI is live."}
    return payload
|
|
|
9 |
sys.path.append("/app/ai-toolkit")
|
10 |
|
11 |
from toolkit.job import run_job
|
12 |
+
import logging
|
13 |
+
import io
|
14 |
+
|
15 |
+
# In-memory text buffer that collects every record emitted by `logger`, so
# the captured output can be read back later with `log_stream.getvalue()`.
log_stream = io.StringIO()

# Handler writing timestamped, level-tagged lines into the buffer above.
handler = logging.StreamHandler(log_stream)
handler.setFormatter(
    logging.Formatter(fmt="%(asctime)s %(levelname)s %(message)s")
)

# Dedicated named logger; DEBUG is the threshold, so all levels are kept.
logger = logging.getLogger("lora_training")
logger.addHandler(handler)
logger.setLevel(logging.DEBUG)
|
21 |
+
|
22 |
|
23 |
# Application object; the `@app.get(...)` decorators further down register
# their handlers on it.
app = FastAPI()
|
24 |
|
|
|
33 |
def run_lora_training(push_to_hub: bool = False):
|
34 |
try:
|
35 |
status.update({"running": True, "error": None})
|
36 |
+
logger.info("Starting training...")
|
37 |
+
|
38 |
local_dir = f"/tmp/{LORA_NAME}-{uuid.uuid4()}"
|
39 |
+
logger.info(f"Downloading dataset to {local_dir} ...")
|
40 |
snapshot_download(
|
41 |
repo_id=REPO_ID,
|
42 |
repo_type="dataset",
|
|
|
44 |
local_dir=local_dir,
|
45 |
local_dir_use_symlinks=False
|
46 |
)
|
|
|
47 |
training_path = os.path.join(local_dir, FOLDER_IN_REPO)
|
48 |
+
logger.info(f"Building job with training path: {training_path}")
|
49 |
job = build_job(CONCEPT_SENTENCE, training_path, LORA_NAME, push_to_hub=push_to_hub)
|
50 |
+
logger.info("Running job...")
|
51 |
run_job(job)
|
52 |
+
logger.info("Training completed successfully.")
|
53 |
status.update({"running": False, "last_job": job})
|
54 |
except Exception as e:
|
55 |
+
logger.error(f"Training failed: {e}")
|
56 |
status.update({"running": False, "error": str(e)})
|
57 |
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
@app.get("/logs")
def get_logs():
    """Return everything currently held in the in-memory log buffer.

    The result is a dict with a single ``"logs"`` key whose value is the
    full contents of ``log_stream`` as one string.
    """
    captured = log_stream.getvalue()
    return {"logs": captured}
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
@app.get("/")
def root():
    """Return a fixed message confirming that the service is reachable."""
    payload = {"message": "LoRA training FastAPI is live."}
    return payload
|