Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python | |
| # coding=utf-8 | |
| import os | |
| import sys | |
| import json | |
| import argparse | |
| import logging | |
| from datetime import datetime | |
| import time | |
| import warnings | |
| import torch | |
| from importlib.util import find_spec | |
# Hardware detection, evaluated once at import time.
CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    NUM_GPUS = torch.cuda.device_count()
    DEVICE_TYPE = "cuda"
else:
    NUM_GPUS = 0
    DEVICE_TYPE = "cpu"

# Unsloth is imported ahead of the other ML imports below. A missing install is
# recorded in `unsloth_available` instead of aborting the script here.
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template
except ImportError:
    unsloth_available = False
    logger = logging.getLogger(__name__)
    logger.warning("Unsloth not available. Please install with: pip install unsloth")
else:
    unsloth_available = True
| from datasets import load_dataset | |
| from transformers import ( | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| TrainingArguments, | |
| Trainer, | |
| TrainerCallback, | |
| set_seed, | |
| BitsAndBytesConfig | |
| ) | |
# Send this script's log records to stdout with a timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Raise the threshold of chatty third-party loggers so only their warnings
# (and worse) interleave with this script's INFO output.
for _noisy_logger in ("transformers", "datasets", "accelerate", "torch", "bitsandbytes"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Preliminary availability probe: checks that the package can be located
# without importing it.
peft_available = find_spec("peft") is not None
def log_info(message):
    """Log ``message`` at INFO level and flush stdout immediately.

    The module's logging handler writes to stdout; flushing after each call
    keeps streamed output (e.g. a Hugging Face Space log view) current.
    """
    logger.info(message)
    stdout_stream = sys.stdout
    stdout_stream.flush()
# Optional dependency: 4-bit quantization config class from transformers.
try:
    from transformers import BitsAndBytesConfig
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
else:
    bitsandbytes_available = True

# Optional dependency: PEFT. `peft_available` is (re)assigned here from whether
# the import actually succeeds.
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
else:
    peft_available = True
def load_env_variables():
    """Populate Hugging Face related settings in ``os.environ``.

    * Inside a Hugging Face Space (``SPACE_ID`` set): logs whether
      ``HF_TOKEN`` / ``HF_USERNAME`` are present (never their values) and, if
      ``HF_USERNAME`` is missing, derives it from the part of ``SPACE_ID``
      before the first ``/``.
    * Otherwise: loads ``shared/.env`` located in the parent of this script's
      directory via python-dotenv, when both the package and file exist.

    Afterwards it warns about a missing ``HF_USERNAME`` / ``HF_SPACE_NAME`` and
    copies ``HF_TOKEN`` into ``HUGGING_FACE_HUB_TOKEN``.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info("Running in Hugging Face Space")
        # Report presence only; secret values must never reach the logs.
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
        # SPACE_ID has the form "<owner>/<space>"; the owner is the fallback username.
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
        else:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            env_path = os.path.join(os.path.dirname(script_dir), "shared", ".env")
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                logger.warning(f"No .env file found at {env_path}")

    # NOTE(review): these only warn; no default value is assigned in this function.
    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # Expose HF_TOKEN under the name huggingface_hub is given here.
    token = os.environ.get("HF_TOKEN")
    if token:
        os.environ["HUGGING_FACE_HUB_TOKEN"] = token
def load_configs(base_path):
    """Load the consolidated JSON training configuration.

    Args:
        base_path: Path of the configuration file itself (e.g.
            ``transformers_config.json``), not a directory.

    Returns:
        The parsed JSON document.

    Raises:
        Any error from opening or parsing the file; it is logged first and
        then re-raised unchanged.
    """
    config_file = base_path
    try:
        # JSON text is UTF-8; don't depend on the platform's default locale encoding.
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
    except Exception as e:
        logger.error(f"Error loading {config_file}: {e}")
        raise
    logger.info(f"Loaded configuration from {config_file}")
    return config
def parse_args():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace whose ``config`` attribute is the path to the JSON
        configuration file (default ``transformers_config.json``).
    """
    cli = argparse.ArgumentParser(
        description="Fine-tune a language model on a text dataset"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="transformers_config.json",
        help="Path to configuration file",
    )
    namespace = cli.parse_args()
    return namespace
def _resolve_max_seq_length(config, default=2048):
    """Return ``max_seq_length`` from the top level, else ``tokenizer.max_seq_length``, else ``default``."""
    top_level = config.get("max_seq_length")
    if top_level:
        return top_level
    return config.get("tokenizer", {}).get("max_seq_length", default)


def _resolve_gradient_checkpointing(config):
    """Return ``gradient_checkpointing`` from the top level, else ``training``, else True.

    An explicit ``False`` at either level is honoured. The previous
    ``a or b`` form turned a top-level ``False`` back into ``True``.
    """
    value = config.get("gradient_checkpointing")
    if value is None:
        value = config.get("training", {}).get("gradient_checkpointing", True)
    return value


def _select_device_map():
    """Log visible GPUs and return ``(device_map, max_memory)`` for model loading."""
    if not torch.cuda.is_available():
        logger.warning("No CUDA available, falling back to CPU")
        return {"": "cpu"}, None  # Force CPU placement

    gpu_count = torch.cuda.device_count()
    logger.info(f"CUDA available, found {gpu_count} GPU(s)")
    for i in range(gpu_count):
        logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        logger.info(f"Memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.2f} GB")

    if gpu_count > 1:
        logger.info(f"Creating balanced device map for {gpu_count} GPUs")
        # Cap each GPU at 85% of its total memory (whole GiB) for "auto" placement.
        max_memory = {
            i: f"{int(torch.cuda.get_device_properties(i).total_memory * 0.85 / 1024**3)}GiB"
            for i in range(gpu_count)
        }
        logger.info(f"Max memory settings: {max_memory}")
        return "auto", max_memory
    return "auto", None


def _select_dtype():
    """Pick bfloat16 on compute capability >= 8, float16 on older GPUs, None on CPU."""
    if not torch.cuda.is_available():
        logger.info("Using default precision (CPU)")
        return None
    if torch.cuda.get_device_capability()[0] >= 8:
        logger.info("Using bfloat16 precision (Ampere+ GPU)")
        return torch.bfloat16
    logger.info("Using float16 precision (pre-Ampere GPU)")
    return torch.float16


def load_model_and_tokenizer(config):
    """Load a model and tokenizer through Unsloth and attach LoRA adapters.

    Args:
        config: Parsed training configuration. Keys read here:
            ``model_name`` or ``model.name``; ``use_flash_attention``;
            ``max_seq_length`` or ``tokenizer.max_seq_length`` (default 2048);
            ``unsloth`` (``r``, ``alpha``, ``dropout``, ``target_modules``);
            ``gradient_checkpointing`` or ``training.gradient_checkpointing``;
            ``seed``; ``chat_template`` or ``tokenizer.chat_template``.

    Returns:
        ``(model, tokenizer)``. The model is the result of
        ``FastLanguageModel.get_peft_model``.

    Raises:
        ImportError: Unsloth could not be imported.
        ValueError: no model name is configured.
        Other loading errors are logged and re-raised. A non-OOM
        ``RuntimeError`` first triggers one retry on CPU.
    """
    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")

        model_name = config.get("model_name") or config.get("model", {}).get("name")
        logger.info(f"Loading model: {model_name}")
        if not model_name:
            raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
        logger.info("Using Unsloth optimizations with pre-quantized model")

        # Informational only: flash attention is not passed to from_pretrained below.
        if config.get("use_flash_attention", True) and not find_spec("flash_attn"):
            logger.warning("flash-attn not found. Will continue without flash attention.")
            logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")

        device_map, max_memory = _select_device_map()
        dtype = _select_dtype()
        max_seq_length = _resolve_max_seq_length(config)

        try:
            # Improved memory settings for multi-GPU setup
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=dtype,
                device_map=device_map,
                max_memory=max_memory,
                # Flash attention is not configured explicitly; left to Unsloth.
            )
        except RuntimeError as e:
            if "CUDA out of memory" in str(e):
                logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
                raise
            # Any other runtime error: retry once on CPU with the default dtype.
            logger.warning(f"Error loading model on default device: {str(e)}")
            logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name=model_name,
                max_seq_length=max_seq_length,
                dtype=None,
                device_map={"": "cpu"},
            )
            logger.warning("Model loaded on CPU. Training will be very slow.")

        logger.info(f"Model device map: {getattr(model, 'hf_device_map', 'Not available')}")

        # Apply Unsloth's LoRA wrapper using the "unsloth" config section.
        unsloth_config = config.get("unsloth", {})
        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get(
                "target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            ),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=unsloth_config.get("dropout", 0.05),
            bias="none",
            use_gradient_checkpointing=_resolve_gradient_checkpointing(config),
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")

        chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
        if chat_template:
            # Use the configured template name; a non-string truthy value keeps the
            # previously hard-coded "phi". NOTE(review): confirm the name is one
            # Unsloth's template registry accepts (failures are only warned about).
            template_name = chat_template if isinstance(chat_template, str) else "phi"
            try:
                # Unsloth's helper takes the tokenizer plus a template name and returns
                # the tokenizer. The old call passed the name in the tokenizer position.
                tokenizer = get_chat_template(tokenizer, chat_template=template_name)
                logger.info(f"Set {template_name} chat template")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")

        # Padding needs a pad id; reuse EOS when the tokenizer defines none.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer

    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise
def _apply_column_mapping(dataset, column_mapping):
    """Rename columns per a ``{target: source}`` mapping without overwriting existing columns."""
    if not column_mapping:
        return dataset
    logger.info(f"Checking column mapping: {column_mapping}")
    safe_mappings = {}
    for target, source in column_mapping.items():
        if source not in dataset.column_names:
            continue
        if target in dataset.column_names and target != source:
            logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
        else:
            safe_mappings[source] = target
    if safe_mappings:
        logger.info(f"Applying safe column mapping: {safe_mappings}")
        for source, target in safe_mappings.items():
            if source != target:
                dataset = dataset.rename_column(source, target)
    return dataset


def _add_prompt_numbers(dataset):
    """Add an integer ``prompt_number`` column equal to the 1-based row position."""
    logger.info("Adding prompt numbers based on original dataset order (starting at 1)")

    def prompt_number_for(_example, idx):
        # Unbatched map with indices passes one example plus its int index. Return a
        # scalar; the old helper wrapped it in a list, storing ``[n]`` per row.
        return {"prompt_number": idx + 1}

    try:
        dataset = dataset.map(prompt_number_for, with_indices=True, desc="Adding prompt numbers")
        logger.info("Successfully added prompt_number field to dataset")
    except Exception as e:
        logger.error(f"Error adding prompt numbers: {e}")
        logger.info("Attempting fallback method for adding prompt numbers")
        # Fallback materialises every row in memory and rebuilds the dataset.
        from datasets import Dataset
        dataset = Dataset.from_list(
            [{**example, "prompt_number": i + 1} for i, example in enumerate(dataset)]
        )
        logger.info("Successfully added prompt_number field using fallback method")
    return dataset


def _normalize_id_and_column_order(dataset):
    """Rename ``id`` to ``article_id`` (when free) and move ``prompt_number`` to the front."""
    if "id" in dataset.column_names and "article_id" not in dataset.column_names:
        logger.info("Renaming 'id' column to 'article_id'")
        dataset = dataset.rename_column("id", "article_id")
    if "prompt_number" in dataset.column_names:
        logger.info("Reordering columns to place prompt_number first")
        others = [col for col in dataset.column_names if col != "prompt_number"]
        dataset = dataset.select_columns(["prompt_number"] + others)
    return dataset


def _ensure_conversations(dataset):
    """Derive a ``conversations`` column from ``text`` when only ``text`` exists."""
    if "conversations" in dataset.column_names or "text" not in dataset.column_names:
        return dataset
    logger.info("Converting 'text' field to 'conversations' format")

    def convert_text_to_conversations(example):
        text = example.get("text")
        if isinstance(text, list):
            # Already a list of conversation turns.
            example["conversations"] = text
        else:
            # Wrap plain text as a single user turn.
            example["conversations"] = [
                {"role": "user", "content": "" if text is None else str(text)}
            ]
        return example

    dataset = dataset.map(convert_text_to_conversations)
    logger.info("Successfully converted 'text' to 'conversations'")
    return dataset


def _apply_ordering_rules(dataset, dataset_config):
    """Optionally sort by ``article_id``; raise ValueError if shuffling is enabled."""
    processing_config = dataset_config.get("dataset", {}).get("processing", {})
    data_loading_config = dataset_config.get("data_loading", {})

    if processing_config.get("sort_by_article_id", False) and "article_id" in dataset.column_names:
        logger.info("Sorting dataset by article_id as specified in config")
        dataset = dataset.sort("article_id")
        sorted_ids = dataset[: min(5, len(dataset))]["article_id"]
        logger.info(f"First few article_ids after sorting: {sorted_ids}")

    # Sequence preservation defaults to on; turning it off is only warned about.
    if not processing_config.get("preserve_entry_sequence", True):
        logger.warning("CRITICAL: preserve_entry_sequence is set to False. This is NOT RECOMMENDED!")
        logger.warning("Data sequence integrity is essential for proper model training.")

    # NOTE(review): source indentation was lost; the shuffle check is reconstructed
    # as independent of preserve_entry_sequence — confirm against the original.
    if data_loading_config.get("shuffle", False):
        logger.error("CRITICAL: shuffle is enabled in the dataset config!")
        logger.error("This will RANDOMIZE your dataset and break sequential order.")
        logger.error("Please set shuffle: false in your data_loading configuration.")
        raise ValueError("Dataset shuffling is enabled but preserve_entry_sequence is required. " +
                         "Please disable shuffling in your configuration.")
    return dataset


def _log_sequence_check(dataset):
    """Log whether the first rows have sequential prompt numbers and non-decreasing numeric ids."""
    if len(dataset) <= 1:
        return
    try:
        sample_size = min(10, len(dataset))
        head = dataset[:sample_size]  # slicing yields {column: [values]}

        if "prompt_number" in head:
            prompt_numbers = head["prompt_number"]
        else:
            logger.warning("Samples missing prompt_number, using row position as fallback")
            prompt_numbers = list(range(1, sample_size + 1))
        logger.info(f"Verifying sequential integrity with prompt numbers: {prompt_numbers}")
        if prompt_numbers == list(range(1, len(prompt_numbers) + 1)):
            logger.info("Prompt numbers verify that samples are in sequential order.")
        else:
            logger.warning("WARNING: Prompt numbers are not in sequential order!")
            logger.warning("This may indicate that data sequence is not preserved.")

        # Secondary check on ids, only when every sampled id is numeric.
        id_field = "article_id" if "article_id" in dataset.column_names else "id"
        ids = head.get(id_field, [])
        if ids and all(isinstance(v, int) or (isinstance(v, str) and v.isdigit()) for v in ids):
            numeric_ids = [int(v) for v in ids]
            if all(a <= b for a, b in zip(numeric_ids, numeric_ids[1:])):
                logger.info(f"Sample {id_field}s appear to be in sequential order.")
            else:
                logger.warning(f"WARNING: Sample {id_field}s are not in sequential order.")
    except Exception as e:
        logger.warning(f"Could not verify sequential integrity: {e}")


def _log_sample_overview(dataset):
    """Log ids of the first rows and the role/length/preview of the first conversation."""
    if "conversations" not in dataset.column_names:
        return
    try:
        head = dataset[: min(5, len(dataset))]
        logger.info(
            f"First few samples - Prompt numbers: {head.get('prompt_number', [])}, "
            f"Article IDs: {head.get('article_id', [])}"
        )
        if len(dataset) > 0:
            structure = []
            for msg in dataset[0].get("conversations") or []:
                if isinstance(msg, dict):
                    content = str(msg.get("content") or "")
                    preview = content[:50] + "..." if len(content) > 50 else content
                    structure.append({
                        "role": msg.get("role", ""),
                        "content_length": len(content),
                        "preview": preview,
                    })
            logger.info(f"Conversation structure: {structure}")
    except Exception as e:
        logger.warning(f"Error logging sample examples: {e}")


def _first_sample_is_valid(dataset):
    """Return True if row 0 holds a non-empty list in ``conversations``; log the reason otherwise."""
    try:
        conversations = dataset[0].get("conversations")
    except Exception as e:
        logger.error(f"Error validating first sample: {e}")
        return False
    if not conversations:
        logger.error("First sample has no conversations! Data format may be incorrect.")
        return False
    if not isinstance(conversations, list):
        logger.error(f"Conversations field should be a list but got {type(conversations)}")
        return False
    return True


def _add_metadata(dataset, metadata_config):
    """Prepend (or extend) a system message with formatted article/prompt metadata."""
    include_article_id = metadata_config.get("include_article_id", False)
    include_prompt_number = metadata_config.get("include_prompt_number", False)
    metadata_format = metadata_config.get("metadata_format", "")
    if not ((include_article_id or include_prompt_number) and metadata_format):
        return dataset
    logger.info("Adding metadata to conversations")

    def add_metadata(example):
        conversations = example.get("conversations")
        if not conversations or not isinstance(conversations, list):
            return example
        metadata = metadata_format
        if include_article_id and "article_id" in example:
            metadata = metadata.replace("{article_id}", str(example.get("article_id", "")))
        if include_prompt_number and "prompt_number" in example:
            metadata = metadata.replace("{prompt_number}", str(example.get("prompt_number", "")))
        if not metadata.strip():
            return example
        first = conversations[0]
        if isinstance(first, dict) and first.get("role") == "system":
            first["content"] = (first.get("content") or "") + f"\n\nMetadata: {metadata}"
        else:
            conversations.insert(0, {"role": "system", "content": f"Metadata: {metadata}"})
        return example

    dataset = dataset.map(add_metadata)
    logger.info("Metadata added to conversations")
    return dataset


def load_dataset_with_mapping(dataset_config):
    """Load the training dataset and normalise it for conversation fine-tuning.

    Pipeline:
      1. ``datasets.load_dataset(dataset.name, split=dataset.split)``.
      2. Apply ``dataset.column_mapping`` (target -> source) where safe.
      3. Add a 1-based ``prompt_number`` column. Rename ``id`` to
         ``article_id`` and put ``prompt_number`` first.
      4. Build ``conversations`` from ``text`` if needed. Require a
         non-empty dataset with a ``conversations`` column.
      5. Optionally sort by ``article_id``. Refuse to continue if
         ``data_loading.shuffle`` is enabled.
      6. Log ordering and sample diagnostics.
      7. If the first row is well-formed and
         ``data_formatting.metadata_handling`` is configured, inject a
         metadata system message.

    Args:
        dataset_config: Parsed configuration dictionary.

    Returns:
        The processed ``datasets.Dataset``.

    Raises:
        ValueError: the dataset name is missing, the dataset is empty, there is
            no ``conversations`` column, or shuffling is enabled.
        Every error is logged before being re-raised.
    """
    try:
        dataset_section = dataset_config.get("dataset", {})
        dataset_name = dataset_section.get("name", "")
        dataset_split = dataset_section.get("split", "train")
        if not dataset_name:
            raise ValueError("Dataset name not provided in configuration")
        logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
        dataset = load_dataset(dataset_name, split=dataset_split)

        dataset = _apply_column_mapping(dataset, dataset_section.get("column_mapping", {}))
        dataset = _add_prompt_numbers(dataset)
        dataset = _normalize_id_and_column_order(dataset)
        # Convert before the required-column check; previously the check raised
        # first and this conversion could never run.
        dataset = _ensure_conversations(dataset)

        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
        logger.info(f"Dataset columns: {dataset.column_names}")
        if len(dataset) == 0:
            logger.error("Dataset is empty! This will cause errors during training.")
            raise ValueError("Empty dataset loaded")
        if "conversations" not in dataset.column_names:
            logger.error("Required column 'conversations' not found in dataset!")
            raise ValueError("Required column 'conversations' missing from dataset")
        missing_columns = {"article_id", "conversations", "prompt_number"} - set(dataset.column_names)
        if missing_columns:
            logger.warning(f"Some expected columns are missing: {missing_columns}")

        dataset = _apply_ordering_rules(dataset, dataset_config)
        _log_sequence_check(dataset)
        _log_sample_overview(dataset)

        # A malformed first row is reported but not fatal; metadata injection is skipped.
        if not _first_sample_is_valid(dataset):
            return dataset

        metadata_config = dataset_config.get("data_formatting", {}).get("metadata_handling", {})
        if metadata_config:
            dataset = _add_metadata(dataset, metadata_config)
        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise
def format_phi_chat(messages, dataset_config):
    """Render a list of chat messages into a single prompt string.

    Role templates are ``str.format`` strings with a ``{content}`` field, read
    from ``dataset_config["data_formatting"]["roles"]``. The defaults are
    ``"System: {content}\\n\\n"``, ``"Human: {content}\\n\\n"`` and
    ``"Assistant: {content}\\n\\n"``.

    Ordering rules:
      * The first message whose string content contains
        ``[RESEARCH INTRODUCTION]`` is removed from the list and rendered first
        with the system template.
      * ``user``/``human`` and ``assistant``/``bot`` messages are appended in order.
      * ``system`` messages are prepended to everything rendered so far.
      * Any other role is appended using the system template, with a warning.

    Non-dict messages and messages without a ``content`` key are skipped with a
    warning. A ``None`` role or content is treated as an empty string.

    Returns:
        The rendered conversation with leading/trailing whitespace stripped.
    """
    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n"
    })
    system_template = roles.get("system", "System: {content}\n\n")

    formatted_chat = ""
    # Only string content can carry the marker; None/non-str content used to raise here.
    intro = next(
        (
            msg for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and "[RESEARCH INTRODUCTION]" in msg["content"]
        ),
        None,
    )
    if intro is not None:
        formatted_chat = system_template.format(content=intro["content"])
        messages = [msg for msg in messages if msg != intro]

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        role = str(message.get("role") or "").lower()
        content = message.get("content")
        if content is None:
            content = ""

        if role in ("human", "user"):
            template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
            formatted_chat += template.format(content=content)
        elif role in ("assistant", "bot"):
            template = roles.get("assistant", "Assistant: {content}\n\n")
            formatted_chat += template.format(content=content)
        elif role == "system":
            # System messages go in front of everything rendered so far.
            formatted_chat = system_template.format(content=content) + formatted_chat
        else:
            logger.warning(f"Unknown role '{role}' - treating as system message")
            formatted_chat += system_template.format(content=content)

    return formatted_chat.strip()
