""" Main training script for Code Comment Quality Classifier """ import os import argparse import logging from pathlib import Path from transformers import ( Trainer, TrainingArguments, EarlyStoppingCallback ) from src import ( load_config, prepare_datasets_for_training, create_model, get_model_size, get_trainable_params, compute_metrics_factory ) def setup_logging(config: dict) -> None: """Setup logging configuration.""" log_config = config.get('logging', {}) log_level = getattr(logging, log_config.get('level', 'INFO')) log_file = log_config.get('log_file', './results/training.log') # Create log directory if needed log_dir = os.path.dirname(log_file) if log_dir: os.makedirs(log_dir, exist_ok=True) # Configure logging logging.basicConfig( level=log_level, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_file), logging.StreamHandler() ] ) def main(config_path: str = "config.yaml"): """ Main training function. Args: config_path: Path to configuration file """ print("=" * 60) print("Code Comment Quality Classifier - Training") print("=" * 60) # Load configuration print("\n[1/7] Loading configuration...") config = load_config(config_path) print(f"✓ Configuration loaded from {config_path}") # Validate configuration from src.validation import validate_config config_errors = validate_config(config) if config_errors: print("\n✗ Configuration validation errors:") for error in config_errors: print(f" - {error}") raise ValueError("Invalid configuration. Please fix the errors above.") # Setup logging setup_logging(config) logging.info("Starting training process") # Prepare datasets print("\n[2/7] Preparing datasets...") tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training(config_path) print(f"✓ Train samples: {len(tokenized_datasets['train'])}") print(f"✓ Validation samples: {len(tokenized_datasets['validation'])}") print(f"✓ Test samples: {len(tokenized_datasets['test'])}") logging.info(f"Dataset sizes - Train: {len(tokenized_datasets['train'])}, " f"Val: {len(tokenized_datasets['validation'])}, " f"Test: {len(tokenized_datasets['test'])}") # Create model print("\n[3/7] Loading model...") dropout = config['model'].get('dropout') model = create_model( model_name=config['model']['name'], num_labels=config['model']['num_labels'], label2id=label2id, id2label=id2label, dropout=dropout ) model_size = get_model_size(model) params_info = get_trainable_params(model) print(f"✓ Model: {config['model']['name']}") print(f"✓ Total Parameters: {model_size:.2f}M") print(f"✓ Trainable Parameters: {params_info['trainable'] / 1e6:.2f}M") logging.info(f"Model: {config['model']['name']}, Size: {model_size:.2f}M parameters") # Setup training arguments print("\n[4/7] Setting up training...") output_dir = config['training']['output_dir'] os.makedirs(output_dir, exist_ok=True) training_args = TrainingArguments( output_dir=output_dir, num_train_epochs=config['training']['num_train_epochs'], per_device_train_batch_size=config['training']['per_device_train_batch_size'], per_device_eval_batch_size=config['training']['per_device_eval_batch_size'], gradient_accumulation_steps=config['training'].get('gradient_accumulation_steps', 1), learning_rate=config['training']['learning_rate'], lr_scheduler_type=config['training'].get('lr_scheduler_type', 'linear'), weight_decay=config['training']['weight_decay'], warmup_steps=config['training'].get('warmup_steps'), warmup_ratio=config['training'].get('warmup_ratio'), logging_dir=os.path.join(output_dir, 'logs'), 
        logging_steps=config['training']['logging_steps'],
        eval_steps=config['training']['eval_steps'],
        save_steps=config['training']['save_steps'],
        save_total_limit=config['training'].get('save_total_limit', 3),
        eval_strategy=config['training']['evaluation_strategy'],
        save_strategy=config['training']['save_strategy'],
        load_best_model_at_end=config['training']['load_best_model_at_end'],
        metric_for_best_model=config['training']['metric_for_best_model'],
        greater_is_better=config['training'].get('greater_is_better', True),
        seed=config['training']['seed'],
        fp16=config['training'].get('fp16', False),
        dataloader_num_workers=config['training'].get('dataloader_num_workers', 4),
        dataloader_pin_memory=config['training'].get('dataloader_pin_memory', True),
        remove_unused_columns=config['training'].get('remove_unused_columns', True),
        report_to=config['training'].get('report_to', ['none']),
        push_to_hub=False,
    )

    # Create compute_metrics function with label mapping
    compute_metrics_fn = compute_metrics_factory(id2label)

    # Setup callbacks
    callbacks = []
    if config['training'].get('early_stopping_patience'):
        early_stopping = EarlyStoppingCallback(
            early_stopping_patience=config['training']['early_stopping_patience'],
            early_stopping_threshold=config['training'].get('early_stopping_threshold', 0.0)
        )
        callbacks.append(early_stopping)
        logging.info(f"Early stopping enabled with patience={config['training']['early_stopping_patience']}")

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks
    )
    print("✓ Trainer initialized")
    logging.info("Trainer initialized with all configurations")

    # Train model
    print("\n[5/7] Training model...")
    print("-" * 60)
    logging.info("Starting training")
    train_result = trainer.train()
    logging.info(f"Training completed. Train loss: {train_result.training_loss:.4f}")

    # Save final model
    print("\n[6/7] Saving model...")
    final_model_path = os.path.join(output_dir, 'final_model')
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Model saved to {final_model_path}")
    logging.info(f"Model saved to {final_model_path}")

    # Evaluate on test set
    print("\n[7/7] Evaluating on test set...")
    print("=" * 60)
    print("Final Evaluation on Test Set")
    print("=" * 60)
    test_results = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix='test')

    print("\nTest Results:")
    for key, value in sorted(test_results.items()):
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")
    logging.info("Test evaluation completed")

    print("\n" + "=" * 60)
    print("Training Complete! 🎉")
    print("=" * 60)
    print(f"\nModel location: {final_model_path}")
    print("\nNext steps:")
    print("1. Run evaluation: python scripts/evaluate.py")
    print("2. Test inference: python inference.py")
    print("3. Upload to Hub: python scripts/upload_to_hub.py")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Code Comment Quality Classifier")
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to configuration file"
    )
    args = parser.parse_args()
    main(args.config)