# firstAI / training / train_gemma_unsloth.py
# ndc8
# update
# 91181f3
#!/usr/bin/env python3
"""
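LoRA fine-tuning runner for Gemma-3n-E4B-it (default model id: unsloth/gemma-3n-E4B-it).
- Trains a LoRA adapter on top of a Hugging Face Transformers-format base model (not GGUF),
  using plain Transformers + PEFT; despite the filename, no Unsloth code path is used.
- Output: a PEFT adapter that can later be merged and/or converted to GGUF separately if desired.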
This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess.
"""
import argparse
import os
import json
import time
from pathlib import Path
from typing import Any, Dict
import logging
# Module-level logger (used to report a failure while saving the trained adapter).
logger: logging.Logger = logging.getLogger(name=__name__)
# Lazy imports to keep API light
def _import_training_libs() -> Dict[str, Any]:
    """Import the heavy training dependencies on demand.

    The imports happen at call time rather than at module import, so code paths
    that never train (such as ``--dry-run``) do not load ``datasets``,
    ``transformers``, ``peft`` or ``torch``. Despite the script's name there is
    no Unsloth fast path here: only the plain Hugging Face Transformers + PEFT
    stack is imported.

    Returns:
        A dict with exactly these keys, each mapped to the imported object:
        ``load_dataset``, ``AutoTokenizer``, ``AutoModelForCausalLM``,
        ``get_peft_model``, ``LoraConfig``, ``Trainer``, ``TrainingArguments``
        and ``torch``.

    Raises:
        ImportError: If any of those packages is not installed.
    """
    # Avoid heavy optional deps on macOS (no xformers/bitsandbytes), which is
    # why Unsloth itself is not imported.
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
    from peft import get_peft_model, LoraConfig
    import torch
    return {
        "load_dataset": load_dataset,
        "AutoTokenizer": AutoTokenizer,
        "AutoModelForCausalLM": AutoModelForCausalLM,
        "get_peft_model": get_peft_model,
        "LoraConfig": LoraConfig,
        "Trainer": Trainer,
        "TrainingArguments": TrainingArguments,
        "torch": torch,
    }
def parse_args(argv=None):
    """Parse the command line of the training runner.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is how the script is normally invoked.

    Returns:
        The parsed ``argparse.Namespace``.

    ``--use-bf16`` and ``--use-fp16`` are mutually exclusive. Transformers'
    ``TrainingArguments`` rejects both being enabled, and checking here makes
    a bad invocation fail immediately instead of after the model has loaded.
    Passing both exits with status 2 via ``SystemExit``.
    """
    p = argparse.ArgumentParser(
        description="Fine-tune a LoRA adapter on a Transformers-format base model."
    )
    p.add_argument("--job-id", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")
    p.add_argument("--text-field", dest="text_field", default=None)
    p.add_argument("--prompt-field", dest="prompt_field", default=None)
    p.add_argument("--response-field", dest="response_field", default=None)
    p.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--max-steps", dest="max_steps", type=int, default=None)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--batch-size", dest="batch_size", type=int, default=1)
    p.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)
    p.add_argument("--lora-r", dest="lora_r", type=int, default=16)
    p.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
    p.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)
    # Only one mixed-precision mode may be requested at a time.
    precision = p.add_mutually_exclusive_group()
    precision.add_argument("--use-bf16", dest="use_bf16", action="store_true")
    precision.add_argument("--use-fp16", dest="use_fp16", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", dest="dry_run", action="store_true", help="Write DONE and exit without training (for CI)")
    # The next flags are accepted so existing callers keep working, but the
    # training code does not act on them; the help text says so.
    p.add_argument("--grpo", dest="use_grpo", action="store_true", help="Reserved for GRPO; currently ignored by this script")
    p.add_argument("--cpt", dest="use_cpt", action="store_true", help="Reserved for CPT; currently ignored by this script")
    p.add_argument("--export-gguf", dest="export_gguf", action="store_true", help="After training, print llama.cpp GGUF conversion instructions (no export is performed)")
    p.add_argument("--gguf-out", dest="gguf_out", default=None, help="GGUF output path to use in the printed conversion command")
    return p.parse_args(argv)
def _is_local_path(s: str) -> bool:
    """Report whether *s* names an existing entry on the local filesystem.

    ``_load_dataset`` uses this to choose between loading local files and
    treating *s* as a Hugging Face Hub dataset identifier.
    """
    found = os.path.exists(s)
    return found
def _load_dataset(load_dataset: Any, path: str) -> Any:
    """Load the ``train`` split of a dataset given as a local path or Hub id.

    Args:
        load_dataset: The ``datasets.load_dataset`` callable. It is injected so
            this module never imports ``datasets`` at import time.
        path: One of:
            - a local ``.json``/``.jsonl`` file, optionally ``.gz``-compressed
              (suffix matched case-insensitively);
            - a local directory, which is handed to ``load_dataset`` to infer
              the format from its contents;
            - anything that does not exist locally, which is treated as a
              Hugging Face Hub dataset identifier.

    Returns:
        Whatever ``load_dataset(..., split="train")`` returns.

    Raises:
        ValueError: If *path* is an existing local file without a JSON/JSONL suffix.
    """
    json_suffixes = (".json", ".jsonl", ".json.gz", ".jsonl.gz")
    if not _is_local_path(path):
        # Not on disk: let `datasets` resolve it as a Hub dataset id.
        return load_dataset(path, split="train")
    if os.path.isdir(path):
        # Local dataset directory: `load_dataset` infers the builder from its files.
        return load_dataset(path, split="train")
    if path.lower().endswith(json_suffixes):
        # JSON and JSON Lines share the generic "json" builder.
        return load_dataset("json", data_files=path, split="train")
    raise ValueError("Unsupported local dataset format. Use JSON or JSONL.")
def _make_formatter(text_field, prompt_field, response_field):
    """Return a function that maps one dataset row to its training text.

    Args:
        text_field: Column that already holds the full training text. Takes
            precedence when set.
        prompt_field: Column holding the user turn (needs *response_field*).
        response_field: Column holding the model turn (needs *prompt_field*).

    Returns:
        ``format_row(ex) -> str``. It raises ``KeyError`` if a row lacks a
        required column.

    Raises:
        ValueError: If neither *text_field* nor both prompt/response fields are set.
    """
    if text_field:
        def format_row(ex: Dict[str, Any]) -> str:
            if text_field not in ex:
                raise KeyError(f"Missing required text field '{text_field}' in example: {ex}")
            return ex[text_field]
        return format_row

    if prompt_field and response_field:
        def format_row(ex: Dict[str, Any]) -> str:
            missing = [f for f in (prompt_field, response_field) if f not in ex]
            if missing:
                raise KeyError(f"Missing required field(s) {missing} in example: {ex}")
            # Gemma-style turn markers: one user turn followed by one model turn.
            return (
                f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n"
                f"<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
            )
        return format_row

    raise ValueError("Provide either --text-field or both --prompt-field and --response-field")


def main():
    """CLI entry point: train a LoRA adapter and write artifacts to ``--output-dir``.

    Always writes ``meta.json`` first. With ``--dry-run`` it then writes
    ``DONE`` (content ``dry_run``) and stops. Otherwise it trains, saves the
    adapter and tokenizer under ``adapter/``, and writes ``DONE`` (content
    ``ok``). If saving the adapter fails, the error is logged and re-raised, so
    ``DONE`` is never written for a run without a saved adapter.
    """
    args = parse_args()
    start = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "meta.json").write_text(json.dumps({
        "job_id": args.job_id,
        "model_id": args.model_id,
        "dataset": args.dataset,
        "created_at": int(start),
    }, indent=2))
    if args.dry_run:
        (out_dir / "DONE").write_text("dry_run")
        print("[train] Dry run complete. DONE written.")
        return

    # Validate the dataset-field options before the slow model load, so a bad
    # invocation fails immediately.
    format_row = _make_formatter(args.text_field, args.prompt_field, args.response_field)

    # Set these before torch/tokenizers are imported so they are guaranteed to
    # be seen at initialisation time.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    # Heavy training imports (Transformers + PEFT; there is no Unsloth code path).
    libs: Dict[str, Any] = _import_training_libs()
    load_dataset = libs["load_dataset"]
    AutoTokenizer = libs["AutoTokenizer"]
    AutoModelForCausalLM = libs["AutoModelForCausalLM"]
    get_peft_model = libs["get_peft_model"]
    LoraConfig = libs["LoraConfig"]
    Trainer = libs["Trainer"]
    TrainingArguments = libs["TrainingArguments"]
    torch = libs["torch"]
    # Builds causal-LM labels for each batch (see the collator below).
    from transformers import DataCollatorForLanguageModeling

    print(f"[train] Loading base model: {args.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # The collator pads batches, so a pad token is mandatory; fall back to EOS.
        # NOTE: with pad == eos the collator also masks EOS positions out of the loss.
        tokenizer.pad_token = tokenizer.eos_token

    use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if use_mps:
        # On Apple MPS the weights are always loaded in float32.
        dtype = torch.float32
    elif args.use_fp16:
        dtype = torch.float16
    elif args.use_bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_mps:
        model.to("mps")

    print("[train] Attaching LoRA adapter (PEFT)")
    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    print(f"[train] Loading dataset: {args.dataset}")
    ds = _load_dataset(load_dataset, args.dataset)

    def map_fn(ex: Dict[str, Any]) -> Dict[str, str]:
        return {"text": format_row(ex)}

    ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])

    def tokenize_fn(batch):
        # Truncate only; padding is done per batch by the collator instead of
        # padding every example to --cutoff-len.
        return tokenizer(
            batch["text"],
            truncation=True,
            max_length=args.cutoff_len,
        )

    # Drop the raw "text" column so only token fields reach the collator.
    tokenized_ds = ds.map(tokenize_fn, batched=True, remove_columns=ds.column_names)

    # mlm=False: labels = input_ids with padding set to -100. Without labels the
    # model returns no loss and Trainer cannot train.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=str(out_dir / "hf"),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps if args.max_steps else -1,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=args.use_bf16,
        fp16=args.use_fp16,
        seed=args.seed,
        report_to=[],
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    print("[train] Starting training...")
    trainer.train()

    print("[train] Saving adapter...")
    adapter_path = out_dir / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)
    try:
        model.save_pretrained(str(adapter_path))
    except Exception as e:
        # Re-raise: writing DONE=ok without a saved adapter would report a false success.
        logger.error("Error during model saving: %s", e, exc_info=True)
        raise
    tokenizer.save_pretrained(str(adapter_path))

    # GGUF export is not performed here; only print how to do it with llama.cpp.
    if args.export_gguf:
        gguf_path = args.gguf_out or str(out_dir / "adapter-gguf-q4_k_xl")
        print("[train] Export to GGUF is not supported in Hugging Face-only mode. Convert the adapter with llama.cpp after training, e.g.:")
        print(f"python convert_lora_to_gguf.py --outtype f16 --outfile {gguf_path} {adapter_path}")
        print("[train] Note: llama.cpp's converters have no q4_k_xl --outtype; produce 4-bit files with llama-quantize on a merged-model GGUF.")

    (out_dir / "DONE").write_text("ok")
    elapsed = time.time() - start
    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")
if __name__ == "__main__":
    # Script entry point. main() returns None, so SystemExit(None) exits with
    # status 0; uncaught exceptions still propagate as before.
    raise SystemExit(main())