class SimpleDataCollator:
    """Tokenize ``conversations`` examples via the tokenizer's chat template and pad them.

    Each feature is expected to be a dict with a ``conversations`` list of
    role/content messages; features without one are skipped. Labels are a
    copy of ``input_ids`` with padded positions set to ``-100`` so they are
    excluded from the loss. Sequences are right-padded to the longest in the
    batch.

    ``self.stats`` counts processed and skipped examples and total tokens.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        # Fall back to id 0 when the tokenizer defines no pad token.
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
        logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
        logger.info("Using exact dataset structure without reformatting")
        # Recorded for reference; not used when building the tensors below.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"SimpleDataCollator using device: {self.device}")

    @staticmethod
    def _example_id(example, default=""):
        """Identifier for log messages: ``article_id`` (set by the loader), then ``id``."""
        if not hasattr(example, "get"):
            return default
        return example.get("article_id", example.get("id", default))

    def _encode(self, conversations, example_id):
        """Return token ids for a conversation, falling back to concatenated plain text."""
        try:
            encoded = self.tokenizer.apply_chat_template(
                conversations,
                return_tensors=None,
                add_generation_prompt=False
            )
        except Exception as chat_error:
            logger.warning(f"Chat template application failed for example {example_id}: {str(chat_error)[:100]}")
            conversation_text = "".join(
                str(msg.get("content") or "") + "\n\n"
                for msg in conversations
                if isinstance(msg, dict) and "content" in msg
            )
            encoded = self.tokenizer(
                conversation_text,
                add_special_tokens=True,
                return_tensors=None
            )
        # A plain tokenizer call returns a mapping (input_ids, attention_mask, ...);
        # the old code used that mapping itself as the token list.
        if hasattr(encoded, "keys"):
            encoded = encoded["input_ids"]
        return list(encoded)

    def __call__(self, features):
        """Build a padded ``input_ids`` / ``attention_mask`` / ``labels`` batch of long tensors."""
        batch = {"input_ids": [], "attention_mask": [], "labels": []}
        log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)

        for example in features:
            try:
                paper_id = self._example_id(example)
                conversations = example.get("conversations", [])
                if not conversations:
                    self.stats["skipped"] += 1
                    continue

                inputs = self._encode(conversations, paper_id)

                # Length cap (shouldn't trigger for pre-audited data); None/0 disables it.
                if self.max_seq_length and self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
                    logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
                    inputs = inputs[:self.max_seq_length]

                if not inputs:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(inputs)
                batch["attention_mask"].append([1] * len(inputs))
                # Causal LM: labels start as a copy of the inputs.
                batch["labels"].append(list(inputs))
                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(inputs)

                if self.stats["processed"] <= log_samples:
                    logger.info(f"Example {self.stats['processed']}:")
                    logger.info(f"Paper ID: {paper_id}")
                    logger.info(f"Token count: {len(inputs)}")
                    logger.info(f"Conversation entries: {len(conversations)}")
            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example ID: {self._example_id(example, 'unknown')}")
                self.stats["skipped"] += 1
                continue

        if not batch["input_ids"]:
            # NOTE(review): placeholder batch so the Trainer step does not crash.
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long)
            }

        # Right-pad every sequence to the longest one in the batch.
        max_length = max(len(ids) for ids in batch["input_ids"])
        for ids, mask, labels in zip(batch["input_ids"], batch["attention_mask"], batch["labels"]):
            padding_length = max_length - len(ids)
            if padding_length > 0:
                ids.extend([self.pad_token_id] * padding_length)
                mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)

        batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}

        # Periodic stats; a non-positive interval disables them (0 used to raise).
        log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
        processed = self.stats["processed"]
        if log_interval and log_interval > 0 and processed > 0 and processed % log_interval == 0:
            logger.info(f"Data collator stats: processed={self.stats['processed']}, "
                        f"skipped={self.stats['skipped']}, "
                        f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")

        return batch
class LoggingCallback(TrainerCallback):
    """Log training progress, throughput and GPU memory; optionally re-check data order.

    Everything the callback needs from the training run comes from the
    arguments the ``Trainer`` passes to each event: ``args``, ``state`` and the
    ``model`` / ``train_dataloader`` keyword arguments. Module-level names such
    as ``model``, ``trainer`` or ``dataset_config`` are not used, because those
    only exist as locals inside ``main()``.

    Args:
        dataset_config: Optional dataset configuration dict. When
            ``dataset_config["validation"]["verify_sequence_integrity"]`` is
            truthy, up to five reference samples are captured at train start.
            They are re-read every ``verify_interval`` steps and compared.
        verify_interval: Number of training steps (``state.global_step``)
            between integrity checks.
    """

    def __init__(self, dataset_config=None, verify_interval=100):
        super().__init__()
        self.dataset_config = dataset_config
        self.verify_interval = max(1, int(verify_interval))
        self.training_started = time.time()
        self.last_log_time = self.training_started
        self.last_step = 0
        self.verify_sequence = None
        self.sequence_samples = None
        self.sample_indices = None
        # Dataset the reference samples were read from; set in on_train_begin.
        self._dataset = None

    @staticmethod
    def _last_logged_loss(log_history):
        """Return the most recent ``loss`` value in ``log_history``, or None.

        The list is empty before the first logging step. The Trainer also
        appends entries without a ``loss`` key (e.g. the end-of-training
        summary), so blindly indexing ``log_history[-1]['loss']`` can raise.
        """
        for entry in reversed(log_history or []):
            if isinstance(entry, dict) and "loss" in entry:
                return entry["loss"]
        return None

    @staticmethod
    def _field(sample, key):
        """Return ``sample[key]`` for dict samples, otherwise None."""
        return sample.get(key) if isinstance(sample, dict) else None

    def _sample_matches(self, position, orig_sample, current_sample):
        """Compare one reference sample with its current value and log any change.

        ``prompt_number`` is compared when both samples carry a non-None value.
        Otherwise ``article_id`` is compared. The number of ``conversations``
        entries is compared whenever both samples have them.

        Returns:
            True when no difference was detected.
        """
        matches = True
        orig_prompt = self._field(orig_sample, "prompt_number")
        curr_prompt = self._field(current_sample, "prompt_number")
        if orig_prompt is not None and curr_prompt is not None:
            if orig_prompt != curr_prompt:
                log_info(f"WARNING: Sequence integrity compromised! Sample {position} prompt number changed from {orig_prompt} to {curr_prompt}")
                matches = False
        else:
            orig_id = self._field(orig_sample, "article_id")
            curr_id = self._field(current_sample, "article_id")
            if orig_id is not None and curr_id is not None and orig_id != curr_id:
                log_info(f"WARNING: Sequence integrity compromised! Sample {position} article_id changed from {orig_id} to {curr_id}")
                matches = False

        orig_conv = self._field(orig_sample, "conversations")
        curr_conv = self._field(current_sample, "conversations")
        if orig_conv is not None and curr_conv is not None:
            orig_len, curr_len = len(orig_conv), len(curr_conv)
            if orig_len != curr_len:
                log_info(f"WARNING: Sequence integrity compromised! Sample {position} conversation length changed from {orig_len} to {curr_len}")
                matches = False
        return matches

    def _verify_sequence_integrity(self):
        """Re-read every reference index from the dataset and report whether the values changed."""
        if self._dataset is None:
            log_info("Warning: Train dataset is not available")
            return
        log_info("Verifying data sequence integrity...")
        compared = 0
        is_sequence_maintained = True
        for position, (idx, orig_sample) in enumerate(zip(self.sample_indices, self.sequence_samples)):
            try:
                current_sample = self._dataset[idx]
            except Exception as e:
                log_info(f"Warning: Error accessing dataset at index {idx}: {e}")
                continue
            compared += 1
            # Evaluate every sample (no short-circuit) so each difference gets logged.
            if not self._sample_matches(position, orig_sample, current_sample):
                is_sequence_maintained = False

        if compared == 0:
            log_info("Warning: Not enough samples available for sequence verification")
        elif is_sequence_maintained:
            log_info("Data sequence integrity check: OK")
        else:
            log_info("CRITICAL WARNING: Data sequence integrity check FAILED!")

    def _setup_sequence_verification(self, train_dataloader):
        """Read the verification flag and capture up to five reference samples.

        Args:
            train_dataloader: The dataloader the Trainer passes to callbacks;
                its ``dataset`` attribute is the source of the reference samples.
        """
        try:
            validation_cfg = (self.dataset_config or {}).get("validation") or {}
            self.verify_sequence = bool(validation_cfg.get("verify_sequence_integrity", False))
            if not self.verify_sequence:
                return
            log_info("Sequence integrity verification enabled during training")

            dataset = getattr(train_dataloader, "dataset", None)
            if dataset is None:
                log_info("Warning: Could not capture reference samples - verification will be limited")
                return

            self._dataset = dataset
            self.sample_indices = []
            self.sequence_samples = []
            for i in range(min(5, len(dataset))):
                try:
                    sample = dataset[i]
                except Exception as e:
                    log_info(f"Warning: Error capturing reference sample at index {i}: {e}")
                    continue
                # Record index and sample together so the two lists stay aligned.
                self.sample_indices.append(i)
                self.sequence_samples.append(sample)

            if self.sequence_samples:
                log_info(f"Captured {len(self.sequence_samples)} reference samples for sequence integrity verification")
                sample_prompt_numbers = [
                    s["prompt_number"]
                    for s in self.sequence_samples
                    if isinstance(s, dict) and s.get("prompt_number") is not None
                ]
                if sample_prompt_numbers:
                    log_info(f"Reference sample prompt numbers: {sample_prompt_numbers}")
            else:
                log_info("Warning: No reference samples were captured")
        except Exception as e:
            log_info(f"Warning: Could not set up sequence integrity verification: {e}")
            self.verify_sequence = False

    def on_step_end(self, args, state, control, **kwargs):
        """Run the optional integrity check, then log progress every 50 steps or 5 minutes."""
        current_time = time.time()

        if (self.verify_sequence and self.sequence_samples
                and state.global_step % self.verify_interval == 0):
            try:
                self._verify_sequence_integrity()
            except Exception as e:
                log_info(f"Warning: Couldn't verify sequence integrity: {e}")

        # Log every 50 steps or every 5 minutes, whichever comes first
        time_interval = current_time - self.last_log_time
        step_interval = state.global_step - self.last_step
        if step_interval >= 50 or time_interval >= 300:  # 5 minutes = 300 seconds
            examples_per_second = (
                step_interval * args.per_device_train_batch_size * args.gradient_accumulation_steps
                / max(time_interval, 1e-6)
            )
            elapsed_total = time.strftime("%H:%M:%S", time.gmtime(current_time - self.training_started))
            loss = self._last_logged_loss(state.log_history)
            loss_text = f"{loss:.4f}" if isinstance(loss, (int, float)) else "N/A"
            log_info(f"Step: {state.global_step}, Loss: {loss_text}, "
                     f"Rate: {examples_per_second:.2f} examples/sec, Elapsed: {elapsed_total}")
            if CUDA_AVAILABLE:
                # max_memory_* report the peak since process start (or last reset).
                log_info(f"GPU Memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB peak allocated, "
                         f"{torch.cuda.max_memory_reserved() / 1024**3:.2f} GB peak reserved")
            self.last_log_time = current_time
            self.last_step = state.global_step

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset timers, log run parameters and set up optional sequence verification."""
        # Measure from the real start of training and from the current step
        # (non-zero when resuming from a checkpoint).
        self.training_started = time.time()
        self.last_log_time = self.training_started
        self.last_step = state.global_step

        log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
        model = kwargs.get("model")
        if model is not None:
            log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

        self._setup_sequence_verification(kwargs.get("train_dataloader"))

        log_info("=== Training is starting ===")
        # Count the CPU as one device so the total is never 0 without CUDA.
        num_devices = max(1, NUM_GPUS)
        device_desc = f"{NUM_GPUS} GPUs" if NUM_GPUS > 0 else "1 CPU"
        total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * num_devices
        log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {device_desc} = {total_batch_size} total")
        log_info(f"Learning rate: {args.learning_rate}")
        log_info(f"Epochs: {args.num_train_epochs}")

        if CUDA_AVAILABLE:
            memory_info = []
            for i in range(NUM_GPUS):
                allocated = torch.cuda.memory_allocated(i) / 1024**2
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
            log_info(f"Initial memory usage - {', '.join(memory_info)}")

    def on_train_end(self, args, state, control, **kwargs):
        """Log duration, peak memory, step count and the last logged training loss."""
        training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
        log_info(f"=== Training completed in {training_time} ===")
        if CUDA_AVAILABLE:
            for i in range(NUM_GPUS):
                max_mem = torch.cuda.max_memory_allocated(i) / 1024**3  # GB
                log_info(f"GPU {i} max memory: {max_mem:.2f} GB")
            torch.cuda.empty_cache()
            log_info("GPU memory cleared")
        log_info(f"Total steps: {state.global_step}")
        # The final log_history entry is the train summary (no "loss" key), so
        # search backwards for the last per-step loss.
        final_loss = self._last_logged_loss(state.log_history)
        log_info(f"Final loss: {final_loss if final_loss is not None else 'N/A'}")
