#!/usr/bin/env python
# coding=utf-8
"""Fine-tune a causal language model with Unsloth + LoRA on a conversation dataset.

Configuration is read from ``transformers_config.json``, ``dataset_config.json``
and (optionally) ``hardware_config.json`` in the directory given by
``--config_dir``.
"""

import os
import sys
import json
import argparse
import logging
from datetime import datetime
import time

# Import Unsloth first, before other ML imports
try:
    from unsloth import FastLanguageModel
    from unsloth.chat_templates import get_chat_template

    unsloth_available = True
except ImportError:
    unsloth_available = False
    logger = logging.getLogger(__name__)
    logger.warning("Unsloth not available. Please install with: pip install unsloth")

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback,
    set_seed,
    BitsAndBytesConfig,
)

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

# Check for BitsAndBytes
try:
    from transformers import BitsAndBytesConfig

    bitsandbytes_available = True
except ImportError:
    bitsandbytes_available = False
    logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")

# Check for PEFT
try:
    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

    peft_available = True
except ImportError:
    peft_available = False
    logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")


def load_env_variables():
    """Populate Hugging Face related environment variables.

    Inside a Hugging Face Space (``SPACE_ID`` is set) the Space's own variables
    are used, and ``HF_USERNAME`` is derived from the ``owner/name`` form of
    ``SPACE_ID`` when it is missing. Outside a Space, ``shared/.env`` one
    directory above this file is loaded with python-dotenv if that package is
    installed. Finally ``HF_TOKEN`` is mirrored into ``HUGGING_FACE_HUB_TOKEN``.
    Only the presence of secrets is logged, never their values.
    """
    space_id = os.environ.get("SPACE_ID")
    if space_id:
        logger.info("Running in Hugging Face Space")
        # Log the presence of variables (without revealing values)
        logger.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
        logger.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")

        # If username is not set, try to extract it from SPACE_ID ("owner/name")
        if not os.environ.get("HF_USERNAME") and "/" in space_id:
            username = space_id.split("/")[0]
            os.environ["HF_USERNAME"] = username
            logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
    else:
        # Try to load from .env file if not in a Space
        try:
            from dotenv import load_dotenv
        except ImportError:
            logger.warning("python-dotenv not installed, not loading from .env file")
        else:
            env_path = os.path.join(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env"
            )
            if os.path.exists(env_path):
                load_dotenv(env_path)
                logger.info(f"Loaded environment variables from {env_path}")
                logger.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
                logger.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
                logger.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
            else:
                logger.warning(f"No .env file found at {env_path}")

    if not os.environ.get("HF_USERNAME"):
        logger.warning("HF_USERNAME is not set. Using default username.")
    if not os.environ.get("HF_SPACE_NAME"):
        logger.warning("HF_SPACE_NAME is not set. Using default space name.")

    # huggingface_hub reads HUGGING_FACE_HUB_TOKEN
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token


def load_configs(base_path, optional_files=("hardware_config.json",)):
    """Load the JSON configuration files from *base_path*.

    Args:
        base_path: Directory containing the ``*_config.json`` files.
        optional_files: File names that may be absent. A missing optional file
            is skipped with a warning instead of aborting.

    Returns:
        Dict keyed by the file name without ``_config.json``
        (``"transformers"``, ``"hardware"``, ``"dataset"``).

    Raises:
        Any error from opening or parsing a required file (or an optional file
        that exists but is invalid) is logged and re-raised.
    """
    configs = {}
    config_files = [
        "transformers_config.json",
        "hardware_config.json",
        "dataset_config.json",
    ]

    for config_file in config_files:
        file_path = os.path.join(base_path, config_file)
        config_name = config_file.replace("_config.json", "")
        if config_file in optional_files and not os.path.exists(file_path):
            logger.warning(f"Optional configuration {config_file} not found at {file_path}, skipping")
            continue
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                configs[config_name] = json.load(f)
            logger.info(f"Loaded {config_name} configuration from {file_path}")
        except Exception as e:
            logger.error(f"Error loading {config_file}: {e}")
            raise

    return configs


