phi4training / run_transformers_training.py
George-API's picture
Upload folder using huggingface_hub
a57357b verified
raw
history blame
28.7 kB
#!/usr/bin/env python
# coding=utf-8
import os
import sys
import json
import argparse
import logging
from datetime import datetime
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
TrainerCallback,
set_seed,
BitsAndBytesConfig
)
# Configure logging: INFO and above go to stdout with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Check for BitsAndBytes.
# BitsAndBytesConfig is already imported unconditionally at the top of the
# file, so re-importing it here could never fail and says nothing about
# whether the ``bitsandbytes`` package is installed. Probe for the package
# itself instead; find_spec locates the module without importing it.
import importlib.util

bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
if not bitsandbytes_available:
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")

# Check for PEFT (LoRA config, adapter wrapping, k-bit training preparation).
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")

# Import Unsloth (optional fast loader and chat-template helper).
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template
    unsloth_available = True
except ImportError:
    unsloth_available = False
    logger.warning("Unsloth not available. Please install with: pip install unsloth")
def load_env_variables():
    """Load environment variables from system, .env file, or Hugging Face Space variables.

    Inside a Hugging Face Space (``SPACE_ID`` is set):
      * The presence (never the value) of ``HF_TOKEN`` and ``HF_USERNAME`` is logged.
      * If ``HF_USERNAME`` is missing and ``SPACE_ID`` contains a ``/``, the part
        before the first ``/`` is stored as ``HF_USERNAME``.

    Outside a Space:
      * ``shared/.env`` in the parent of this script's directory is loaded with
        python-dotenv, when both the package and the file exist.

    Afterwards, in both cases:
      * A warning is logged for each of ``HF_USERNAME`` / ``HF_SPACE_NAME``
        that is still unset.
      * ``HF_TOKEN`` is copied to ``HUGGING_FACE_HUB_TOKEN``.

    Side effects: mutates ``os.environ``.
    """
    # Use the module logger throughout, like the rest of the file, rather than
    # the root logger.
    if os.environ.get("SPACE_ID"):
        logger.info("Running in Hugging Face Space")

        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")

        # If username is not set, try to extract it from SPACE_ID ("<owner>/<space>").
        if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
            username = os.environ.get("SPACE_ID").split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from .env file if not in a Space
        try:
            from dotenv import load_dotenv

            # <parent of this script's directory>/shared/.env
            env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                logger.warning(f"No .env file found at {env_path}")
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")

    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Set HF_TOKEN for huggingface_hub
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
    """Load all configuration files.

    Reads ``transformers_config.json``, ``hardware_config.json`` and
    ``dataset_config.json`` from ``base_path``.

    Args:
        base_path: Directory containing the three JSON files.

    Returns:
        dict mapping ``"transformers"``, ``"hardware"`` and ``"dataset"`` to
        the parsed JSON content of the corresponding file.

    Raises:
        Any error raised while opening or parsing a file is logged and re-raised.
    """
    configs = {}

    # List of config files to load
    config_files = [
        "transformers_config.json",
        "hardware_config.json",
        "dataset_config.json",
    ]

    for config_file in config_files:
        file_path = os.path.join(base_path, config_file)
        config_name = config_file.replace("_config.json", "")
        try:
            # JSON text is UTF-8 (RFC 8259); don't let the platform's locale
            # encoding decide how the file is decoded.
            with open(file_path, "r", encoding="utf-8") as f:
                configs[config_name] = json.load(f)
        except Exception as e:
            logger.error(f"Error loading {config_file}: {e}")
            raise
        logger.info(f"Loaded {config_name} configuration from {file_path}")

    return configs
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            a list lets tests or other scripts reuse the parser.

    Returns:
        argparse.Namespace with ``config_dir`` (defaults to ``"."``).
    """
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument("--config_dir", type=str, default=".", help="Directory containing configuration files")
    return parser.parse_args(argv)
def load_model_and_tokenizer(config):
    """Load model and tokenizer with proper error handling and optimizations.

    Two loading paths:

    * Unsloth path (``use_unsloth`` is truthy and Unsloth imported):
      ``FastLanguageModel.from_pretrained``, then LoRA adapters via
      ``FastLanguageModel.get_peft_model``, configured by the ``unsloth_*``
      keys. Gradient checkpointing is configured by Unsloth here.
    * Standard path: ``AutoModelForCausalLM`` / ``AutoTokenizer``.
      - 4-bit NF4 quantization when ``load_in_4bit`` is set and bitsandbytes
        is available.
      - Gradient checkpointing when ``gradient_checkpointing`` is truthy
        (default True).

    In both paths, a truthy ``chat_template`` key sets the tokenizer's chat
    template, and a missing pad token id falls back to the eos token id.

    Args:
        config: Transformers configuration dict (``model_name``,
            ``max_seq_length``, ``load_in_4bit``, ``trust_remote_code``, ...).
            ``unsloth_chat_template`` (default ``"phi-4"``) names the Unsloth
            template used on the Unsloth path.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        Any exception raised while loading is logged and re-raised.
    """
    try:
        use_unsloth = config.get("use_unsloth", False)

        if use_unsloth and unsloth_available:
            logger.info("Using Unsloth optimizations")
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=config.get("model_name"),
                max_seq_length=config.get("max_seq_length", 2048),
                dtype=None,  # Let Unsloth choose optimal dtype
                load_in_4bit=config.get("load_in_4bit", True),
                device_map="auto",
            )

            # Apply Unsloth's training optimizations with config parameters.
            # Gradient checkpointing for this path is set here, by Unsloth.
            model = FastLanguageModel.get_peft_model(
                model,
                r=config.get("unsloth_r", 32),
                target_modules=config.get(
                    "unsloth_target_modules",
                    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                ),
                lora_alpha=config.get("unsloth_alpha", 16),
                lora_dropout=config.get("unsloth_dropout", 0.05),
                bias="none",
                use_gradient_checkpointing=config.get("gradient_checkpointing", True),
                random_state=config.get("seed", 42),
            )
            logger.info("Unsloth optimizations applied successfully")
        else:
            if use_unsloth:
                logger.warning("Unsloth requested but not available. Falling back to standard training.")

            # Standard quantization setup: NF4 4-bit with double quantization
            # and fp16 compute, only when requested and bitsandbytes is present.
            quantization_config = None
            if config.get("load_in_4bit", False) and bitsandbytes_available:
                logger.info("Using 4-bit quantization")
                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                )

            # Load model with standard settings; use_cache is turned off
            # whenever gradient checkpointing is on.
            model = AutoModelForCausalLM.from_pretrained(
                config.get("model_name"),
                quantization_config=quantization_config,
                device_map="auto",
                trust_remote_code=config.get("trust_remote_code", True),
                use_cache=not config.get("gradient_checkpointing", True),
            )

            tokenizer = AutoTokenizer.from_pretrained(
                config.get("model_name"),
                use_fast=config.get("use_fast_tokenizer", True),
                trust_remote_code=config.get("trust_remote_code", True),
            )

            # Enable gradient checkpointing if requested. transformers takes
            # checkpoint options through gradient_checkpointing_kwargs;
            # use_reentrant is not an accepted keyword of
            # gradient_checkpointing_enable itself (it raised TypeError).
            if config.get("gradient_checkpointing", True) and hasattr(model, "gradient_checkpointing_enable"):
                model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
                logger.info("Gradient checkpointing enabled")

        # Set up tokenizer settings
        if config.get("chat_template"):
            if use_unsloth and unsloth_available:
                # Unsloth's helper takes the tokenizer as its first argument and
                # returns a tokenizer carrying the named template. The previous
                # call passed the template name in place of the tokenizer.
                tokenizer = get_chat_template(
                    tokenizer,
                    chat_template=config.get("unsloth_chat_template", "phi-4"),
                )
            else:
                tokenizer.chat_template = config.get("chat_template")
            logger.info(f"Set chat template to {config.get('chat_template')}")

        # Ensure proper token settings: padding needs a pad token id.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer

    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        raise
def load_dataset_with_mapping(dataset_config):
    """Load and prepare dataset with proper column mapping.

    Args:
        dataset_config: Parsed dataset config. Its ``"dataset"`` section must
            contain ``name`` and ``split``. It may also contain:
            - ``column_mapping``: a ``{new_name: existing_name}`` dict; it is
              inverted before being passed to ``rename_columns``.
            - ``processing.sort_by_id``: when truthy, rows are sorted by the
              ``id`` column. A missing section or flag means no sorting.

    Returns:
        The loaded (and optionally renamed/sorted) dataset.

    Raises:
        Any error while loading or transforming is logged and re-raised.
    """
    try:
        dataset_section = dataset_config["dataset"]

        dataset = load_dataset(
            dataset_section["name"],
            split=dataset_section["split"],
        )
        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")

        # Apply column mapping if specified (config stores new -> existing).
        if "column_mapping" in dataset_section:
            mapping = dataset_section["column_mapping"]
            dataset = dataset.rename_columns({v: k for k, v in mapping.items()})
            logger.info(f"Applied column mapping: {mapping}")

        # Sorting is optional; a missing "processing" section or flag used to
        # raise KeyError and abort the whole run.
        processing = dataset_section.get("processing") or {}
        if processing.get("sort_by_id", False):
            logger.info("Sorting dataset by ID to maintain paper chunk order")
            dataset = dataset.sort("id")

            # Log first few IDs to verify sorting
            sample_ids = [example["id"] for example in dataset.select(range(min(5, len(dataset))))]
            logger.info(f"First few IDs after sorting: {sample_ids}")

        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise
class _SimpleDataCollator:
    """Tokenize each conversation independently and pad the batch.

    Example input:
      * Each example is read as a dict or an object exposing ``id`` and
        ``conversations``.
      * Each message in ``conversations`` exposes ``role`` and ``content``.

    Per example:
      1. A synthetic system message ``"Paper ID: <id> | Chunk: <n>"`` is
         prepended. ``n`` counts how many examples with that id this collator
         instance has seen so far.
      2. The conversation is rendered as plain ``System:``/``Human:``/
         ``Assistant:`` text.
      3. The text is tokenized with truncation to ``max_length`` tokens.

    Per batch: sequences are right-padded to the longest one. ``labels`` are
    a copy of ``input_ids`` with padded positions set to -100 (no loss on
    padding).

    Skipped examples are counted in ``stats["skipped"]``: those with no
    conversation, those that tokenize to nothing, and those that raise while
    being processed.

    NOTE(review): ``paper_counters`` lives as long as the collator, so chunk
    numbers keep growing across epochs and restart after a resume. Confirm
    this is intended.
    """

    def __init__(self, tokenizer, max_length=2048):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Fall back to id 0 when the tokenizer has no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.prompt_counter = 0
        self.paper_counters = {}
        logger.info("SimpleDataCollator initialized - using phi-4 chat format")

    def format_phi_chat(self, messages):
        """Render messages as ``Role: content`` paragraphs.

        - System messages are moved to the front of the text seen so far.
        - Unknown roles are appended as ``System:`` with a warning.

        NOTE(review): this is a hand-written plain-text format, not the
        tokenizer's chat template.
        """
        formatted_chat = ""
        for message in messages:
            # Extract role and content
            if isinstance(message, dict):
                role = message.get("role", "").lower()
                content = message.get("content", "")
            else:
                role = getattr(message, "role", "").lower()
                content = getattr(message, "content", "")

            # Format based on role
            if role == "human" or role == "user":
                formatted_chat += f"Human: {content}\n\n"
            elif role == "assistant":
                formatted_chat += f"Assistant: {content}\n\n"
            elif role == "system":
                # For system messages, we prepend them with a special format
                formatted_chat = f"System: {content}\n\n" + formatted_chat
            else:
                logger.warning(f"Unknown role '{role}' - treating as system message")
                formatted_chat += f"System: {content}\n\n"
        return formatted_chat.strip()

    def __call__(self, features):
        """Build a dict of padded ``input_ids``/``attention_mask``/``labels`` tensors."""
        batch = {"input_ids": [], "attention_mask": [], "labels": []}

        for example in features:
            try:
                # Get ID and conversation fields
                paper_id = example.get("id", "") if isinstance(example, dict) else getattr(example, "id", "")
                conversation = example.get("conversations", []) if isinstance(example, dict) else getattr(example, "conversations", [])

                if not conversation:
                    self.stats["skipped"] += 1
                    continue

                # Increment counters
                self.prompt_counter += 1
                self.paper_counters[paper_id] = self.paper_counters.get(paper_id, 0) + 1

                # Add metadata as system message
                metadata = {
                    "role": "system",
                    "content": f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}",
                }

                formatted_content = self.format_phi_chat([metadata] + conversation)

                inputs = self.tokenizer(
                    formatted_content,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors=None,  # Return list instead of tensors
                )

                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]

                if len(input_ids) > 0:
                    # For causal language modeling, labels are the same as inputs
                    labels = input_ids.copy()

                    batch["input_ids"].append(input_ids)
                    batch["attention_mask"].append(attention_mask)
                    batch["labels"].append(labels)

                    self.stats["processed"] += 1
                    self.stats["total_tokens"] += len(input_ids)

                    # Debug logging for first few examples
                    if self.stats["processed"] <= 3:
                        logger.info(f"Example {self.stats['processed']} format:")
                        logger.info(f"Paper ID: {paper_id} | Chunk: {self.paper_counters[paper_id]}")
                        logger.info(f"Token count: {len(input_ids)}")
                        logger.info(f"Content preview:\n{formatted_content[:500]}...")
                else:
                    self.stats["skipped"] += 1
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                self.stats["skipped"] += 1
                continue

        # Handle empty batches
        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long),
            }

        # Pad the batch (right padding to the longest sequence)
        longest = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = longest - len(ids)
            if padding_length > 0:
                ids.extend([self.pad_token_id] * padding_length)
                mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)  # Don't compute loss on padding

        batch = {k: torch.tensor(v) for k, v in batch.items()}

        # Log stats periodically
        if self.stats["processed"] % 100 == 0 and self.stats["processed"] > 0:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}, "
                        f"unique_papers={len(self.paper_counters)}")

        return batch


def _find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory, or None.

    Entries whose suffix is not an integer (e.g. partial/temporary folders)
    are ignored instead of crashing the ``int()`` parse.
    """
    if not os.path.isdir(output_dir):
        return None
    candidates = []
    for name in os.listdir(output_dir):
        prefix, _, step = name.partition("-")
        if prefix == "checkpoint" and step.isdigit() and os.path.isdir(os.path.join(output_dir, name)):
            candidates.append((int(step), name))
    if not candidates:
        return None
    return os.path.join(output_dir, max(candidates)[1])