def check_dependencies():
    """Confirm that the packages this script cannot run without are importable.

    flash-attn is optional: its presence or absence is only reported in the log.

    Returns:
        True if every critical package is available. Otherwise False, after
        logging which packages are missing.
    """
    required = (
        ("unsloth>=2024.3", unsloth_available),
        ("peft>=0.9.0", peft_available),
    )
    missing = [requirement for requirement, present in required if not present]

    # Optional accelerator: log only, never counted as missing.
    if not find_spec("flash_attn"):
        logger.warning("flash-attn not found. Training will work but may be slower.")
        logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
    else:
        logger.info("flash-attn found. Flash attention will be used for faster training.")

    if not missing:
        return True

    logger.error("Critical dependencies missing:")
    for requirement in missing:
        logger.error(f" - {requirement}")
    logger.error("Please ensure the space has these packages in requirements.txt")
    return False
def _resolve_precision(hardware_config, transformers_config):
    """Choose the mixed-precision mode for training.

    ``hardware.training_optimizations.mixed_precision`` ("bf16" or "fp16",
    case-insensitive) takes priority. Otherwise the transformers config is
    used: ``bf16: true`` or ``torch_dtype: bfloat16`` selects bf16, and
    ``fp16: true`` selects fp16 only when bf16 was not selected.

    Args:
        hardware_config: The ``hardware`` section of the configuration.
        transformers_config: The full training configuration.

    Returns:
        ``(use_bf16, use_fp16)``; at most one element is True.
    """
    training_optimizations = hardware_config.get("training_optimizations") or {}
    # `or ""` also guards against an explicit null in the config.
    hardware_precision = str(training_optimizations.get("mixed_precision") or "").lower()
    if hardware_precision == "bf16":
        log_info("Using BF16 precision from hardware config")
        return True, False
    if hardware_precision == "fp16":
        log_info("Using FP16 precision from hardware config")
        return False, True
    use_bf16 = bool(
        transformers_config.get("bf16", False)
        or transformers_config.get("torch_dtype", "") == "bfloat16"
    )
    use_fp16 = bool(transformers_config.get("fp16", False)) and not use_bf16
    return use_bf16, use_fp16


