firstAI / training /train_gemma_unsloth.py
ndc8
Cleanup: Remove unnecessary files and update .gitignore
78b611a
raw
history blame
9.77 kB
#!/usr/bin/env python3
"""
Unsloth fine-tuning runner for Gemma-3n-E4B-it.
- Trains a LoRA adapter on top of HF Transformers-format base model (not GGUF).
- Output: PEFT adapter that can later be merged/exported to GGUF separately if desired.
This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess.
"""
import argparse
import os
import json
import time
from pathlib import Path
from typing import Any, Dict
# Lazy imports to keep API light
def _import_training_libs() -> Dict[str, Any]:
    """Import the heavy training stack lazily and report which backend is usable.

    The Unsloth fast path is tried first; if importing it fails for any reason,
    the plain Transformers + PEFT stack is used instead.

    Returns:
        A dict that always contains ``mode`` (``"unsloth"`` or ``"hf"``),
        ``load_dataset``, ``SFTTrainer``, ``SFTConfig`` and ``AutoTokenizer``.
        When ``mode == "unsloth"`` it also contains ``FastLanguageModel``.
        When ``mode == "hf"`` it also contains ``AutoModelForCausalLM``,
        ``get_peft_model``, ``LoraConfig`` and ``torch``.

    Raises:
        ImportError: If ``datasets``, ``trl`` or ``transformers`` (or, on the
            fallback path, ``peft`` / ``torch``) cannot be imported.
    """
    # Opt out of heavy optional deps (xformers / bitsandbytes), e.g. on macOS.
    # NOTE(review): whether Unsloth honours these variables is not verified here.
    os.environ.setdefault("UNSLOTH_DISABLE_XFORMERS", "1")
    os.environ.setdefault("UNSLOTH_DISABLE_BITSANDBYTES", "1")

    # Unsloth asks to be imported before trl / transformers / peft so that its
    # patches are applied, hence it is attempted before anything else.
    try:
        from unsloth import FastLanguageModel
    except Exception as err:
        # Deliberately broad: importing Unsloth can fail with more than
        # ImportError (e.g. on unsupported hardware). Report why the slower
        # path is used instead of failing silently.
        FastLanguageModel = None
        print(f"[train] Unsloth unavailable ({err!r}); falling back to Transformers + PEFT")

    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer
    from transformers import AutoTokenizer

    libs: Dict[str, Any] = {
        "load_dataset": load_dataset,
        "SFTTrainer": SFTTrainer,
        "SFTConfig": SFTConfig,
        "AutoTokenizer": AutoTokenizer,
    }

    if FastLanguageModel is not None:
        libs["mode"] = "unsloth"
        libs["FastLanguageModel"] = FastLanguageModel
        return libs

    # Fallback: pure HF + PEFT (CPU / MPS friendly).
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model
    import torch

    libs["mode"] = "hf"
    libs["AutoModelForCausalLM"] = AutoModelForCausalLM
    libs["get_peft_model"] = get_peft_model
    libs["LoraConfig"] = LoraConfig
    libs["torch"] = torch
    return libs
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line of the training runner.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point relies on; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()

    # Job identity and I/O locations.
    p.add_argument("--job-id", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")

    # Dataset columns: either a single text column or a prompt/response pair.
    p.add_argument("--text-field", dest="text_field", default=None,
                   help="Column holding the full training text")
    p.add_argument("--prompt-field", dest="prompt_field", default=None,
                   help="Column holding the user prompt (use with --response-field)")
    p.add_argument("--response-field", dest="response_field", default=None,
                   help="Column holding the model response (use with --prompt-field)")

    # Base model and optimisation hyper-parameters.
    p.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--max-steps", dest="max_steps", type=int, default=None)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--batch-size", dest="batch_size", type=int, default=1)
    p.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)

    # LoRA shape and sequence length.
    p.add_argument("--lora-r", dest="lora_r", type=int, default=16)
    p.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
    p.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)

    # Precision, reproducibility and CI switches.
    p.add_argument("--use-bf16", dest="use_bf16", action="store_true")
    p.add_argument("--use-fp16", dest="use_fp16", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", dest="dry_run", action="store_true",
                   help="Write DONE and exit without training (for CI)")

    return p.parse_args(argv)