def parse_args():
    """Parse command-line arguments (currently only ``--config_dir``)."""
    parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
    parser.add_argument(
        "--config_dir",
        type=str,
        default=".",
        help="Directory containing configuration files",
    )
    return parser.parse_args()


def _get_model_name(config):
    """Return the first non-empty of ``model.name``, ``model_name_or_path``, ``model_name`` (or None)."""
    return (
        config.get("model", {}).get("name")
        or config.get("model_name_or_path")
        or config.get("model_name")
    )


def load_model_and_tokenizer(config):
    """Load the model through Unsloth and wrap it with LoRA adapters.

    Reads from *config* (the transformers config dict):
      * model id via ``model.name`` / ``model_name_or_path`` / ``model_name``;
      * ``max_seq_length`` (top level, else ``tokenizer.max_seq_length``, default 2048);
      * LoRA settings under ``unsloth`` (``r``, ``alpha``, ``dropout``, ``target_modules``);
      * ``gradient_checkpointing`` (top level, else ``training.gradient_checkpointing``, default True);
      * ``chat_template`` (top level or under ``tokenizer``) and ``seed``.

    Returns:
        ``(model, tokenizer)``. The tokenizer's ``pad_token_id`` falls back to
        ``eos_token_id`` when unset.

    Raises:
        ImportError: if Unsloth is not installed.
        ValueError: if no model name is configured.
        Any loading error is logged and re-raised.
    """
    try:
        if not unsloth_available:
            logger.error("Unsloth is required for training with pre-quantized model")
            logger.error("Please ensure unsloth is in requirements.txt")
            raise ImportError("Unsloth is required for this training setup")

        model_name = _get_model_name(config)
        logger.info(f"Loading model: {model_name}")
        if not model_name:
            raise ValueError(
                "Model name not found in configuration. Please check your transformers_config.json file."
            )

        logger.info("Using Unsloth optimizations with pre-quantized model")

        # Only probe for flash attention for logging; Unsloth picks the attention implementation itself.
        try:
            import flash_attn  # noqa: F401

            logger.info("Flash attention detected and will be used")
        except ImportError:
            logger.warning("Flash attention not available, falling back to standard attention")

        training_cfg = config.get("training", {})
        tokenizer_cfg = config.get("tokenizer", {})

        # No default on the first lookup, otherwise the tokenizer-level fallback is unreachable.
        max_seq_length = config.get("max_seq_length") or tokenizer_cfg.get("max_seq_length", 2048)

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_name,
            max_seq_length=max_seq_length,
            dtype=None,  # Let Unsloth choose optimal dtype
            device_map="auto",
            # Don't explicitly use flash attention config here, let Unsloth handle it
        )

        # An explicit top-level value wins; otherwise use training.gradient_checkpointing,
        # which is where main() writes hardware-specific overrides.
        if "gradient_checkpointing" in config:
            use_gradient_checkpointing = config["gradient_checkpointing"]
        else:
            use_gradient_checkpointing = training_cfg.get("gradient_checkpointing", True)

        # Apply Unsloth's training optimizations with config parameters
        unsloth_config = config.get("unsloth", {})
        model = FastLanguageModel.get_peft_model(
            model,
            r=unsloth_config.get("r", 32),
            target_modules=unsloth_config.get(
                "target_modules",
                ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
            ),
            lora_alpha=unsloth_config.get("alpha", 16),
            lora_dropout=unsloth_config.get("dropout", 0.05),
            bias="none",
            use_gradient_checkpointing=use_gradient_checkpointing,
            random_state=config.get("seed", 42),
        )
        logger.info("Unsloth optimizations applied successfully")

        chat_template = config.get("chat_template") or tokenizer_cfg.get("chat_template")
        if chat_template:
            # get_chat_template takes the tokenizer first and returns a configured tokenizer.
            template_name = chat_template if isinstance(chat_template, str) else "phi"
            try:
                tokenizer = get_chat_template(tokenizer, chat_template=template_name)
                logger.info(f"Set {template_name} chat template")
            except Exception as e:
                logger.warning(f"Failed to set chat template: {str(e)}")

        # Ensure proper token settings
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")

        return model, tokenizer

    except Exception as e:
        logger.error(f"Error in model/tokenizer loading: {str(e)}")
        logger.error("If missing dependencies, check the requirements.txt file")
        raise


