ndc8 committed on
Commit 78b611a · 1 Parent(s): 4ecf54e

Cleanup: Remove unnecessary files and update .gitignore

.gitignore CHANGED
@@ -84,3 +84,11 @@ logs/
 
 # Hugging Face cache directory
 .hf_cache/
+
+# Ignore Python cache and virtual environment directories
+__pycache__/
+.venv/
+*.pyc
+*~
+.env
+.DS_Store
README.md CHANGED
@@ -396,3 +396,39 @@ Ready for production with:
 Successfully transformed from broken Gradio app to production-ready AI backend service.
 
 For detailed conversion documentation, see [`CONVERSION_COMPLETE.md`](CONVERSION_COMPLETE.md).
+
+# Gemma 3n GGUF FastAPI Backend (Hugging Face Space)
+
+This Space provides an OpenAI-compatible chat API for Gemma 3n GGUF models, powered by FastAPI.
+
+**Note:** On Hugging Face Spaces, the backend runs in `DEMO_MODE` (no model loaded) for demonstration and endpoint testing. For real inference, run locally with a GGUF model and llama-cpp-python.
+
+## Endpoints
+
+- `/health` — Health check
+- `/v1/chat/completions` — OpenAI-style chat completions (returns a demo response)
+- `/train/start` — Start a (demo) training job
+- `/train/status/{job_id}` — Check training job status
+- `/train/logs/{job_id}` — Get training logs
+
+## Usage
+
+1. **Clone this repo** or create a Hugging Face Space (type: FastAPI).
+2. All dependencies are listed in `requirements.txt`.
+3. The Space starts in demo mode (no model download required).
+
+## Local Inference (with GGUF)
+
+To run with a real model locally:
+
+1. Download a Gemma 3n GGUF model (e.g. from https://huggingface.co/unsloth/gemma-3n-E4B-it-GGUF).
+2. Set `AI_MODEL` to the local path or repo.
+3. Unset `DEMO_MODE`.
+4. Run:
+```bash
+pip install -r requirements.txt
+uvicorn gemma_gguf_backend:app --host 0.0.0.0 --port 8000
+```
+
+## License
+Apache 2.0
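
As a quick check of the endpoints described in this README, here is a minimal sketch of calling the chat route with Python's `requests`. The port (8000) and the payload field names follow the OpenAI chat convention and the local-run command above; the request model itself is not shown in this diff, so treat the field names as assumptions, and in `DEMO_MODE` the reply is a canned demo response.

```python
import requests

# Minimal sketch: exercise the OpenAI-style chat endpoint on a locally running backend.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gemma-3n",  # illustrative model name; ignored in DEMO_MODE
        "messages": [{"role": "user", "content": "Hello!"}],
    },
    timeout=30,
)
resp.raise_for_status()
# Print the raw JSON rather than assuming the exact response schema.
print(resp.json())
```
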
gemma_gguf_backend.py CHANGED
@@ -9,6 +9,11 @@ import logging
 import time
 from contextlib import asynccontextmanager
 from typing import List, Dict, Any, Optional
+import uuid
+import sys
+import subprocess
+import threading
+from pathlib import Path
 
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import JSONResponse
@@ -97,6 +102,12 @@ async def lifespan(app: FastAPI):
     """Application lifespan manager for startup and shutdown events"""
     global llm
     logger.info("🚀 Starting Gemma 3n GGUF Backend Service...")
+    if os.environ.get("DEMO_MODE", "").strip() not in ("", "0", "false", "False"):
+        logger.info("🧪 DEMO_MODE enabled: skipping model load")
+        llm = None
+        yield
+        logger.info("🔄 Shutting down Gemma 3n Backend Service (demo mode)...")
+        return
 
     if not llama_cpp_available:
         logger.error("❌ llama-cpp-python is not available. Please install with: pip install llama-cpp-python")
@@ -262,6 +273,177 @@ async def create_chat_completion(
         logger.error(f"Error in chat completion: {e}")
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
+# -----------------------------
+# Training Job Management (Unsloth)
+# -----------------------------
+
+# Jobs are tracked in-memory; logs and artifacts are written to disk
+TRAIN_JOBS: Dict[str, Dict[str, Any]] = {}
+TRAIN_DIR = Path(os.environ.get("TRAIN_DIR", "./training_runs")).resolve()
+TRAIN_DIR.mkdir(parents=True, exist_ok=True)
+
+def _start_training_subprocess(job_id: str, args: Dict[str, Any]) -> subprocess.Popen[Any]:
+    """Spawn a subprocess to run the Unsloth fine-tuning script."""
+    logs_dir = TRAIN_DIR / job_id
+    logs_dir.mkdir(parents=True, exist_ok=True)
+    log_file = open(logs_dir / "train.log", "w", encoding="utf-8")
+
+    # Build absolute script path to avoid module/package resolution issues
+    script_path = (Path(__file__).parent / "training" / "train_gemma_unsloth.py").resolve()
+    python_exec = sys.executable
+
+    cmd = [
+        python_exec,
+        str(script_path),
+        "--job-id", job_id,
+        "--output-dir", str(logs_dir),
+    ]
+
+    # Optional user-specified args
+    def _extend(k: str, v: Any):
+        if v is None:
+            return
+        if isinstance(v, bool):
+            cmd.extend([f"--{k}"] if v else [])
+        else:
+            cmd.extend([f"--{k}", str(v)])
+
+    _extend("dataset", args.get("dataset"))
+    _extend("text-field", args.get("text_field"))
+    _extend("prompt-field", args.get("prompt_field"))
+    _extend("response-field", args.get("response_field"))
+    _extend("max-steps", args.get("max_steps"))
+    _extend("epochs", args.get("epochs"))
+    _extend("lr", args.get("lr"))
+    _extend("batch-size", args.get("batch_size"))
+    _extend("gradient-accumulation", args.get("gradient_accumulation"))
+    _extend("lora-r", args.get("lora_r"))
+    _extend("lora-alpha", args.get("lora_alpha"))
+    _extend("cutoff-len", args.get("cutoff_len"))
+    _extend("model-id", args.get("model_id"))
+    _extend("use-bf16", args.get("use_bf16"))
+    _extend("use-fp16", args.get("use_fp16"))
+    _extend("seed", args.get("seed"))
+    _extend("dry-run", args.get("dry_run"))
+
+    logger.info(f"🧵 Starting training subprocess for job {job_id}: {' '.join(cmd)}")
+    logger.info(f"🐍 Using interpreter: {python_exec}")
+    proc = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT, cwd=str(Path(__file__).parent))
+    return proc
+
+def _watch_process(job_id: str, proc: subprocess.Popen[Any]):
+    """Monitor a training process and update job state on exit."""
+    return_code = proc.wait()
+    status = "completed" if return_code == 0 else "failed"
+    TRAIN_JOBS[job_id]["status"] = status
+    TRAIN_JOBS[job_id]["return_code"] = return_code
+    TRAIN_JOBS[job_id]["ended_at"] = int(time.time())
+    logger.info(f"🏁 Training job {job_id} finished with status={status}, code={return_code}")
+
+class StartTrainingRequest(BaseModel):
+    dataset: str = Field(..., description="HF dataset name or path to local JSONL/JSON file")
+    model_id: Optional[str] = Field(default="unsloth/gemma-3n-E4B-it", description="Base model for training (HF Transformers format)")
+    text_field: Optional[str] = Field(default=None, description="Single text field name (SFT)")
+    prompt_field: Optional[str] = Field(default=None, description="Prompt/instruction field (chat data)")
+    response_field: Optional[str] = Field(default=None, description="Response/output field (chat data)")
+    max_steps: Optional[int] = Field(default=None)
+    epochs: Optional[int] = Field(default=1)
+    lr: Optional[float] = Field(default=2e-4)
+    batch_size: Optional[int] = Field(default=1)
+    gradient_accumulation: Optional[int] = Field(default=8)
+    lora_r: Optional[int] = Field(default=16)
+    lora_alpha: Optional[int] = Field(default=32)
+    cutoff_len: Optional[int] = Field(default=4096)
+    use_bf16: Optional[bool] = Field(default=True)
+    use_fp16: Optional[bool] = Field(default=False)
+    seed: Optional[int] = Field(default=42)
+    dry_run: Optional[bool] = Field(default=False, description="Write DONE and exit without running (for CI/macOS)")
+
+class StartTrainingResponse(BaseModel):
+    job_id: str
+    status: str
+    output_dir: str
+
+class TrainStatusResponse(BaseModel):
+    job_id: str
+    status: str
+    created_at: int
+    started_at: Optional[int] = None
+    ended_at: Optional[int] = None
+    output_dir: Optional[str] = None
+    return_code: Optional[int] = None
+
+@app.post("/train/start", response_model=StartTrainingResponse)
+def start_training(req: StartTrainingRequest):
+    """Start a background Unsloth fine-tuning job. Returns a job_id to poll."""
+    job_id = uuid.uuid4().hex[:12]
+    now = int(time.time())
+    output_dir = str((TRAIN_DIR / job_id).resolve())
+    TRAIN_JOBS[job_id] = {
+        "status": "starting",
+        "created_at": now,
+        "started_at": now,
+        "args": req.model_dump(),
+        "output_dir": output_dir,
+    }
+
+    try:
+        proc = _start_training_subprocess(job_id, req.model_dump())
+        TRAIN_JOBS[job_id]["status"] = "running"
+        TRAIN_JOBS[job_id]["pid"] = proc.pid
+        watcher = threading.Thread(target=_watch_process, args=(job_id, proc), daemon=True)
+        watcher.start()
+        return StartTrainingResponse(job_id=job_id, status="running", output_dir=output_dir)
+    except Exception as e:
+        logger.exception("Failed to start training job")
+        TRAIN_JOBS[job_id]["status"] = "failed_to_start"
+        raise HTTPException(status_code=500, detail=f"Failed to start training: {e}")
+
+@app.get("/train/status/{job_id}", response_model=TrainStatusResponse)
+def train_status(job_id: str):
+    job = TRAIN_JOBS.get(job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="Job not found")
+    return TrainStatusResponse(
+        job_id=job_id,
+        status=job.get("status", "unknown"),
+        created_at=job.get("created_at", 0),
+        started_at=job.get("started_at"),
+        ended_at=job.get("ended_at"),
+        output_dir=job.get("output_dir"),
+        return_code=job.get("return_code"),
+    )
+
+@app.get("/train/logs/{job_id}")
+def train_logs(job_id: str, tail: int = 200):
+    job = TRAIN_JOBS.get(job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="Job not found")
+    log_path = Path(job["output_dir"]) / "train.log"
+    if not log_path.exists():
+        return {"job_id": job_id, "logs": "(no logs yet)"}
+    try:
+        with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
+            lines = f.readlines()[-tail:]
+        return {"job_id": job_id, "logs": "".join(lines)}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to read logs: {e}")
+
+@app.post("/train/stop/{job_id}")
+def train_stop(job_id: str):
+    job = TRAIN_JOBS.get(job_id)
+    if not job:
+        raise HTTPException(status_code=404, detail="Job not found")
+    pid = job.get("pid")
+    if not pid:
+        raise HTTPException(status_code=400, detail="Job does not have an active PID")
+    try:
+        os.kill(pid, 15)  # SIGTERM
+        job["status"] = "stopping"
+        return {"job_id": job_id, "status": "stopping"}
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Failed to stop job: {e}")
+
 # Main entry point
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
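
A quick way to exercise the job endpoints added in this file is the sketch below, which tails a job's log and then requests a stop. It assumes the backend is running locally on port 8000; the job id placeholder must come from an earlier `/train/start` call, and the `tail` query parameter and `/train/stop` route are the ones defined above.

```python
import requests

BASE = "http://localhost:8000"                 # assumed local port; adjust to your deployment
job_id = "<job_id returned by /train/start>"   # placeholder, not a real id

# Tail the last 50 lines of the job's train.log via GET /train/logs/{job_id}
logs = requests.get(f"{BASE}/train/logs/{job_id}", params={"tail": 50})
print(logs.json().get("logs", ""))

# Ask the server to SIGTERM the training subprocess via POST /train/stop/{job_id}
stop = requests.post(f"{BASE}/train/stop/{job_id}")
print(stop.status_code, stop.json())
```
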
requirements.txt CHANGED
@@ -1,14 +1,12 @@
-gradio>=5.41.0
-huggingface_hub>=0.34.0
+fastapi
+uvicorn[standard]
+pydantic
+llama-cpp-python
+# Training dependencies for GGUF/Unsloth
+unsloth>=2024.7.0
+datasets>=2.20.0
+trl>=0.9.6
+peft>=0.11.1
 transformers>=4.36.0
 torch>=2.0.0
-Pillow>=10.0.0
 accelerate>=0.24.0
-requests>=2.31.0
-protobuf>=3.20.0
-# llama-cpp-python for GGUF model support (Gemma 3n)
-llama-cpp-python>=0.3.14
-# NOTE: GGUF models like 'unsloth/gemma-3n-E4B-it-GGUF' can be loaded directly from HuggingFace
-fastapi>=0.100.0
-uvicorn[standard]>=0.23.0
-pydantic>=2.0.0
sample_data/train.jsonl ADDED
@@ -0,0 +1,2 @@
+{"prompt": "Hello! Introduce yourself.", "response": "I'm a helpful assistant built on Gemma 3n."}
+{"prompt": "Give me a fun fact.", "response": "Honey never spoils; archaeologists found edible honey in ancient Egyptian tombs."}
space.yaml ADDED
@@ -0,0 +1,5 @@
+sdk: fastapi
+python_version: 3.10
+app_file: gemma_gguf_backend.py
+env:
+  - DEMO_MODE=1
test_training_api.py ADDED
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+"""
+Minimal integration test for training endpoints.
+"""
+import time
+import json
+import requests
+
+BASE = "http://localhost:8001"
+
+print("1) Start a training job")
+resp = requests.post(f"{BASE}/train/start", json={
+    "dataset": "./sample_data/train.jsonl",
+    "model_id": "unsloth/gemma-3n-E4B-it",
+    "prompt_field": "prompt",
+    "response_field": "response",
+    "epochs": 1,
+    "batch_size": 1,
+    "gradient_accumulation": 8,
+    "use_bf16": True,
+    "dry_run": True
+})
+print(resp.status_code, resp.text)
+resp.raise_for_status()
+job = resp.json()
+job_id = job["job_id"]
+print("job_id=", job_id)
+
+print("2) Poll status (10s)")
+for _ in range(10):
+    s = requests.get(f"{BASE}/train/status/{job_id}")
+    print(s.status_code, json.dumps(s.json(), indent=2))
+    time.sleep(1)
+
+print("3) Done")
training/train_gemma_unsloth.py ADDED
@@ -0,0 +1,256 @@
+#!/usr/bin/env python3
+"""
+Unsloth fine-tuning runner for Gemma-3n-E4B-it.
+- Trains a LoRA adapter on top of an HF Transformers-format base model (not GGUF).
+- Output: PEFT adapter that can later be merged/exported to GGUF separately if desired.
+
+This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess.
+"""
+import argparse
+import os
+import json
+import time
+from pathlib import Path
+from typing import Any, Dict
+
+# Lazy imports to keep API light
+
+def _import_training_libs() -> Dict[str, Any]:
+    """Try to import Unsloth fast path; if unavailable, fall back to Transformers+PEFT.
+
+    Returns a dict with keys:
+      mode: "unsloth" | "hf"
+      load_dataset, SFTTrainer, SFTConfig
+      If mode=="unsloth": FastLanguageModel, AutoTokenizer
+      If mode=="hf": AutoTokenizer, AutoModelForCausalLM, get_peft_model, LoraConfig, torch
+    """
+    # Avoid heavy optional deps on macOS (no xformers/bitsandbytes)
+    os.environ.setdefault("UNSLOTH_DISABLE_XFORMERS", "1")
+    os.environ.setdefault("UNSLOTH_DISABLE_BITSANDBYTES", "1")
+    from datasets import load_dataset
+    from trl import SFTTrainer, SFTConfig
+    try:
+        from unsloth import FastLanguageModel
+        from transformers import AutoTokenizer
+        return {
+            "mode": "unsloth",
+            "load_dataset": load_dataset,
+            "SFTTrainer": SFTTrainer,
+            "SFTConfig": SFTConfig,
+            "FastLanguageModel": FastLanguageModel,
+            "AutoTokenizer": AutoTokenizer,
+        }
+    except Exception:
+        # Fallback: pure HF + PEFT (CPU / MPS friendly)
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+        from peft import get_peft_model, LoraConfig
+        import torch
+        return {
+            "mode": "hf",
+            "load_dataset": load_dataset,
+            "SFTTrainer": SFTTrainer,
+            "SFTConfig": SFTConfig,
+            "AutoTokenizer": AutoTokenizer,
+            "AutoModelForCausalLM": AutoModelForCausalLM,
+            "get_peft_model": get_peft_model,
+            "LoraConfig": LoraConfig,
+            "torch": torch,
+        }
+
+
+def parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument("--job-id", required=True)
+    p.add_argument("--output-dir", required=True)
+    p.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")
+    p.add_argument("--text-field", dest="text_field", default=None)
+    p.add_argument("--prompt-field", dest="prompt_field", default=None)
+    p.add_argument("--response-field", dest="response_field", default=None)
+    p.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
+    p.add_argument("--epochs", type=int, default=1)
+    p.add_argument("--max-steps", dest="max_steps", type=int, default=None)
+    p.add_argument("--lr", type=float, default=2e-4)
+    p.add_argument("--batch-size", dest="batch_size", type=int, default=1)
+    p.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)
+    p.add_argument("--lora-r", dest="lora_r", type=int, default=16)
+    p.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
+    p.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)
+    p.add_argument("--use-bf16", dest="use_bf16", action="store_true")
+    p.add_argument("--use-fp16", dest="use_fp16", action="store_true")
+    p.add_argument("--seed", type=int, default=42)
+    p.add_argument("--dry-run", dest="dry_run", action="store_true", help="Write DONE and exit without training (for CI)")
+    return p.parse_args()
+
+
+def _is_local_path(s: str) -> bool:
+    return os.path.exists(s)
+
+
+def _load_dataset(load_dataset: Any, path: str) -> Any:
+    if _is_local_path(path):
+        # Infer extension
+        if path.endswith(".jsonl") or path.endswith(".jsonl.gz"):
+            return load_dataset("json", data_files=path, split="train")
+        elif path.endswith(".json"):
+            return load_dataset("json", data_files=path, split="train")
+        else:
+            raise ValueError("Unsupported local dataset format. Use JSON or JSONL.")
+    else:
+        return load_dataset(path, split="train")
+
+
+def main():
+    args = parse_args()
+    start = time.time()
+    out_dir = Path(args.output_dir)
+    out_dir.mkdir(parents=True, exist_ok=True)
+    (out_dir / "meta.json").write_text(json.dumps({
+        "job_id": args.job_id,
+        "model_id": args.model_id,
+        "dataset": args.dataset,
+        "created_at": int(start),
+    }, indent=2))
+
+    if args.dry_run:
+        (out_dir / "DONE").write_text("dry_run")
+        print("[train] Dry run complete. DONE written.")
+        return
+
+    # Training imports (supports Unsloth fast path and HF fallback)
+    libs: Dict[str, Any] = _import_training_libs()
+    load_dataset = libs["load_dataset"]
+    SFTTrainer = libs["SFTTrainer"]
+    SFTConfig = libs["SFTConfig"]
+
+    # Environment for stability on T4 etc. per Unsloth guidance
+    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
+    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+    print(f"[train] Loading base model: {args.model_id}")
+    if libs["mode"] == "unsloth":
+        FastLanguageModel = libs["FastLanguageModel"]
+        AutoTokenizer = libs["AutoTokenizer"]
+        model, tokenizer = FastLanguageModel.from_pretrained(
+            model_name=args.model_id,
+            max_seq_length=args.cutoff_len,
+            # Avoid bitsandbytes/xformers
+            load_in_4bit=False,
+            dtype=None,
+            use_gradient_checkpointing="unsloth",
+        )
+        # Prepare LoRA via Unsloth helper
+        print("[train] Attaching LoRA adapter (Unsloth)")
+        model = FastLanguageModel.get_peft_model(
+            model,
+            r=args.lora_r,
+            lora_alpha=args.lora_alpha,
+            lora_dropout=0,
+            bias="none",
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            use_rslora=True,
+            loftq_config=None,
+        )
+    else:
+        # HF + PEFT fallback (CPU / MPS)
+        AutoTokenizer = libs["AutoTokenizer"]
+        AutoModelForCausalLM = libs["AutoModelForCausalLM"]
+        get_peft_model = libs["get_peft_model"]
+        LoraConfig = libs["LoraConfig"]
+        torch = libs["torch"]
+
+        tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)
+        # Prefer MPS on Apple Silicon if available
+        use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        torch_dtype = torch.float16 if (args.use_fp16 or args.use_bf16) and not use_mps else torch.float32
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model_id,
+            torch_dtype=torch_dtype,
+            trust_remote_code=True,
+        )
+        if use_mps:
+            model.to("mps")
+        print("[train] Attaching LoRA adapter (HF/PEFT)")
+        lora_config = LoraConfig(
+            r=args.lora_r,
+            lora_alpha=args.lora_alpha,
+            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+            lora_dropout=0.0,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+        model = get_peft_model(model, lora_config)
+
+    # Load dataset
+    print(f"[train] Loading dataset: {args.dataset}")
+    ds = _load_dataset(load_dataset, args.dataset)
+
+    # Build formatting
+    text_field = args.text_field
+    prompt_field = args.prompt_field
+    response_field = args.response_field
+
+    if text_field:
+        # Simple SFT: single text field
+        def format_row(ex):
+            return ex[text_field]
+    elif prompt_field and response_field:
+        # Chat data: prompt + response
+        def format_row(ex):
+            return f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
+    else:
+        raise ValueError("Provide either --text-field or both --prompt-field and --response-field")
+
+    def map_fn(ex):
+        return {"text": format_row(ex)}
+
+    ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])
+
+    # Trainer
+    trainer = SFTTrainer(
+        model=model,
+        tokenizer=tokenizer,
+        train_dataset=ds,
+        max_seq_length=args.cutoff_len,
+        dataset_text_field="text",
+        packing=True,
+        args=SFTConfig(
+            output_dir=str(out_dir / "hf"),
+            per_device_train_batch_size=args.batch_size,
+            gradient_accumulation_steps=args.gradient_accumulation,
+            learning_rate=args.lr,
+            num_train_epochs=args.epochs,
+            max_steps=args.max_steps if args.max_steps else -1,
+            logging_steps=10,
+            save_steps=200,
+            save_total_limit=2,
+            bf16=args.use_bf16,
+            fp16=args.use_fp16,
+            seed=args.seed,
+            report_to=[],
+        ),
+    )
+
+    print("[train] Starting training...")
+    trainer.train()
+    print("[train] Saving adapter...")
+    adapter_path = out_dir / "adapter"
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    # Save adapter-only weights if PEFT; Unsloth path is also PEFT-compatible
+    try:
+        model.save_pretrained(str(adapter_path))
+    except Exception:
+        # Fallback: save full model (large); unlikely on LoRA
+        try:
+            model.base_model.save_pretrained(str(adapter_path))  # type: ignore[attr-defined]
+        except Exception:
+            pass
+    tokenizer.save_pretrained(str(adapter_path))
+
+    # Write done file
+    (out_dir / "DONE").write_text("ok")
+    elapsed = time.time() - start
+    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")
+
+
+if __name__ == "__main__":
+    main()
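
To exercise this runner without the API layer, here is a minimal sketch that invokes it the same way `_start_training_subprocess` does, using the `--dry-run` flag defined above so nothing heavy is loaded. The output directory and job id are illustrative placeholders, not paths from the commit.

```python
import subprocess
import sys
from pathlib import Path

# Run the trainer directly in dry-run mode; it should write meta.json and DONE
# into the output directory without loading any model or dataset libraries.
out_dir = Path("training_runs/manual-test")  # illustrative output location
out_dir.mkdir(parents=True, exist_ok=True)

cmd = [
    sys.executable, "training/train_gemma_unsloth.py",
    "--job-id", "manual-test",           # hypothetical job id
    "--output-dir", str(out_dir),
    "--dataset", "./sample_data/train.jsonl",
    "--prompt-field", "prompt",
    "--response-field", "response",
    "--dry-run",
]
print(subprocess.run(cmd, check=True))
```
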
training_runs/c6fdb7b0a765/meta.json ADDED
@@ -0,0 +1,6 @@
+{
+  "job_id": "c6fdb7b0a765",
+  "model_id": "unsloth/gemma-3n-E4B-it",
+  "dataset": "./sample_data/train.jsonl",
+  "created_at": 1754620412
+}
training_runs/devlocal/DONE ADDED
@@ -0,0 +1 @@
+dry_run
training_runs/devlocal/meta.json ADDED
@@ -0,0 +1,6 @@
+{
+  "job_id": "devlocal",
+  "model_id": "unsloth/gemma-3n-E4B-it",
+  "dataset": "/Users/congnguyen/DevRepo/firstAI/sample_data/train.jsonl",
+  "created_at": 1754620844
+}