phi4training / run_transformers_training.py
George-API's picture
Upload folder using huggingface_hub
2b5da3a verified
raw
history blame
51.9 kB
#!/usr/bin/env python
# coding=utf-8
# Basic Python imports
import os
import sys
import json
import argparse
import logging
from datetime import datetime
import time
import warnings
from importlib.util import find_spec
import multiprocessing
# torch must be imported before the hardware probe below, which relies on it.
import torch

# Check hardware capabilities first.
# Ask torch directly whether a usable CUDA device exists instead of inspecting
# CUDA_VISIBLE_DEVICES / NVIDIA_VISIBLE_DEVICES: those variables may be unset
# (os.environ.get returns None, and None != "" is True) or set on hosts without
# a working driver.
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"

# Set the multiprocessing start method to 'spawn' for CUDA compatibility
# (CUDA state cannot be safely shared with fork()ed worker processes).
if CUDA_AVAILABLE:
    try:
        multiprocessing.set_start_method('spawn', force=True)
        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
    except RuntimeError:
        # Method already set, which is fine
        print("Multiprocessing start method already set")
# Configure logging as early as possible so every later message is captured,
# always writing to stdout.
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[_stdout_handler],
)
logger = logging.getLogger(__name__)

# Raise chatty third-party libraries to WARNING so this script's own INFO
# lines stay visible.
for _noisy_logger in ("transformers", "datasets", "accelerate", "torch", "bitsandbytes"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger
# Import Unsloth before the other ML libraries; the flag records whether the
# import succeeded so later code can refuse to train without it.
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template
except ImportError:
    unsloth_available = False
    logger.warning("Unsloth not available. Please install with: pip install unsloth")
else:
    unsloth_available = True
    logger.info("Unsloth successfully imported")
# Now import other ML libraries.
# transformers is a hard requirement: module-level code further down subclasses
# TrainerCallback, so carrying on without it only defers the failure to a less
# helpful NameError. Log the reason, then re-raise.
try:
    import transformers
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
        Trainer,
        TrainerCallback,
        set_seed,
        BitsAndBytesConfig
    )
except ImportError:
    logger.error("Transformers not available. This is a critical dependency.")
    raise
logger.info(f"Transformers version: {transformers.__version__}")
# Check availability of PEFT without importing it unless it is installed.
peft_available = find_spec("peft") is not None
if not peft_available:
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
else:
    import peft
    logger.info(f"PEFT version: {peft.__version__}")
# Import the datasets library after the main ML libraries. A failure is only
# logged here; dataset loading later reports it again.
try:
    from datasets import load_dataset
except ImportError:
    logger.error("Datasets library not available. This is required for loading training data.")
else:
    logger.info("Datasets library successfully imported")
# Logging helper used for user-facing progress messages (HF Space friendly).
def log_info(message):
    """Emit *message* at INFO level on the module logger, then flush stdout.

    The explicit flush pushes the line out immediately so streamed log
    viewers (e.g. Hugging Face Spaces) show progress without buffering delays.
    """
    logger.log(logging.INFO, message)
    sys.stdout.flush()
# Check for BitsAndBytes. 4-bit quantization needs both the transformers-side
# config class AND the bitsandbytes package itself; importing
# BitsAndBytesConfig alone succeeds even when bitsandbytes is not installed.
try:
    from transformers import BitsAndBytesConfig
    bitsandbytes_available = find_spec("bitsandbytes") is not None
except ImportError:
    bitsandbytes_available = False
if not bitsandbytes_available:
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
# Import the PEFT helpers; this re-derives peft_available from whether the
# actual symbols can be imported.
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
else:
    peft_available = True
def _load_dotenv_file():
    """Load the first ``.env`` found beside this script, else in ``../shared/.env``.

    Requires python-dotenv; if it is missing a warning is logged and nothing
    is loaded.
    """
    try:
        from dotenv import load_dotenv
    except ImportError:
        logger.warning("python-dotenv not installed, not loading from .env file")
        return

    script_dir = os.path.dirname(os.path.abspath(__file__))
    # (path, label) pairs in priority order; the label only distinguishes the
    # log lines ("... from .env file" vs "... from shared .env file").
    candidates = [
        (os.path.join(script_dir, ".env"), ""),
        (os.path.join(os.path.dirname(script_dir), "shared", ".env"), "shared "),
    ]
    for env_path, label in candidates:
        if os.path.exists(env_path):
            load_dotenv(env_path)
            logger.info(f"Loaded environment variables from {env_path}")
            # Report presence only; never log secret values.
            for var in ("HF_TOKEN", "HF_USERNAME", "HF_SPACE_NAME"):
                logger.info(f"{var} loaded from {label}.env file: {bool(os.environ.get(var))}")
            return
    logger.warning("No .env file found in current or shared directory")