def load_dataset_with_mapping(dataset_config):
    """Load the configured dataset split and normalise its columns.

    Reads ``dataset.name``, ``dataset.split`` (default ``"train"``),
    ``dataset.column_mapping`` (``{target_name: source_name}``) and
    ``dataset.processing.sort_by_id`` from *dataset_config*.

    Renames that would overwrite an existing, different column are skipped
    with a warning. If there is no ``conversations`` column but a ``text``
    column exists, ``text`` is converted into a ``conversations`` list.

    Raises:
        ValueError: if no dataset name is configured.
        Any loading error is logged and re-raised.
    """
    try:
        ds_cfg = dataset_config.get("dataset", {})
        dataset_name = ds_cfg.get("name", "")
        dataset_split = ds_cfg.get("split", "train")

        if not dataset_name:
            raise ValueError("Dataset name not provided in configuration")

        logger.info(f"Loading dataset {dataset_name}, split {dataset_split}")
        dataset = load_dataset(dataset_name, split=dataset_split)

        # Map columns if specified - with checks to avoid conflicts
        column_mapping = ds_cfg.get("column_mapping", {})
        if column_mapping:
            logger.info(f"Checking column mapping: {column_mapping}")

            safe_mappings = {}
            for target, source in column_mapping.items():
                if source not in dataset.column_names:
                    continue
                if target in dataset.column_names and target != source:
                    logger.warning(f"Cannot rename '{source}' to '{target}' - target column already exists")
                else:
                    safe_mappings[source] = target

            if safe_mappings:
                logger.info(f"Applying safe column mapping: {safe_mappings}")
                for source, target in safe_mappings.items():
                    if source != target:  # Only rename if names are different
                        dataset = dataset.rename_column(source, target)

        # Verify expected columns exist
        for col in ("id", "conversations"):
            if col in dataset.column_names:
                continue
            if col == "conversations" and "text" in dataset.column_names:
                logger.info("Converting 'text' field to 'conversations' format")

                def convert_text_to_conversations(example):
                    """Wrap a ``text`` row as a conversation (lists are used as-is)."""
                    if isinstance(example.get("text"), list):
                        return {"conversations": example["text"]}
                    return {
                        "conversations": [
                            {"role": "user", "content": str(example.get("text", ""))}
                        ]
                    }

                dataset = dataset.map(convert_text_to_conversations)
            else:
                logger.warning(f"Expected column '{col}' not found in dataset")

        # Sort dataset if required
        sort_by_id = ds_cfg.get("processing", {}).get("sort_by_id", False)
        if sort_by_id and "id" in dataset.column_names:
            logger.info("Sorting dataset by ID")
            dataset = dataset.sort("id")
            sample_ids = [example["id"] for example in dataset.select(range(min(5, len(dataset))))]
            logger.info(f"First few IDs after sorting: {sample_ids}")

        # Log one conversation; index the row first so the whole column is not materialised.
        if "conversations" in dataset.column_names:
            sample_conv = dataset[0]["conversations"] if len(dataset) > 0 else []
            logger.info(f"Example conversation structure: {sample_conv}")

        logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
        logger.info(f"Dataset columns: {dataset.column_names}")
        return dataset

    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        raise


