# Snaseem2026's picture
# Upload train.py with huggingface_hub
# e4495a9 verified
"""
Main training script for Code Comment Quality Classifier
"""
import os
import argparse
import logging
from pathlib import Path
from transformers import (
Trainer,
TrainingArguments,
EarlyStoppingCallback
)
from src import (
load_config,
prepare_datasets_for_training,
create_model,
get_model_size,
get_trainable_params,
compute_metrics_factory
)
def setup_logging(config: dict) -> None:
    """Configure root logging from the ``logging`` section of *config*.

    Reads two optional keys from ``config['logging']``:

    * ``level`` -- a logging level name (case-insensitive, e.g. ``"INFO"`` or
      ``"debug"``) or a numeric level; defaults to ``"INFO"``.
    * ``log_file`` -- path of the log file; defaults to
      ``./results/training.log``. Its parent directory is created if needed.
      A falsy value (e.g. ``null`` in YAML) disables file logging.

    Records are always written to stderr as well. Any handlers already on the
    root logger are replaced.

    Args:
        config: Parsed configuration dictionary.

    Raises:
        ValueError: If ``level`` is not a recognised logging level.
    """
    # ``or {}`` also covers a ``logging:`` key present with no value (None).
    log_config = config.get('logging') or {}

    raw_level = log_config.get('level', 'INFO')
    if isinstance(raw_level, int):
        log_level = raw_level
    else:
        # A plain getattr(logging, name) would accept any module attribute
        # (e.g. "getLogger"), so only names that map to an int level pass.
        log_level = getattr(logging, str(raw_level).upper(), None)
        if not isinstance(log_level, int):
            raise ValueError(f"Unknown logging level: {raw_level!r}")

    log_file = log_config.get('log_file', './results/training.log')

    handlers = [logging.StreamHandler()]
    if log_file:
        # Create log directory if needed
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        handlers.insert(0, logging.FileHandler(log_file, encoding='utf-8'))

    # force=True: without it basicConfig is a silent no-op when the root
    # logger already has handlers, and the FileHandler above would be
    # opened but never attached.
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=handlers,
        force=True,
    )
def _build_training_arguments(train_cfg: dict, output_dir: str) -> "TrainingArguments":
    """Build ``TrainingArguments`` from the ``training`` section of the config.

    Required keys are read with ``[]`` so a missing one fails fast with a
    ``KeyError``; optional keys fall back to the defaults shown here. The
    warmup keys are forwarded only when set, so an absent or ``null`` value
    leaves the library default in place instead of passing ``None`` into
    ``TrainingArguments`` (which compares them numerically).
    """
    warmup_kwargs = {
        key: train_cfg[key]
        for key in ('warmup_steps', 'warmup_ratio')
        if train_cfg.get(key) is not None
    }
    return TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=train_cfg['num_train_epochs'],
        per_device_train_batch_size=train_cfg['per_device_train_batch_size'],
        per_device_eval_batch_size=train_cfg['per_device_eval_batch_size'],
        gradient_accumulation_steps=train_cfg.get('gradient_accumulation_steps', 1),
        learning_rate=train_cfg['learning_rate'],
        lr_scheduler_type=train_cfg.get('lr_scheduler_type', 'linear'),
        weight_decay=train_cfg['weight_decay'],
        logging_dir=os.path.join(output_dir, 'logs'),
        logging_steps=train_cfg['logging_steps'],
        eval_steps=train_cfg['eval_steps'],
        save_steps=train_cfg['save_steps'],
        save_total_limit=train_cfg.get('save_total_limit', 3),
        eval_strategy=train_cfg['evaluation_strategy'],
        save_strategy=train_cfg['save_strategy'],
        load_best_model_at_end=train_cfg['load_best_model_at_end'],
        metric_for_best_model=train_cfg['metric_for_best_model'],
        greater_is_better=train_cfg.get('greater_is_better', True),
        seed=train_cfg['seed'],
        fp16=train_cfg.get('fp16', False),
        dataloader_num_workers=train_cfg.get('dataloader_num_workers', 4),
        dataloader_pin_memory=train_cfg.get('dataloader_pin_memory', True),
        remove_unused_columns=train_cfg.get('remove_unused_columns', True),
        report_to=train_cfg.get('report_to', ['none']),
        push_to_hub=False,
        **warmup_kwargs,
    )


def _build_callbacks(train_cfg: dict) -> list:
    """Return the Trainer callbacks requested by the ``training`` config.

    Early stopping is added only when ``early_stopping_patience`` is truthy,
    so ``0`` or ``null`` disables it.
    """
    patience = train_cfg.get('early_stopping_patience')
    if not patience:
        return []
    early_stopping = EarlyStoppingCallback(
        early_stopping_patience=patience,
        early_stopping_threshold=train_cfg.get('early_stopping_threshold', 0.0)
    )
    logging.info(f"Early stopping enabled with patience={patience}")
    return [early_stopping]