def load_env_variables():
    """Load environment variables from system, .env file, or Hugging Face Space variables.

    * Inside a Hugging Face Space (``SPACE_ID`` set), variables are expected to
      come from the Space settings. If ``HF_USERNAME`` is missing it is derived
      from the owner part of ``SPACE_ID`` ("owner/name").
    * Otherwise a ``.env`` file in this script's directory is loaded, falling
      back to ``../shared/.env``.

    Afterwards, warnings are logged for missing ``HF_TOKEN``, ``HF_USERNAME``, or
    ``HF_SPACE_NAME``. If ``HF_TOKEN`` is set, it is mirrored into
    ``HUGGING_FACE_HUB_TOKEN``. Only presence is logged, never secret values.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info("Running in Hugging Face Space")
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        _load_dotenv_file()

    # The "default" values referred to below are applied by downstream code,
    # not here.
    if not os.environ.get("HF_TOKEN"):
        logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Set HF_TOKEN for huggingface_hub
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
    """Load configuration from transformers_config.json file.

    Args:
        base_path: Path to the single consolidated JSON configuration file.

    Returns:
        The parsed JSON document.

    Raises:
        Any exception from opening or parsing the file; it is logged first and
        then re-raised unchanged.
    """
    config_file = base_path
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # default locale encoding (e.g. cp1252 on Windows).
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
    logger.info(f"Loaded configuration from {config_file}")
    return config
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before. Passing an explicit list lets
            other Python code and tests drive the parser.

    Returns:
        ``argparse.Namespace`` with a ``config`` attribute holding the path to
        the JSON configuration file (default ``transformers_config.json``).
    """
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
    return parser.parse_args(argv)
def _config_value(config, key, section, default):
    """Return ``config[key]``, else ``config[section][key]``, else *default*.

    A location counts as "set" when it holds a non-None value. Explicit falsy
    values such as ``False`` are honoured. An ``a.get(k, default) or b`` chain
    would never reach the nested section, because its default is truthy.
    """
    value = config.get(key)
    if value is None:
        value = (config.get(section) or {}).get(key)
    return default if value is None else value


def _select_dtype():
    """Choose the model weight dtype for the current device and log the choice.

    Returns ``torch.bfloat16`` on GPUs with compute capability >= 8 (Ampere or
    newer), ``torch.float16`` on older GPUs, and ``None`` (library default) on CPU.
    """
    if not torch.cuda.is_available():
        logger.info("Using default precision (CPU)")
        return None
    if torch.cuda.get_device_capability()[0] >= 8:
        logger.info("Using bfloat16 precision (Ampere+ GPU)")
        return torch.bfloat16
    logger.info("Using float16 precision (pre-Ampere GPU)")
    return torch.float16


def _build_max_memory(config, gpu_count):
    """Return a per-device memory budget for multi-GPU loading, or ``None``.

    Each GPU is given 85% of ``hardware.specs.vram_per_gpu`` (in GiB, default
    24). A 64GiB ``"cpu"`` entry allows offloading.
    """
    if gpu_count <= 1:
        return None
    memory_per_gpu = config.get("hardware", {}).get("specs", {}).get("vram_per_gpu", 24)
    max_memory = {i: f"{int(memory_per_gpu * 0.85)}GiB" for i in range(gpu_count)}
    max_memory["cpu"] = "64GiB"  # Allow CPU offloading if needed
    return max_memory