def _is_local_path(s: str) -> bool:
    """Return True when *s* names an existing filesystem entry."""
    return os.path.exists(s)


def _load_dataset(load_dataset: Any, path: str) -> Any:
    """Load the ``train`` split of a dataset from a local file or by identifier.

    Args:
        load_dataset: The ``datasets.load_dataset`` callable (injected so the
            heavy import stays lazy).
        path: A local JSON/JSONL file (optionally ``.gz``), a local dataset
            directory, or any identifier ``load_dataset`` can resolve itself.

    Returns:
        Whatever ``load_dataset`` returns for ``split="train"``.

    Raises:
        ValueError: If *path* is an existing local file whose extension is not
            one of the supported JSON variants.
    """
    json_suffixes = (".json", ".jsonl", ".json.gz", ".jsonl.gz")

    # Non-existent paths are treated as dataset identifiers, and directories
    # are handed over unchanged: ``load_dataset`` resolves both on its own.
    if not _is_local_path(path) or os.path.isdir(path):
        return load_dataset(path, split="train")

    # Extension matching is case-insensitive so e.g. ``DATA.JSONL`` is accepted.
    if path.lower().endswith(json_suffixes):
        return load_dataset("json", data_files=path, split="train")

    raise ValueError("Unsupported local dataset format. Use JSON or JSONL.")
# Projection layers that receive LoRA adapters on both backends.
_LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


def _build_row_formatter(text_field, prompt_field, response_field):
    """Return a callable mapping one dataset row to its training text.

    Raises:
        ValueError: If neither ``text_field`` nor both ``prompt_field`` and
            ``response_field`` are provided.
    """
    if text_field:
        # Simple SFT: a single column already holds the full text.
        def format_row(ex):
            return ex[text_field]

        return format_row

    if prompt_field and response_field:
        # Chat data: wrap prompt + response in user/model turn markers.
        def format_row(ex):
            return f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"

        return format_row

    raise ValueError("Provide either --text-field or both --prompt-field and --response-field")


def _load_unsloth_model(libs: Dict[str, Any], args: Any):
    """Load the base model through Unsloth and attach a LoRA adapter."""
    FastLanguageModel = libs["FastLanguageModel"]
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_id,
        max_seq_length=args.cutoff_len,
        # Avoid bitsandbytes/xformers
        load_in_4bit=False,
        dtype=None,
        use_gradient_checkpointing="unsloth",
    )
    print("[train] Attaching LoRA adapter (Unsloth)")
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0,
        bias="none",
        target_modules=list(_LORA_TARGET_MODULES),
        use_rslora=True,
        loftq_config=None,
    )
    return model, tokenizer


