|
|
|
""" |
|
Unsloth fine-tuning runner for Gemma-3n-E4B-it. |
|
- Trains a LoRA adapter on top of HF Transformers-format base model (not GGUF). |
|
- Output: PEFT adapter that can later be merged/exported to GGUF separately if desired. |
|
|
|
This is a minimal, production-friendly CLI so the API server can spawn it as a subprocess. |
|
""" |
|
import argparse |
|
import os |
|
import json |
|
import time |
|
from pathlib import Path |
|
from typing import Any, Dict |
|
|
|
|
|
|
|
def _import_training_libs() -> Dict[str, Any]:
    """Import the training stack, preferring Unsloth's fast path.

    Falls back to plain Transformers + PEFT when Unsloth cannot be
    imported (unsupported platform, missing install, version conflict).

    Returns a dict with keys:
        mode: "unsloth" | "hf"
        load_dataset, SFTTrainer, SFTConfig
        If mode == "unsloth": FastLanguageModel, AutoTokenizer
        If mode == "hf": AutoTokenizer, AutoModelForCausalLM,
            get_peft_model, LoraConfig, torch
    """
    # Disable optional Unsloth backends before importing it; setdefault so an
    # operator-provided value still wins.
    os.environ.setdefault("UNSLOTH_DISABLE_XFORMERS", "1")
    os.environ.setdefault("UNSLOTH_DISABLE_BITSANDBYTES", "1")

    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    # Keys shared by both modes.
    libs: Dict[str, Any] = {
        "load_dataset": load_dataset,
        "SFTTrainer": SFTTrainer,
        "SFTConfig": SFTConfig,
    }

    try:
        from unsloth import FastLanguageModel
        from transformers import AutoTokenizer
    except Exception:
        # Unsloth import can fail for many reasons beyond ImportError
        # (CUDA probing, platform checks) — fall back to the vanilla stack.
        from transformers import AutoTokenizer, AutoModelForCausalLM
        from peft import get_peft_model, LoraConfig
        import torch

        libs.update(
            mode="hf",
            AutoTokenizer=AutoTokenizer,
            AutoModelForCausalLM=AutoModelForCausalLM,
            get_peft_model=get_peft_model,
            LoraConfig=LoraConfig,
            torch=torch,
        )
        return libs

    libs.update(
        mode="unsloth",
        FastLanguageModel=FastLanguageModel,
        AutoTokenizer=AutoTokenizer,
    )
    return libs
|
|
|
|
|
def parse_args():
    """Build and parse the CLI arguments for one fine-tuning job."""
    parser = argparse.ArgumentParser()

    # Job bookkeeping / output location.
    parser.add_argument("--job-id", required=True)
    parser.add_argument("--output-dir", required=True)

    # Dataset selection and row formatting.
    parser.add_argument("--dataset", required=True, help="HF dataset path or local JSON/JSONL file")
    parser.add_argument("--text-field", dest="text_field", default=None)
    parser.add_argument("--prompt-field", dest="prompt_field", default=None)
    parser.add_argument("--response-field", dest="response_field", default=None)

    # Model and optimization hyperparameters.
    parser.add_argument("--model-id", dest="model_id", default="unsloth/gemma-3n-E4B-it")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--max-steps", dest="max_steps", type=int, default=None)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--batch-size", dest="batch_size", type=int, default=1)
    parser.add_argument("--gradient-accumulation", dest="gradient_accumulation", type=int, default=8)
    parser.add_argument("--lora-r", dest="lora_r", type=int, default=16)
    parser.add_argument("--lora-alpha", dest="lora_alpha", type=int, default=32)
    parser.add_argument("--cutoff-len", dest="cutoff_len", type=int, default=4096)

    # Precision, reproducibility, and the CI escape hatch.
    parser.add_argument("--use-bf16", dest="use_bf16", action="store_true")
    parser.add_argument("--use-fp16", dest="use_fp16", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--dry-run",
        dest="dry_run",
        action="store_true",
        help="Write DONE and exit without training (for CI)",
    )

    return parser.parse_args()