def load_model_and_tokenizer(config):
    """Load model and tokenizer with proper error handling and optimizations.

    Loads the model through Unsloth's ``FastLanguageModel``, attaches LoRA
    adapters, optionally installs a chat template, and ensures a pad token.

    Args:
        config: Consolidated config dict. Keys read:
            ``model_name`` or ``model.name`` (required);
            ``max_seq_length`` or ``tokenizer.max_seq_length`` (default 2048);
            ``use_flash_attention``;
            ``hardware.hardware_setup.device_map`` (default ``"auto"``);
            ``hardware.specs.vram_per_gpu``;
            ``unsloth.{r, alpha, dropout, target_modules}``;
            ``gradient_checkpointing`` or ``training.gradient_checkpointing``
            (default True);
            ``seed``;
            ``chat_template`` or ``tokenizer.chat_template``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ImportError: Unsloth is not installed.
        ValueError: No model name is configured.
        Any loading error is logged and re-raised. A non-OOM ``RuntimeError``
        first triggers a single retry on CPU.
    """
    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")

        model_name = config.get("model_name") or config.get("model", {}).get("name")
        logger.info(f"Loading model: {model_name}")
        if not model_name:
            raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
        logger.info("Using Unsloth optimizations with pre-quantized model")

        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            logger.info(f"Found {gpu_count} CUDA devices")
        else:
            logger.warning("No CUDA devices detected. Training will be slow on CPU!")
            gpu_count = 0

        dtype = _select_dtype()

        # flash-attn availability is only reported; nothing below passes an
        # attention implementation, so Unsloth decides on its own.
        if config.get("use_flash_attention", True) and not find_spec("flash_attn"):
            logger.warning("flash-attn not found. Will continue without flash attention.")
            logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")

        max_seq_length = _config_value(config, "max_seq_length", "tokenizer", 2048)
        device_map = config.get("hardware", {}).get("hardware_setup", {}).get("device_map", "auto")
        max_memory = _build_max_memory(config, gpu_count)

        try:
            # Ask PyTorch's CUDA caching allocator to use expandable segments.
            # NOTE(review): this env var is read when the allocator initialises;
            # confirm CUDA has not allocated memory before this point.
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=dtype,
                device_map=device_map,
                max_memory=max_memory,
                # Don't explicitly use flash attention config here, let Unsloth handle it
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
                raise
            # Any other runtime error: retry once on CPU with the default dtype.
            logger.warning(f"Error loading model on default device: {str(e)}")
            logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=None,
                device_map={"": "cpu"},
            )
            logger.warning("Model loaded on CPU. Training will be very slow.")

        logger.info(f"Model device map: {getattr(model, 'hf_device_map', 'Not available')}")

        # Apply Unsloth's LoRA/training optimizations with config parameters
        unsloth_config = config.get("unsloth", {})
        lora_dropout = unsloth_config.get("dropout", 0.05)
        if lora_dropout > 0:
            logger.warning(f"Unsloth works best with dropout=0, but config has dropout={lora_dropout}")
            logger.warning("This will impact performance but training will still work")
            logger.warning("Consider setting dropout=0 in your config for better performance")

        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get(
                "target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            ),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=lora_dropout,
            bias="none",
            use_gradient_checkpointing=_config_value(config, "gradient_checkpointing", "training", True),
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")

        chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
        if chat_template:
            # Unsloth's get_chat_template takes the tokenizer plus a template
            # name and returns the tokenizer with that template installed. A
            # string config value is used as the name; otherwise fall back to
            # "phi". Failures are non-fatal: the collator tokenizes raw content
            # and does not rely on the chat template.
            template_name = chat_template if isinstance(chat_template, str) else "phi"
            try:
                tokenizer = get_chat_template(tokenizer, chat_template=template_name)
                logger.info(f"Set chat template: {template_name}")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")
                logger.warning("Chat formatting may not work correctly, but training can continue")

        # Padding needs a pad token; reuse EOS when the tokenizer has none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer
    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise
def _log_conversation_sample(dataset, sample_size=5):
    """Log the structure of the first few ``conversations`` entries.

    Diagnostic only: anomalies are logged as warnings, never raised. Expects a
    non-empty dataset whose rows behave like dicts.
    """
    conversations = dataset[0].get("conversations", [])
    if conversations:
        first_conv = conversations[0]
        if isinstance(first_conv, dict):
            fields = list(first_conv.keys())
            logger.info(f"Conversation fields: {fields}")
            # The expected pre-processed format has only a 'content' field.
            if fields == ["content"]:
                logger.info("Confirmed conversations have correct format with only 'content' field")
            else:
                logger.warning(f"Unexpected conversation fields: {fields}")
                logger.warning("Expected only 'content' field")

    logger.info("Validating conversation structure...")
    for i in range(min(sample_size, len(dataset))):
        conv = dataset[i].get("conversations")
        if conv is None:
            logger.warning(f"Example {i} has None as 'conversations' value")
        elif not isinstance(conv, list):
            logger.warning(f"Example {i} has non-list 'conversations': {type(conv)}")
        elif not conv:
            logger.warning(f"Example {i} has empty conversations list")
        else:
            first_entry = conv[0]
            if isinstance(first_entry, dict) and "content" in first_entry:
                logger.info(f"Content field example: {str(first_entry['content'])[:50]}...")
            else:
                logger.warning(f"Example {i} missing 'content' key in conversation")


def load_dataset_with_mapping(dataset_config):
    """Load the training dataset described by *dataset_config* and sanity-check it.

    Despite the name, no column mapping is applied. The dataset is expected to
    be pre-processed and to already contain a ``conversations`` column.

    Args:
        dataset_config: Dict whose ``dataset`` section supplies ``name``
            (required) and ``split`` (default ``"train"``).

    Returns:
        The loaded dataset, or ``None`` if loading or validation failed (the
        error is logged). Previously the integer ``1`` was returned on failure,
        which slipped past ``is None`` checks.
    """
    try:
        dataset_section = dataset_config.get("dataset") or {}
        dataset_name = dataset_section.get("name", "")
        dataset_split = dataset_section.get("split", "train")
        if not dataset_name:
            raise ValueError("Dataset name not provided in configuration")

        logger.info(f"Loading pre-processed dataset {dataset_name}, split {dataset_split}")
        try:
            dataset = load_dataset(dataset_name, split=dataset_split)
            if dataset is None:
                raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) loaded as None - check dataset exists and is accessible")
            if len(dataset) == 0:
                raise ValueError(f"Dataset {dataset_name} (split {dataset_split}) is empty (contains 0 examples)")
            if "conversations" not in dataset.column_names:
                raise ValueError(f"Dataset {dataset_name} missing required 'conversations' column")
            _log_conversation_sample(dataset)
        except Exception as dataset_error:
            logger.error(f"Failed to load dataset {dataset_name}: {str(dataset_error)}")
            logger.error("Make sure the dataset exists and you have proper access permissions")
            logger.error("This could be due to authentication issues with your HF_TOKEN")
            raise
        return dataset
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        return None
def format_phi_chat(messages, dataset_config):
    """Format messages according to phi-4's chat template and dataset config.

    Only the conversation structure (role prefixes) is added; message content
    is preserved verbatim (not stripped).

    Args:
        messages: Iterable of dicts with a ``content`` key. Non-dict entries and
            dicts without ``content`` are logged and skipped. Entries with empty
            content are skipped silently.
        dataset_config: Dict; ``data_formatting.roles`` may override the
            ``system``/``human``/``assistant`` templates. Each template is a
            ``str.format`` pattern containing ``{content}``.

    Returns:
        The formatted conversation string. Messages containing
        ``[RESEARCH INTRODUCTION]`` use the system template and are prepended.
        All other messages alternate human/assistant, starting with human.
    """
    formatted_chat = ""
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n"
    })
    default_turn_templates = {
        "human": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    }

    # Track conversational turns with an explicit counter. Inferring the next
    # speaker by counting "Human:"/"Assistant:" substrings in the output breaks
    # with custom role templates and with message text containing those words.
    turn = 0
    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue
        content = message.get("content", "")  # Don't strip() - preserve exact content
        if not content:
            continue

        if "[RESEARCH INTRODUCTION]" in content:
            # System message: rendered first, ahead of everything so far.
            template = roles.get("system", "System: {content}\n\n")
            formatted_chat = template.format(content=content) + formatted_chat
            continue

        # In phi-4 format, human messages come first, followed by assistant responses
        role = "human" if turn % 2 == 0 else "assistant"
        template = roles.get(role, default_turn_templates[role])
        formatted_chat += template.format(content=content)
        turn += 1
    return formatted_chat