def format_phi_chat(messages, dataset_config):
    """Render a list of ``{"role", "content"}`` messages as one prompt string.

    Role templates come from ``data_formatting.roles`` in *dataset_config*.
    By default these are ``"System: ..."``, ``"Human: ..."`` and
    ``"Assistant: ..."``, each followed by a blank line.

    Ordering rules:
      * The first message whose string content contains
        ``[RESEARCH INTRODUCTION]`` is rendered first with the system template.
      * Any other ``system`` message is prepended to what has been rendered so far.
      * ``user``/``human`` and ``assistant``/``bot`` messages are appended.
      * Unknown roles are appended with the system template.

    Non-dict messages and messages without ``content`` are skipped.

    Returns:
        The rendered text with leading and trailing whitespace stripped.
    """
    formatted_chat = ""

    roles = dataset_config.get("data_formatting", {}).get("roles", {
        "system": "System: {content}\n\n",
        "human": "Human: {content}\n\n",
        "user": "Human: {content}\n\n",
        "assistant": "Assistant: {content}\n\n",
    })
    system_template = roles.get("system", "System: {content}\n\n")

    # Handle research introduction metadata first (only string content can carry the marker)
    metadata = next(
        (
            msg for msg in messages
            if isinstance(msg, dict)
            and isinstance(msg.get("content"), str)
            and "[RESEARCH INTRODUCTION]" in msg["content"]
        ),
        None,
    )
    if metadata:
        formatted_chat = system_template.format(content=metadata["content"])
        messages = [msg for msg in messages if msg != metadata]

    for message in messages:
        if not isinstance(message, dict) or "content" not in message:
            logger.warning(f"Skipping invalid message format: {message}")
            continue

        role = str(message.get("role") or "").lower()
        content = message.get("content", "")

        if role in ("human", "user"):
            template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
            formatted_chat += template.format(content=content)
        elif role in ("assistant", "bot"):
            template = roles.get("assistant", "Assistant: {content}\n\n")
            formatted_chat += template.format(content=content)
        elif role == "system":
            # System messages go in front of everything rendered so far
            formatted_chat = system_template.format(content=content) + formatted_chat
        else:
            logger.warning(f"Unknown role '{role}' - treating as system message")
            formatted_chat += system_template.format(content=content)

    return formatted_chat.strip()