|
|
|
|
|
def _is_local_path(s: str) -> bool: |
|
return os.path.exists(s) |
|
|
|
|
|
def _load_dataset(load_dataset: Any, path: str) -> Any: |
|
if _is_local_path(path): |
|
|
|
if path.endswith(".jsonl") or path.endswith(".jsonl.gz"): |
|
return load_dataset("json", data_files=path, split="train") |
|
elif path.endswith(".json"): |
|
return load_dataset("json", data_files=path, split="train") |
|
else: |
|
raise ValueError("Unsupported local dataset format. Use JSON or JSONL.") |
|
else: |
|
return load_dataset(path, split="train") |
|
|
|
|
|
def main():
    """Entry point: fine-tune a LoRA adapter and write artifacts to --output-dir.

    Writes into the output directory:
        meta.json  - job metadata (id, model, dataset, start timestamp)
        hf/        - trainer checkpoints
        adapter/   - final PEFT adapter + tokenizer
        DONE       - sentinel file ("ok" or "dry_run") for the supervising server
    """
    args = parse_args()
    start = time.time()
    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Record job metadata first so even a crashed run leaves a trace.
    (out_dir / "meta.json").write_text(json.dumps({
        "job_id": args.job_id,
        "model_id": args.model_id,
        "dataset": args.dataset,
        "created_at": int(start),
    }, indent=2))

    if args.dry_run:
        # CI path: prove the subprocess can be spawned and writes its sentinel
        # without importing the heavy ML stack.
        (out_dir / "DONE").write_text("dry_run")
        print("[train] Dry run complete. DONE written.")
        return

    libs: Dict[str, Any] = _import_training_libs()
    load_dataset = libs["load_dataset"]
    SFTTrainer = libs["SFTTrainer"]
    SFTConfig = libs["SFTConfig"]

    # Reduce CUDA fragmentation and tokenizer fork warnings; setdefault keeps
    # any operator-provided values.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

    print(f"[train] Loading base model: {args.model_id}")
    if libs["mode"] == "unsloth":
        FastLanguageModel = libs["FastLanguageModel"]
        AutoTokenizer = libs["AutoTokenizer"]
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=args.model_id,
            max_seq_length=args.cutoff_len,
            load_in_4bit=False,  # full-precision base; quantize later if desired
            dtype=None,          # let Unsloth pick the best dtype for the device
            use_gradient_checkpointing="unsloth",
        )

        print("[train] Attaching LoRA adapter (Unsloth)")
        model = FastLanguageModel.get_peft_model(
            model,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=0,
            bias="none",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            use_rslora=True,
            loftq_config=None,
        )
    else:
        AutoTokenizer = libs["AutoTokenizer"]
        AutoModelForCausalLM = libs["AutoModelForCausalLM"]
        get_peft_model = libs["get_peft_model"]
        LoraConfig = libs["LoraConfig"]
        torch = libs["torch"]

        tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=True, trust_remote_code=True)

        use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        # Honor the requested precision. BUGFIX: --use-bf16 previously loaded
        # the model in float16, inconsistent with SFTConfig(bf16=True) below;
        # it now loads bfloat16. Half precision is skipped on MPS (float32).
        if use_mps:
            torch_dtype = torch.float32
        elif args.use_bf16:
            torch_dtype = torch.bfloat16
        elif args.use_fp16:
            torch_dtype = torch.float16
        else:
            torch_dtype = torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )
        if use_mps:
            model.to("mps")
        print("[train] Attaching LoRA adapter (HF/PEFT)")
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            lora_dropout=0.0,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)

    print(f"[train] Loading dataset: {args.dataset}")
    ds = _load_dataset(load_dataset, args.dataset)

    text_field = args.text_field
    prompt_field = args.prompt_field
    response_field = args.response_field

    if text_field:
        # Rows already contain fully formatted training text.
        def format_row(ex):
            return ex[text_field]
    elif prompt_field and response_field:
        # Wrap prompt/response pairs in the Gemma chat-turn template.
        def format_row(ex):
            return f"<start_of_turn>user\n{ex[prompt_field]}<end_of_turn>\n<start_of_turn>model\n{ex[response_field]}<end_of_turn>\n"
    else:
        raise ValueError("Provide either --text-field or both --prompt-field and --response-field")

    def map_fn(ex):
        return {"text": format_row(ex)}

    # Normalize every row down to the single "text" column SFTTrainer reads.
    ds = ds.map(map_fn, remove_columns=[c for c in ds.column_names if c != "text"])

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=ds,
        max_seq_length=args.cutoff_len,
        dataset_text_field="text",
        packing=True,
        args=SFTConfig(
            output_dir=str(out_dir / "hf"),
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.gradient_accumulation,
            learning_rate=args.lr,
            num_train_epochs=args.epochs,
            max_steps=args.max_steps if args.max_steps else -1,  # -1 = derive from epochs
            logging_steps=10,
            save_steps=200,
            save_total_limit=2,
            bf16=args.use_bf16,
            fp16=args.use_fp16,
            seed=args.seed,
            report_to=[],
        ),
    )

    print("[train] Starting training...")
    trainer.train()
    print("[train] Saving adapter...")
    adapter_path = out_dir / "adapter"
    adapter_path.mkdir(parents=True, exist_ok=True)

    try:
        model.save_pretrained(str(adapter_path))
    except Exception as err:
        # Unsloth sometimes wraps the PEFT model; retry on the inner model.
        # Best-effort, but surface failures instead of swallowing them silently.
        print(f"[train] model.save_pretrained failed ({err!r}); retrying on base_model")
        try:
            model.base_model.save_pretrained(str(adapter_path))
        except Exception as inner_err:
            print(f"[train] WARNING: adapter save failed: {inner_err!r}")
    tokenizer.save_pretrained(str(adapter_path))

    # Sentinel for the supervising API server: training completed.
    (out_dir / "DONE").write_text("ok")
    elapsed = time.time() - start
    print(f"[train] Finished in {elapsed:.1f}s. Artifacts at: {out_dir}")
|
|
|
|
|
# CLI entry point: run directly via `python train.py ...` while keeping the
# module importable (e.g. by the API server that spawns it as a subprocess).
if __name__ == "__main__":
    main()
|
|