def _tokenizer_kwargs(tokenizer) -> dict:
    """Pass *tokenizer* to ``Trainer`` under the keyword this version accepts.

    Newer ``transformers`` releases take ``processing_class`` and deprecate
    ``tokenizer``; older releases only know ``tokenizer``.
    """
    import inspect

    params = inspect.signature(Trainer.__init__).parameters
    key = 'processing_class' if 'processing_class' in params else 'tokenizer'
    return {key: tokenizer}


def main(config_path: str = "config.yaml"):
    """
    Main training function.

    Loads and validates the configuration, prepares the tokenized datasets,
    builds the model and Trainer, trains, saves the final model and tokenizer
    under ``<output_dir>/final_model``, and evaluates on the test split.

    Args:
        config_path: Path to configuration file

    Raises:
        ValueError: If the configuration fails validation.
    """
    print("=" * 60)
    print("Code Comment Quality Classifier - Training")
    print("=" * 60)

    # Load configuration
    print("\n[1/7] Loading configuration...")
    config = load_config(config_path)
    print(f"✓ Configuration loaded from {config_path}")

    # Validate configuration
    from src.validation import validate_config
    config_errors = validate_config(config)
    if config_errors:
        print("\n✗ Configuration validation errors:")
        for error in config_errors:
            print(f" - {error}")
        raise ValueError("Invalid configuration. Please fix the errors above.")

    # Setup logging
    setup_logging(config)
    logging.info("Starting training process")

    # Prepare datasets
    print("\n[2/7] Preparing datasets...")
    tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training(config_path)
    n_train = len(tokenized_datasets['train'])
    n_val = len(tokenized_datasets['validation'])
    n_test = len(tokenized_datasets['test'])
    print(f"✓ Train samples: {n_train}")
    print(f"✓ Validation samples: {n_val}")
    print(f"✓ Test samples: {n_test}")
    logging.info(f"Dataset sizes - Train: {n_train}, "
                 f"Val: {n_val}, "
                 f"Test: {n_test}")

    # Create model
    print("\n[3/7] Loading model...")
    model_cfg = config['model']
    model = create_model(
        model_name=model_cfg['name'],
        num_labels=model_cfg['num_labels'],
        label2id=label2id,
        id2label=id2label,
        dropout=model_cfg.get('dropout')
    )
    model_size = get_model_size(model)
    params_info = get_trainable_params(model)
    print(f"✓ Model: {model_cfg['name']}")
    print(f"✓ Total Parameters: {model_size:.2f}M")
    print(f"✓ Trainable Parameters: {params_info['trainable'] / 1e6:.2f}M")
    logging.info(f"Model: {model_cfg['name']}, Size: {model_size:.2f}M parameters")

    # Setup training arguments
    print("\n[4/7] Setting up training...")
    train_cfg = config['training']
    output_dir = train_cfg['output_dir']
    os.makedirs(output_dir, exist_ok=True)
    training_args = _build_training_arguments(train_cfg, output_dir)

    # Create compute_metrics function with label mapping
    compute_metrics_fn = compute_metrics_factory(id2label)

    # Setup callbacks
    callbacks = _build_callbacks(train_cfg)

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks,
        **_tokenizer_kwargs(tokenizer),
    )
    print("✓ Trainer initialized")
    logging.info("Trainer initialized with all configurations")

    # Train model
    print("\n[5/7] Training model...")
    print("-" * 60)
    logging.info("Starting training")
    train_result = trainer.train()
    logging.info(f"Training completed. Train loss: {train_result.training_loss:.4f}")

    # Save final model
    print("\n[6/7] Saving model...")
    final_model_path = os.path.join(output_dir, 'final_model')
    trainer.save_model(final_model_path)
    tokenizer.save_pretrained(final_model_path)
    print(f"✓ Model saved to {final_model_path}")
    logging.info(f"Model saved to {final_model_path}")

    # Evaluate on test set
    print("\n[7/7] Evaluating on test set...")
    print("=" * 60)
    print("Final Evaluation on Test Set")
    print("=" * 60)
    test_results = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix='test')
    print("\nTest Results:")
    for key, value in sorted(test_results.items()):
        if isinstance(value, float):
            print(f" {key}: {value:.4f}")
    # Persist the metrics to the log file too; previously they only went to stdout.
    logging.info("Test evaluation completed: %s", test_results)

    print("\n" + "=" * 60)
    print("Training Complete! 🎉")
    print("=" * 60)
    print(f"\nModel location: {final_model_path}")
    print("\nNext steps:")
    print("1. Run evaluation: python scripts/evaluate.py")
    print("2. Test inference: python inference.py")
    print("3. Upload to Hub: python scripts/upload_to_hub.py")
if __name__ == "__main__":
    # Command-line entry point: the only option is the config file path.
    cli = argparse.ArgumentParser(
        description="Train Code Comment Quality Classifier",
    )
    cli.add_argument(
        "--config",
        default="config.yaml",
        type=str,
        help="Path to configuration file",
    )
    cli_args = cli.parse_args()
    main(config_path=cli_args.config)