class SimpleDataCollator:
    """Tokenize single-text conversations into fixed-length causal-LM batches.

    Each feature is expected to carry a ``conversations`` list whose first
    entry's ``content`` is the full training text. Optional ``article_id`` and
    ``prompt_number`` fields are used only in log messages.
    """

    def __init__(self, tokenizer, dataset_config):
        """Store the tokenizer and the effective maximum sequence length.

        ``max_seq_length`` is the smaller of ``dataset_config["max_seq_length"]``
        (default 2048) and ``tokenizer.model_max_length``.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length)
        # Running counters, used only for periodic progress logging.
        self.stats = {
            "processed": 0,
            "skipped": 0,
            "total_tokens": 0
        }
        logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}")

    def _pad(self, sequence, pad_value):
        """Pad *sequence* to ``self.max_seq_length`` on the tokenizer's padding side."""
        padding = [pad_value] * (self.max_seq_length - len(sequence))
        if getattr(self.tokenizer, "padding_side", "right") == "left":
            return padding + list(sequence)
        return list(sequence) + padding

    def __call__(self, features):
        """Build a padded batch from *features*.

        Returns:
            A dict of ``torch.long`` tensors ``input_ids``, ``attention_mask``
            and ``labels``, each of shape ``(num_valid_features, max_seq_length)``.
            Padding positions have attention mask 0 and label -100, the index
            ignored by PyTorch's cross-entropy loss, so padding is never trained
            on. If no feature is usable, a dict of empty lists is returned.

        Raises:
            ValueError: The tokenizer has no ``pad_token_id``.
        """
        batch = {
            "input_ids": [],
            "attention_mask": [],
            "labels": []
        }

        for feature in features:
            paper_id = feature.get("article_id", "unknown")
            prompt_num = feature.get("prompt_number", 0)
            conversations = feature.get("conversations", [])

            if not conversations:
                logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue

            content = conversations[0].get("content", "")
            if not content:
                logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}")
                self.stats["skipped"] += 1
                continue

            if isinstance(content, str):
                input_ids = self.tokenizer.encode(content, add_special_tokens=True)
            else:
                # Content is assumed to be already tokenized (a sequence of ids).
                input_ids = content

            if len(input_ids) > self.max_seq_length:
                input_ids = input_ids[:self.max_seq_length]
                logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}")

            batch["input_ids"].append(list(input_ids))
            batch["attention_mask"].append([1] * len(input_ids))
            batch["labels"].append(list(input_ids))  # For causal LM, labels = input_ids

            self.stats["processed"] += 1
            self.stats["total_tokens"] += len(input_ids)

            if self.stats["processed"] % 100 == 0:
                avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"])
                logger.info(f"Data collation stats: processed={self.stats['processed']}, "
                            f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}")

        if not batch["input_ids"]:
            # Return empty batch if no valid examples
            return {k: [] for k in batch}

        # Pad manually: tokenizer.pad() pads only model inputs (input_ids,
        # attention_mask), not the ragged "labels" lists, so converting the
        # labels to a tensor would fail. Labels are padded with -100 so the loss
        # ignores padding, which here is usually the EOS id reused as pad.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("Tokenizer has no pad_token_id; cannot pad the batch")
        return {
            "input_ids": torch.tensor([self._pad(ids, pad_id) for ids in batch["input_ids"]], dtype=torch.long),
            "attention_mask": torch.tensor([self._pad(mask, 0) for mask in batch["attention_mask"]], dtype=torch.long),
            "labels": torch.tensor([self._pad(lbl, -100) for lbl in batch["labels"]], dtype=torch.long),
        }
