# src/plotting.py import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import matplotlib.colors as mcolors from colorsys import rgb_to_hls, hls_to_rgb from collections import defaultdict import numpy as np import pandas as pd from config import LANGUAGE_NAMES def create_leaderboard_plot(leaderboard_df: pd.DataFrame, metric: str = 'quality_score') -> plt.Figure: """Create a horizontal bar chart showing model rankings.""" fig, ax = plt.subplots(figsize=(12, 8)) # Sort by the selected metric (descending) df_sorted = leaderboard_df.sort_values(metric, ascending=True) # Create color palette colors = plt.cm.viridis(np.linspace(0, 1, len(df_sorted))) # Create horizontal bar chart bars = ax.barh(range(len(df_sorted)), df_sorted[metric], color=colors) # Customize the plot ax.set_yticks(range(len(df_sorted))) ax.set_yticklabels(df_sorted['model_display_name']) ax.set_xlabel(f'{metric.replace("_", " ").title()} Score') ax.set_title(f'Model Leaderboard - {metric.replace("_", " ").title()}', fontsize=16, pad=20) # Add value labels on bars for i, (bar, value) in enumerate(zip(bars, df_sorted[metric])): ax.text(value + 0.001, bar.get_y() + bar.get_height()/2, f'{value:.3f}', ha='left', va='center', fontweight='bold') # Add grid for better readability ax.grid(axis='x', linestyle='--', alpha=0.7) ax.set_axisbelow(True) # Set x-axis limits with some padding max_val = df_sorted[metric].max() ax.set_xlim(0, max_val * 1.15) plt.tight_layout() return fig def create_detailed_comparison_plot(metrics_data: dict, model_names: list) -> plt.Figure: """Create detailed comparison plot similar to the original evaluation script.""" # Filter metrics_data to only include models in model_names filtered_metrics = {name: metrics_data[name] for name in model_names if name in metrics_data} if not filtered_metrics: # Create empty plot if no data fig, ax = plt.subplots(figsize=(10, 6)) ax.text(0.5, 0.5, 'No data available for comparison', ha='center', va='center', transform=ax.transAxes, fontsize=16) ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.axis('off') return fig return plot_translation_metric_comparison(filtered_metrics, metric='bleu') def plot_translation_metric_comparison(metrics_by_model: dict, metric: str = 'bleu') -> plt.Figure: """ Creates a grouped bar chart comparing a selected metric across translation models. Adapted from the original plotting code. """ # Split language pairs into xx_to_eng and eng_to_xx categories first_model_data = list(metrics_by_model.values())[0] xx_to_eng = [key for key in first_model_data.keys() if key.endswith('_to_eng') and key != 'averages'] eng_to_xx = [key for key in first_model_data.keys() if key.startswith('eng_to_') and key != 'averages'] # Function to create nice labels def format_label(label): if label.startswith("eng_to_"): source, target = "English", label.replace("eng_to_", "") target = LANGUAGE_NAMES.get(target, target) else: source, target = label.replace("_to_eng", ""), "English" source = LANGUAGE_NAMES.get(source, source) return f"{source}→{target}" # Extract metric values for each category def extract_metric_values(model_metrics, pairs, metric_name): return [model_metrics.get(pair, {}).get(metric_name, 0.0) for pair in pairs] xx_to_eng_data = { model_name: extract_metric_values(model_data, xx_to_eng, metric) for model_name, model_data in metrics_by_model.items() } eng_to_xx_data = { model_name: extract_metric_values(model_data, eng_to_xx, metric) for model_name, model_data in metrics_by_model.items() } averages_data = { model_name: [model_data.get("averages", {}).get(metric, 0.0)] for model_name, model_data in metrics_by_model.items() } # Set up plot with custom grid fig = plt.figure(figsize=(18, 12)) # Increased height for better spacing # Create a GridSpec with 1 row and 5 columns gs = gridspec.GridSpec(1, 5) # Colors for the models model_names = list(metrics_by_model.keys()) family_base_colors = { 'gemma': '#3274A1', 'nllb': '#7f7f7f', 'qwen': '#E1812C', 'google': '#3A923A', 'other': '#D62728', } # Identify the family for each model def get_family(model_name): model_lower = model_name.lower() if 'gemma' in model_lower: return 'gemma' elif 'qwen' in model_lower: return 'qwen' elif 'nllb' in model_lower: return 'nllb' elif 'google' in model_lower or model_name == 'google-translate': return 'google' else: return 'other' # Count how many models belong to each family family_counts = defaultdict(int) for model in model_names: family = get_family(model) family_counts[family] += 1 # Generate slightly varied lightness within each family colors = [] family_indices = defaultdict(int) for model in model_names: family = get_family(model) base_rgb = mcolors.to_rgb(family_base_colors[family]) h, l, s = rgb_to_hls(*base_rgb) index = family_indices[family] count = family_counts[family] # Vary lightness: from 0.35 to 0.65 if count == 1: new_l = l # Keep original for single models else: new_l = 0.65 - 0.3 * (index / max(count - 1, 1)) varied_rgb = hls_to_rgb(h, new_l, s) hex_color = mcolors.to_hex(varied_rgb) colors.append(hex_color) family_indices[family] += 1 bar_width = 0.2 opacity = 0.8 # Positions for the bars xx_to_eng_indices = np.arange(len(xx_to_eng)) eng_to_xx_indices = np.arange(len(eng_to_xx)) avg_index = np.array([0]) # Determine y-axis limits based on metric if metric in ['chrf', 'len_ratio']: y_max = 1.1 elif metric in ['cer', 'wer']: y_max = 1.0 elif metric == 'bleu': y_max = 65 # Increased from 55 to accommodate high scores elif metric in ['rouge1', 'rouge2', 'rougeL']: y_max = 1.0 elif metric == 'quality_score': y_max = 0.65 else: # Auto-scale based on data all_values = [] for data in [xx_to_eng_data, eng_to_xx_data, averages_data]: for model_data in data.values(): all_values.extend(model_data) y_max = max(all_values) * 1.1 if all_values else 1.0 # Format metric name for display metric_display = metric.upper() if metric in ['bleu', 'chrf', 'cer', 'wer'] else metric.replace('_', ' ').title() # Create bars for xx_to_eng (using first 2 columns) if xx_to_eng: ax1 = plt.subplot(gs[0, 0:2]) for i, (model_name, color) in enumerate(zip(model_names, colors)): if model_name in xx_to_eng_data: ax1.bar(xx_to_eng_indices + i*bar_width, xx_to_eng_data[model_name], bar_width, alpha=opacity, color=color, label=model_name) ax1.set_xlabel('Translation Direction') ax1.set_ylabel(f'{metric_display} Score') ax1.set_title(f'XX→English {metric_display} Performance') ax1.set_xticks(xx_to_eng_indices + bar_width) ax1.set_xticklabels([format_label(label) for label in xx_to_eng], rotation=45, ha='right') ax1.set_ylim(0, y_max) ax1.grid(axis='y', linestyle='--', alpha=0.7) # Create bars for eng_to_xx (using next 2 columns) if eng_to_xx: ax2 = plt.subplot(gs[0, 2:4]) for i, (model_name, color) in enumerate(zip(model_names, colors)): if model_name in eng_to_xx_data: ax2.bar(eng_to_xx_indices + i*bar_width, eng_to_xx_data[model_name], bar_width, alpha=opacity, color=color, label=model_name) ax2.set_xlabel('Translation Direction') ax2.set_ylabel(f'{metric_display} Score') ax2.set_title(f'English→XX {metric_display} Performance') ax2.set_xticks(eng_to_xx_indices + bar_width) ax2.set_xticklabels([format_label(label) for label in eng_to_xx], rotation=45, ha='right') ax2.set_ylim(0, y_max) ax2.grid(axis='y', linestyle='--', alpha=0.7) # Create bars for averages (using last column) ax3 = plt.subplot(gs[0, 4]) for i, (model_name, color) in enumerate(zip(model_names, colors)): if model_name in averages_data: ax3.bar(avg_index + i*bar_width, averages_data[model_name], bar_width, alpha=opacity, color=color, label=model_name) ax3.set_xlabel('Overall') ax3.set_ylabel(f'{metric_display} Score') ax3.set_title(f'Average {metric_display}') ax3.set_xticks(avg_index + bar_width) ax3.set_xticklabels(['Average']) ax3.set_ylim(0, y_max) ax3.grid(axis='y', linestyle='--', alpha=0.7) ax3.legend() # Add note for metrics where lower is better if metric in ['cer', 'wer']: plt.figtext(0.5, 0.01, "Note: Lower values indicate better performance for this metric", ha='center', fontsize=12, style='italic') # Add an overall title and adjust layout model_list = ' vs '.join(model_names) plt.suptitle(f'{metric_display} Score Comparison: {model_list}', fontsize=16, y=0.98) plt.tight_layout(rect=[0, 0.02, 1, 0.95]) return fig def create_summary_metrics_plot(leaderboard_df: pd.DataFrame) -> plt.Figure: """Create a summary plot showing multiple metrics for top models.""" if leaderboard_df.empty: fig, ax = plt.subplots(figsize=(10, 6)) ax.text(0.5, 0.5, 'No data available', ha='center', va='center', transform=ax.transAxes, fontsize=16) return fig # Select top 5 models by quality score top_models = leaderboard_df.nlargest(5, 'quality_score') # Metrics to display metrics = ['bleu', 'chrf', 'quality_score'] metric_labels = ['BLEU', 'ChrF', 'Quality Score'] fig, axes = plt.subplots(1, 3, figsize=(15, 6)) for i, (metric, label) in enumerate(zip(metrics, metric_labels)): ax = axes[i] # Sort by current metric sorted_models = top_models.sort_values(metric, ascending=True) # Create horizontal bar chart bars = ax.barh(range(len(sorted_models)), sorted_models[metric], color=plt.cm.viridis(np.linspace(0, 1, len(sorted_models)))) ax.set_yticks(range(len(sorted_models))) ax.set_yticklabels(sorted_models['model_display_name']) ax.set_xlabel(f'{label} Score') ax.set_title(f'Top Models - {label}') ax.grid(axis='x', linestyle='--', alpha=0.7) # Add value labels for j, (bar, value) in enumerate(zip(bars, sorted_models[metric])): ax.text(value + value*0.01, bar.get_y() + bar.get_height()/2, f'{value:.3f}', ha='left', va='center', fontsize=10) plt.tight_layout() return fig