#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Basic Python imports
import os
import sys
import json
import argparse
import logging
from datetime import datetime
import time
import warnings
import traceback
from importlib.util import find_spec
import multiprocessing

import torch
from torch.utils.data import DataLoader
import random
import numpy as np
from tqdm import tqdm

# Check hardware capabilities first.
# Query torch directly; checking CUDA_VISIBLE_DEVICES/NVIDIA_VISIBLE_DEVICES is
# unreliable because an unset variable is not the same as "no GPU".
CUDA_AVAILABLE = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"

# Set the multiprocessing start method to 'spawn' for CUDA compatibility
if CUDA_AVAILABLE:
    try:
        multiprocessing.set_start_method('spawn', force=True)
        print("Set multiprocessing start method to 'spawn' for CUDA compatibility")
    except RuntimeError:
        # Method already set, which is fine
        print("Multiprocessing start method already set")

# Import order is important: unsloth should be imported before transformers.
# Check for libraries without importing them first.
unsloth_available = find_spec("unsloth") is not None
if unsloth_available:
    import unsloth

# Import transformers (after unsloth) if available
transformers_available = find_spec("transformers") is not None
if transformers_available:
    import transformers
    from transformers import AutoTokenizer, TrainingArguments, Trainer, TrainerCallback, set_seed

peft_available = find_spec("peft") is not None
if peft_available:
    import peft

# Only import HF datasets if available
datasets_available = find_spec("datasets") is not None
if datasets_available:
    from datasets import load_dataset

# Set up the logger
logger = logging.getLogger(__name__)
log_handler = logging.StreamHandler()
log_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
logger.setLevel(logging.INFO)


# Define a clean logging function for HF Space compatibility
def log_info(message):
    """Log information in a format compatible with Hugging Face Spaces."""
    # Just use the logger, but ensure consistent formatting
    logger.info(message)
    # Also ensure output is flushed immediately for streaming
    sys.stdout.flush()


# Check for BitsAndBytes
try:
    from transformers import BitsAndBytesConfig
    bitsandbytes_available = True
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")

# Check for PEFT
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")


def load_env_variables():
    """Load environment variables from the system, a .env file, or Hugging Face Space variables."""
    # Check if we're running in a Hugging Face Space
    if os.environ.get("SPACE_ID"):
        logger.info("Running in Hugging Face Space")

        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")

        # If the username is not set, try to extract it from SPACE_ID
        if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
            username = os.environ.get("SPACE_ID").split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from a .env file if not in a Space
        try:
            from dotenv import load_dotenv

            # First check the current directory
            env_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                # Try the shared directory as a fallback
                shared_env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
                if os.path.exists(shared_env_path):
                    load_dotenv(shared_env_path)
                    logger.info(f"Loaded environment variables from {shared_env_path}")
                    logger.info(f"HF_TOKEN loaded from shared .env file: {bool(os.environ.get('HF_TOKEN'))}")
                    logger.info(f"HF_USERNAME loaded from shared .env file: {bool(os.environ.get('HF_USERNAME'))}")
                    logger.info(f"HF_SPACE_NAME loaded from shared .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
                else:
                    logger.warning("No .env file found in the current or shared directory")
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")

    if not os.environ.get("HF_TOKEN"):
        logger.warning("HF_TOKEN is not set. Pushing to Hugging Face Hub will not work.")
    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Set HF_TOKEN for huggingface_hub
    if os.environ.get("HF_TOKEN"):
        os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")


def load_configs(base_path):
    """Load configuration from the transformers_config.json file."""
    # Using a single consolidated config file
    config_file = base_path
    try:
        with open(config_file, "r") as f:
            config = json.load(f)
        logger.info(f"Loaded configuration from {config_file}")
        return config
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise


def parse_args():
    """
    Parse command line arguments for the training script.
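    Example (illustrative invocation; the script filename is an assumption and the
    config filename is simply the default used below):

        python run_transformers_training.py --config_file transformers_config.json --seed 42 --log_level debug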
Returns: argparse.Namespace: The parsed command line arguments """ parser = argparse.ArgumentParser(description="Run training for language models") parser.add_argument( "--config_file", type=str, default=None, help="Path to the configuration file (default: transformers_config.json in script directory)" ) parser.add_argument( "--seed", type=int, default=None, help="Random seed for reproducibility (default: based on current time)" ) parser.add_argument( "--log_level", type=str, choices=["debug", "info", "warning", "error", "critical"], default="info", help="Logging level (default: info)" ) return parser.parse_args() def load_model_and_tokenizer(config): """ Load the model and tokenizer according to the configuration. Args: config (dict): Complete configuration dictionary Returns: tuple: (model, tokenizer) - The loaded model and tokenizer """ # Extract model configuration model_config = get_config_value(config, "model", {}) model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit") use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True) trust_remote_code = get_config_value(model_config, "trust_remote_code", True) model_revision = get_config_value(config, "model_revision", "main") # Unsloth configuration unsloth_config = get_config_value(config, "unsloth", {}) unsloth_enabled = get_config_value(unsloth_config, "enabled", True) # Tokenizer configuration tokenizer_config = get_config_value(config, "tokenizer", {}) max_seq_length = min( get_config_value(tokenizer_config, "max_seq_length", 2048), 4096 # Maximum supported by most models ) add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True) chat_template = get_config_value(tokenizer_config, "chat_template", None) padding_side = get_config_value(tokenizer_config, "padding_side", "right") # Check for flash attention use_flash_attention = get_config_value(config, "use_flash_attention", False) flash_attention_available = False try: import flash_attn flash_attention_available = True log_info(f"Flash Attention detected (version: {flash_attn.__version__})") except ImportError: if use_flash_attention: log_info("Flash Attention requested but not available") log_info(f"Loading model: {model_name} (revision: {model_revision})") log_info(f"Max sequence length: {max_seq_length}") try: if unsloth_enabled and unsloth_available: log_info("Using Unsloth for 4-bit quantized model and LoRA") # Load using Unsloth from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=max_seq_length, dtype=get_config_value(config, "torch_dtype", "bfloat16"), revision=model_revision, trust_remote_code=trust_remote_code, use_flash_attention_2=use_flash_attention and flash_attention_available ) # Configure tokenizer settings tokenizer.padding_side = padding_side if add_eos_token and tokenizer.eos_token is None: log_info("Setting EOS token") tokenizer.add_special_tokens({"eos_token": ""}) # Set chat template if specified if chat_template: log_info(f"Setting chat template: {chat_template}") if hasattr(tokenizer, "chat_template"): tokenizer.chat_template = chat_template else: log_info("Tokenizer does not support chat templates, using default formatting") # Apply LoRA lora_r = get_config_value(unsloth_config, "r", 16) lora_alpha = get_config_value(unsloth_config, "alpha", 32) lora_dropout = get_config_value(unsloth_config, "dropout", 0) target_modules = get_config_value(unsloth_config, "target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", 
"gate_proj", "up_proj", "down_proj"]) log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") model = FastLanguageModel.get_peft_model( model, r=lora_r, target_modules=target_modules, lora_alpha=lora_alpha, lora_dropout=lora_dropout, bias="none", use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True), random_state=0, max_seq_length=max_seq_length, modules_to_save=None ) if use_flash_attention and flash_attention_available: log_info("šŸš€ Using Flash Attention for faster training") elif use_flash_attention and not flash_attention_available: log_info("āš ļø Flash Attention requested but not available - using standard attention") else: # Standard HuggingFace loading log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)") from transformers import AutoModelForCausalLM, AutoTokenizer # Check if flash attention should be enabled in config use_attn_implementation = None if use_flash_attention and flash_attention_available: use_attn_implementation = "flash_attention_2" log_info("šŸš€ Using Flash Attention for faster training") # Load tokenizer first tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=trust_remote_code, use_fast=use_fast_tokenizer, revision=model_revision, padding_side=padding_side ) # Configure tokenizer settings if add_eos_token and tokenizer.eos_token is None: log_info("Setting EOS token") tokenizer.add_special_tokens({"eos_token": ""}) # Set chat template if specified if chat_template: log_info(f"Setting chat template: {chat_template}") if hasattr(tokenizer, "chat_template"): tokenizer.chat_template = chat_template else: log_info("Tokenizer does not support chat templates, using default formatting") # Now load model with updated tokenizer model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=trust_remote_code, revision=model_revision, torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16, device_map="auto" if CUDA_AVAILABLE else None, attn_implementation=use_attn_implementation ) # Apply PEFT/LoRA if enabled but using standard loading if peft_available and get_config_value(unsloth_config, "enabled", True): log_info("Applying standard PEFT/LoRA configuration") from peft import LoraConfig, get_peft_model lora_r = get_config_value(unsloth_config, "r", 16) lora_alpha = get_config_value(unsloth_config, "alpha", 32) lora_dropout = get_config_value(unsloth_config, "dropout", 0) target_modules = get_config_value(unsloth_config, "target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") lora_config = LoraConfig( r=lora_r, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout, bias="none", task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) # Print model summary log_info(f"Model loaded successfully: {model.__class__.__name__}") if hasattr(model, "print_trainable_parameters"): model.print_trainable_parameters() else: total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})") return model, tokenizer except Exception as e: log_info(f"Error loading model: {str(e)}") traceback.print_exc() return None, None def 
load_dataset_with_mapping(config): """ Load dataset from Hugging Face or local files and apply necessary transformations. Args: config (dict): Dataset configuration dictionary Returns: Dataset: The loaded and processed dataset """ # Extract dataset configuration dataset_info = get_config_value(config, "dataset", {}) dataset_name = get_config_value(dataset_info, "name", None) dataset_split = get_config_value(dataset_info, "split", "train") # Data formatting configuration formatting_config = get_config_value(config, "data_formatting", {}) if not dataset_name: raise ValueError("Dataset name not specified in config") log_info(f"Loading dataset: {dataset_name} (split: {dataset_split})") try: # Load dataset from Hugging Face or local path from datasets import load_dataset # Check if it's a local path or Hugging Face dataset if os.path.exists(dataset_name) or os.path.exists(os.path.join(os.getcwd(), dataset_name)): log_info(f"Loading dataset from local path: {dataset_name}") # Local dataset - check if it's a directory or file if os.path.isdir(dataset_name): # Directory - look for data files dataset = load_dataset( "json", data_files={"train": os.path.join(dataset_name, "*.json")}, split=dataset_split ) else: # Single file dataset = load_dataset( "json", data_files={"train": dataset_name}, split=dataset_split ) else: # Hugging Face dataset log_info(f"Loading dataset from Hugging Face: {dataset_name}") dataset = load_dataset(dataset_name, split=dataset_split) log_info(f"Dataset loaded with {len(dataset)} examples") # Check if dataset contains required fields required_fields = ["conversations"] missing_fields = [field for field in required_fields if field not in dataset.column_names] if missing_fields: log_info(f"WARNING: Dataset missing required fields: {missing_fields}") log_info("Attempting to map dataset structure to required format") # Implement conversion logic based on dataset structure if "messages" in dataset.column_names: log_info("Converting 'messages' field to 'conversations' format") dataset = dataset.map( lambda x: {"conversations": x["messages"]}, remove_columns=["messages"] ) elif "text" in dataset.column_names: log_info("Converting plain text to conversations format") dataset = dataset.map( lambda x: {"conversations": [{"role": "user", "content": x["text"]}]}, remove_columns=["text"] ) else: raise ValueError(f"Cannot convert dataset format - missing required fields and no conversion path available") # Log dataset info log_info(f"Dataset has {len(dataset)} examples and columns: {dataset.column_names}") # Show a few examples for verification for i in range(min(3, len(dataset))): example = dataset[i] log_info(f"Example {i}:") for key, value in example.items(): if key == "conversations": log_info(f" conversations: {len(value)} messages") # Show first message only to avoid cluttering logs if value and len(value) > 0: first_msg = value[0] if isinstance(first_msg, dict) and "content" in first_msg: content = first_msg["content"] log_info(f" First message: {content[:50]}..." if len(content) > 50 else f" First message: {content}") else: log_info(f" {key}: {value}") return dataset except Exception as e: log_info(f"Error loading dataset: {str(e)}") traceback.print_exc() return None def format_phi_chat(messages, dataset_config): """Format messages according to phi-4's chat template and dataset config. 
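    For example, with the default role templates below, a plain two-message
    exchange is rendered (illustratively) as a "Human: <first message content>"
    block followed by an "Assistant: <second message content>" block, separated
    by blank lines; the exact prefixes come from data_formatting.roles in the
    dataset config.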
Only formats the conversation structure, preserves the actual content.""" formatted_chat = "" # Get role templates from config roles = dataset_config.get("data_formatting", {}).get("roles", { "system": "System: {content}\n\n", "human": "Human: {content}\n\n", "assistant": "Assistant: {content}\n\n" }) # Handle each message in the conversation for message in messages: if not isinstance(message, dict) or "content" not in message: logger.warning(f"Skipping invalid message format: {message}") continue content = message.get("content", "") # Don't strip() - preserve exact content # Skip empty content if not content: continue # Only add role prefixes based on position/content if "[RESEARCH INTRODUCTION]" in content: # System message template = roles.get("system", "System: {content}\n\n") formatted_chat = template.format(content=content) + formatted_chat else: # Alternate between human and assistant for regular conversation turns # In phi-4 format, human messages come first, followed by assistant responses if len(formatted_chat.split("Human:")) == len(formatted_chat.split("Assistant:")): # If equal numbers of Human and Assistant messages, next is Human template = roles.get("human", "Human: {content}\n\n") else: # Otherwise, next is Assistant template = roles.get("assistant", "Assistant: {content}\n\n") formatted_chat += template.format(content=content) return formatted_chat class SimpleDataCollator: def __init__(self, tokenizer, dataset_config): self.tokenizer = tokenizer self.max_seq_length = min(dataset_config.get("max_seq_length", 2048), tokenizer.model_max_length) self.stats = { "processed": 0, "skipped": 0, "total_tokens": 0 } logger.info(f"Initialized SimpleDataCollator with max_seq_length={self.max_seq_length}") def __call__(self, features): # Initialize tensors on CPU to save GPU memory batch = { "input_ids": [], "attention_mask": [], "labels": [] } for feature in features: paper_id = feature.get("article_id", "unknown") prompt_num = feature.get("prompt_number", 0) conversations = feature.get("conversations", []) if not conversations: logger.warning(f"No conversations for paper_id {paper_id}, prompt {prompt_num}") self.stats["skipped"] += 1 continue # Get the content directly content = conversations[0].get("content", "") if not content: logger.warning(f"Empty content for paper_id {paper_id}, prompt {prompt_num}") self.stats["skipped"] += 1 continue # Process the content string by tokenizing it if isinstance(content, str): # Tokenize the content string input_ids = self.tokenizer.encode(content, add_special_tokens=True) else: # If somehow the content is already tokenized (not a string), use it directly input_ids = content # Truncate if needed if len(input_ids) > self.max_seq_length: input_ids = input_ids[:self.max_seq_length] logger.warning(f"Truncated sequence for paper_id {paper_id}, prompt {prompt_num}") # Create attention mask (1s for all tokens) attention_mask = [1] * len(input_ids) # Add to batch batch["input_ids"].append(input_ids) batch["attention_mask"].append(attention_mask) batch["labels"].append(input_ids.copy()) # For causal LM, labels = input_ids self.stats["processed"] += 1 self.stats["total_tokens"] += len(input_ids) # Log statistics periodically if self.stats["processed"] % 100 == 0: avg_tokens = self.stats["total_tokens"] / max(1, self.stats["processed"]) logger.info(f"Data collation stats: processed={self.stats['processed']}, " f"skipped={self.stats['skipped']}, avg_tokens={avg_tokens:.1f}") # Convert to tensors or pad sequences (PyTorch will handle) if batch["input_ids"]: 
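            # Collation step (descriptive note): each feature above contributes a
            # token-ID list, an attention mask, and a copy of the IDs as labels.
            # Everything is padded to self.max_seq_length below; padded label
            # positions use -100 so they are ignored by the causal-LM loss. This
            # assumes right padding, which is the script's default padding_side.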
            # tokenizer.pad() needs a pad token to be defined; fall back to EOS if missing
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Pad input_ids/attention_mask with the tokenizer, then pad labels by hand,
            # since tokenizer.pad() does not pad extra keys such as "labels"
            labels = batch.pop("labels")
            batch = self.tokenizer.pad(
                batch,
                padding="max_length",
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            batch["labels"] = torch.tensor(
                [label + [-100] * (self.max_seq_length - len(label)) for label in labels],
                dtype=torch.long
            )
            return batch
        else:
            # Return an empty batch if no valid examples were found
            return {k: [] for k in batch}


def log_gpu_memory_usage(step=None, frequency=50, clear_cache_threshold=0.9, label=None):
    """
    Log GPU memory usage statistics, with optional cache clearing.

    Args:
        step: Current training step (if None, logs regardless of frequency)
        frequency: How often to log when step is provided
        clear_cache_threshold: Fraction of memory used that triggers cache clearing (0-1)
        label: Optional label for the log message (e.g., "Initial", "Error", "Step")
    """
    if not CUDA_AVAILABLE:
        return

    # Only log every 'frequency' steps if a step is provided
    if step is not None and frequency > 0 and step % frequency != 0:
        return

    # Get memory usage for each GPU
    memory_info = []
    for i in range(NUM_GPUS):
        allocated = torch.cuda.memory_allocated(i) / (1024 ** 2)  # MB
        reserved = torch.cuda.memory_reserved(i) / (1024 ** 2)  # MB
        max_mem = torch.cuda.max_memory_allocated(i) / (1024 ** 2)  # MB

        # Calculate the percentage of reserved memory that is currently allocated
        usage_percent = (allocated / reserved) * 100 if reserved > 0 else 0
        memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB ({usage_percent:.1f}%, max: {max_mem:.1f}MB)")

        # Automatically clear the cache if usage is over the threshold
        if clear_cache_threshold > 0 and reserved > 0 and (allocated / reserved) > clear_cache_threshold:
            log_info(f"Clearing CUDA cache for GPU {i} - high utilization ({allocated:.1f}/{reserved:.1f}MB)")
            with torch.cuda.device(i):
                torch.cuda.empty_cache()

    prefix = f"{label} " if label else ""
    log_info(f"{prefix}GPU Memory: {', '.join(memory_info)}")


class LoggingCallback(TrainerCallback):
    """
    Custom callback for logging training progress and metrics.
    Provides detailed information about training status, GPU memory usage,
    and model performance.
""" def __init__(self, model=None, dataset=None): # Ensure we have TrainerCallback try: super().__init__() except Exception as e: # Try to import directly if initial import failed try: from transformers.trainer_callback import TrainerCallback self.__class__.__bases__ = (TrainerCallback,) super().__init__() log_info("Successfully imported TrainerCallback directly") except ImportError as ie: log_info(f"āŒ Error: Could not import TrainerCallback: {str(ie)}") log_info("Please ensure transformers is properly installed") raise self.training_started = time.time() self.last_log_time = time.time() self.last_step_time = None self.step_durations = [] self.best_loss = float('inf') self.model = model self.dataset = dataset def on_train_begin(self, args, state, control, **kwargs): """Called at the beginning of training""" try: log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===") # Log model info if available if self.model is not None: total_params = sum(p.numel() for p in self.model.parameters()) trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) log_info(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable") # Log dataset info if available if self.dataset is not None: log_info(f"Dataset size: {len(self.dataset)} examples") # Log important training parameters for visibility total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS total_steps = int(len(self.dataset or []) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs) log_info(f"Training plan: {len(self.dataset or [])} examples over {args.num_train_epochs} epochs ā‰ˆ {total_steps} steps") log_info(f"Batch size: {args.per_device_train_batch_size} Ɨ {args.gradient_accumulation_steps} steps Ɨ {NUM_GPUS} GPUs = {total_batch_size} total") # Log initial GPU memory usage with label log_gpu_memory_usage(label="Initial") except Exception as e: logger.warning(f"Error logging training begin statistics: {str(e)}") def on_step_end(self, args, state, control, **kwargs): """Called at the end of each step""" try: if state.global_step == 1 or state.global_step % args.logging_steps == 0: # Track step timing current_time = time.time() if self.last_step_time: step_duration = current_time - self.last_step_time self.step_durations.append(step_duration) # Keep only last 100 steps for averaging if len(self.step_durations) > 100: self.step_durations.pop(0) avg_step_time = sum(self.step_durations) / len(self.step_durations) log_info(f"Step {state.global_step}: {step_duration:.2f}s (avg: {avg_step_time:.2f}s)") self.last_step_time = current_time # Log GPU memory usage with step number log_gpu_memory_usage(state.global_step, args.logging_steps) # Log loss if state.log_history: latest_logs = state.log_history[-1] if state.log_history else {} if "loss" in latest_logs: loss = latest_logs["loss"] log_info(f"Step {state.global_step} loss: {loss:.4f}") # Track best loss if loss < self.best_loss: self.best_loss = loss log_info(f"New best loss: {loss:.4f}") except Exception as e: logger.warning(f"Error logging step end statistics: {str(e)}") def on_train_end(self, args, state, control, **kwargs): """Called at the end of training""" try: # Calculate training duration training_time = time.time() - self.training_started hours, remainder = divmod(training_time, 3600) minutes, seconds = divmod(remainder, 60) log_info(f"=== Training completed at {time.strftime('%Y-%m-%d %H:%M:%S')} ===") 
log_info(f"Training duration: {int(hours)}h {int(minutes)}m {int(seconds)}s") log_info(f"Final step: {state.global_step}") log_info(f"Best loss: {self.best_loss:.4f}") # Log final GPU memory usage log_gpu_memory_usage(label="Final") except Exception as e: logger.warning(f"Error logging training end statistics: {str(e)}") # Other callback methods with proper error handling def on_save(self, args, state, control, **kwargs): """Called when a checkpoint is saved""" try: log_info(f"Saving checkpoint at step {state.global_step}") except Exception as e: logger.warning(f"Error in on_save: {str(e)}") def on_log(self, args, state, control, **kwargs): """Called when a log is created""" pass def on_evaluate(self, args, state, control, **kwargs): """Called when evaluation is performed""" pass # Only implement the methods we actually need, remove the others def on_prediction_step(self, args, state, control, **kwargs): """Called when prediction is performed""" pass def on_save_model(self, args, state, control, **kwargs): """Called when model is saved""" try: # Log memory usage after saving log_gpu_memory_usage(label=f"Save at step {state.global_step}") except Exception as e: logger.warning(f"Error in on_save_model: {str(e)}") def on_epoch_end(self, args, state, control, **kwargs): """Called at the end of an epoch""" try: epoch = state.epoch log_info(f"Completed epoch {epoch:.2f}") log_gpu_memory_usage(label=f"Epoch {epoch:.2f}") except Exception as e: logger.warning(f"Error in on_epoch_end: {str(e)}") def on_step_begin(self, args, state, control, **kwargs): """Called at the beginning of a step""" pass def install_flash_attention(): """ Attempt to install Flash Attention for improved performance. Returns True if installation was successful, False otherwise. """ log_info("Attempting to install Flash Attention...") # Check for CUDA before attempting installation if not CUDA_AVAILABLE: log_info("āŒ Cannot install Flash Attention: CUDA not available") return False try: # Check CUDA version to determine correct installation command cuda_version = torch.version.cuda if cuda_version is None: log_info("āŒ Cannot determine CUDA version for Flash Attention installation") return False import subprocess # Use --no-build-isolation for better compatibility install_cmd = [ sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation" ] log_info(f"Running: {' '.join(install_cmd)}") result = subprocess.run( install_cmd, capture_output=True, text=True, check=False ) if result.returncode == 0: log_info("āœ… Flash Attention installed successfully!") # Attempt to import to verify installation try: import flash_attn log_info(f"āœ… Flash Attention version {flash_attn.__version__} is now available") return True except ImportError: log_info("āš ļø Flash Attention installed but import failed") return False else: log_info(f"āŒ Flash Attention installation failed with error: {result.stderr}") return False except Exception as e: log_info(f"āŒ Error installing Flash Attention: {str(e)}") return False def check_dependencies(): """ Check for required and optional dependencies, ensuring proper versions and import order. Returns True if all required dependencies are present, False otherwise. 
""" # Define required packages with versions and descriptions required_packages = { "unsloth": {"version": ">=2024.3", "feature": "fast 4-bit quantization and LoRA"}, "transformers": {"version": ">=4.38.0", "feature": "core model functionality"}, "peft": {"version": ">=0.9.0", "feature": "parameter-efficient fine-tuning"}, "accelerate": {"version": ">=0.27.0", "feature": "multi-GPU training"} } # Optional packages that enhance functionality optional_packages = { "flash_attn": {"feature": "faster attention computation"}, "bitsandbytes": {"feature": "quantization support"}, "optimum": {"feature": "model optimization"}, "wandb": {"feature": "experiment tracking"} } # Store results missing_packages = [] package_versions = {} order_issues = [] missing_optional = [] # Check required packages log_info("Checking required dependencies...") for package, info in required_packages.items(): version_req = info["version"] feature = info["feature"] try: # Special handling for packages we've already checked if package == "unsloth" and not unsloth_available: missing_packages.append(f"{package}{version_req}") log_info(f"āŒ {package} - {feature} MISSING") continue elif package == "peft" and not peft_available: missing_packages.append(f"{package}{version_req}") log_info(f"āŒ {package} - {feature} MISSING") continue # Try to import and get version module = __import__(package) version = getattr(module, "__version__", "unknown") package_versions[package] = version log_info(f"āœ… {package} v{version} - {feature}") except ImportError: missing_packages.append(f"{package}{version_req}") log_info(f"āŒ {package} - {feature} MISSING") # Check optional packages log_info("\nChecking optional dependencies...") for package, info in optional_packages.items(): feature = info["feature"] try: __import__(package) log_info(f"āœ… {package} - {feature} available") except ImportError: log_info(f"āš ļø {package} - {feature} not available") missing_optional.append(package) # Check import order for optimal performance if "transformers" in package_versions and "unsloth" in package_versions: try: import sys modules = list(sys.modules.keys()) transformers_idx = modules.index("transformers") unsloth_idx = modules.index("unsloth") if transformers_idx < unsloth_idx: order_issue = "āš ļø For optimal performance, import unsloth before transformers" order_issues.append(order_issue) log_info(order_issue) log_info("This might cause performance issues but won't prevent training") else: log_info("āœ… Import order: unsloth before transformers (optimal)") except (ValueError, IndexError) as e: log_info(f"āš ļø Could not verify import order: {str(e)}") # Try to install missing optional packages if "flash_attn" in missing_optional and CUDA_AVAILABLE: log_info("\nFlash Attention is missing but would improve performance.") install_result = install_flash_attention() if install_result: missing_optional.remove("flash_attn") # Report missing required packages if missing_packages: log_info("\nāŒ Critical dependencies missing:") for pkg in missing_packages: log_info(f" - {pkg}") log_info("Please install missing dependencies with:") log_info(f" pip install {' '.join(missing_packages)}") return False log_info("\nāœ… All required dependencies satisfied!") return True def get_config_value(config, path, default=None): """ Safely get a nested value from a config dictionary using a dot-separated path. 
Args: config: The configuration dictionary path: Dot-separated path to the value (e.g., "training.optimizer.lr") default: Default value to return if path doesn't exist Returns: The value at the specified path or the default value """ if not config: return default parts = path.split('.') current = config for part in parts: if isinstance(current, dict) and part in current: current = current[part] else: return default return current def update_huggingface_space(): """Update the Hugging Face Space with the current code.""" log_info("Updating Hugging Face Space...") update_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "update_space.py") if not os.path.exists(update_script): logger.warning(f"Update space script not found at {update_script}") return False try: import subprocess # Explicitly set space_name to ensure we're targeting the right Space result = subprocess.run( [sys.executable, update_script, "--force", "--space_name", "phi4training"], capture_output=True, text=True, check=False ) if result.returncode == 0: log_info("Hugging Face Space updated successfully!") log_info(f"Space URL: https://huggingface.co/spaces/George-API/phi4training") return True else: logger.error(f"Failed to update Hugging Face Space: {result.stderr}") return False except Exception as e: logger.error(f"Error updating Hugging Face Space: {str(e)}") return False def validate_huggingface_credentials(): """Validate Hugging Face credentials to ensure they work correctly.""" if not os.environ.get("HF_TOKEN"): logger.warning("HF_TOKEN not found. Skipping Hugging Face credentials validation.") return False try: # Import here to avoid requiring huggingface_hub if not needed from huggingface_hub import HfApi, login # Try to login with the token login(token=os.environ.get("HF_TOKEN")) # Check if we can access the API api = HfApi() username = os.environ.get("HF_USERNAME", "George-API") space_name = os.environ.get("HF_SPACE_NAME", "phi4training") # Try to get whoami info user_info = api.whoami() logger.info(f"Successfully authenticated with Hugging Face as {user_info['name']}") # Check if we're using the expected Space expected_space_id = "George-API/phi4training" actual_space_id = f"{username}/{space_name}" if actual_space_id != expected_space_id: logger.warning(f"Using Space '{actual_space_id}' instead of the expected '{expected_space_id}'") logger.warning(f"Make sure this is intentional. To use the correct Space, update your .env file.") else: logger.info(f"Confirmed using Space: {expected_space_id}") # Check if the space exists try: space_id = f"{username}/{space_name}" space_info = api.space_info(repo_id=space_id) logger.info(f"Space {space_id} is accessible at: https://huggingface.co/spaces/{space_id}") return True except Exception as e: logger.warning(f"Could not access Space {username}/{space_name}: {str(e)}") logger.warning("Space updating may not work correctly") return False except ImportError: logger.warning("huggingface_hub not installed. Cannot validate Hugging Face credentials.") return False except Exception as e: logger.warning(f"Error validating Hugging Face credentials: {str(e)}") return False def setup_environment(args): """ Set up the training environment including logging, seed, and configurations. 
Args: args: Command line arguments Returns: tuple: (transformers_config, seed) - The loaded configuration and random seed """ # Load environment variables first load_env_variables() # Set random seed for reproducibility seed = args.seed if args.seed is not None else int(time.time()) % 10000 set_seed(seed) log_info(f"Using random seed: {seed}") # Load configuration base_path = os.path.dirname(os.path.abspath(__file__)) config_file = args.config_file or os.path.join(base_path, "transformers_config.json") if not os.path.exists(config_file): raise FileNotFoundError(f"Config file not found: {config_file}") log_info(f"Loading configuration from {config_file}") transformers_config = load_configs(config_file) # Set up hardware environment variables if CUDA is available if CUDA_AVAILABLE: memory_fraction = get_config_value(transformers_config, "hardware.system_settings.cuda_memory_fraction", 0.75) if memory_fraction < 1.0: os.environ["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:128,expandable_segments:True" log_info(f"Set CUDA memory allocation limit to expandable with max_split_size_mb:128") # Check dependencies and install optional ones if needed if not check_dependencies(): raise RuntimeError("Critical dependencies missing") # Check if flash attention was successfully installed flash_attention_available = False try: import flash_attn flash_attention_available = True log_info(f"Flash Attention will be used (version: {flash_attn.__version__})") # Update config to use flash attention if "use_flash_attention" not in transformers_config: transformers_config["use_flash_attention"] = True except ImportError: log_info("Flash Attention not available, will use standard attention mechanism") return transformers_config, seed def setup_model_and_tokenizer(config): """ Load and configure the model and tokenizer. Args: config (dict): Complete configuration dictionary Returns: tuple: (model, tokenizer) - The loaded model and tokenizer """ # Extract model configuration model_config = get_config_value(config, "model", {}) model_name = get_config_value(model_config, "name", "unsloth/phi-4-unsloth-bnb-4bit") use_fast_tokenizer = get_config_value(model_config, "use_fast_tokenizer", True) trust_remote_code = get_config_value(model_config, "trust_remote_code", True) model_revision = get_config_value(config, "model_revision", "main") # Detect if model is already pre-quantized (includes '4bit', 'bnb', or 'int4' in name) is_prequantized = any(q in model_name.lower() for q in ['4bit', 'bnb', 'int4', 'quant']) if is_prequantized: log_info("āš ļø Detected pre-quantized model. 
No additional quantization will be applied.") # Unsloth configuration unsloth_config = get_config_value(config, "unsloth", {}) unsloth_enabled = get_config_value(unsloth_config, "enabled", True) # Tokenizer configuration tokenizer_config = get_config_value(config, "tokenizer", {}) max_seq_length = min( get_config_value(tokenizer_config, "max_seq_length", 2048), 4096 # Maximum supported by most models ) add_eos_token = get_config_value(tokenizer_config, "add_eos_token", True) chat_template = get_config_value(tokenizer_config, "chat_template", None) padding_side = get_config_value(tokenizer_config, "padding_side", "right") # Check for flash attention use_flash_attention = get_config_value(config, "use_flash_attention", False) flash_attention_available = False try: import flash_attn flash_attention_available = True log_info(f"Flash Attention detected (version: {flash_attn.__version__})") except ImportError: if use_flash_attention: log_info("Flash Attention requested but not available") log_info(f"Loading model: {model_name} (revision: {model_revision})") log_info(f"Max sequence length: {max_seq_length}") try: if unsloth_enabled and unsloth_available: log_info("Using Unsloth for LoRA fine-tuning") if is_prequantized: log_info("Using pre-quantized model - no additional quantization will be applied") else: log_info("Using 4-bit quantization for efficient training") # Load using Unsloth from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=model_name, max_seq_length=max_seq_length, dtype=get_config_value(config, "torch_dtype", "bfloat16"), revision=model_revision, trust_remote_code=trust_remote_code, use_flash_attention_2=use_flash_attention and flash_attention_available ) # Configure tokenizer settings tokenizer.padding_side = padding_side if add_eos_token and tokenizer.eos_token is None: log_info("Setting EOS token") tokenizer.add_special_tokens({"eos_token": ""}) # Set chat template if specified if chat_template: log_info(f"Setting chat template: {chat_template}") if hasattr(tokenizer, "chat_template"): tokenizer.chat_template = chat_template else: log_info("Tokenizer does not support chat templates, using default formatting") # Apply LoRA lora_r = get_config_value(unsloth_config, "r", 16) lora_alpha = get_config_value(unsloth_config, "alpha", 32) lora_dropout = get_config_value(unsloth_config, "dropout", 0) target_modules = get_config_value(unsloth_config, "target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") model = FastLanguageModel.get_peft_model( model, r=lora_r, target_modules=target_modules, lora_alpha=lora_alpha, lora_dropout=lora_dropout, bias="none", use_gradient_checkpointing=get_config_value(config, "training.gradient_checkpointing", True), random_state=0, max_seq_length=max_seq_length, modules_to_save=None ) if use_flash_attention and flash_attention_available: log_info("šŸš€ Using Flash Attention for faster training") elif use_flash_attention and not flash_attention_available: log_info("āš ļø Flash Attention requested but not available - using standard attention") else: # Standard HuggingFace loading log_info("Using standard HuggingFace model loading (Unsloth not available or disabled)") from transformers import AutoModelForCausalLM, AutoTokenizer # Check if flash attention should be enabled in config use_attn_implementation = None if use_flash_attention and flash_attention_available: 
use_attn_implementation = "flash_attention_2" log_info("šŸš€ Using Flash Attention for faster training") # Load tokenizer first tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=trust_remote_code, use_fast=use_fast_tokenizer, revision=model_revision, padding_side=padding_side ) # Configure tokenizer settings if add_eos_token and tokenizer.eos_token is None: log_info("Setting EOS token") tokenizer.add_special_tokens({"eos_token": ""}) # Set chat template if specified if chat_template: log_info(f"Setting chat template: {chat_template}") if hasattr(tokenizer, "chat_template"): tokenizer.chat_template = chat_template else: log_info("Tokenizer does not support chat templates, using default formatting") # Only apply quantization config if model is not already pre-quantized quantization_config = None if not is_prequantized and CUDA_AVAILABLE: try: from transformers import BitsAndBytesConfig log_info("Using 4-bit quantization (BitsAndBytes) for efficient training") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True ) except ImportError: log_info("BitsAndBytes not available - quantization disabled") # Now load model with updated tokenizer model = AutoModelForCausalLM.from_pretrained( model_name, trust_remote_code=trust_remote_code, revision=model_revision, torch_dtype=torch.bfloat16 if get_config_value(config, "torch_dtype", "bfloat16") == "bfloat16" else torch.float16, device_map="auto" if CUDA_AVAILABLE else None, attn_implementation=use_attn_implementation, quantization_config=quantization_config ) # Apply PEFT/LoRA if enabled but using standard loading if peft_available and get_config_value(unsloth_config, "enabled", True): log_info("Applying standard PEFT/LoRA configuration") from peft import LoraConfig, get_peft_model lora_r = get_config_value(unsloth_config, "r", 16) lora_alpha = get_config_value(unsloth_config, "alpha", 32) lora_dropout = get_config_value(unsloth_config, "dropout", 0) target_modules = get_config_value(unsloth_config, "target_modules", ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]) log_info(f"Applying LoRA with r={lora_r}, alpha={lora_alpha}, dropout={lora_dropout}") lora_config = LoraConfig( r=lora_r, lora_alpha=lora_alpha, target_modules=target_modules, lora_dropout=lora_dropout, bias="none", task_type="CAUSAL_LM" ) model = get_peft_model(model, lora_config) # Print model summary log_info(f"Model loaded successfully: {model.__class__.__name__}") if hasattr(model, "print_trainable_parameters"): model.print_trainable_parameters() else: total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) log_info(f"Model has {total_params:,} parameters, {trainable_params:,} trainable ({trainable_params/total_params:.2%})") return model, tokenizer except Exception as e: log_info(f"Error loading model: {str(e)}") traceback.print_exc() return None, None def setup_dataset_and_collator(config, tokenizer): """ Load and configure the dataset and data collator. Args: config: Complete configuration dictionary tokenizer: The tokenizer for the data collator Returns: tuple: (dataset, data_collator) - The loaded dataset and configured data collator """ dataset_config = get_config_value(config, "dataset", {}) log_info("Loading dataset...") dataset = load_dataset_with_mapping(dataset_config) # Validate dataset if dataset is None: raise ValueError("Dataset is None! 
Cannot proceed with training.") if not hasattr(dataset, '__len__') or len(dataset) == 0: raise ValueError("Dataset is empty! Cannot proceed with training.") log_info(f"Dataset loaded with {len(dataset)} examples") # Create data collator data_collator = SimpleDataCollator(tokenizer, dataset_config) return dataset, data_collator def create_training_arguments(config, dataset): """ Create and configure training arguments for the Trainer. Args: config: Complete configuration dictionary dataset: The dataset to determine total steps Returns: TrainingArguments: Configured training arguments """ # Extract configuration sections training_config = get_config_value(config, "training", {}) hardware_config = get_config_value(config, "hardware", {}) huggingface_config = get_config_value(config, "huggingface_hub", {}) distributed_config = get_config_value(config, "distributed_training", {}) # Extract key training parameters per_device_batch_size = get_config_value(training_config, "per_device_train_batch_size", 4) gradient_accumulation_steps = get_config_value(training_config, "gradient_accumulation_steps", 8) learning_rate = get_config_value(training_config, "learning_rate", 2e-5) num_train_epochs = get_config_value(training_config, "num_train_epochs", 3) # Extract hardware settings dataloader_workers = get_config_value(hardware_config, "system_settings.dataloader_num_workers", get_config_value(distributed_config, "dataloader_num_workers", 2)) pin_memory = get_config_value(hardware_config, "system_settings.dataloader_pin_memory", True) # BF16/FP16 settings - ensure only one is enabled use_bf16 = get_config_value(training_config, "bf16", False) use_fp16 = get_config_value(training_config, "fp16", False) if not use_bf16 else False # Configure distributed training fsdp_config = get_config_value(distributed_config, "fsdp_config", {}) fsdp_enabled = get_config_value(fsdp_config, "enabled", False) ddp_config = get_config_value(distributed_config, "ddp_config", {}) ddp_find_unused_parameters = get_config_value(ddp_config, "find_unused_parameters", False) # Set up FSDP args if enabled fsdp_args = None if fsdp_enabled and NUM_GPUS > 1: from accelerate import FullyShardedDataParallelPlugin from torch.distributed.fsdp.fully_sharded_data_parallel import ( FullOptimStateDictConfig, FullStateDictConfig ) fsdp_plugin = FullyShardedDataParallelPlugin( sharding_strategy=get_config_value(fsdp_config, "sharding_strategy", "FULL_SHARD"), mixed_precision_policy=get_config_value(fsdp_config, "mixed_precision", "BF16"), state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), ) fsdp_args = { "fsdp": fsdp_plugin, "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer", "PhiDecoderLayer"] } # Create and return training arguments training_args = TrainingArguments( output_dir=get_config_value(config, "checkpointing.output_dir", "./results"), overwrite_output_dir=True, num_train_epochs=num_train_epochs, per_device_train_batch_size=per_device_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, learning_rate=learning_rate, weight_decay=get_config_value(training_config, "weight_decay", 0.01), max_grad_norm=get_config_value(training_config, "max_grad_norm", 1.0), warmup_ratio=get_config_value(training_config, "warmup_ratio", 0.03), lr_scheduler_type=get_config_value(training_config, "lr_scheduler_type", "cosine"), logging_steps=get_config_value(training_config, "logging_steps", 10), 
save_strategy=get_config_value(config, "checkpointing.save_strategy", "steps"), save_steps=get_config_value(config, "checkpointing.save_steps", 500), save_total_limit=get_config_value(config, "checkpointing.save_total_limit", 3), bf16=use_bf16, fp16=use_fp16, push_to_hub=get_config_value(huggingface_config, "push_to_hub", False), hub_model_id=get_config_value(huggingface_config, "hub_model_id", None), hub_strategy=get_config_value(huggingface_config, "hub_strategy", "every_save"), hub_private_repo=get_config_value(huggingface_config, "hub_private_repo", True), gradient_checkpointing=get_config_value(training_config, "gradient_checkpointing", True), dataloader_pin_memory=pin_memory, optim=get_config_value(training_config, "optim", "adamw_torch"), ddp_find_unused_parameters=ddp_find_unused_parameters, dataloader_drop_last=False, dataloader_num_workers=dataloader_workers, no_cuda=False if CUDA_AVAILABLE else True, **({} if fsdp_args is None else fsdp_args) ) log_info("Training arguments created successfully") return training_args def configure_custom_dataloader(trainer, dataset, config, training_args): """ Configure a custom dataloader for the trainer if needed. Args: trainer: The Trainer instance to configure dataset: The dataset to use config: Complete configuration dictionary training_args: The training arguments Returns: None (modifies trainer in-place) """ dataset_config = get_config_value(config, "dataset", {}) # Check if we need a custom dataloader if get_config_value(dataset_config, "data_loading.sequential_processing", True): log_info("Using custom sequential dataloader") # Create sequential sampler to maintain dataset order sequential_sampler = torch.utils.data.SequentialSampler(dataset) log_info("Sequential sampler created") # Define custom dataloader getter def custom_get_train_dataloader(): """Create a custom dataloader that maintains dataset order""" # Get configuration values batch_size = training_args.per_device_train_batch_size drop_last = get_config_value(dataset_config, "data_loading.drop_last", False) num_workers = training_args.dataloader_num_workers pin_memory = training_args.dataloader_pin_memory prefetch_factor = get_config_value(dataset_config, "data_loading.prefetch_factor", 2) persistent_workers = get_config_value(dataset_config, "data_loading.persistent_workers", False) # Create DataLoader with sequential sampler return DataLoader( dataset, batch_size=batch_size, sampler=sequential_sampler, collate_fn=trainer.data_collator, drop_last=drop_last, num_workers=num_workers, pin_memory=pin_memory, prefetch_factor=prefetch_factor if num_workers > 0 else None, persistent_workers=persistent_workers if num_workers > 0 else False, ) # Override the default dataloader trainer.get_train_dataloader = custom_get_train_dataloader def run_training(trainer, tokenizer, training_args): """ Run the training process and handle model saving. Args: trainer: Configured Trainer instance tokenizer: The tokenizer to save with the model training_args: Training arguments Returns: int: 0 for success, 1 for failure """ log_info("Starting training...") trainer.train() log_info("Training complete! Saving final model...") trainer.save_model() tokenizer.save_pretrained(training_args.output_dir) # Push to Hub if configured if training_args.push_to_hub: log_info(f"Pushing model to Hugging Face Hub: {training_args.hub_model_id}") trainer.push_to_hub() log_info("Training completed successfully!") return 0 def main(): """ Main entry point for the training script. 
Returns: int: 0 for success, non-zero for failure """ # Set up logging logger.info("Starting training process") try: # Verify critical imports are available if not transformers_available: log_info("āŒ Error: transformers library not available. Please install it with: pip install transformers") return 1 # Check for required classes for required_class in ["Trainer", "TrainingArguments", "TrainerCallback"]: if not hasattr(transformers, required_class): log_info(f"āŒ Error: {required_class} not found in transformers. Please update transformers.") return 1 # Check for potential import order issue and warn early if "transformers" in sys.modules and "unsloth" in sys.modules: if list(sys.modules.keys()).index("transformers") < list(sys.modules.keys()).index("unsloth"): log_info("āš ļø Warning: transformers was imported before unsloth. This may affect performance.") log_info(" For optimal performance in future runs, import unsloth first.") # Parse command line arguments args = parse_args() # Set up environment and load configuration transformers_config, seed = setup_environment(args) # Load model and tokenizer try: model, tokenizer = setup_model_and_tokenizer(transformers_config) except Exception as e: logger.error(f"Error setting up model: {str(e)}") return 1 # Load dataset and create data collator try: dataset, data_collator = setup_dataset_and_collator(transformers_config, tokenizer) except Exception as e: logger.error(f"Error setting up dataset: {str(e)}") return 1 # Configure training arguments try: training_args = create_training_arguments(transformers_config, dataset) except Exception as e: logger.error(f"Error configuring training arguments: {str(e)}") return 1 # Initialize trainer with callbacks log_info("Initializing Trainer") trainer = Trainer( model=model, args=training_args, train_dataset=dataset, data_collator=data_collator, callbacks=[LoggingCallback(model=model, dataset=dataset)], ) # Configure custom dataloader if needed try: configure_custom_dataloader(trainer, dataset, transformers_config, training_args) except Exception as e: logger.error(f"Error configuring custom dataloader: {str(e)}") return 1 # Run training process try: return run_training(trainer, tokenizer, training_args) except Exception as e: logger.error(f"Training failed with error: {str(e)}") # Log GPU memory for debugging log_gpu_memory_usage(label="Error") # Print full stack trace traceback.print_exc() return 1 except Exception as e: logger.error(f"Error in main function: {str(e)}") traceback.print_exc() return 1 if __name__ == "__main__": sys.exit(main())
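
# ---------------------------------------------------------------------------
# Illustrative transformers_config.json skeleton. This is an assumption
# reconstructed from the get_config_value() lookups above, not an authoritative
# schema; keys shown are the ones this script reads, and values are examples or
# the script's defaults. Kept as a comment so it has no runtime effect.
#
# {
#   "model": {"name": "unsloth/phi-4-unsloth-bnb-4bit", "use_fast_tokenizer": true, "trust_remote_code": true},
#   "model_revision": "main",
#   "torch_dtype": "bfloat16",
#   "use_flash_attention": false,
#   "unsloth": {"enabled": true, "r": 16, "alpha": 32, "dropout": 0,
#               "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]},
#   "tokenizer": {"max_seq_length": 2048, "add_eos_token": true, "chat_template": null, "padding_side": "right"},
#   "dataset": {"name": "your-username/your-dataset", "split": "train",
#               "data_loading": {"sequential_processing": true, "drop_last": false,
#                                "prefetch_factor": 2, "persistent_workers": false}},
#   "data_formatting": {"roles": {"system": "System: {content}\n\n", "human": "Human: {content}\n\n",
#                                 "assistant": "Assistant: {content}\n\n"}},
#   "training": {"per_device_train_batch_size": 4, "gradient_accumulation_steps": 8, "learning_rate": 2e-05,
#                "num_train_epochs": 3, "weight_decay": 0.01, "max_grad_norm": 1.0, "warmup_ratio": 0.03,
#                "lr_scheduler_type": "cosine", "logging_steps": 10, "gradient_checkpointing": true,
#                "optim": "adamw_torch", "bf16": true, "fp16": false},
#   "checkpointing": {"output_dir": "./results", "save_strategy": "steps", "save_steps": 500, "save_total_limit": 3},
#   "huggingface_hub": {"push_to_hub": false, "hub_model_id": null, "hub_strategy": "every_save", "hub_private_repo": true},
#   "hardware": {"system_settings": {"cuda_memory_fraction": 0.75, "dataloader_num_workers": 2,
#                                    "dataloader_pin_memory": true}},
#   "distributed_training": {"fsdp_config": {"enabled": false}, "ddp_config": {"find_unused_parameters": false},
#                            "dataloader_num_workers": 2}
# }
# ---------------------------------------------------------------------------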