def _build_fsdp_config(multi_gpu_strategy, is_distributed, use_bf16, use_fp16):
    """Return a ``TrainingArguments.fsdp_config`` dict, or None when FSDP is not used.

    FSDP is enabled only when the strategy is "fsdp", the process runs in
    distributed mode (torchrun/accelerate) and more than one GPU is visible.
    Sharding (full_shard) and wrapping (auto_wrap) are selected by the caller
    through the ``fsdp`` option string, not by keys in this dict.
    """
    if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
        try:
            from torch.distributed.fsdp import FullyShardedDataParallel  # noqa: F401  (availability check)
        except ImportError:
            log_info("FSDP imports failed, falling back to standard DDP")
            return None
        log_info("Using FSDP for distributed training")
        # `min_num_params` is deliberately omitted. TrainingArguments rejects it
        # together with `transformer_layer_cls_to_wrap` (mutually exclusive).
        fsdp_config = {
            # NOTE(review): hard-coded for Llama-style models; other
            # architectures need their own decoder-layer class name here.
            "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
            "backward_prefetch": "backward_pre",
        }
        if use_bf16 or use_fp16:
            precision_type = "bf16" if use_bf16 else "fp16"
            fsdp_config["state_dict_type"] = "FULL_STATE_DICT"
            log_info(f"FSDP using mixed precision: {precision_type}")
        return fsdp_config

    if multi_gpu_strategy == "fsdp" and not is_distributed:
        log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
        log_info("Using DataParallel for multi-GPU training instead")
    else:
        log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
    return None


