""" |
|
|
Main training script for Code Comment Quality Classifier |
|
|
""" |
|
|
import os |
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from transformers import (
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)

from src import (
    load_config,
    prepare_datasets_for_training,
    create_model,
    get_model_size,
    get_trainable_params,
    compute_metrics_factory
)

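# Illustrative sketch of the config.yaml keys this script reads. The key names
# come from the accesses below; the values here are placeholders (including the
# example checkpoint name), not the project's actual defaults.
#
#   model:
#     name: microsoft/codebert-base   # example checkpoint, pick your own
#     num_labels: 3
#     dropout: 0.1
#   training:
#     output_dir: ./results
#     num_train_epochs: 3
#     per_device_train_batch_size: 16
#     per_device_eval_batch_size: 32
#     learning_rate: 2.0e-5
#     weight_decay: 0.01
#     logging_steps: 50
#     eval_steps: 200
#     save_steps: 200
#     evaluation_strategy: steps
#     save_strategy: steps
#     load_best_model_at_end: true
#     metric_for_best_model: f1
#     seed: 42
#     early_stopping_patience: 3
#   logging:
#     level: INFO
#     log_file: ./results/training.log
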
def setup_logging(config: dict) -> None:
    """Set up logging to both a log file and the console."""
    log_config = config.get('logging', {})
    log_level = getattr(logging, log_config.get('level', 'INFO'))
    log_file = log_config.get('log_file', './results/training.log')

    # Make sure the log file's directory exists before attaching the file handler.
    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler()
        ]
    )

def main(config_path: str = "config.yaml"):
    """
    Run the full training pipeline.

    Args:
        config_path: Path to the YAML configuration file.
    """
    print("=" * 60)
    print("Code Comment Quality Classifier - Training")
    print("=" * 60)

print("\n[1/7] Loading configuration...") |
|
|
config = load_config(config_path) |
|
|
print(f"✓ Configuration loaded from {config_path}") |
|
|
|
|
|
|
|
|
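    # Fail fast: validate the config before any data or model work starts.
    # validate_config is assumed to return a list of human-readable error
    # strings (empty when the config is valid).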
    from src.validation import validate_config
    config_errors = validate_config(config)
    if config_errors:
        print("\n✗ Configuration validation errors:")
        for error in config_errors:
            print(f"  - {error}")
        raise ValueError("Invalid configuration. Please fix the errors above.")

    setup_logging(config)
    logging.info("Starting training process")

print("\n[2/7] Preparing datasets...") |
|
|
tokenized_datasets, label2id, id2label, tokenizer = prepare_datasets_for_training(config_path) |
|
|
print(f"✓ Train samples: {len(tokenized_datasets['train'])}") |
|
|
print(f"✓ Validation samples: {len(tokenized_datasets['validation'])}") |
|
|
print(f"✓ Test samples: {len(tokenized_datasets['test'])}") |
|
|
logging.info(f"Dataset sizes - Train: {len(tokenized_datasets['train'])}, " |
|
|
f"Val: {len(tokenized_datasets['validation'])}, " |
|
|
f"Test: {len(tokenized_datasets['test'])}") |
|
|
|
|
|
|
|
|
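    # config['model'].get('dropout') yields None when the key is absent;
    # create_model is expected to treat None as "keep the checkpoint's default
    # dropout" rather than overriding it.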
print("\n[3/7] Loading model...") |
|
|
dropout = config['model'].get('dropout') |
|
|
model = create_model( |
|
|
model_name=config['model']['name'], |
|
|
num_labels=config['model']['num_labels'], |
|
|
label2id=label2id, |
|
|
id2label=id2label, |
|
|
dropout=dropout |
|
|
) |
|
|
model_size = get_model_size(model) |
|
|
params_info = get_trainable_params(model) |
|
|
print(f"✓ Model: {config['model']['name']}") |
|
|
print(f"✓ Total Parameters: {model_size:.2f}M") |
|
|
print(f"✓ Trainable Parameters: {params_info['trainable'] / 1e6:.2f}M") |
|
|
logging.info(f"Model: {config['model']['name']}, Size: {model_size:.2f}M parameters") |
|
|
|
|
|
|
|
|
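    # Note on warmup (used below): in current Transformers releases a non-zero
    # warmup_steps takes precedence over warmup_ratio, so configs should set only
    # one of them. Defaults of 0 / 0.0 keep TrainingArguments happy when neither
    # key is present in the config.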
print("\n[4/7] Setting up training...") |
|
|
output_dir = config['training']['output_dir'] |
|
|
os.makedirs(output_dir, exist_ok=True) |
|
|
|
|
|
training_args = TrainingArguments( |
|
|
output_dir=output_dir, |
|
|
num_train_epochs=config['training']['num_train_epochs'], |
|
|
per_device_train_batch_size=config['training']['per_device_train_batch_size'], |
|
|
per_device_eval_batch_size=config['training']['per_device_eval_batch_size'], |
|
|
gradient_accumulation_steps=config['training'].get('gradient_accumulation_steps', 1), |
|
|
learning_rate=config['training']['learning_rate'], |
|
|
lr_scheduler_type=config['training'].get('lr_scheduler_type', 'linear'), |
|
|
weight_decay=config['training']['weight_decay'], |
|
|
warmup_steps=config['training'].get('warmup_steps'), |
|
|
warmup_ratio=config['training'].get('warmup_ratio'), |
|
|
logging_dir=os.path.join(output_dir, 'logs'), |
|
|
logging_steps=config['training']['logging_steps'], |
|
|
eval_steps=config['training']['eval_steps'], |
|
|
save_steps=config['training']['save_steps'], |
|
|
save_total_limit=config['training'].get('save_total_limit', 3), |
|
|
eval_strategy=config['training']['evaluation_strategy'], |
|
|
save_strategy=config['training']['save_strategy'], |
|
|
load_best_model_at_end=config['training']['load_best_model_at_end'], |
|
|
metric_for_best_model=config['training']['metric_for_best_model'], |
|
|
greater_is_better=config['training'].get('greater_is_better', True), |
|
|
seed=config['training']['seed'], |
|
|
fp16=config['training'].get('fp16', False), |
|
|
dataloader_num_workers=config['training'].get('dataloader_num_workers', 4), |
|
|
dataloader_pin_memory=config['training'].get('dataloader_pin_memory', True), |
|
|
remove_unused_columns=config['training'].get('remove_unused_columns', True), |
|
|
report_to=config['training'].get('report_to', ['none']), |
|
|
push_to_hub=False, |
|
|
) |
|
|
|
|
|
|
|
|
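    # compute_metrics_factory is assumed to close over id2label and return a
    # callable mapping an EvalPrediction to a dict of metric values, which is
    # what Trainer expects for compute_metrics.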
    compute_metrics_fn = compute_metrics_factory(id2label)

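    # Early stopping is optional. EarlyStoppingCallback requires an evaluation
    # strategy other than "no", a metric_for_best_model, and
    # load_best_model_at_end=True, all of which come from the config above.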
    callbacks = []
    if config['training'].get('early_stopping_patience'):
        early_stopping = EarlyStoppingCallback(
            early_stopping_patience=config['training']['early_stopping_patience'],
            early_stopping_threshold=config['training'].get('early_stopping_threshold', 0.0)
        )
        callbacks.append(early_stopping)
        logging.info(f"Early stopping enabled with patience={config['training']['early_stopping_patience']}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics_fn,
        callbacks=callbacks
    )

    print("✓ Trainer initialized")
    logging.info("Trainer initialized with all configurations")

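    # With load_best_model_at_end=True the Trainer reloads the best checkpoint
    # (per metric_for_best_model) after training, so the model saved in the next
    # step is the best one seen during training, not just the last.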
print("\n[5/7] Training model...") |
|
|
print("-" * 60) |
|
|
logging.info("Starting training") |
|
|
train_result = trainer.train() |
|
|
logging.info(f"Training completed. Train loss: {train_result.training_loss:.4f}") |
|
|
|
|
|
|
|
|
print("\n[6/7] Saving model...") |
|
|
final_model_path = os.path.join(output_dir, 'final_model') |
|
|
trainer.save_model(final_model_path) |
|
|
tokenizer.save_pretrained(final_model_path) |
|
|
print(f"✓ Model saved to {final_model_path}") |
|
|
logging.info(f"Model saved to {final_model_path}") |
|
|
|
|
|
|
|
|
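    # metric_key_prefix='test' returns metrics under 'test_*' keys (e.g.
    # 'test_loss'), keeping them distinct from the 'eval_*' metrics logged
    # during training.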
print("\n[7/7] Evaluating on test set...") |
|
|
print("=" * 60) |
|
|
print("Final Evaluation on Test Set") |
|
|
print("=" * 60) |
|
|
test_results = trainer.evaluate(tokenized_datasets['test'], metric_key_prefix='test') |
|
|
|
|
|
print("\nTest Results:") |
|
|
for key, value in sorted(test_results.items()): |
|
|
if isinstance(value, float): |
|
|
print(f" {key}: {value:.4f}") |
|
|
logging.info("Test evaluation completed") |
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("Training Complete! 🎉") |
|
|
print("=" * 60) |
|
|
print(f"\nModel location: {final_model_path}") |
|
|
print("\nNext steps:") |
|
|
print("1. Run evaluation: python scripts/evaluate.py") |
|
|
print("2. Test inference: python inference.py") |
|
|
print("3. Upload to Hub: python scripts/upload_to_hub.py") |
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Code Comment Quality Classifier")
    parser.add_argument(
        "--config",
        type=str,
        default="config.yaml",
        help="Path to configuration file"
    )
    args = parser.parse_args()

    main(args.config)