"""Unsloth fine-tuning runner for Gemma-3n-E4B-it.

- Trains a LoRA adapter on top of the HF Transformers-format base model (not GGUF).
- Output: a PEFT adapter that can later be merged/exported to GGUF separately if desired.

This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess.
"""

import argparse
import os
import json
import time
from pathlib import Path
from typing import Any, Dict
import logging

logger = logging.getLogger(__name__)


def _import_training_libs() -> Dict[str, Any]:
    """Import the training stack (Transformers + PEFT).

    Returns a dict with keys:
        load_dataset, AutoTokenizer, AutoModelForCausalLM, get_peft_model,
        LoraConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling, torch

    Note: the Unsloth fast path (FastLanguageModel) is not wired in here yet; this
    runner currently always trains with plain Transformers + PEFT.
    """
    from datasets import load_dataset
    from transformers import (
        AutoTokenizer,
        AutoModelForCausalLM,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )
    from peft import get_peft_model, LoraConfig
    import torch

    return {
        "load_dataset": load_dataset,
        "AutoTokenizer": AutoTokenizer,
        "AutoModelForCausalLM": AutoModelForCausalLM,
        "get_peft_model": get_peft_model,
        "LoraConfig": LoraConfig,
        "Trainer": Trainer,
        "TrainingArguments": TrainingArguments,
        "DataCollatorForLanguageModeling": DataCollatorForLanguageModeling,
        "torch": torch,
    }


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--job-id", required=True)
    p.add_argument("--output-dir", required=True)
    p.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")
    p.add_argument("--text-field", dest="text_field", default=None)
    p.add_argument("--prompt-field", dest="prompt_field", default=None)
    p.add_argument("--response-field", dest="response_field", default=None)
    p.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--max-steps", dest="max_steps", type=int, default=None)
    p.add_argument("--lr", type=float, default=2e-4)
    p.add_argument("--batch-size", dest="batch_size", type=int, default=1)
    p.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)
    p.add_argument("--lora-r", dest="lora_r", type=int, default=16)
    p.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
    p.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)
    p.add_argument("--use-bf16", dest="use_bf16", action="store_true")
    p.add_argument("--use-fp16", dest="use_fp16", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--dry-run", dest="dry_run", action="store_true", help="Write DONE and exit without training (for CI)")
    p.add_argument("--grpo", dest="use_grpo", action="store_true", help="Enable GRPO (if supported by Unsloth)")
    p.add_argument("--cpt", dest="use_cpt", action="store_true", help="Enable CPT (if supported by Unsloth)")
    p.add_argument("--export-gguf", dest="export_gguf", action="store_true", help="Export model to GGUF Q4_K_XL after training")
    p.add_argument("--gguf-out", dest="gguf_out", default=None, help="Path to save GGUF file (if exporting)")
    return p.parse_args()


def _is_local_path(s: str) -> bool:
    return os.path.exists(s)


def _load_dataset(load_dataset: Any, path: str) -> Any:
    if _is_local_path(path):
        # Local files: the datasets "json" builder handles both JSON and JSONL (including .gz).
        if path.endswith((".json", ".jsonl", ".jsonl.gz")):
            return load_dataset("json", data_files=path, split="train")
        raise ValueError("Unsupported local dataset format. Use JSON or JSONL.")
    # Anything else is treated as a Hugging Face Hub dataset id.
    return load_dataset(path, split="train")


def main():
    args = parse_args()
    start = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "meta.json").write_text(json.dumps({
        "job_id": args.job_id,
        "model_id": args.model_id,
        "dataset": args.dataset,
        "created_at": int(start),
    }, indent=2))

    if args.dry_run:
        (out_dir / "DONE").write_text("dry_run")
        print("[train] Dry run complete. DONE written.")
        return

    libs: Dict[str, Any] = _import_training_libs()
    load_dataset = libs["load_dataset"]
    AutoTokenizer = libs["AutoTokenizer"]
    AutoModelForCausalLM = libs["AutoModelForCausalLM"]
    get_peft_model = libs["get_peft_model"]
    LoraConfig = libs["LoraConfig"]
    Trainer = libs["Trainer"]
    TrainingArguments = libs["TrainingArguments"]
    DataCollatorForLanguageModeling = libs["DataCollatorForLanguageModeling"]
    torch = libs["torch"]

    # Reduce CUDA allocator fragmentation and silence the tokenizers fork warning.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    print(f"[train] Loading base model: {args.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)

    use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if use_mps:
        # The half-precision flags are ignored on Apple MPS; train in float32 there.
        dtype = torch.float32
    elif args.use_fp16:
        dtype = torch.float16
    elif args.use_bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if use_mps:
        model.to("mps")
print("[train] Attaching LoRA adapter (PEFT)") |
|
lora_config = LoraConfig( |
|
r=args.lora_r, |
|
lora_alpha=args.lora_alpha, |
|
target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], |
|
lora_dropout=0.0, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
model = get_peft_model(model, lora_config) |
|
|
|
|
|
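    # With the defaults (--lora-r 16, --lora-alpha 32) the LoRA updates are scaled by
    # alpha/r = 2.0. To sanity-check the adapter size you can optionally call
    # model.print_trainable_parameters() here (provided by PEFT).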
print(f"[train] Loading dataset: {args.dataset}") |
|
ds = _load_dataset(load_dataset, args.dataset) |
|
|
|
|
|
text_field = args.text_field |
|
prompt_field = args.prompt_field |
|
response_field = args.response_field |
|
|

    if text_field:

        def format_row(ex: Dict[str, Any]) -> str:
            if text_field not in ex:
                raise KeyError(f"Missing required text field '{text_field}' in example: {ex}")
            return ex[text_field]

    elif prompt_field and response_field:

        def format_row(ex: Dict[str, Any]) -> str:
            missing = [f for f in (prompt_field, response_field) if f not in ex]
            if missing:
                raise KeyError(f"Missing required field(s) {missing} in example: {ex}")
            return (
                f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n"
                f"<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
            )

    else:
        raise ValueError("Provide either --text-field or both --prompt-field and --response-field")
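
    # Illustrative example (assuming --prompt-field prompt and --response-field response):
    # the JSONL row {"prompt": "Hi", "response": "Hello!"} is rendered as
    #   <start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello!<end_of_turn>\n
    # which matches the Gemma chat turn format emitted above.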

    def map_fn(ex: Dict[str, Any]) -> Dict[str, str]:
        return {"text": format_row(ex)}

    ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])

    def tokenize_fn(ex):
        # Truncate to --cutoff-len; labels are added later by the data collator.
        return tokenizer(
            ex["text"],
            truncation=True,
            max_length=args.cutoff_len,
            padding="max_length",
        )

    tokenized_ds = ds.map(tokenize_fn, batched=True)
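
    # With the defaults (--batch-size 1, --gradient-accumulation 8), each optimizer
    # step sees an effective batch of 1 * 8 = 8 sequences.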
    training_args = TrainingArguments(
        output_dir=str(out_dir / "hf"),
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps if args.max_steps else -1,
        logging_steps=10,
        save_steps=200,
        save_total_limit=2,
        bf16=args.use_bf16,
        fp16=args.use_fp16,
        seed=args.seed,
        report_to=[],
    )
    # DataCollatorForLanguageModeling with mlm=False copies input_ids into labels
    # (padding positions are masked out), so the Trainer can compute the causal-LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    print("[train] Starting training...")
    trainer.train()

    print("[train] Saving adapter...")
    adapter_path = out_dir / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)
    try:
        model.save_pretrained(str(adapter_path))
    except Exception as e:
        logger.error("Error during model saving: %s", e, exc_info=True)
    tokenizer.save_pretrained(str(adapter_path))
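
    # The saved PEFT adapter can later be re-applied to the base model, e.g. with
    # peft.PeftModel.from_pretrained(base_model, str(adapter_path)).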

    if args.export_gguf:
        gguf_path = args.gguf_out or str(out_dir / "adapter-gguf-q4_k_xl")
        print("[train] Export to GGUF is not supported in Hugging Face-only mode.")
        print("[train] Merge the adapter into the base model first, then convert and quantize with llama.cpp, e.g.:")
        print(f"[train]   python convert_hf_to_gguf.py <merged_model_dir> --outfile {gguf_path}.f16.gguf --outtype f16")
        print(f"[train]   llama-quantize {gguf_path}.f16.gguf {gguf_path}.q4_k_m.gguf Q4_K_M")

    (out_dir / "DONE").write_text("ok")
    elapsed = time.time() - start
    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")


if __name__ == "__main__":
    main()