ndc8 committed
Commit 3960f0f · 1 Parent(s): 78b611a
README.md CHANGED
@@ -431,4 +431,5 @@ To run with a real model locally:
 ```
 
 ## License
+
 Apache 2.0
 
gemma_gguf_backend.py CHANGED
@@ -14,8 +14,9 @@ import sys
 import subprocess
 import threading
 from pathlib import Path
+import signal  # Use signal.SIGTERM for process termination
 
-from fastapi import FastAPI, HTTPException
+from fastapi import FastAPI, HTTPException, Query
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field, field_validator
@@ -28,6 +29,8 @@ except ImportError:
 llama_cpp_available = False
 
 import uvicorn
+import sqlite3
+import json  # For persisting job metadata
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -72,6 +75,7 @@ class HealthResponse(BaseModel):
     version: str
     backend: str
 
+from pathlib import Path
 # Global variables for model management
 current_model = os.environ.get("AI_MODEL", "unsloth/gemma-3n-E4B-it-GGUF")
 llm = None
@@ -277,19 +281,57 @@ async def create_chat_completion(
 # Training Job Management (Unsloth)
 # -----------------------------
 
-# Jobs are tracked in-memory; logs and artifacts are written to disk
+# Persistent job store: in-memory dict backed by SQLite
 TRAIN_JOBS: Dict[str, Dict[str, Any]] = {}
+# Initialize SQLite DB for job persistence
+DB_PATH = Path(os.environ.get("JOB_DB_PATH", "./jobs.db"))
+conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
+cursor = conn.cursor()
+cursor.execute(
+    """
+    CREATE TABLE IF NOT EXISTS jobs (
+        job_id TEXT PRIMARY KEY,
+        data TEXT NOT NULL
+    )
+    """
+)
+conn.commit()
+
+def load_jobs() -> None:
+    cursor.execute("SELECT job_id, data FROM jobs")
+    for job_id, data in cursor.fetchall():
+        TRAIN_JOBS[job_id] = json.loads(data)
+
+def save_job(job_id: str) -> None:
+    cursor.execute(
+        "INSERT OR REPLACE INTO jobs (job_id, data) VALUES (?, ?)",
+        (job_id, json.dumps({k: v for k, v in TRAIN_JOBS[job_id].items() if k != "log_file"})),  # skip the open log-file handle; it is not JSON-serializable
+    )
+    conn.commit()
+
+# Load existing jobs on startup
+load_jobs()
+
 TRAIN_DIR = Path(os.environ.get("TRAIN_DIR", "./training_runs")).resolve()
 TRAIN_DIR.mkdir(parents=True, exist_ok=True)
+# Maximum concurrent training jobs
+MAX_CONCURRENT_JOBS = int(os.environ.get("MAX_CONCURRENT_JOBS", "5"))
 
 def _start_training_subprocess(job_id: str, args: Dict[str, Any]) -> subprocess.Popen[Any]:
     """Spawn a subprocess to run the Unsloth fine-tuning script."""
     logs_dir = TRAIN_DIR / job_id
     logs_dir.mkdir(parents=True, exist_ok=True)
     log_file = open(logs_dir / "train.log", "w", encoding="utf-8")
+    # Store log file handle to close later
+    TRAIN_JOBS.setdefault(job_id, {})["log_file"] = log_file
+    save_job(job_id)
 
     # Build absolute script path to avoid module/package resolution issues
     script_path = (Path(__file__).parent / "training" / "train_gemma_unsloth.py").resolve()
+    # Verify training script exists
+    if not script_path.exists():
+        logger.error(f"Training script not found at {script_path}")
+        raise HTTPException(status_code=500, detail=f"Training script not found at {script_path}")
     python_exec = sys.executable
 
     cmd = [
@@ -338,6 +380,15 @@ def _watch_process(job_id: str, proc: subprocess.Popen[Any]):
     TRAIN_JOBS[job_id]["status"] = status
     TRAIN_JOBS[job_id]["return_code"] = return_code
     TRAIN_JOBS[job_id]["ended_at"] = int(time.time())
+    # Persist updated job status
+    save_job(job_id)
+    # Close the log file handle to prevent resource leaks
+    log_file = TRAIN_JOBS[job_id].get("log_file")
+    if log_file:
+        try:
+            log_file.close()
+        except Exception as close_err:
+            logger.warning(f"Failed to close log file for job {job_id}: {close_err}")
     logger.info(f"🏁 Training job {job_id} finished with status={status}, code={return_code}")
 
 class StartTrainingRequest(BaseModel):
@@ -376,6 +427,13 @@ class TrainStatusResponse(BaseModel):
 @app.post("/train/start", response_model=StartTrainingResponse)
 def start_training(req: StartTrainingRequest):
     """Start a background Unsloth fine-tuning job. Returns a job_id to poll."""
+    # Enforce maximum concurrent training jobs
+    running_jobs = sum(1 for job in TRAIN_JOBS.values() if job.get("status") == "running")
+    if running_jobs >= MAX_CONCURRENT_JOBS:
+        raise HTTPException(
+            status_code=429,
+            detail=f"Maximum concurrent training jobs reached ({MAX_CONCURRENT_JOBS}). Try again later."
+        )
     job_id = uuid.uuid4().hex[:12]
     now = int(time.time())
     output_dir = str((TRAIN_DIR / job_id).resolve())
@@ -386,18 +444,21 @@ def start_training(req: StartTrainingRequest):
         "args": req.model_dump(),
         "output_dir": output_dir,
     }
+    save_job(job_id)
 
     try:
         proc = _start_training_subprocess(job_id, req.model_dump())
         TRAIN_JOBS[job_id]["status"] = "running"
         TRAIN_JOBS[job_id]["pid"] = proc.pid
+        save_job(job_id)
         watcher = threading.Thread(target=_watch_process, args=(job_id, proc), daemon=True)
         watcher.start()
         return StartTrainingResponse(job_id=job_id, status="running", output_dir=output_dir)
     except Exception as e:
         logger.exception("Failed to start training job")
         TRAIN_JOBS[job_id]["status"] = "failed_to_start"
+        save_job(job_id)
         raise HTTPException(status_code=500, detail=f"Failed to start training: {e}")
 
 @app.get("/train/status/{job_id}", response_model=TrainStatusResponse)
 def train_status(job_id: str):
@@ -415,7 +476,10 @@ def train_status(job_id: str):
     )
 
 @app.get("/train/logs/{job_id}")
-def train_logs(job_id: str, tail: int = 200):
+def train_logs(
+    job_id: str,
+    tail: int = Query(200, ge=0, le=1000, description="Number of lines to tail, between 0 and 1000"),
+):
     job = TRAIN_JOBS.get(job_id)
     if not job:
         raise HTTPException(status_code=404, detail="Job not found")
@@ -438,11 +502,20 @@ def train_stop(job_id: str):
     if not pid:
         raise HTTPException(status_code=400, detail="Job does not have an active PID")
     try:
-        os.kill(pid, 15)  # SIGTERM
-        job["status"] = "stopping"
-        return {"job_id": job_id, "status": "stopping"}
+        os.kill(pid, signal.SIGTERM)
+    except ProcessLookupError:
+        logger.warning(
+            f"Process {pid} for job {job_id} not found; may have exited already"
+        )
+        job["status"] = "stopping_failed"
+        save_job(job_id)
+        return {"job_id": job_id, "status": job["status"]}
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to stop job: {e}")
+    else:
+        job["status"] = "stopping"
+        save_job(job_id)
+        return {"job_id": job_id, "status": "stopping"}
 
 # Main entry point
 if __name__ == "__main__":
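A quick way to sanity-check the new persistence layer is to read the `jobs` table back out of SQLite. A minimal sketch, assuming the default `./jobs.db` path and the exact schema created in the hunk above:

```python
# Inspect persisted training jobs (illustrative; schema from the commit above).
import json
import sqlite3

con = sqlite3.connect("jobs.db")
for job_id, data in con.execute("SELECT job_id, data FROM jobs"):
    job = json.loads(data)
    print(job_id, job.get("status"), job.get("output_dir"))
con.close()
```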
 
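End to end, the new endpoints form a start/poll/tail workflow. The sketch below is illustrative only: the base URL is an assumption, the request payload fields follow `StartTrainingRequest` (whose fields are outside this diff), and the route paths and the `job_id`/`status`/`output_dir` response fields come from the handlers above.

```python
# Hypothetical client for the training endpoints added in this commit.
import time
import requests

BASE = "http://localhost:8000"  # assumed host/port

payload = {}  # fill in StartTrainingRequest fields (not shown in this diff)
job = requests.post(f"{BASE}/train/start", json=payload).json()
print("started", job["job_id"], "->", job["output_dir"])

# Poll until the watcher thread records a terminal status
while requests.get(f"{BASE}/train/status/{job['job_id']}").json()["status"] in ("running", "stopping"):
    time.sleep(10)

# tail is validated server-side by Query(200, ge=0, le=1000)
print(requests.get(f"{BASE}/train/logs/{job['job_id']}", params={"tail": 100}).text)
```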
space.yaml CHANGED
@@ -2,4 +2,4 @@ sdk: fastapi
 python_version: 3.10
 app_file: gemma_gguf_backend.py
 env:
-  - DEMO_MODE=1
+  - DEMO_MODE=0  # Ensure model loads properly in production
 
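How the backend consumes this flag is not part of the diff; the usual pattern for such an env toggle is a sketch like the following (hypothetical, not the file's actual check):

```python
import os

# Hypothetical read of the DEMO_MODE toggle set in space.yaml.
demo_mode = os.environ.get("DEMO_MODE", "0") == "1"
if demo_mode:
    print("Demo mode: serving canned responses, no GGUF model loaded.")
```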
training/train_gemma_unsloth.py CHANGED
@@ -12,6 +12,9 @@ import json
 import time
 from pathlib import Path
 from typing import Any, Dict
+import logging
+
+logger = logging.getLogger(__name__)
 
 # Lazy imports to keep API light
 
@@ -40,7 +43,12 @@ def _import_training_libs() -> Dict[str, Any]:
             "FastLanguageModel": FastLanguageModel,
             "AutoTokenizer": AutoTokenizer,
         }
-    except Exception:
+    except ImportError as e:
+        logger.warning(
+            "Primary Unsloth import failed, falling back to HF+PEFT: %s",
+            e,
+            exc_info=True,
+        )
         # Fallback: pure HF + PEFT (CPU / MPS friendly)
         from transformers import AutoTokenizer, AutoModelForCausalLM
         from peft import get_peft_model, LoraConfig
@@ -161,10 +169,18 @@ def main():
     tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)
     # Prefer MPS on Apple Silicon if available
     use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-    torch_dtype = torch.float16 if (args.use_fp16 or args.use_bf16) and not use_mps else torch.float32
+    if not use_mps:
+        if args.use_fp16:
+            dtype = torch.float16
+        elif args.use_bf16:
+            dtype = torch.bfloat16
+        else:
+            dtype = torch.float32
+    else:
+        dtype = torch.float32
     model = AutoModelForCausalLM.from_pretrained(
         args.model_id,
-        torch_dtype=torch_dtype,
+        torch_dtype=dtype,
         trust_remote_code=True,
     )
     if use_mps:
@@ -190,17 +206,25 @@ def main():
     response_field = args.response_field
 
     if text_field:
-        # Simple SFT: single text field
-        def format_row(ex):
+        # Simple SFT: single text field with validation
+        def format_row(ex: Dict[str, Any]) -> str:
+            if text_field not in ex:
+                raise KeyError(f"Missing required text field '{text_field}' in example: {ex}")
             return ex[text_field]
     elif prompt_field and response_field:
-        # Chat data: prompt + response
-        def format_row(ex):
-            return f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
+        # Chat data: prompt + response with validation
+        def format_row(ex: Dict[str, Any]) -> str:
+            missing = [f for f in (prompt_field, response_field) if f not in ex]
+            if missing:
+                raise KeyError(f"Missing required field(s) {missing} in example: {ex}")
+            return (
+                f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n"
+                f"<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
+            )
     else:
         raise ValueError("Provide either --text-field or both --prompt-field and --response-field")
 
-    def map_fn(ex):
+    def map_fn(ex: Dict[str, Any]) -> Dict[str, str]:
         return {"text": format_row(ex)}
 
     ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])
@@ -237,13 +261,16 @@ def main():
     adapter_path.mkdir(parents=True, exist_ok=True)
     # Save adapter-only weights if PEFT; Unsloth path is also PEFT-compatible
    try:
+        # Primary model saving logic
         model.save_pretrained(str(adapter_path))
-    except Exception:
-        # Fallback: save full model (large); unlikely on LoRA
+    except Exception as e:
+        logger.error("Error during primary model saving: %s", e, exc_info=True)
         try:
-            model.base_model.save_pretrained(str(adapter_path))  # type: ignore[attr-defined]
-        except Exception:
-            pass
+            # Fallback model saving logic
+            model.base_model.save_pretrained(str(adapter_path))  # type: ignore[attr-defined]
+        except Exception as fallback_e:
+            logger.error("Fallback model saving failed: %s", fallback_e, exc_info=True)
+            pass  # Optionally re-raise or handle accordingly
     tokenizer.save_pretrained(str(adapter_path))
 
     # Write done file
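To see exactly what the chat branch emits, here is a standalone check of the turn markup from the hunk above; the `prompt`/`response` field names are illustrative values for `--prompt-field`/`--response-field`:

```python
# Standalone check of the Gemma turn template used by format_row above.
prompt_field, response_field = "prompt", "response"  # example field names

def format_row(ex):
    return (
        f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n"
        f"<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
    )

print(format_row({"prompt": "What is LoRA?", "response": "A low-rank adapter method."}))
# <start_of_turn>user
# What is LoRA?<end_of_turn>
# <start_of_turn>model
# A low-rank adapter method.<end_of_turn>
```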
 
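Once `save_pretrained` has written the adapter directory, it can be reloaded onto the base model with PEFT. A sketch under stated assumptions: the adapter directory name and run location are illustrative (`adapter_path` is defined outside the hunks shown), while the base model id comes from `meta.json` below.

```python
# Hedged sketch: reload a saved LoRA adapter with PEFT.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "training_runs/devlocal/lora_adapter"  # assumed location/name
base = AutoModelForCausalLM.from_pretrained("unsloth/gemma-3n-E4B-it", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
model = PeftModel.from_pretrained(base, adapter_dir)
model.eval()
```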
training_runs/devlocal/meta.json CHANGED
@@ -1,6 +1,6 @@
 {
   "job_id": "devlocal",
   "model_id": "unsloth/gemma-3n-E4B-it",
-  "dataset": "/Users/congnguyen/DevRepo/firstAI/sample_data/train.jsonl",
+  "dataset": "sample_data/train.jsonl",
   "created_at": 1754620844
 }
 