def main():
    """Load config, model and data, then train with the original dataset order preserved.

    Returns:
        Process exit code: 0 on success; 1 when configuration loading,
        dependency checks, dataset loading or training fail.
    """
    logger.info("Starting training process")
    args = parse_args()
    load_env_variables()

    try:
        transformers_config = load_configs(args.config)
        hardware_config = transformers_config.get("hardware") or {}
        dataset_config = transformers_config.get("dataset") or {}
        logger.info("Configuration loaded successfully")
    except Exception as e:
        logger.error(f"Error loading configuration: {e}")
        return 1

    if not check_dependencies():
        logger.error("Aborting due to missing critical dependencies")
        return 1

    is_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if is_distributed:
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
    else:
        log_info("Running in non-distributed mode (single process)")

    seed = transformers_config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    system_settings = hardware_config.get("system_settings") or {}

    # Configure the CUDA caching allocator BEFORE the model is moved to the GPU.
    # The allocator reads this variable when it initialises, so setting it
    # after model loading has no effect.
    if CUDA_AVAILABLE:
        cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
        if cuda_memory_fraction < 1.0:
            os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
            log_info("Set CUDA memory allocation limit to expandable with max_split_size_mb:128")

    model, tokenizer = load_model_and_tokenizer(transformers_config)

    if CUDA_AVAILABLE:
        torch.cuda.empty_cache()
        log_info("Cleared CUDA cache")

    try:
        log_info("Loading dataset...")
        dataset = load_dataset_with_mapping(dataset_config)
        # Validate before calling len() so a None dataset is reported, not a TypeError.
        if dataset is None or len(dataset) == 0:
            logger.error("Dataset is empty or None! Cannot proceed with training.")
            return 1
        log_info(f"Dataset loaded with {len(dataset)} examples")

        data_collator = SimpleDataCollator(tokenizer, dataset_config)

        use_bf16, use_fp16 = _resolve_precision(hardware_config, transformers_config)
        log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")

        training_config = transformers_config.get("training") or {}
        checkpoint_config = transformers_config.get("checkpointing") or {}
        hub_config = transformers_config.get("huggingface_hub") or {}

        per_device_batch_size = training_config.get("per_device_train_batch_size", 16)
        gradient_accumulation_steps = training_config.get("gradient_accumulation_steps", 3)

        # NOTE(review): the key location (hardware.training_optimizations) is
        # assumed; confirm against the hardware config schema.
        multi_gpu_strategy = (hardware_config.get("training_optimizations") or {}).get(
            "multi_gpu_strategy", "data_parallel"
        )
        fsdp_config = None
        if CUDA_AVAILABLE and NUM_GPUS > 1:
            log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
            fsdp_config = _build_fsdp_config(multi_gpu_strategy, is_distributed, use_bf16, use_fp16)

        dataloader_workers = system_settings.get("dataloader_num_workers", 2)
        pin_memory = system_settings.get("dataloader_pin_memory", True)

        # Top-level output_dir wins, then checkpointing.output_dir, then the default.
        output_dir = (
            transformers_config.get("output_dir")
            or checkpoint_config.get("output_dir")
            or "./results"
        )

        log_info("Setting up training arguments")
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=training_config.get("num_train_epochs", 3),
            per_device_train_batch_size=per_device_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=training_config.get("learning_rate", 2e-5),
            weight_decay=training_config.get("weight_decay", 0.01),
            warmup_ratio=training_config.get("warmup_ratio", 0.05),
            lr_scheduler_type=training_config.get("lr_scheduler_type", "cosine"),
            logging_steps=training_config.get("logging_steps", 10),
            save_strategy=checkpoint_config.get("save_strategy", "steps"),
            save_steps=checkpoint_config.get("save_steps", 100),
            save_total_limit=checkpoint_config.get("save_total_limit", 3),
            fp16=use_fp16,
            bf16=use_bf16,
            max_grad_norm=training_config.get("max_grad_norm", 1.0),
            push_to_hub=hub_config.get("push_to_hub", False),
            hub_model_id=hub_config.get("hub_model_id", None),
            hub_token=os.environ.get("HF_TOKEN", None),
            report_to="tensorboard",
            remove_unused_columns=False,  # Keep all columns
            gradient_checkpointing=training_config.get("gradient_checkpointing", True),
            dataloader_pin_memory=pin_memory,
            optim=training_config.get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,  # Improve distributed training efficiency
            dataloader_drop_last=False,  # Process all examples
            dataloader_num_workers=dataloader_workers,
            no_cuda=not CUDA_AVAILABLE,
            # `fsdp` takes an option string and `fsdp_config` takes the dict.
            # "" (not None) is the "disabled" value TrainingArguments expects.
            fsdp="full_shard auto_wrap" if fsdp_config else "",
            fsdp_config=fsdp_config,
        )

        # Sequential sampler keeps the original dataset order.
        sequential_sampler = torch.utils.data.SequentialSampler(dataset)

        log_info("Initializing Trainer")
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,  # Iterated through the custom dataloader below
            data_collator=data_collator,
            callbacks=[LoggingCallback()],
        )

        def custom_get_train_dataloader():
            """Build a dataloader that preserves the original dataset order.

            Raises:
                ValueError: If shuffling is enabled in ``dataset.data_loading``.
            """
            log_info("Creating sequential dataloader to maintain original dataset order")
            data_loading_config = dataset_config.get("data_loading") or {}
            sequential_processing = data_loading_config.get("sequential_processing", True)
            shuffle_disabled = not data_loading_config.get("shuffle", False)

            if not sequential_processing:
                log_info("CRITICAL WARNING: sequential_processing flag is disabled! This may affect data order.")
                log_info("Data sequence integrity is essential - using sequential sampler regardless of flag.")
            if not shuffle_disabled:
                log_info("CRITICAL ERROR: Shuffle is not disabled! This will randomize data entry order!")
                raise ValueError("Dataset shuffling is enabled but sequential processing is required. " +
                                 "Please disable shuffling in your configuration.")

            # train_batch_size == per_device_train_batch_size * max(1, n_gpu).
            # n_gpu is the GPU count under DataParallel but 1 per process under
            # torchrun, so each rank no longer gets a batch sized for all GPUs.
            batch_size = training_args.train_batch_size
            log_info(f"Using sequential sampler with batch size {batch_size}")

            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                sampler=sequential_sampler,
                collate_fn=data_collator,
                drop_last=training_args.dataloader_drop_last,
                num_workers=training_args.dataloader_num_workers,
                pin_memory=training_args.dataloader_pin_memory,
            )
            if is_distributed:
                accelerator = getattr(trainer, "accelerator", None)
                if accelerator is not None:
                    # Shard the ordered batches across ranks, as the Trainer's own
                    # get_train_dataloader does, instead of every rank
                    # iterating the full dataset.
                    dataloader = accelerator.prepare(dataloader)
            return dataloader

        trainer.get_train_dataloader = custom_get_train_dataloader

        log_info("=== Starting Training ===")
        try:
            if CUDA_AVAILABLE:
                torch.cuda.empty_cache()
                log_info("Cleared CUDA cache before training")

            # Count the CPU as one device so this estimate never divides by zero.
            steps_per_epoch_divisor = max(1, per_device_batch_size * max(1, NUM_GPUS) * gradient_accumulation_steps)
            total_steps = int(len(dataset) / steps_per_epoch_divisor * training_args.num_train_epochs)
            log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")

            trainer.train()
            log_info("Training completed successfully!")

            log_info("Saving final model...")
            trainer.save_model()
            log_info(f"Model saved to {training_args.output_dir}")

            if hub_config.get("push_to_hub", False):
                hub_id = hub_config.get("hub_model_id", "model")
                log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
                trainer.push_to_hub()
                log_info("Model successfully pushed to Hub")
            return 0
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            if CUDA_AVAILABLE:
                memory_info = []
                for i in range(NUM_GPUS):
                    allocated = torch.cuda.memory_allocated(i) / 1024**2
                    reserved = torch.cuda.memory_reserved(i) / 1024**2
                    max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
                    memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
                logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
            raise
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"Error in main training loop: {str(e)}")
        return 1
if __name__ == "__main__":
    # Propagate main()'s return value (0 = success, 1 = failure) as the process exit status.
    exit_code = main()
    sys.exit(exit_code)