#!/usr/bin/env python
# coding=utf-8
import os
import sys
import json
import argparse
import logging
from datetime import datetime
import time
import warnings
import torch
from importlib.util import find_spec
# Global variables for hardware detection
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
# Import Unsloth first, before other ML imports
try:
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
unsloth_available = True
except ImportError:
unsloth_available = False
logger = logging.getLogger(__name__)
logger.warning("Unsloth not available. Please install with: pip install unsloth")
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
TrainerCallback,
set_seed,
BitsAndBytesConfig
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Set other loggers to WARNING to reduce noise and ensure our logs are visible
logging.getLogger("transformers").setLevel(logging.WARNING)
logging.getLogger("datasets").setLevel(logging.WARNING)
logging.getLogger("accelerate").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
# Check availability of libraries
peft_available = find_spec("peft") is not None
# Define a clean logging function for HF Space compatibility
def log_info(message):
"""Log information in a format compatible with Hugging Face Spaces"""
# Just use the logger, but ensure consistent formatting
logger.info(message)
# Also ensure output is flushed immediately for streaming
sys.stdout.flush()
# Check for BitsAndBytes (the transformers BitsAndBytesConfig class imports even
# when the bitsandbytes package is missing, so check for the package itself)
bitsandbytes_available = find_spec("bitsandbytes") is not None
if not bitsandbytes_available:
logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
# Check for PEFT
try:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
peft_available = True
except ImportError:
peft_available = False
logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
def load_env_variables():
"""Load environment variables from system, .env file, or Hugging Face Space variables."""
# Check if we're running in a Hugging Face Space
if os.environ.get("SPACE_ID"):
logging.info("Running in Hugging Face Space")
# Log the presence of variables (without revealing values)
logging.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
logging.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
# If username is not set, try to extract from SPACE_ID
if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
username = os.environ.get("SPACE_ID").split("/")[0]
os.environ["HF_USERNAME"] = username
logging.info(f"Set HF_USERNAME from SPACE_ID: {username}")
else:
# Try to load from .env file if not in a Space
try:
from dotenv import load_dotenv
# Updated path to .env file in the new directory structure
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
if os.path.exists(env_path):
load_dotenv(env_path)
logging.info(f"Loaded environment variables from {env_path}")
logging.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
logging.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
logging.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
else:
logging.warning(f"No .env file found at {env_path}")
except ImportError:
logging.warning("python-dotenv not installed, not loading from .env file")
if not os.environ.get("HF_USERNAME"):
logger.warning("HF_USERNAME is not set. Using default username.")
if not os.environ.get("HF_SPACE_NAME"):
logger.warning("HF_SPACE_NAME is not set. Using default space name.")
# Set HF_TOKEN for huggingface_hub
if os.environ.get("HF_TOKEN"):
os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
def load_configs(base_path):
"""Load configuration from transformers_config.json file."""
# Using a single consolidated config file
config_file = base_path
try:
with open(config_file, "r") as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_file}")
return config
except Exception as e:
logger.error(f"Error loading {config_file}: {e}")
raise
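# The consolidated config is a single JSON file. The sketch below is illustrative,
# not an authoritative schema: the keys are inferred from how this script reads the
# config, and all concrete values are placeholders.
#
# {
#   "model_name": "<hf-model-id>",
#   "max_seq_length": 2048,
#   "seed": 42,
#   "chat_template": "<template-name>",
#   "unsloth": {"r": 32, "alpha": 16, "dropout": 0.05},
#   "training": {"per_device_train_batch_size": 16, "gradient_accumulation_steps": 3,
#                "learning_rate": 2e-5, "num_train_epochs": 3},
#   "checkpointing": {"output_dir": "./results", "save_strategy": "steps", "save_steps": 100},
#   "huggingface_hub": {"push_to_hub": false, "hub_model_id": null},
#   "hardware": {"training_optimizations": {"mixed_precision": "bf16"},
#                "system_settings": {"dataloader_num_workers": 2, "dataloader_pin_memory": true}},
#   "dataset": {"dataset": {"name": "<hf-dataset-id>", "split": "train"},
#               "data_loading": {"shuffle": false, "sequential_processing": true},
#               "validation": {"verify_sequence_integrity": true}}
# }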
def parse_args():
parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
return parser.parse_args()
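# The script is driven entirely by the JSON config described above; a typical
# invocation (path illustrative) looks like:
#
#   python run_transformers_training.py --config transformers_config.json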
def load_model_and_tokenizer(config):
"""Load model and tokenizer with proper error handling and optimizations."""
try:
if not unsloth_available:
logger.error("Unsloth is required for training with pre-quantized model")
logger.error("Please ensure unsloth is in requirements.txt")
raise ImportError("Unsloth is required for this training setup")
# Get model name correctly from config
model_name = config.get("model_name") or config.get("model", {}).get("name")
if not model_name:
raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
logger.info(f"Loading model: {model_name}")
logger.info("Using Unsloth optimizations with pre-quantized model")
# Check for flash attention
use_flash_attention = config.get("use_flash_attention", True)
if use_flash_attention and not find_spec("flash_attn"):
logger.warning("flash-attn not found. Will continue without flash attention.")
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
use_flash_attention = False
# First detect if we have a GPU
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
logger.info(f"CUDA available, found {gpu_count} GPU(s)")
# Log GPU info
for i in range(gpu_count):
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")
# Create an optimized device map for better balance
if gpu_count > 1:
logger.info(f"Creating balanced device map for {gpu_count} GPUs")
# Use auto mapping but with memory tracking
device_map = "auto"
# Set max memory for better balancing
max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB" for i in range(gpu_count)}
logger.info(f"Max memory settings: {max_memory}")
else:
device_map = "auto"
max_memory = None
else:
logger.warning("No CUDA available, falling back to CPU")
device_map = {"": "cpu"} # Force CPU placement
max_memory = None
# Set default dtype for better numerics
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
# Use bfloat16 for Ampere or newer
dtype = torch.bfloat16
logger.info("Using bfloat16 precision (Ampere+ GPU)")
elif torch.cuda.is_available():
# Use float16 for older GPUs
dtype = torch.float16
logger.info("Using float16 precision (pre-Ampere GPU)")
else:
# CPU, use default dtype
dtype = None
logger.info("Using default precision (CPU)")
# Load model with proper error handling for out-of-memory
try:
# Improved memory settings for multi-GPU setup
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=config.get("max_seq_length") or config.get("tokenizer", {}).get("max_seq_length", 2048),
dtype=dtype,
device_map=device_map,
max_memory=max_memory,
# Don't explicitly use flash attention config here, let Unsloth handle it
)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
raise
else:
# Try again with CPU placement to see if it's a memory issue
logger.warning(f"Error loading model on default device: {str(e)}")
logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=model_name,
max_seq_length=config.get("max_seq_length") or config.get("tokenizer", {}).get("max_seq_length", 2048),
dtype=None,
device_map={"": "cpu"},
)
logger.warning("Model loaded on CPU. Training will be very slow.")
# Ensure model and optimizer init is on the same device
logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
# Apply Unsloth's training optimizations with config parameters
unsloth_config = config.get("unsloth", {})
model = FastLanguageModel.get_peft_model(
model,
r=unsloth_config.get("r", 32),
target_modules=unsloth_config.get("target_modules",
["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]),
lora_alpha=unsloth_config.get("alpha", 16),
lora_dropout=unsloth_config.get("dropout", 0.05),
bias="none",
use_gradient_checkpointing=config.get("gradient_checkpointing", config.get("training", {}).get("gradient_checkpointing", True)),
random_state=config.get("seed", 42),
)
logger.info("Unsloth optimizations applied successfully")
# Set up tokenizer settings
chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
if chat_template:
try:
# get_chat_template returns the tokenizer with the requested template applied
tokenizer = get_chat_template(tokenizer, chat_template=chat_template)
logger.info(f"Set chat template: {chat_template}")
except Exception as e:
logger.warning(f"Failed to set chat template: {str(e)}")
# Ensure proper token settings
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
return model, tokenizer
except Exception as e:
logger.error(f"Error in model/tokenizer loading: {str(e)}")
logger.error("If missing dependencies, check the requirements.txt file")
raise
def load_dataset_with_mapping(dataset_config):
"""Load dataset and apply appropriate column mappings."""
try:
# Load dataset
dataset_name = dataset_config.get("dataset", {}).get("name", "")
dataset_split = dataset_config.get("dataset", {}).get("split", "train")
if not dataset_name:
raise ValueError("Dataset name not provided in configuration")
logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
dataset = load_dataset(dataset_name, split=dataset_split)
# Map columns if specified - with checks to avoid conflicts
column_mapping = dataset_config.get("dataset", {}).get("column_mapping", {})
if column_mapping:
logger.info(f"Checking column mapping: {column_mapping}")
# Only apply mappings for columns that need renaming and don't already exist
safe_mappings = {}
for target, source in column_mapping.items():
if source in dataset.column_names:
# Skip if target already exists and is not the same as source
if target in dataset.column_names and target != source:
logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
else:
safe_mappings[source] = target
# Apply safe renames
if safe_mappings:
logger.info(f"Applying safe column mapping: {safe_mappings}")
for source, target in safe_mappings.items():
if source != target: # Only rename if names are different
dataset = dataset.rename_column(source, target)
# Add prompt_number field that increments based on original order
def add_prompt_numbers(examples, indices):
# Defensive check to ensure indices is not None and is iterable
if indices is None:
logger.warning("Warning: indices is None in add_prompt_numbers, using empty list")
indices = []
elif isinstance(indices, int):
# Handle case where indices is a single integer
logger.warning(f"Warning: indices is an integer ({indices}) in add_prompt_numbers, converting to list")
indices = [indices]
# Ensure indices is always a list/iterable
try:
# Create a new field with the dataset index as the prompt number, starting at 1
examples["prompt_number"] = [idx + 1 for idx in indices] # Adding 1 to make it 1-indexed
except TypeError:
# Fallback for non-iterable types
logger.warning(f"Warning: non-iterable indices in add_prompt_numbers: {type(indices)}, using default")
examples["prompt_number"] = [1] * len(next(iter(examples.values())))
return examples
# Add prompt numbers to the dataset based on original order
logger.info("Adding prompt numbers based on original dataset order (starting at 1)")
try:
dataset = dataset.map(
add_prompt_numbers,
with_indices=True,
batched=True,
desc="Adding prompt numbers"
)
logger.info(f"Successfully added prompt_number field to dataset")
except Exception as e:
logger.error(f"Error adding prompt numbers: {e}")
# Create a fallback implementation that doesn't rely on with_indices
logger.info("Attempting fallback method for adding prompt numbers")
def add_prompt_numbers_fallback(example, idx):
example["prompt_number"] = idx + 1
return example
# Process each example one by one with explicit indices
updated_examples = []
for i, example in enumerate(dataset):
updated_examples.append(add_prompt_numbers_fallback(dict(example), i))
# Create a new dataset with the updated examples
from datasets import Dataset
dataset = Dataset.from_list(updated_examples)
logger.info(f"Successfully added prompt_number field using fallback method")
# Rename 'id' to 'article_id' if it exists
if 'id' in dataset.column_names and 'article_id' not in dataset.column_names:
logger.info("Renaming 'id' column to 'article_id'")
dataset = dataset.rename_column('id', 'article_id')
# Reorder columns to make prompt_number first if it exists
if 'prompt_number' in dataset.column_names:
logger.info("Reordering columns to place prompt_number first")
# Get current column names
current_columns = dataset.column_names
# Create new column order with prompt_number first
new_column_order = ['prompt_number'] + [col for col in current_columns if col != 'prompt_number']
# Reorder columns
dataset = dataset.select_columns(new_column_order)
# Verify all new column names for logging
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
logger.info(f"Dataset columns: {dataset.column_names}")
# Verify dataset is not empty
if len(dataset) == 0:
logger.error("Dataset is empty! This will cause errors during training.")
raise ValueError("Empty dataset loaded")
# Check for required columns
required_columns = ['conversations']
for col in required_columns:
if col not in dataset.column_names:
logger.error(f"Required column '{col}' not found in dataset!")
raise ValueError(f"Required column '{col}' missing from dataset")
# Verify expected columns exist
expected_columns = {"article_id", "conversations", "prompt_number"}
missing_columns = expected_columns - set(dataset.column_names)
if missing_columns:
logger.warning(f"Some expected columns are missing: {missing_columns}")
# If "conversations" is missing but "text" exists, attempt conversion
if "conversations" not in dataset.column_names and "text" in dataset.column_names:
logger.info("Converting 'text' field to 'conversations' format")
def convert_text_to_conversations(example):
# Check if text is already a list of conversation turns
if isinstance(example.get("text"), list):
example["conversations"] = example["text"]
# Otherwise, create a simple conversation with the text as user message
else:
example["conversations"] = [
{"role": "user", "content": str(example.get("text", ""))}
]
return example
dataset = dataset.map(convert_text_to_conversations)
logger.info("Successfully converted 'text' to 'conversations'")
# Verify data ordering requirements
processing_config = dataset_config.get("dataset", {}).get("processing", {})
data_loading_config = dataset_config.get("data_loading", {})
# Check if sorting is required
sort_by_article_id = processing_config.get("sort_by_article_id", False)
if sort_by_article_id and 'article_id' in dataset.column_names:
logger.info("Sorting dataset by article_id as specified in config")
dataset = dataset.sort("article_id")
sorted_ids = [example['article_id'] for example in dataset.select(range(min(5, len(dataset))))]
logger.info(f"First few article_ids after sorting: {sorted_ids}")
# Flag consolidation - we only need one flag to control sequence preservation
# Default to True to ensure safety
preserve_sequence = processing_config.get("preserve_entry_sequence", True)
shuffle_disabled = not data_loading_config.get("shuffle", False)
if not preserve_sequence:
logger.warning("CRITICAL: preserve_entry_sequence is set to False. This is NOT RECOMMENDED!")
logger.warning("Data sequence integrity is essential for proper model training.")
if not shuffle_disabled:
logger.error("CRITICAL: shuffle is enabled in the dataset config!")
logger.error("This will RANDOMIZE your dataset and break sequential order.")
logger.error("Please set shuffle: false in your data_loading configuration.")
# Actually enforce sequence preservation by raising an error
raise ValueError("Dataset shuffling is enabled but preserve_entry_sequence is required. " +
"Please disable shuffling in your configuration.")
# Verify the IDs are in sequential order if they're numeric
try:
if len(dataset) > 1:
# Check prompt numbers are sequential
sample_indices = range(min(10, len(dataset)))
sample_prompt_numbers = []
# Defensive collection of prompt numbers
for i in sample_indices:
try:
if i < len(dataset) and "prompt_number" in dataset[i]:
sample_prompt_numbers.append(dataset[i]["prompt_number"])
else:
# If prompt_number doesn't exist, use index+1 as fallback
sample_prompt_numbers.append(i + 1)
logger.warning(f"Sample at index {i} missing prompt_number, using {i+1} as fallback")
except Exception as e:
logger.warning(f"Error accessing sample at index {i}: {e}")
sample_prompt_numbers.append(i + 1) # Use fallback
logger.info(f"Verifying sequential integrity with prompt numbers: {sample_prompt_numbers}")
# Check if prompt numbers are sequential (1-indexed)
if sample_prompt_numbers:
is_sequential = all(sample_prompt_numbers[i] == i + 1 for i in range(len(sample_prompt_numbers)))
if not is_sequential:
logger.warning("WARNING: Prompt numbers are not in sequential order!")
logger.warning("This may indicate that data sequence is not preserved.")
else:
logger.info("Prompt numbers verify that samples are in sequential order.")
else:
logger.warning("Could not verify sequential integrity: no prompt numbers collected")
# Also check original IDs as a backup if numeric
try:
sample_examples = []
for i in sample_indices:
try:
if i < len(dataset):
sample_examples.append(dataset[i])
except Exception as e:
logger.warning(f"Error accessing dataset at index {i}: {e}")
if sample_examples:
id_field = 'article_id' if 'article_id' in dataset.column_names else 'id'
if all(isinstance(example.get(id_field, ''), (int, str)) for example in sample_examples):
sample_ids = [example.get(id_field, '') for example in sample_examples if id_field in example]
if sample_ids and all(isinstance(id, int) or (isinstance(id, str) and id.isdigit()) for id in sample_ids):
numeric_ids = [int(id) if isinstance(id, str) else id for id in sample_ids]
if len(numeric_ids) > 1:
is_ordered = all(numeric_ids[i] <= numeric_ids[i+1] for i in range(len(numeric_ids)-1))
if not is_ordered:
logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
else:
logger.info(f"Sample {id_field}s appear to be in sequential order.")
except Exception as e:
logger.warning(f"Error checking ID sequence: {e}")
except Exception as e:
logger.warning(f"Could not verify sequential integrity: {e}")
# Log examples without printing full content - with defensive coding
if "conversations" in dataset.column_names:
try:
# Safely get first few samples
first_few_indices = range(min(5, len(dataset)))
sample_prompt_numbers = []
sample_article_ids = []
for i in first_few_indices:
try:
example = dataset[i]
if 'prompt_number' in example:
sample_prompt_numbers.append(example['prompt_number'])
if 'article_id' in example:
sample_article_ids.append(example['article_id'])
except Exception as e:
logger.warning(f"Error accessing sample at index {i}: {e}")
logger.info(f"First few samples - Prompt numbers: {sample_prompt_numbers}, Article IDs: {sample_article_ids}")
# Log conversation structure without full content
if len(dataset) > 0:
try:
sample_conv_structure = []
first_example = dataset[0]
if 'conversations' in first_example and first_example['conversations'] is not None:
for msg in first_example['conversations']:
if isinstance(msg, dict):
content = msg.get('content', '')
preview = content[:50] + "..." if len(content) > 50 else content
sample_conv_structure.append({
"role": msg.get('role', ''),
"content_length": len(content),
"preview": preview
})
logger.info(f"Conversation structure: {sample_conv_structure}")
except Exception as e:
logger.warning(f"Error logging conversation structure: {e}")
except Exception as e:
logger.warning(f"Error logging sample examples: {e}")
logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
logger.info(f"Dataset columns: {dataset.column_names}")
# Verify dataset is not empty
if len(dataset) == 0:
logger.error("Dataset is empty! Cannot proceed with training.")
return dataset
# Check for required columns
required_cols = ['conversations', 'prompt_number']
for col in required_cols:
if col not in dataset.column_names:
logger.error(f"Required column '{col}' missing from dataset. Cannot proceed with training.")
return dataset
# Validate at least one sample can be processed
try:
if len(dataset) > 0:
sample = dataset[0]
if 'conversations' not in sample or not sample['conversations']:
logger.error("First sample has no conversations! Data format may be incorrect.")
return dataset
if not isinstance(sample['conversations'], list):
logger.error(f"Conversations field should be a list but got {type(sample['conversations'])}")
return dataset
except Exception as e:
logger.error(f"Error validating first sample: {e}")
return dataset
# Add metadata if specified
metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
if metadata_config:
include_article_id = metadata_config.get("include_article_id", False)
include_prompt_number = metadata_config.get("include_prompt_number", False)
metadata_format = metadata_config.get("metadata_format", "")
if (include_article_id or include_prompt_number) and metadata_format:
logger.info("Adding metadata to conversations")
def add_metadata(example):
if not example.get("conversations"):
return example
# Prepare metadata
metadata = metadata_format
if include_article_id and "article_id" in example:
metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
if include_prompt_number and "prompt_number" in example:
metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
# Add system message with metadata if not empty
if metadata.strip():
if example["conversations"] and isinstance(example["conversations"], list):
# Check if first message is already a system message
if (isinstance(example["conversations"][0], dict) and
example["conversations"][0].get("role") == "system"):
# Append to existing system message
example["conversations"][0]["content"] += f"\n\nMetadata: {metadata}"
else:
# Add new system message at the beginning
example["conversations"].insert(0, {
"role": "system",
"content": f"Metadata: {metadata}"
})
return example
dataset = dataset.map(add_metadata)
logger.info("Metadata added to conversations")
return dataset
except Exception as e:
logger.error(f"Error loading dataset: {str(e)}")
raise
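# After the mapping above, each record is expected to look roughly like the sketch
# below (inferred from the columns this function checks for; real records may carry
# additional fields):
#
#   {
#       "prompt_number": 1,                       # added here, 1-indexed original order
#       "article_id": "12345",                    # renamed from "id" when present
#       "conversations": [
#           {"role": "system", "content": "..."},
#           {"role": "user", "content": "..."},
#           {"role": "assistant", "content": "..."}
#       ]
#   }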
def format_phi_chat(messages, dataset_config):
"""Format messages according to phi-4's chat template and dataset config."""
formatted_chat = ""
# Get role templates from config
roles = dataset_config.get("data_formatting", {}).get("roles", {
"system": "System: {content}\n\n",
"human": "Human: {content}\n\n",
"user": "Human: {content}\n\n",
"assistant": "Assistant: {content}\n\n"
})
# Handle research introduction metadata first
metadata = next((msg for msg in messages if isinstance(msg, dict) and
"[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
if metadata:
system_template = roles.get("system", "System: {content}\n\n")
formatted_chat = system_template.format(content=metadata['content'])
messages = [msg for msg in messages if msg != metadata]
# Process remaining messages
for message in messages:
if not isinstance(message, dict) or "content" not in message:
logger.warning(f"Skipping invalid message format: {message}")
continue
role = message.get("role", "").lower()
content = message.get("content", "")
# Format based on role
if role == "human" or role == "user":
template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
formatted_chat += template.format(content=content)
elif role == "assistant" or role == "bot":
template = roles.get("assistant", "Assistant: {content}\n\n")
formatted_chat += template.format(content=content)
elif role == "system":
# For system messages, prepend them
template = roles.get("system", "System: {content}\n\n")
formatted_chat = template.format(content=content) + formatted_chat
else:
# Default to system for unknown roles
logger.warning(f"Unknown role '{role}' - treating as system message")
template = roles.get("system", "System: {content}\n\n")
formatted_chat += template.format(content=content)
return formatted_chat.strip()
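# Illustrative example using the default role templates above (an empty dataset
# config falls back to those defaults):
#
#   format_phi_chat(
#       [
#           {"role": "system", "content": "You are helpful."},
#           {"role": "user", "content": "Hi"},
#           {"role": "assistant", "content": "Hello!"},
#       ],
#       {},
#   )
#   # -> "System: You are helpful.\n\nHuman: Hi\n\nAssistant: Hello!"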
class SimpleDataCollator:
def __init__(self, tokenizer, dataset_config):
self.tokenizer = tokenizer
self.dataset_config = dataset_config
self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
logger.info("Using exact dataset structure without reformatting")
# Check if we're on GPU
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"SimpleDataCollator using device: {self.device}")
def __call__(self, features):
"""Process examples preserving exact JSONL structure"""
batch = {"input_ids": [], "attention_mask": [], "labels": []}
for example in features:
try:
# Get ID (note: the 'id' column is renamed to 'article_id' during dataset loading)
paper_id = example.get("article_id", example.get("id", ""))
# Get conversations - these should already contain role and content
conversations = example.get("conversations", [])
if not conversations:
self.stats["skipped"] += 1
continue
# Directly use the conversations array as input to the model's chat template
# This preserves the exact structure with roles and content as they are
try:
# Let tokenizer handle the content with the model's chat template
inputs = self.tokenizer.apply_chat_template(
conversations,
return_tensors=None,
add_generation_prompt=False
)
except Exception as chat_error:
# Fallback if apply_chat_template fails
logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
# Create a basic representation of the conversation
conversation_text = ""
for msg in conversations:
if isinstance(msg, dict) and 'content' in msg:
conversation_text += msg.get('content', '') + "\n\n"
# Basic tokenization; keep only the token ids so this branch returns the
# same list-of-ids shape as apply_chat_template above
inputs = self.tokenizer(
conversation_text,
add_special_tokens=True,
return_tensors=None
)["input_ids"]
# Apply length cap if needed (shouldn't be necessary for pre-audited data)
if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
inputs = inputs[:self.max_seq_length]
# Create attention mask (1 for all tokens)
attention_mask = [1] * len(inputs)
if len(inputs) > 0:
# For causal language modeling, labels are the same as inputs
labels = inputs.copy()
batch["input_ids"].append(inputs)
batch["attention_mask"].append(attention_mask)
batch["labels"].append(labels)
self.stats["processed"] += 1
self.stats["total_tokens"] += len(inputs)
# Debug logging for first few examples
log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
if self.stats["processed"] <= log_samples:
logger.info(f"Example {self.stats['processed']}:")
logger.info(f"Paper ID: {paper_id}")
logger.info(f"Token count: {len(inputs)}")
logger.info(f"Conversation entries: {len(conversations)}")
else:
self.stats["skipped"] += 1
except Exception as e:
logger.warning(f"Error processing example: {str(e)[:100]}...")
logger.warning(f"Problematic example ID: {example.get('article_id', example.get('id', 'unknown'))}")
self.stats["skipped"] += 1
continue
if not batch["input_ids"]:
logger.warning("Empty batch, returning dummy tensors")
return {
"input_ids": torch.zeros((1, 1), dtype=torch.long),
"attention_mask": torch.zeros((1, 1), dtype=torch.long),
"labels": torch.zeros((1, 1), dtype=torch.long)
}
# Pad the batch
max_length = max(len(ids) for ids in batch["input_ids"])
for i in range(len(batch["input_ids"])):
padding_length = max_length - len(batch["input_ids"][i])
if padding_length > 0:
batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
batch["attention_mask"][i].extend([0] * padding_length)
batch["labels"][i].extend([-100] * padding_length)
# Convert to tensors
batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
# Log stats periodically
log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
logger.info(f"Data collator stats: processed={self.stats['processed']}, "
f"skipped={self.stats['skipped']}, "
f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")
return batch
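# Illustrative collator output for two examples of 5 and 3 tokens (assuming
# pad_token_id == 0): input_ids are right-padded with pad_token_id, attention_mask
# with 0, and labels with -100 so padded positions are ignored by the loss:
#
#   input_ids      -> tensor([[11, 12, 13, 14, 15],
#                             [21, 22, 23,  0,  0]])
#   attention_mask -> tensor([[ 1,  1,  1,  1,  1],
#                             [ 1,  1,  1,  0,  0]])
#   labels         -> tensor([[ 11,  12,  13,  14,  15],
#                             [ 21,  22,  23, -100, -100]])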
class LoggingCallback(TrainerCallback):
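# Note: the methods below reference `model`, `trainer`, and `dataset_config` as free
# names; main() publishes them at module scope (via its `global` declaration) before
# training starts so these lookups resolve at run time.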
def __init__(self):
super().__init__()
self.training_started = time.time()
self.last_log_time = time.time()
self.last_step = 0
self.verify_sequence = None
self.sequence_samples = None
self.sample_indices = None
def on_step_end(self, args, state, control, **kwargs):
# Log every 50 steps or every 5 minutes, whichever comes first
current_time = time.time()
# Perform actual sequence integrity verification if enabled
if self.verify_sequence is True and state.global_step % 100 == 0 and self.sequence_samples:
try:
# Get a batch of data without disturbing the training
train_dataloader = trainer.get_train_dataloader()
if train_dataloader is None:
log_info("Warning: Could not get train dataloader for verification")
else:
batch_iterator = iter(train_dataloader)
if batch_iterator is None:
log_info("Warning: Could not get batch iterator for verification")
else:
try:
batch = next(batch_iterator)
if batch is None:
log_info("Warning: Could not get batch for verification")
elif 'input_ids' in batch and 'labels' in batch:
log_info("Verifying data sequence integrity...")
# Check if we can access some of our reference samples
if not hasattr(trainer, 'train_dataset') or trainer.train_dataset is None:
log_info("Warning: Train dataset is not available")
else:
# Get current samples defensively
current_samples = []
current_indices = list(range(min(3, len(trainer.train_dataset))))
for idx in current_indices:
try:
if idx < len(trainer.train_dataset):
current_samples.append(trainer.train_dataset[idx])
except Exception as e:
log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
# Only proceed if we have samples to compare
if current_samples and self.sequence_samples:
# Compare current samples with our reference samples from training start
is_sequence_maintained = True
for i, (orig_idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
# Check if sample index is valid
if i < len(current_samples):
current_sample = current_samples[i]
# Compare prompt numbers if available
if ('prompt_number' in orig_sample and
'prompt_number' in current_sample and
orig_sample['prompt_number'] is not None and
current_sample['prompt_number'] is not None):
if orig_sample['prompt_number'] != current_sample['prompt_number']:
log_info(f"WARNING: Sequence integrity compromised! Sample {i} prompt number changed from {orig_sample['prompt_number']} to {current_sample['prompt_number']}")
is_sequence_maintained = False
# Also compare IDs as a backup check
elif ('article_id' in orig_sample and
'article_id' in current_sample and
orig_sample['article_id'] is not None and
current_sample['article_id'] is not None):
if orig_sample['article_id'] != current_sample['article_id']:
log_info(f"WARNING: Sequence integrity compromised! Sample {i} article_id changed from {orig_sample['article_id']} to {current_sample['article_id']}")
is_sequence_maintained = False
# Compare input fingerprints
if ('conversations' in orig_sample and
'conversations' in current_sample and
orig_sample['conversations'] is not None and
current_sample['conversations'] is not None):
orig_len = len(orig_sample['conversations'])
curr_len = len(current_sample['conversations'])
if orig_len != curr_len:
log_info(f"WARNING: Sequence integrity compromised! Sample {i} conversation length changed from {orig_len} to {curr_len}")
is_sequence_maintained = False
if is_sequence_maintained:
log_info("Data sequence integrity check: OK")
else:
log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")
else:
log_info("Warning: Not enough samples available for sequence verification")
except StopIteration:
log_info("Warning: No batches available in the dataloader")
except Exception as e:
log_info(f"Warning: Error iterating through dataloader: {e}")
except Exception as e:
log_info(f"Warning: Couldn't verify sequence integrity: {e}")
time_interval = current_time - self.last_log_time
step_interval = state.global_step - self.last_step
if step_interval >= 50 or time_interval >= 300: # 5 minutes = 300 seconds
# Calculate throughput
examples_per_second = step_interval * args.per_device_train_batch_size * args.gradient_accumulation_steps / max(time_interval, 1e-6)
elapsed_total = time.strftime("%H:%M:%S", time.gmtime(current_time - self.training_started))
# Log progress (guard against an empty log history or entries without a loss value)
last_loss = state.log_history[-1].get("loss") if state.log_history else None
loss_str = f"{last_loss:.4f}" if isinstance(last_loss, (int, float)) else "N/A"
log_info(f"Step: {state.global_step}, Loss: {loss_str}, "
f"Rate: {examples_per_second:.2f} examples/sec, Elapsed: {elapsed_total}")
# Report memory usage if CUDA is available
if CUDA_AVAILABLE:
log_info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB allocated, "
f"{torch.cuda.max_memory_reserved() / 1024**3:.2f} GB reserved")
# Reset for next interval
self.last_log_time = current_time
self.last_step = state.global_step
def on_train_begin(self, args, state, control, **kwargs):
log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
# Set up sequence verification with actual sample capturing
try:
self.verify_sequence = dataset_config.get("validation", {}).get("verify_sequence_integrity", False)
if self.verify_sequence:
log_info("Sequence integrity verification enabled during training")
# Save actual samples for later verification
if trainer and hasattr(trainer, 'train_dataset') and trainer.train_dataset is not None:
# Get some reference samples from the beginning of the dataset defensively
self.sample_indices = []
self.sequence_samples = []
max_samples = min(5, len(trainer.train_dataset))
for i in range(max_samples):
try:
if i < len(trainer.train_dataset):
self.sample_indices.append(i)
self.sequence_samples.append(trainer.train_dataset[i])
except Exception as e:
log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
if self.sequence_samples:
log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
# Log sample prompt numbers for debugging
sample_prompt_numbers = []
for s in self.sequence_samples:
if isinstance(s, dict) and 'prompt_number' in s and s['prompt_number'] is not None:
sample_prompt_numbers.append(s.get('prompt_number'))
if sample_prompt_numbers:
log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
else:
log_info("Warning: No reference samples were captured")
else:
log_info("Warning: Could not capture reference samples - verification will be limited")
except Exception as e:
log_info(f"Warning: Could not set up sequence integrity verification: {e}")
self.verify_sequence = False
log_info("=== Training is starting ===")
# Log important training parameters for visibility
total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * max(NUM_GPUS, 1)
log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {max(NUM_GPUS, 1)} device(s) = {total_batch_size} total")
log_info(f"Learning rate: {args.learning_rate}")
log_info(f"Epochs: {args.num_train_epochs}")
# Log memory information in compact format
if CUDA_AVAILABLE:
memory_info = []
for i in range(NUM_GPUS):
allocated = torch.cuda.memory_allocated(i) / 1024**2
max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
log_info(f"Initial memory usage - {', '.join(memory_info)}")
def on_train_end(self, args, state, control, **kwargs):
training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
log_info(f"=== Training completed in {training_time} ===")
# Log final memory usage
if CUDA_AVAILABLE:
for i in range(NUM_GPUS):
max_mem = torch.cuda.max_memory_allocated(i) / 1024**3 # GB
log_info(f"GPU {i} max memory: {max_mem:.2f} GB")
# Clear GPU memory
torch.cuda.empty_cache()
log_info("GPU memory cleared")
log_info(f"Total steps: {state.global_step}")
log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
def check_dependencies():
"""Check if all required dependencies are installed."""
missing_packages = []
# Critical packages
if not unsloth_available:
missing_packages.append("unsloth>=2024.3")
if not peft_available:
missing_packages.append("peft>=0.9.0")
# Optional packages - don't add to missing list, just log
if find_spec("flash_attn"):
logger.info("flash-attn found. Flash attention will be used for faster training.")
else:
logger.warning("flash-attn not found. Training will work but may be slower.")
logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
# If critical packages are missing, exit with instructions
if missing_packages:
logger.error("Critical dependencies missing:")
for pkg in missing_packages:
logger.error(f" - {pkg}")
logger.error("Please ensure the space has these packages in requirements.txt")
return False
return True
def main():
# Publish these names at module scope as well: LoggingCallback's methods reference
# `model`, `trainer`, and `dataset_config` directly, so they must be resolvable
# outside main() once training starts.
global model, trainer, dataset_config
# Set up logging
logger.info("Starting training process")
# Parse arguments
args = parse_args()
# Load environment variables
load_env_variables()
# Load configuration
try:
transformers_config = load_configs(args.config)
hardware_config = transformers_config.get("hardware", {})
dataset_config = transformers_config.get("dataset", {})
logger.info("Configuration loaded successfully")
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return 1
# Check dependencies
if not check_dependencies():
logger.error("Aborting due to missing critical dependencies")
return 1
# Check if we're in distributed mode
is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
if is_distributed:
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
else:
log_info("Running in non-distributed mode (single process)")
# Set random seed for reproducibility
seed = transformers_config.get("seed", 42)
set_seed(seed)
logger.info(f"Set random seed to {seed}")
# Load model and tokenizer using the consolidated config
model, tokenizer = load_model_and_tokenizer(transformers_config)
# Empty CUDA cache to ensure clean state
if CUDA_AVAILABLE:
torch.cuda.empty_cache()
log_info("Cleared CUDA cache")
# Setup environment variable for CUDA memory allocation
if CUDA_AVAILABLE:
system_settings = hardware_config.get("system_settings", {})
cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
if cuda_memory_fraction < 1.0:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
log_info("Set CUDA allocator to expandable segments with max_split_size_mb:128")
try:
log_info("Loading dataset...")
dataset = load_dataset_with_mapping(dataset_config)
log_info(f"Dataset loaded with {len(dataset)} examples")
# Minimal validation before proceeding
if dataset is None or len(dataset) == 0:
logger.error("Dataset is empty or None! Cannot proceed with training.")
return 1
# Create data collator
data_collator = SimpleDataCollator(tokenizer, dataset_config)
# Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
# First check hardware config, then transformers config
use_bf16 = False
use_fp16 = False
# Check hardware config first
hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
if hardware_precision.lower() == "bf16":
use_bf16 = True
log_info("Using BF16 precision from hardware config")
elif hardware_precision.lower() == "fp16":
use_fp16 = True
log_info("Using FP16 precision from hardware config")
else:
# Fall back to transformers config
use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16"
use_fp16 = transformers_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
# Get per device batch size - from transformers config, but possibly overridden by hardware config
per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
# Multi-GPU strategy comes from the hardware config; "fsdp" enables FSDP below and
# anything else falls back to DataParallel/DDP (the config key name here is an assumption)
multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "ddp")
# For multi-GPU setup, adjust for better balance
if CUDA_AVAILABLE and NUM_GPUS > 1:
log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
# Set up FSDP for multi-GPU training if specified and in distributed mode
fsdp_config = None
if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
try:
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
enable_wrap,
wrap,
)
log_info("Using FSDP for distributed training")
# Configure FSDP
fsdp_config = {
"fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
"fsdp_offload_params": False,
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_min_num_params": 1e6,
"fsdp_sharding_strategy": 1, # FULL_SHARD
}
if use_bf16 or use_fp16:
precision_type = "bf16" if use_bf16 else "fp16"
fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT"
log_info(f"FSDP using mixed precision: {precision_type}")
except ImportError:
log_info("FSDP imports failed, falling back to standard DDP")
fsdp_config = None
elif multi_gpu_strategy == "fsdp" and not is_distributed:
log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
log_info("Using DataParallel for multi-GPU training instead")
else:
log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
# Get system settings from hardware config
dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)
# Set up training arguments
log_info("Setting up training arguments")
training_args = TrainingArguments(
output_dir=transformers_config.get("output_dir") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
per_device_train_batch_size=per_device_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5),
weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01),
warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05),
lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"),
logging_steps=transformers_config.get("training", {}).get("logging_steps", 10),
save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"),
save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100),
save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3),
fp16=use_fp16,
bf16=use_bf16,
max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
hub_token=os.environ.get("HF_TOKEN", None),
report_to="tensorboard",
remove_unused_columns=False, # Keep all columns
gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True),
dataloader_pin_memory=pin_memory,
optim=transformers_config.get("training", {}).get("optim", "adamw_torch"),
ddp_find_unused_parameters=False, # Improve distributed training efficiency
dataloader_drop_last=False, # Process all examples
dataloader_num_workers=dataloader_workers,
no_cuda=not CUDA_AVAILABLE, # Use CUDA if available
# Only enable FSDP if we're in distributed mode with the FSDP strategy; the
# options dict collected above goes to fsdp_config rather than to fsdp itself
fsdp="full_shard" if (is_distributed and multi_gpu_strategy == "fsdp" and fsdp_config) else "",
fsdp_config=fsdp_config if (is_distributed and multi_gpu_strategy == "fsdp" and fsdp_config) else None,
)
# Create sequential sampler to maintain original dataset order
sequential_sampler = torch.utils.data.SequentialSampler(dataset)
# Initialize trainer first
log_info("Initializing Trainer")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset, # We'll override this with our custom dataloader
data_collator=data_collator,
callbacks=[LoggingCallback()],
)
# Then override the get_train_dataloader method
def custom_get_train_dataloader():
"""Custom dataloader that preserves original dataset order"""
log_info("Creating sequential dataloader to maintain original dataset order")
# Verification of sequence preservation flags - consolidated
data_loading_config = dataset_config.get("data_loading", {})
sequential_processing = data_loading_config.get("sequential_processing", True)
shuffle_disabled = not data_loading_config.get("shuffle", False)
if not sequential_processing:
log_info("CRITICAL WARNING: sequential_processing flag is disabled! This may affect data order.")
log_info("Data sequence integrity is essential - using sequential sampler regardless of flag.")
# Force sequential processing regardless of flag
if not shuffle_disabled:
log_info("CRITICAL ERROR: Shuffle is not disabled! This will randomize data entry order!")
# Actually handle the error rather than just logging it
raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
"Please disable shuffling in your configuration.")
# Calculate batch size based on device availability
if getattr(training_args, "no_cuda", False):
batch_size = training_args.per_device_train_batch_size
else:
batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)
log_info(f"Using sequential sampler with batch size {batch_size}")
# Return DataLoader with sequential sampler
return torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
sampler=sequential_sampler,
collate_fn=data_collator,
drop_last=training_args.dataloader_drop_last,
num_workers=training_args.dataloader_num_workers,
pin_memory=training_args.dataloader_pin_memory,
)
# Override the get_train_dataloader method
trainer.get_train_dataloader = custom_get_train_dataloader
# Start training
log_info("=== Starting Training ===")
try:
# Empty cache again right before training
if CUDA_AVAILABLE:
torch.cuda.empty_cache()
log_info("Cleared CUDA cache before training")
# Display compact training info
total_steps = int(len(dataset) / (per_device_batch_size * max(NUM_GPUS, 1) * gradient_accumulation_steps) * training_args.num_train_epochs)
log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
trainer.train()
log_info("Training completed successfully!")
# Save the final model
log_info("Saving final model...")
trainer.save_model()
log_info(f"Model saved to {training_args.output_dir}")
# Push to hub if enabled
if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False):
hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model")
log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
trainer.push_to_hub()
log_info("Model successfully pushed to Hub")
return 0
except Exception as e:
logger.error(f"Training failed with error: {str(e)}")
# Log CUDA memory info if available in compact format
if CUDA_AVAILABLE:
memory_info = []
for i in range(NUM_GPUS):
allocated = torch.cuda.memory_allocated(i) / 1024**2
reserved = torch.cuda.memory_reserved(i) / 1024**2
max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
raise
except Exception as e:
logger.error(f"Error in main training loop: {str(e)}")
return 1
if __name__ == "__main__":
sys.exit(main())