def _load_hf_model(libs: Dict[str, Any], args: Any):
    """Load the base model through Transformers + PEFT (CPU / MPS fallback)."""
    torch = libs["torch"]
    tokenizer = libs["AutoTokenizer"].from_pretrained(
        args.model_id, use_fast=True, trust_remote_code=True
    )

    # Prefer MPS on Apple Silicon if available; weights stay in float32 there.
    use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if use_mps:
        torch_dtype = torch.float32
    elif args.use_bf16:
        # Honour the requested precision: bf16 must not silently become fp16.
        torch_dtype = torch.bfloat16
    elif args.use_fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    model = libs["AutoModelForCausalLM"].from_pretrained(
        args.model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
    if use_mps:
        model.to("mps")

    print("[train] Attaching LoRA adapter (HF/PEFT)")
    lora_config = libs["LoraConfig"](
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=list(_LORA_TARGET_MODULES),
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return libs["get_peft_model"](model, lora_config), tokenizer


def _save_adapter(model: Any, tokenizer: Any, adapter_path: Path) -> None:
    """Save adapter weights and tokenizer; let a total save failure propagate."""
    try:
        model.save_pretrained(str(adapter_path))
    except Exception as err:
        # Fallback: save the wrapped base model (large); unlikely on LoRA.
        # If this fails too, the exception propagates so that the caller does
        # not go on to report a successful run without any saved weights.
        print(f"[train] Adapter save failed ({err!r}); saving base model instead")
        model.base_model.save_pretrained(str(adapter_path))  # type: ignore[attr-defined]
    tokenizer.save_pretrained(str(adapter_path))


def main():
    """Run one fine-tuning job as described by the command line.

    Steps: write ``meta.json`` into the output directory, stop early on
    ``--dry-run``, validate the field arguments, load the base model with a
    LoRA adapter (Unsloth or HF fallback), map the dataset to a single
    ``text`` column, train with ``SFTTrainer``, save the adapter and tokenizer
    under ``<output-dir>/adapter`` and finally write the ``DONE`` marker.

    Raises:
        ValueError: If neither ``--text-field`` nor both ``--prompt-field`` and
            ``--response-field`` were given (checked before any model loading).
    """
    args = parse_args()
    start = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "meta.json").write_text(json.dumps({
        "job_id": args.job_id,
        "model_id": args.model_id,
        "dataset": args.dataset,
        "created_at": int(start),
    }, indent=2))

    if args.dry_run:
        (out_dir / "DONE").write_text("dry_run")
        print("[train] Dry run complete. DONE written.")
        return

    # Fail fast on inconsistent field arguments instead of discovering the
    # problem only after the base model has been loaded.
    format_row = _build_row_formatter(args.text_field, args.prompt_field, args.response_field)

    # Environment for stability on T4 etc per Unsloth guidance; set before the
    # heavy libraries are imported so they see the values from the start.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Training imports (supports Unsloth fast path and HF fallback)
    libs: Dict[str, Any] = _import_training_libs()
    load_dataset = libs["load_dataset"]
    SFTTrainer = libs["SFTTrainer"]
    SFTConfig = libs["SFTConfig"]

    print(f"[train] Loading base model: {args.model_id}")
    if libs["mode"] == "unsloth":
        model, tokenizer = _load_unsloth_model(libs, args)
    else:
        model, tokenizer = _load_hf_model(libs, args)

    print(f"[train] Loading dataset: {args.dataset}")
    ds = _load_dataset(load_dataset, args.dataset)

    def map_fn(ex):
        return {"text": format_row(ex)}

    # Keep only the generated "text" column for the trainer.
    ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])

    # NOTE(review): tokenizer / max_seq_length / dataset_text_field / packing
    # are passed to SFTTrainer directly; confirm that the pinned TRL version
    # still accepts them there rather than on SFTConfig.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        max_seq_length=args.cutoff_len,
        dataset_text_field="text",
        packing=True,
        args=SFTConfig(
            output_dir=str(out_dir / "hf"),
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation,
            learning_rate=args.lr,
            num_train_epochs=args.epochs,
            # -1 means "no step limit": the epoch count decides.
            max_steps=args.max_steps if args.max_steps else -1,
            logging_steps=10,
            save_steps=200,
            save_total_limit=2,
            bf16=args.use_bf16,
            fp16=args.use_fp16,
            seed=args.seed,
            report_to=[],
        ),
    )

    print("[train] Starting training...")
    trainer.train()

    print("[train] Saving adapter...")
    adapter_path = out_dir / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)
    # Save adapter-only weights if PEFT; Unsloth path is also PEFT-compatible
    _save_adapter(model, tokenizer, adapter_path)

    # Write done file only after the artifacts are on disk.
    (out_dir / "DONE").write_text("ok")
    elapsed = time.time() - start
    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")
# Script entry point: run the trainer only when this file is executed directly
# (e.g. spawned as a subprocess), never as a side effect of importing it.
if __name__ == "__main__":
    main()