class LoggingCallback(TrainerCallback):
    """Trainer callback that logs a one-off training summary when training begins.

    Args:
        model: Optional model; if given, its total parameter count is logged.
        dataset: Optional sized dataset, used for the example count and a rough
            optimizer-step estimate.
    """

    def __init__(self, model=None, dataset=None):
        super().__init__()
        self.training_started = time.time()
        self.last_log_time = time.time()
        self.last_step = 0
        self.model = model
        self.dataset = dataset

    def on_train_begin(self, args, state, control, **kwargs):
        """Log start time, model/dataset size, batch-size plan and GPU memory usage."""
        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
        if self.model is not None:
            log_info(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
        if self.dataset is not None:
            log_info(f"Dataset size: {len(self.dataset)} examples")

        # NUM_GPUS is 0 on CPU-only hosts; count that as a single device so the
        # step estimate below never divides by zero.
        num_devices = max(1, NUM_GPUS)
        device_label = "GPUs" if NUM_GPUS else "CPU"
        num_examples = len(self.dataset) if self.dataset is not None else 0
        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
        # Rough estimate only (ignores partial final batches).
        total_steps = int(num_examples / total_batch_size * args.num_train_epochs)
        log_info(f"Training plan: {num_examples} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {num_devices} {device_label} = {total_batch_size} total")

        # Log memory information in compact format
        if CUDA_AVAILABLE:
            memory_info = []
            for i in range(NUM_GPUS):
                allocated = torch.cuda.memory_allocated(i) / 1024**2
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            log_info(f"Initial memory usage - {', '.join(memory_info)}")
def check_dependencies():
    """Report whether the packages this script needs are installed.

    Required packages: unsloth, transformers, peft, accelerate. For unsloth and
    peft the module-level availability flags are consulted first. Any package
    not already known to be missing is imported and its ``__version__`` logged.
    The version specifiers are only echoed in the install hint; they are not
    compared. Optional extras (flash_attn, bitsandbytes) are only reported. A
    warning is logged if ``transformers`` appears in ``sys.modules`` before
    ``unsloth``.

    Returns:
        True if every required package is available, otherwise False.
    """
    required_packages = {
        "unsloth": ">=2024.3",
        "transformers": ">=4.38.0",
        "peft": ">=0.9.0",
        "accelerate": ">=0.27.0"
    }
    # Availability of these two was already probed at import time.
    preprobed = {"unsloth": unsloth_available, "peft": peft_available}

    missing_packages = []
    for package, version in required_packages.items():
        requirement = f"{package}{version}"
        if not preprobed.get(package, True):
            missing_packages.append(requirement)
            continue
        try:
            module = __import__(package)
        except ImportError:
            missing_packages.append(requirement)
            continue
        logger.info(f"Using {package} version {getattr(module, '__version__', 'unknown')}")

    # sys.modules preserves insertion order, so index order approximates
    # import order.
    order_issues = []
    try:
        loaded_modules = list(sys.modules)
        if "transformers" in loaded_modules and "unsloth" in loaded_modules:
            if loaded_modules.index("transformers") < loaded_modules.index("unsloth"):
                order_issues.append("For optimal performance, unsloth should be imported before transformers")
    except Exception as e:
        logger.warning(f"Could not check module import order: {str(e)}")

    optional_packages = {
        "flash_attn": "Flash attention support",
        "bitsandbytes": "4-bit quantization support"
    }
    for package, feature in optional_packages.items():
        if not find_spec(package):
            logger.warning(f"{package} not found - {feature} will not be available")
        else:
            logger.info(f"Found {package} - {feature} enabled")

    if missing_packages:
        logger.error("Critical dependencies missing:")
        for pkg in missing_packages:
            logger.error(f" - {pkg}")
        logger.error("Please install the missing dependencies with:")
        logger.error(f" pip install {' '.join(missing_packages)}")
        return False

    for issue in order_issues:
        logger.warning(issue)
    return True
def update_huggingface_space():
    """Push the current code to the Hugging Face Space via ``update_space.py``.

    Runs the sibling ``update_space.py`` with ``--force --space_name
    phi4training`` under the current interpreter. Its output is captured, and
    stderr is logged on failure.

    Returns:
        True if the script exits with status 0. False if the script is missing,
        exits non-zero, or cannot be launched.
    """
    log_info("Updating Hugging Face Space...")
    here = os.path.dirname(os.path.abspath(__file__))
    update_script = os.path.join(here, "update_space.py")
    if not os.path.exists(update_script):
        logger.warning(f"Update space script not found at {update_script}")
        return False

    try:
        import subprocess
        # The Space name is pinned on purpose so the intended Space is targeted.
        command = [sys.executable, update_script, "--force", "--space_name", "phi4training"]
        completed = subprocess.run(command, capture_output=True, text=True, check=False)
        if completed.returncode != 0:
            logger.error(f"Failed to update Hugging Face Space: {completed.stderr}")
            return False
        log_info("Hugging Face Space updated successfully!")
        log_info("Space URL: https://huggingface.co/spaces/George-API/phi4training")
        return True
    except Exception as e:
        logger.error(f"Error updating Hugging Face Space: {str(e)}")
        return False
def validate_huggingface_credentials():
    """Check that ``HF_TOKEN`` authenticates and the target Space is reachable.

    The function:
    * logs in with ``HF_TOKEN`` and logs the authenticated user;
    * compares ``HF_USERNAME/HF_SPACE_NAME`` (defaults ``George-API`` and
      ``phi4training``) with the expected Space and warns on mismatch;
    * queries the Space's info.

    Returns:
        True only when the Space info lookup succeeds. False when the token is
        missing, huggingface_hub is not installed, or any step fails (all
        failures are logged as warnings).
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.")
        return False

    try:
        # Imported lazily so huggingface_hub is only needed when validating.
        from huggingface_hub import HfApi, login

        login(token=token)
        api = HfApi()
        username = os.environ.get("HF_USERNAME", "George-API")
        space_name = os.environ.get("HF_SPACE_NAME", "phi4training")

        whoami = api.whoami()
        logger.info(f"Successfully authenticated with Hugging Face as {whoami['name']}")

        expected_space_id = "George-API/phi4training"
        space_id = f"{username}/{space_name}"
        if space_id == expected_space_id:
            logger.info(f"Confirmed using Space: {expected_space_id}")
        else:
            logger.warning(f"Using Space '{space_id}' instead of the expected '{expected_space_id}'")
            logger.warning("Make sure this is intentional. To use the correct Space, update your .env file.")

        try:
            api.space_info(repo_id=space_id)
            logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}")
            return True
        except Exception as e:
            logger.warning(f"Could not access Space {space_id}: {str(e)}")
            logger.warning("Space updating may not work correctly")
            return False
    except ImportError:
        logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.")
        return False
    except Exception as e:
        logger.warning(f"Error validating Hugging Face credentials: {str(e)}")
        return False
def main():
# Set up logging
logger.info("Starting training process")
try:
# Check dependencies first, before any other operations
if not check_dependencies():
logger.error("Aborting due to missing critical dependencies")
return 1
# Parse arguments
args = parse_args()
# Load environment variables
load_env_variables()
# Validate Hugging Face credentials if we're going to use them
validate_huggingface_credentials()
# Load configuration
try:
transformers_config = load_configs(args.config)
hardware_config = transformers_config.get("hardware", {})
dataset_config = transformers_config.get("dataset", {})
logger.info("Configuration loaded successfully")
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return 1
# Check if we're in distributed mode
is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
if is_distributed:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
else:
log_info("Running in non-distributed mode (single process)")
# Set random seed for reproducibility
seed = transformers_config.get("seed", 42)
set_seed(seed)
logger.info(f"Set random seed to {seed}")
# Load model and tokenizer using the consolidated config
model, tokenizer = load_model_and_tokenizer(transformers_config)
# Empty CUDA cache to ensure clean state
if CUDA_AVAILABLE:
torch.cuda.empty_cache()
log_info("Cleared CUDA cache")
# Setup environment variable for CUDA memory allocation
if CUDA_AVAILABLE:
system_settings = hardware_config.get("system_settings", {})
cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
if cuda_memory_fraction < 1.0:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True"
log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128")
try:
log_info("Loading dataset...")
dataset = load_dataset_with_mapping(dataset_config)
# Extra validation to catch None/empty dataset issues
if dataset is None:
logger.error("Dataset is None! Cannot proceed with training.")
return 1
if not hasattr(dataset, '__len__') or len(dataset) == 0:
logger.error("Dataset is empty! Cannot proceed with training.")
return 1
log_info(f"Dataset loaded with {len(dataset)} examples")
# Minimal validation before proceeding
if dataset is None or len(dataset) == 0:
logger.error("Dataset is empty or None! Cannot proceed with training.")
return 1
# Create data collator
data_collator = SimpleDataCollator(tokenizer, dataset_config)
# Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
# First check hardware config, then transformers config
use_bf16 = False
use_fp16 = False
# Check hardware config first
hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
if hardware_precision.lower() == "bf16":
use_bf16 = True
log_info("Using BF16 precision from hardware config")
elif hardware_precision.lower() == "fp16":
use_fp16 = True
log_info("Using FP16 precision from hardware config")
else:
# Fall back to transformers config
use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16"
use_fp16 = transformers_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
# Get per device batch size - from transformers config, but possibly overridden by hardware config
per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
# Get multi-GPU strategy from hardware config (default to data_parallel)
multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")
# For multi-GPU setup, adjust for better balance
if CUDA_AVAILABLE and NUM_GPUS > 1:
log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
# Set up FSDP for multi-GPU training if specified and in distributed mode
fsdp_config = None
if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
try:
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
log_info("Using FSDP for distributed training")
# Configure FSDP
fsdp_config = {
"fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
"fsdp_offload_params": False,
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_min_num_params": 1e6,
"fsdp_sharding_strategy": 1, # FULL_SHARD
}
if use_bf16 or use_fp16:
precision_type = "bf16" if use_bf16 else "fp16"
fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT"
log_info(f"FSDP using mixed precision: {precision_type}")
except ImportError:
log_info("FSDP imports failed, falling back to standard DDP")
fsdp_config = None
elif multi_gpu_strategy == "fsdp" and not is_distributed:
log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
log_info("Using DataParallel for multi-GPU training instead")
else:
log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
# Get system settings from hardware config
dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)
# Set up training arguments
log_info("Setting up training arguments")
# Handle FSDP configuration
fsdp_config = transformers_config.get("distributed_training", {}).get("fsdp_config", {})
fsdp_enabled = fsdp_config.get("enabled", False)
# Only set FSDP args if explicitly enabled
fsdp_args = None
if fsdp_enabled and is_distributed and NUM_GPUS > 1:
fsdp_args = {
"fsdp": ["full_shard", "auto_wrap"],
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_offload_params": fsdp_config.get("offload_params", False),
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sharding_strategy": 1, # FULL_SHARD
}
log_info("FSDP configuration enabled")
else:
log_info("FSDP disabled, using standard data parallel")
# Check if we're running in a Space
is_space = bool(os.environ.get("SPACE_ID"))
# Create training arguments
training_args = TrainingArguments(
output_dir=transformers_config.get("output_dir", "./results") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
per_device_train_batch_size=per_device_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5),
weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01),
warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05),
lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"),
logging_steps=transformers_config.get("training", {}).get("logging_steps", 10),
save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"),
save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100),
save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3),
fp16=use_fp16,
bf16=use_bf16,
max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
hub_token=None if is_space else os.environ.get("HF_TOKEN", None),
report_to="tensorboard",
remove_unused_columns=False, # Keep all columns
gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True),
dataloader_pin_memory=pin_memory,
optim=transformers_config.get("training", {}).get("optim", "adamw_torch"),
ddp_find_unused_parameters=False, # Improve distributed training efficiency
dataloader_drop_last=False, # Process all examples
dataloader_num_workers=dataloader_workers,
no_cuda=False if CUDA_AVAILABLE else True, # Use CUDA if available
**({} if fsdp_args is None else fsdp_args) # Only include FSDP args if configured
)
log_info("Training arguments created successfully")
# Validate dataset before creating sampler
if dataset is None:
raise ValueError("Dataset is None - cannot create sampler")
# Create sequential sampler to maintain original dataset order
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
log_info("Sequential sampler created")
# Initialize trainer first
log_info("Initializing Trainer")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=data_collator,
callbacks=[LoggingCallback(model=model, dataset=dataset)],
)
# Then override the get_train_dataloader method
def custom_get_train_dataloader():
    """Build the training DataLoader with a SequentialSampler so examples are read in index order.

    Installed as ``trainer.get_train_dataloader``. It reads ``dataset``,
    ``dataset_config``, ``training_args`` and ``data_collator`` from the
    enclosing scope.

    Returns:
        torch.utils.data.DataLoader: iterates ``dataset`` in index order. Batches
        are ``training_args.train_batch_size`` examples (at least 1) and are
        collated by ``data_collator``.

    Raises:
        ValueError: if ``dataset`` is None or empty, or if it exposes
            ``column_names`` without a ``conversations`` column.
    """
    log_info("Creating sequential dataloader to maintain original dataset order")

    # Safety checks - the sampler and DataLoader below need a real, non-empty dataset
    if dataset is None:
        raise ValueError("Dataset is None - cannot create dataloader")
    if len(dataset) == 0:
        raise ValueError("Dataset is empty - cannot create dataloader")

    # Index-order sampler: no shuffling, regardless of the config flag below
    sequential_sampler = torch.utils.data.SequentialSampler(dataset)

    data_loading_config = dataset_config.get("data_loading", {})
    if data_loading_config.get("shuffle", False):
        log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")
    # NOTE(review): the sampler preserves *index* order; that this equals
    # prompt_number order depends on upstream preprocessing - confirm.
    log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")

    # Verify column order and check for the required 'conversations' field
    expected_order = ["prompt_number", "article_id", "conversations"]
    if hasattr(dataset, 'column_names'):
        actual_order = dataset.column_names
        missing_fields = [field for field in ["conversations"] if field not in actual_order]
        if missing_fields:
            raise ValueError(f"Dataset missing critical fields: {missing_fields}")
        if actual_order == expected_order:
            log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
        else:
            log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
            log_info("This is handled correctly by field-based access, but noting for clarity")
    log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")

    # Spot-check the first few samples; problems are logged, not raised
    for i in range(min(3, len(dataset))):
        sample = dataset[i]
        if "conversations" not in sample:
            log_info(f"WARNING: Sample {i} missing 'conversations' field")
        elif sample["conversations"] is None:
            log_info(f"WARNING: Sample {i} has None 'conversations' field")
        elif not isinstance(sample["conversations"], list):
            log_info(f"WARNING: Sample {i} has non-list 'conversations' field: {type(sample['conversations'])}")

    # TrainingArguments.train_batch_size is per_device_train_batch_size * max(1, n_gpu).
    # n_gpu is 0 on CPU, the GPU count in single-process DataParallel, and 1 per
    # process under a distributed launch. Multiplying by NUM_GPUS unconditionally
    # over-sized each rank's batch in multi-process runs.
    batch_size = max(training_args.train_batch_size, 1)
    log_info(f"Using sequential sampler with batch size {batch_size}")

    # NOTE(review): this loader is returned without the accelerator.prepare()
    # wrapping the stock Trainer applies, so under a multi-process launch each
    # rank would iterate the full dataset - confirm single-process training.
    # Settings shared by the normal path and the fallback retry, so a retry
    # never silently changes the effective batch size or drop_last behaviour.
    loader_kwargs = dict(
        batch_size=batch_size,
        sampler=sequential_sampler,  # Always use sequential sampler
        collate_fn=data_collator,
        drop_last=training_args.dataloader_drop_last,
    )
    try:
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=training_args.dataloader_num_workers,
            pin_memory=training_args.dataloader_pin_memory,
            **loader_kwargs,
        )
    except Exception as e:
        log_info(f"Error creating DataLoader: {e}")
        # Retry with only the worker/pinning options disabled. If the failure
        # came from the batch configuration itself, this raises, rather than
        # training with a different batch size than configured.
        log_info("Attempting to create DataLoader without worker processes or pinned memory")
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=0,  # No parallel workers
            pin_memory=False,
            **loader_kwargs,
        )
# Override the get_train_dataloader method
trainer.get_train_dataloader = custom_get_train_dataloader
# Start training
log_info("=== Starting Training ===")
try:
# Empty cache again right before training
if CUDA_AVAILABLE:
torch.cuda.empty_cache()
log_info("Cleared CUDA cache before training")
# Display compact training info
total_steps = int((len(dataset) / (per_device_batch_size * NUM_GPUS * gradient_accumulation_steps)) * training_args.num_train_epochs)
log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
trainer.train()
log_info("Training completed successfully!")
# Save the final model
log_info("Saving final model...")
trainer.save_model()
log_info(f"Model saved to {training_args.output_dir}")
# Push to hub if enabled
if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False):
hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model")
log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
trainer.push_to_hub()
log_info("Model successfully pushed to Hub")
# Update the Hugging Face Space with current code
if os.environ.get("HF_TOKEN") and os.environ.get("HF_USERNAME") and os.environ.get("HF_SPACE_NAME"):
update_huggingface_space()
return 0
except Exception as e:
logger.error(f"Training failed with error: {str(e)}")
# Log CUDA memory info if available in compact format
if CUDA_AVAILABLE:
memory_info = []
for i in range(NUM_GPUS):
allocated = torch.cuda.memory_allocated(i) / 1024**2
reserved = torch.cuda.memory_reserved(i) / 1024**2
max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
raise
except Exception as e:
logger.error(f"Error in main training loop: {str(e)}")
return 1
except Exception as e:
logger.error(f"Error in main function: {str(e)}")
return 1
if __name__ == "__main__":
    # main() returns 0 on success and 1 on failure; surface that as the process exit status.
    exit_status = main()
    sys.exit(exit_status)