# train_llama4.py
# Script to fine-tune Llama 4 Maverick for healthcare fraud detection
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Llama4ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
import datasets
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import huggingface_hub
import os
# Debug: Confirm file version
print("Running train_llama4.py with CPU offloading (version: 2025-04-21 v2)")
# Authenticate with Hugging Face
LLama = os.getenv("LLama")
if not LLama:
raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
huggingface_hub.login(token=LLama)
# Model setup
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Custom device map for CPU offloading.
# Note: accelerate device maps do not accept range keys like "model.layers.0-15";
# each layer must be listed explicitly. The 0-15 (GPU) / 16-31 (CPU) split below keeps
# the original intent, but the prefixes and layer count must match the actual module
# names of Llama4ForConditionalGeneration (they may be nested under "language_model."),
# so adjust as needed.
device_map = {"model.embed_tokens": 0}
device_map.update({f"model.layers.{i}": 0 for i in range(0, 16)})
device_map.update({f"model.layers.{i}": "cpu" for i in range(16, 32)})
device_map.update({"model.norm": 0, "lm_head": 0})
# Debug: Confirm offloading settings
print("Loading model with CPU offloading: llm_int8_enable_fp32_cpu_offload=True, device_map=", device_map)
# Load model with 8-bit quantization and CPU offloading.
# llm_int8_enable_fp32_cpu_offload belongs inside the quantization config,
# not as a from_pretrained keyword argument.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=True,
)
model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    quantization_config=quantization_config,
    attn_implementation="flex_attention",
)
# Resize token embeddings
model.resize_token_embeddings(len(tokenizer))
# Note: Trainer drives Accelerate internally, so the quantized, device-mapped model
# is not wrapped in a manual Accelerator.prepare() call here.
# Load dataset
dataset = datasets.load_dataset('json', data_files="Bingaman_training_data.json")['train']
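# Tokenize the raw records for causal-LM training. This sketch assumes each JSON
# record carries a "text" field; adjust the column name to match the actual
# structure of Bingaman_training_data.json.
def tokenize_fn(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=1024,
        padding="max_length",
    )

tokenized_dataset = dataset.map(tokenize_fn, batched=True, remove_columns=dataset.column_names)
# The collator copies input_ids into labels for causal language modeling (mlm=False)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)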
# LoRA configuration
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Prepare model for fine-tuning
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
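# Log how many parameters LoRA actually leaves trainable
model.print_trainable_parameters()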
# Training arguments
training_args = {
"output_dir": "./results",
"num_train_epochs": 1,
"per_device_train_batch_size": 2,
"gradient_accumulation_steps": 8,
"optim": "adamw_torch",
"save_steps": 500,
"logging_steps": 100,
"learning_rate": 2e-4,
"fp16": True,
"max_grad_norm": 0.3,
"warmup_ratio": 0.03,
"lr_scheduler_type": "cosine"
}
# Initialize trainer (Trainer and TrainingArguments come from transformers, not datasets)
trainer = Trainer(
    model=model,
    args=TrainingArguments(**training_args),
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
# Train
trainer.train()
model.save_pretrained("./fine_tuned_model")
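# Save the tokenizer as well so the added [PAD] token travels with the adapter
tokenizer.save_pretrained("./fine_tuned_model")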
print("Training completed!")
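# Hypothetical follow-up sketch: reload the saved LoRA adapter for inference.
#   from peft import PeftModel
#   base = Llama4ForConditionalGeneration.from_pretrained(
#       MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
#   )
#   model = PeftModel.from_pretrained(base, "./fine_tuned_model")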