def main():
    """Run the end-to-end fine-tuning pipeline.

    Steps:
      1. Parse CLI args, load environment variables and the three JSON configs.
      2. Apply hardware overrides and seed the RNGs.
      3. Load model/tokenizer, optionally adding LoRA via PEFT.
      4. Load the dataset.
      5. Train with a non-shuffling dataloader, resuming from the newest
         ``checkpoint-<step>`` in ``output_dir`` if one exists.
      6. Push to the Hub or save locally.

    ``TRAINING_IN_PROGRESS.lock`` is created next to this script before
    training and removed afterwards, even on failure.

    Returns:
        0 on success; 1 if configuration, PEFT setup, dataset loading or
        training fails.
    """
    logger.info("Starting training process")

    args = parse_args()
    load_env_variables()

    # Load all configurations and apply hardware-specific overrides.
    try:
        configs = load_configs(args.config_dir)
        logger.info("All configurations loaded successfully")

        model_config = configs["transformers"]
        hardware_config = configs["hardware"]
        dataset_config = configs["dataset"]

        optimizations = hardware_config["training_optimizations"]
        hardware_overrides = {
            "per_device_train_batch_size": optimizations["per_device_batch_size"],
            "gradient_accumulation_steps": optimizations["gradient_accumulation_steps"],
            "gradient_checkpointing": optimizations["memory_optimizations"]["use_gradient_checkpointing"],
        }
        # Keep the nested "training" section updated as before, but also set
        # the values at the top level: TrainingArguments below and
        # load_model_and_tokenizer() only read top-level keys, so updating the
        # nested section alone meant the hardware settings were never used.
        model_config.setdefault("training", {}).update(hardware_overrides)
        model_config.update(hardware_overrides)
    except Exception as e:
        logger.error(f"Error loading configurations: {e}")
        return 1

    # Set random seed for reproducibility
    seed = model_config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    # In a Hugging Face Space, default hub_model_id to <user>/finetuned-<model>.
    # load_env_variables() normally fills HF_USERNAME from SPACE_ID already
    # (which made the old "HF_USERNAME unset" branch dead), so prefer it and
    # only fall back to parsing SPACE_ID.
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        username = os.environ.get("HF_USERNAME")
        if not username:
            username = space_id.split("/")[0]
            logger.info(f"Extracted username from SPACE_ID: {username}")

        if model_config.get("push_to_hub", False) and not model_config.get("hub_model_id"):
            model_name = model_config.get("model_name", "").split("/")[-1]
            model_config["hub_model_id"] = f"{username}/finetuned-{model_name}"
            logger.info(f"Set hub_model_id to {model_config['hub_model_id']}")

    logger.info(f"Loading model: {model_config.get('model_name')}")
    try:
        model, tokenizer = load_model_and_tokenizer(model_config)
        logger.info("Model and tokenizer loaded successfully")

        # load_model_and_tokenizer() already attaches LoRA adapters on the
        # Unsloth path (FastLanguageModel.get_peft_model); wrapping that model
        # in a second PEFT adapter would nest adapters, so PEFT setup only runs
        # on the standard path.
        used_unsloth = model_config.get("use_unsloth", False) and unsloth_available
        if model_config.get("use_peft", False) and peft_available:
            if used_unsloth:
                logger.info("LoRA adapters already applied by Unsloth; skipping PEFT setup")
            else:
                logger.info("Preparing model for parameter-efficient fine-tuning")
                try:
                    model = prepare_model_for_kbit_training(model)

                    target_modules = model_config.get(
                        "target_modules",
                        ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                    )

                    lora_config = LoraConfig(
                        r=model_config.get("lora_r", 16),
                        lora_alpha=model_config.get("lora_alpha", 32),
                        lora_dropout=model_config.get("lora_dropout", 0.05),
                        bias="none",
                        task_type="CAUSAL_LM",
                        target_modules=target_modules,
                    )

                    model = get_peft_model(model, lora_config)
                    logger.info(f"Applied LoRA with r={model_config.get('lora_r', 16)}, alpha={model_config.get('lora_alpha', 32)}")
                except Exception as e:
                    logger.error(f"Error setting up PEFT: {e}")
                    return 1

        # Load dataset with proper mapping
        try:
            dataset = load_dataset_with_mapping(dataset_config)
            logger.info("Dataset loaded and prepared successfully")
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return 1

        data_collator = _SimpleDataCollator(
            tokenizer,
            max_length=model_config.get("max_seq_length", 2048),
        )

        class LoggingCallback(TrainerCallback):
            """Log step, latest loss/LR and elapsed minutes every 50 steps or 5 minutes after the last log."""

            def __init__(self):
                self.last_log_time = datetime.now()
                self.training_start_time = datetime.now()

            def on_step_end(self, args, state, control, **kwargs):
                current_time = datetime.now()
                time_diff = (current_time - self.last_log_time).total_seconds()
                elapsed_time = (current_time - self.training_start_time).total_seconds() / 60  # in minutes

                if state.global_step % 50 == 0 or time_diff > 300:  # 300 seconds = 5 minutes
                    # The newest log_history entry may lack "loss" or
                    # "learning_rate" (not every logged record is a training
                    # step record); look keys up defensively so a logging
                    # callback can't abort training with KeyError.
                    last_entry = state.log_history[-1] if state.log_history else {}
                    loss = last_entry.get("loss", "N/A")
                    lr = last_entry.get("learning_rate", "N/A")

                    loss_str = f"{loss:.4f}" if isinstance(loss, float) else str(loss)
                    lr_str = f"{lr:.8f}" if isinstance(lr, float) else str(lr)

                    logger.info(f"Step: {state.global_step} | Loss: {loss_str} | LR: {lr_str} | Elapsed: {elapsed_time:.2f} min")
                    self.last_log_time = current_time

        logger.info("Setting up training arguments")
        training_args = TrainingArguments(
            output_dir=model_config.get("output_dir", "./results"),
            num_train_epochs=model_config.get("num_train_epochs", 3),
            per_device_train_batch_size=model_config.get("per_device_train_batch_size", 4),
            gradient_accumulation_steps=model_config.get("gradient_accumulation_steps", 8),
            learning_rate=model_config.get("learning_rate", 5e-5),
            weight_decay=model_config.get("weight_decay", 0.01),
            warmup_ratio=model_config.get("warmup_ratio", 0.1),
            lr_scheduler_type=model_config.get("lr_scheduler_type", "cosine"),
            logging_steps=model_config.get("logging_steps", 10),
            save_strategy=model_config.get("save_strategy", "steps"),
            save_steps=model_config.get("save_steps", 100),
            save_total_limit=model_config.get("save_total_limit", 3),
            fp16=model_config.get("fp16", True),
            bf16=model_config.get("bf16", False),
            max_grad_norm=model_config.get("max_grad_norm", 1.0),
            push_to_hub=model_config.get("push_to_hub", False),
            hub_model_id=model_config.get("hub_model_id", None),
            hub_token=os.environ.get("HF_TOKEN", None),
            report_to="tensorboard",
            remove_unused_columns=False,  # Keep the conversations column for the collator
            gradient_checkpointing=model_config.get("gradient_checkpointing", True),
            dataloader_pin_memory=False,  # Reduce memory usage
            optim=model_config.get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,
            dataloader_drop_last=False,  # Process all examples
            dataloader_num_workers=0,  # Sequential data loading
        )

        logger.info("Creating sequential sampler to maintain dataset order")
        logger.info("Creating trainer")

        # Resume from the newest checkpoint in output_dir, if any.
        output_dir = model_config.get("output_dir", "./results")
        resume_from_checkpoint = _find_latest_checkpoint(output_dir) or False
        if resume_from_checkpoint:
            logger.info(f"Found checkpoint: {resume_from_checkpoint}. Training will resume from this point.")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
            callbacks=[LoggingCallback()],
        )

        # Override the default data loader to disable shuffling.
        # NOTE(review): this returns a plain DataLoader, not one built by the
        # Trainer itself; presumably fine single-process — confirm for
        # multi-GPU/distributed runs.
        def get_train_dataloader_no_shuffle():
            """Create a train DataLoader with shuffling disabled."""
            logger.info("Creating train dataloader with sequential sampler (no shuffling)")
            train_sampler = torch.utils.data.SequentialSampler(dataset)
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=training_args.per_device_train_batch_size,
                sampler=train_sampler,
                collate_fn=data_collator,
                drop_last=False,
                num_workers=0,
                pin_memory=False,
            )

        trainer.get_train_dataloader = get_train_dataloader_no_shuffle

        logger.info("Starting training")
        logger.info(f"Processing with batch size = {training_args.per_device_train_batch_size}, each entry processed independently")

        # Create a lock file to indicate training is in progress
        lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
        with open(lock_file, "w") as f:
            f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
            f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
        logger.info(f"Created lock file: {lock_file}")

        try:
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
            logger.info("Training completed successfully")

            if model_config.get("push_to_hub", False):
                logger.info(f"Pushing model to hub: {model_config.get('hub_model_id')}")
                trainer.push_to_hub()
                logger.info("Model pushed to hub successfully")
            else:
                logger.info(f"Saving model to {model_config.get('output_dir', './results')}")
                trainer.save_model()
                logger.info("Model saved successfully")
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            raise
        finally:
            # Remove the lock file when training completes or fails
            if os.path.exists(lock_file):
                os.remove(lock_file)
                logger.info(f"Removed lock file: {lock_file}")

        return 0

    except Exception as e:
        logger.error(f"Error in main training loop: {str(e)}")
        return 1
if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; hand that to the OS.
    exit_status = main()
    sys.exit(exit_status)