class SimpleDataCollator:
    """Turn raw dataset rows into padded causal-LM batches.

    For each row, the collator:
      1. Normalises the ``conversations`` (or ``text``) field.
      2. Optionally prefixes a metadata system message (paper id plus a running
         per-paper chunk count).
      3. Renders the conversation with :func:`format_phi_chat`.
      4. Tokenizes with truncation to ``dataset.processing.max_seq_length``.
      5. Right-pads to the longest sequence in the batch.

    Labels equal the input ids, with padding positions set to -100.

    ``stats`` and ``paper_counters`` are instance state mutated on every call.
    They are not shared across DataLoader worker processes.
    NOTE(review): chunk numbers keep increasing across epochs, because the
    counter is never reset.
    """

    def __init__(self, tokenizer, dataset_config):
        self.tokenizer = tokenizer
        self.dataset_config = dataset_config
        self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
        self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        self.prompt_counter = 0
        self.paper_counters = {}

        processing_cfg = dataset_config.get("dataset", {}).get("processing", {})
        formatting_cfg = dataset_config.get("data_formatting", {})
        metadata_cfg = formatting_cfg.get("metadata_handling", {})
        validation_cfg = dataset_config.get("validation", {})

        self.max_seq_length = processing_cfg.get("max_seq_length", 2048)
        self.include_metadata = metadata_cfg.get("include_paper_id", True)
        self.include_chunk = metadata_cfg.get("include_chunk_number", True)
        self.metadata_format = metadata_cfg.get(
            "metadata_format", "Paper ID: {paper_id} | Chunk: {chunk_number}"
        )
        self.roles = formatting_cfg.get("roles", {})
        # Number of initial examples whose formatted text is logged in full detail
        self.log_samples = validation_cfg.get("log_samples", 3)
        # Log running stats each time this many more examples have been processed
        self.log_interval = validation_cfg.get("log_interval", 100)
        self._next_stats_log = self.log_interval
        logger.info(
            f"SimpleDataCollator initialized - using phi-4 chat format with max_seq_length={self.max_seq_length}"
        )

    def normalize_conversation(self, conversation):
        """Coerce a raw conversation into a list of ``{"role", "content"}`` dicts.

        Input handling:
          * A single dict-like turn is wrapped in a list; any other non-list
            input yields ``[]``.
          * Falsy turns are dropped, and bare strings become user turns.
          * Turns lacking a ``role`` (or with a ``None`` role) or lacking
            ``content`` are skipped with a warning.

        Output normalisation:
          * ``human`` is mapped to ``user`` and ``bot`` to ``assistant``.
          * Other roles are lower-cased and kept.
          * Content is converted to ``str``.
        """
        normalized = []

        if not isinstance(conversation, list):
            logger.warning(f"Conversation is not a list: {type(conversation)}")
            if hasattr(conversation, "items"):  # dict-like single turn
                conversation = [conversation]
            else:
                return []

        for turn in conversation:
            if not turn:
                continue

            if isinstance(turn, str):
                normalized.append({"role": "user", "content": turn})
                continue

            # Dict-like but not a dict: copy over only the non-None fields we need
            if not isinstance(turn, dict) and hasattr(turn, "get"):
                turn = {k: turn.get(k) for k in ("role", "content") if turn.get(k) is not None}

            if not isinstance(turn, dict) or turn.get("role") is None or "content" not in turn:
                logger.warning(f"Skipping malformatted conversation turn: {turn}")
                continue

            role = str(turn["role"]).lower()
            if role in ("user", "human"):
                role = "user"
            elif role in ("assistant", "bot"):
                role = "assistant"

            normalized.append({"role": role, "content": str(turn.get("content", ""))})

        return normalized

    def __call__(self, features):
        """Collate *features* (dataset rows) into padded ``torch.long`` tensors.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``, each of
            shape ``(batch, longest_sequence)``. Rows that cannot be processed
            are skipped and counted in ``self.stats["skipped"]``. If every row
            is skipped, a ``(1, 1)`` all-zero batch is returned.
        """
        batch = {"input_ids": [], "attention_mask": [], "labels": []}

        for example in features:
            try:
                paper_id = example.get("id", "")
                # Handle conversation field - could be under 'conversations' or 'text'
                conversation = self.normalize_conversation(
                    example.get("conversations", example.get("text", []))
                )
                if not conversation:
                    self.stats["skipped"] += 1
                    continue

                self.paper_counters[paper_id] = self.paper_counters.get(paper_id, 0) + 1
                chunk_number = self.paper_counters[paper_id]

                if self.include_metadata:
                    metadata_content = self.metadata_format.format(
                        paper_id=paper_id, chunk_number=chunk_number
                    )
                    # Add as system message if not already in conversation
                    if not any(msg.get("role") == "system" for msg in conversation):
                        conversation = [{"role": "system", "content": metadata_content}] + conversation

                formatted_content = format_phi_chat(conversation, self.dataset_config)

                inputs = self.tokenizer(
                    formatted_content,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.max_seq_length,
                    return_tensors=None,
                )
                input_ids = inputs["input_ids"]
                if not input_ids:
                    self.stats["skipped"] += 1
                    continue

                batch["input_ids"].append(input_ids)
                batch["attention_mask"].append(inputs["attention_mask"])
                # For causal language modeling, labels are the same as inputs
                batch["labels"].append(input_ids.copy())

                self.stats["processed"] += 1
                self.stats["total_tokens"] += len(input_ids)

                if self.stats["processed"] <= self.log_samples:
                    logger.info(f"Example {self.stats['processed']} format:")
                    logger.info(f"Paper ID: {paper_id} | Chunk: {chunk_number}")
                    logger.info(f"Token count: {len(input_ids)}")
                    logger.info(f"Content preview:\n{formatted_content[:500]}...")
                    logger.info(f"Conversation structure: {conversation[:2]}...")

            except Exception as e:
                logger.warning(f"Error processing example: {str(e)[:100]}...")
                logger.warning(f"Problematic example: {str(example)[:200]}...")
                self.stats["skipped"] += 1
                continue

        if not batch["input_ids"]:
            logger.warning("Empty batch, returning dummy tensors")
            return {
                "input_ids": torch.zeros((1, 1), dtype=torch.long),
                "attention_mask": torch.zeros((1, 1), dtype=torch.long),
                "labels": torch.zeros((1, 1), dtype=torch.long),
            }

        # Right-pad every sequence to the longest one in the batch
        max_length = max(len(ids) for ids in batch["input_ids"])
        for input_ids, attention_mask, labels in zip(
            batch["input_ids"], batch["attention_mask"], batch["labels"]
        ):
            padding_length = max_length - len(input_ids)
            if padding_length > 0:
                input_ids.extend([self.pad_token_id] * padding_length)
                attention_mask.extend([0] * padding_length)
                labels.extend([-100] * padding_length)  # -100 is ignored by the LM loss

        batch = {k: torch.tensor(v) for k, v in batch.items()}

        self._maybe_log_stats()
        return batch

    def _maybe_log_stats(self):
        """Log running stats once per additional ``log_interval`` processed examples."""
        if not self.log_interval or self.log_interval <= 0:
            return
        processed = self.stats["processed"]
        if processed < self._next_stats_log:
            return
        logger.info(
            f"Data collator stats: processed={processed}, "
            f"skipped={self.stats['skipped']}, "
            f"avg_tokens={self.stats['total_tokens'] / processed:.1f}, "
            f"unique_papers={len(self.paper_counters)}"
        )
        # A batch may cross a threshold without landing exactly on a multiple of the interval
        self._next_stats_log = (processed // self.log_interval + 1) * self.log_interval


def check_dependencies():
    """Return True if Unsloth and PEFT are importable; log what is missing otherwise.

    flash-attn is optional: its absence is only logged as a warning.
    """
    missing_packages = []

    if not unsloth_available:
        missing_packages.append("unsloth>=2024.3")
    if not peft_available:
        missing_packages.append("peft>=0.9.0")

    # Optional packages - don't add to missing list, just log
    try:
        import flash_attn  # noqa: F401

        logger.info("flash-attn found. Flash attention will be used for faster training.")
    except ImportError:
        logger.warning("flash-attn not found. Training will work but may be slower.")

    if missing_packages:
        logger.error("Critical dependencies missing:")
        for pkg in missing_packages:
            logger.error(f" - {pkg}")
        logger.error("Please ensure the space has these packages in requirements.txt")
        return False

    return True


def _apply_hardware_overrides(model_config, hardware_config):
    """Copy hardware-specific training settings into ``model_config["training"]`` in place.

    Uses ``training_optimizations.per_device_batch_size``,
    ``training_optimizations.gradient_accumulation_steps`` and
    ``training_optimizations.memory_optimizations.use_gradient_checkpointing``.
    Nothing is applied when ``model_config`` has no (or an empty) ``training`` section.
    """
    training_cfg = model_config.get("training")
    if not training_cfg:
        return

    training_opts = hardware_config.get("training_optimizations", {})

    per_device_batch_size = training_opts.get("per_device_batch_size")
    if per_device_batch_size:
        training_cfg["per_device_train_batch_size"] = per_device_batch_size
        logger.info(f"Applied hardware-specific batch size: {per_device_batch_size}")

    gradient_accumulation = training_opts.get("gradient_accumulation_steps")
    if gradient_accumulation:
        training_cfg["gradient_accumulation_steps"] = gradient_accumulation
        logger.info(f"Applied hardware-specific gradient accumulation: {gradient_accumulation}")

    memory_opts = training_opts.get("memory_optimizations", {})
    if memory_opts.get("use_gradient_checkpointing") is not None:
        training_cfg["gradient_checkpointing"] = memory_opts["use_gradient_checkpointing"]


def _find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory in *output_dir*, or None.

    Entries whose suffix is not an integer (e.g. ``checkpoint-final``) are ignored.
    """
    if not os.path.isdir(output_dir):
        return None
    steps = {}
    for entry in os.listdir(output_dir):
        prefix, _, step = entry.partition("-")
        if prefix == "checkpoint" and step.isdigit() and os.path.isdir(os.path.join(output_dir, entry)):
            steps[int(step)] = entry
    if not steps:
        return None
    return os.path.join(output_dir, steps[max(steps)])


def main():
    """Load configs, model and data, then run training.

    Returns:
        Process exit code: 0 on success, 1 on any handled failure.
    """
    logger.info("Starting training process")
    args = parse_args()

    if not check_dependencies():
        logger.error("Aborting due to missing critical dependencies")
        return 1

    load_env_variables()

    try:
        configs = load_configs(args.config_dir)

        if not configs:
            logger.error("Failed to load configurations")
            return 1
        if "transformers" not in configs:
            logger.error("transformers_config.json not found or invalid")
            return 1
        if "hardware" not in configs:
            logger.warning("hardware_config.json not found. Using default hardware configuration.")
        if "dataset" not in configs:
            logger.error("dataset_config.json not found or invalid")
            return 1

        model_config = configs["transformers"]
        model_name = _get_model_name(model_config)
        if not model_name:
            logger.error("Model name not specified in configuration")
            logger.error("Please ensure 'name' is specified under 'model' in transformers_config.json")
            return 1

        logger.info(f"Model name: {model_name}")
        logger.info("All configurations loaded successfully")

        hardware_config = configs.get("hardware", {})
        dataset_config = configs["dataset"]

        if hardware_config:
            _apply_hardware_overrides(model_config, hardware_config)

    except Exception as e:
        logger.error(f"Error loading configurations: {e}")
        return 1

    seed = model_config.get("seed", 42)
    set_seed(seed)
    logger.info(f"Set random seed to {seed}")

    try:
        model, tokenizer = load_model_and_tokenizer(model_config)
        logger.info("Model and tokenizer loaded successfully")

        try:
            dataset = load_dataset_with_mapping(dataset_config)
            logger.info("Dataset loaded and prepared successfully")
        except Exception as e:
            logger.error(f"Error loading dataset: {e}")
            return 1

        data_collator = SimpleDataCollator(tokenizer, dataset_config)

        class LoggingCallback(TrainerCallback):
            """Log the most recent recorded loss every 50 steps, or when 5 minutes have passed."""

            def __init__(self):
                self.last_log_time = time.time()

            def on_step_end(self, args, state, control, **kwargs):
                current_time = time.time()
                if state.global_step % 50 == 0 or current_time - self.last_log_time > 300:
                    # The latest log entry may lack 'loss' (e.g. eval or summary entries)
                    last_loss = state.log_history[-1].get("loss", "N/A") if state.log_history else "N/A"
                    logger.info(f"Step {state.global_step}: Loss {last_loss}")
                    self.last_log_time = current_time

        # Only one of bf16/fp16 may be enabled; bf16 takes precedence
        use_bf16 = model_config.get("bf16", False) or model_config.get("torch_dtype", "") == "bfloat16"
        use_fp16 = model_config.get("fp16", False) and not use_bf16
        logger.info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")

        training_cfg = model_config.get("training", {})
        checkpoint_cfg = model_config.get("checkpointing", {})
        hub_cfg = model_config.get("huggingface_hub", {})
        # No default on the first lookup, otherwise checkpointing.output_dir is unreachable
        output_dir = model_config.get("output_dir") or checkpoint_cfg.get("output_dir", "./results")

        logger.info("Setting up training arguments")
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=training_cfg.get("num_train_epochs", 3),
            per_device_train_batch_size=training_cfg.get("per_device_train_batch_size", 24),
            gradient_accumulation_steps=training_cfg.get("gradient_accumulation_steps", 2),
            learning_rate=training_cfg.get("learning_rate", 2e-5),
            weight_decay=training_cfg.get("weight_decay", 0.01),
            warmup_ratio=training_cfg.get("warmup_ratio", 0.05),
            lr_scheduler_type=training_cfg.get("lr_scheduler_type", "cosine"),
            logging_steps=training_cfg.get("logging_steps", 10),
            save_strategy=checkpoint_cfg.get("save_strategy", "steps"),
            save_steps=checkpoint_cfg.get("save_steps", 100),
            save_total_limit=checkpoint_cfg.get("save_total_limit", 3),
            fp16=use_fp16,
            bf16=use_bf16,
            max_grad_norm=training_cfg.get("max_grad_norm", 1.0),
            push_to_hub=hub_cfg.get("push_to_hub", False),
            hub_model_id=hub_cfg.get("hub_model_id", None),
            hub_token=os.environ.get("HF_TOKEN", None),
            report_to="tensorboard",
            remove_unused_columns=False,  # The collator needs the raw columns
            gradient_checkpointing=training_cfg.get("gradient_checkpointing", True),
            dataloader_pin_memory=False,  # Reduce memory usage
            optim=training_cfg.get("optim", "adamw_torch"),
            ddp_find_unused_parameters=False,  # Improve distributed training efficiency
            dataloader_drop_last=False,  # Process all examples
            # Not used by the training loader: the override below builds it with num_workers=0
            dataloader_num_workers=4,
        )

        logger.info("Creating sequential sampler to maintain dataset order")
        logger.info("Creating trainer")

        # Resume from the newest checkpoint in the directory the Trainer will actually write to
        resume_from_checkpoint = _find_latest_checkpoint(training_args.output_dir) or False
        if resume_from_checkpoint:
            logger.info(f"Found checkpoint: {resume_from_checkpoint}. Training will resume from this point.")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
            callbacks=[LoggingCallback()],
        )

        def get_train_dataloader_no_shuffle():
            """Create a train DataLoader that walks the dataset in order (no shuffling).

            Uses a single worker so the collator's in-process counters see every
            example. The loader is passed through the Trainer's accelerator, as
            ``Trainer.get_train_dataloader`` itself does, so device placement
            and distributed sharding still apply.
            """
            logger.info("Creating train dataloader with sequential sampler (no shuffling)")
            train_sampler = torch.utils.data.SequentialSampler(dataset)
            dataloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=training_args.per_device_train_batch_size,
                sampler=train_sampler,
                collate_fn=data_collator,
                drop_last=False,
                num_workers=0,
                pin_memory=False,
            )
            accelerator = getattr(trainer, "accelerator", None)
            return accelerator.prepare(dataloader) if accelerator is not None else dataloader

        # TrainingArguments has no shuffle switch, so replace the loader factory instead
        trainer.get_train_dataloader = get_train_dataloader_no_shuffle

        logger.info("Starting training")
        logger.info(
            f"Processing with batch size = {training_args.per_device_train_batch_size}, "
            f"each entry processed independently"
        )

        # Lock file signals that training is in progress; removed in the finally below
        lock_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "TRAINING_IN_PROGRESS.lock")
        with open(lock_file, "w", encoding="utf-8") as f:
            f.write(f"Training started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Expected completion: After {training_args.num_train_epochs} epochs\n")
            f.write("DO NOT UPDATE OR RESTART THIS SPACE UNTIL TRAINING COMPLETES\n")
        logger.info(f"Created lock file: {lock_file}")

        # Honour both the huggingface_hub section (used by TrainingArguments) and a legacy top-level flag
        should_push = training_args.push_to_hub or bool(model_config.get("push_to_hub", False))

        try:
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
            logger.info("Training completed successfully")

            if should_push:
                hub_model_id = training_args.hub_model_id or model_config.get("hub_model_id")
                logger.info(f"Pushing model to hub: {hub_model_id}")
                trainer.push_to_hub()
                logger.info("Model pushed to hub successfully")
            else:
                logger.info(f"Saving model to {training_args.output_dir}")
                trainer.save_model()
                logger.info("Model saved successfully")
        except Exception as e:
            logger.error(f"Training failed with error: {str(e)}")
            raise
        finally:
            if os.path.exists(lock_file):
                os.remove(lock_file)
                logger.info(f"Removed lock file: {lock_file}")

        return 0

    except Exception as e:
        logger.exception(f"Error in main training loop: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())