diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- +# <<< Keep all existing imports >>> import os import json import pandas as pd @@ -13,7 +15,9 @@ import time from huggingface_hub import hf_hub_download import psutil import gc +import atexit # Import atexit +# <<< Keep SUBJECT_TRANS and MODEL_TRANS dictionaries >>> # 翻译表 SUBJECT_TRANS = { "代数": "Algebra", @@ -21,7 +25,7 @@ SUBJECT_TRANS = { "几何": "Geometry", "组合": "Combinatorics" } - +# MODEL_TRANS MODEL_TRANS = { "acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B", "deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B", @@ -51,109 +55,206 @@ MODEL_TRANS = { "gemini-2.5-pro-exp-03-25": "Gemini 2.5 Pro Exp 0325", "o3-mini-high": "OpenAI o3-mini (high)", "qwen3-0.6b": "Qwen3-0.6B" - # 添加更多模型映射 } -# Configure matplotlib for better display +# <<< Keep Matplotlib configuration >>> plt.style.use('ggplot') mpl.rcParams['figure.figsize'] = (10, 6) mpl.rcParams['font.size'] = 10 -# Constants +# <<< Keep Constants >>> DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] -# 全局数据库实例 +# Global database instance db = None class ModelDatabase: - """Database access class""" - + """Database access class - Optimized to use in-memory database""" def __init__(self, db_path): - """Initialize database connection""" + """Initialize database connection by copying disk DB to memory.""" self.db_path = db_path - # Use connection pool pattern to avoid too many connections - self.conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None, timeout=60) - self.conn.execute("PRAGMA journal_mode = WAL") # Use Write-Ahead Logging for better performance - self.conn.execute("PRAGMA synchronous = NORMAL") # Reduce synchronization overhead - self.conn.execute("PRAGMA cache_size = -8000") # 8MB cache (比原来大4倍) - self.conn.execute("PRAGMA temp_store = MEMORY") # 临时表存储在内存中 - self.conn.execute("PRAGMA mmap_size = 8589934592") # 尝试使用8GB内存映射 - self.conn.row_factory = sqlite3.Row - - # 创建索引以加速查询 - self._ensure_indices() - - # 初始化模型名称映射 - self.model_display_to_real = {} - self.comp_model_display_to_real = {} - - # 初始化缓存 + self.conn = None self._cache = {} self._problem_cache = {} self._response_cache = {} - + self.model_display_to_real = {} + self.comp_model_display_to_real = {} + + disk_conn = None + try: + # 1. Connect to the source disk database in read-only mode + print(f"Connecting to source database (read-only): {db_path}") + # Ensure the file exists before trying to connect + if not os.path.exists(db_path): + raise FileNotFoundError(f"Database file not found at {db_path}") + disk_conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True, check_same_thread=False, timeout=120) # Increased timeout + print("Applying PRAGMAs to source connection for backup performance...") + disk_conn.execute("PRAGMA journal_mode = OFF") + disk_conn.execute("PRAGMA synchronous = OFF") + # Use a larger cache for reading from disk, e.g., 2GB = -2097152 KiB + disk_conn.execute("PRAGMA cache_size = -2097152") + disk_conn.execute("PRAGMA temp_store = MEMORY") + disk_conn.execute("PRAGMA locking_mode = EXCLUSIVE") # Prevent interference during backup + + # 2. Connect to the target in-memory database + print("Creating in-memory database...") + # Increase timeout for potential long operations on the in-memory DB too + self.conn = sqlite3.connect(':memory:', check_same_thread=False, timeout=120) + self.conn.row_factory = sqlite3.Row # Use Row factory for dict-like access + + # 3. Backup data from disk to memory + print("Starting database backup from disk to memory (this may take a while)...") + start_backup = time.time() + # Use a context manager for the destination connection to handle commits/rollbacks + with self.conn: + disk_conn.backup(self.conn) + end_backup = time.time() + print(f"Database backup completed in {end_backup - start_backup:.2f} seconds.") + + # 4. Apply PRAGMAs suitable for the in-memory database + print("Applying PRAGMAs to in-memory database...") + # temp_store=MEMORY is default for :memory:, but explicit is fine + self.conn.execute("PRAGMA temp_store = MEMORY") + # cache_size might still help slightly, but OS caching is dominant. Can be omitted. + # self.conn.execute("PRAGMA cache_size = -4194304") # e.g., 4GB cache within RAM + + # 5. Ensure indices exist on the in-memory database *after* data loading + print("Creating indices on in-memory database...") + start_index = time.time() + self._ensure_indices() # This now operates on self.conn (the memory DB) + end_index = time.time() + print(f"Index creation completed in {end_index - start_index:.2f} seconds.") + + except sqlite3.Error as e: + print(f"SQLite error during database initialization: {e}") + if self.conn: + self.conn.close() + self.conn = None + raise # Re-raise the exception to signal failure + except FileNotFoundError as e: + print(f"Error: {e}") + raise + except Exception as e: + print(f"Unexpected error during database initialization: {e}") + if self.conn: + self.conn.close() + self.conn = None + raise + finally: + # 6. Close the disk connection, it's no longer needed + if disk_conn: + disk_conn.close() + print("Closed connection to disk database.") + + if self.conn: + print("In-memory database initialized successfully.") + else: + print("Error: In-memory database connection failed.") + raise RuntimeError("Failed to establish in-memory database connection.") + + def _ensure_indices(self): - """确保数据库有必要的索引""" + """Ensure necessary indices exist on the database connection (self.conn).""" + if not self.conn: + print("Error: Connection not established. Cannot ensure indices.") + return try: cursor = self.conn.cursor() - # 添加最常用查询的索引 + print("Creating index: idx_responses_model_dataset") cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)") + print("Creating index: idx_responses_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)") + print("Creating index: idx_problems_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)") - cursor.execute("ANALYZE") # 分析表以优化查询计划 - except Exception as e: - pass - + print("Creating index: idx_problems_subject") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)") + # Analyze the tables after creating indices for optimal query plans + print("Running ANALYZE...") + cursor.execute("ANALYZE") + self.conn.commit() # Commit index creation and analysis + print("Indices created and table analyzed successfully.") + except sqlite3.Error as e: + # Log error but don't necessarily crash the app + print(f"Warning: Could not create or analyze indices: {e}") + # Attempt rollback if something failed partially + try: + self.conn.rollback() + except sqlite3.Error as rb_e: + print(f"Rollback attempt failed after index error: {rb_e}") + # Depending on severity, you might want to raise e here + + # <<< Methods get_available_models, get_available_datasets, get_model_statistics, >>> + # <<< get_all_model_accuracies, get_problems_by_model_dataset, get_problem_data, >>> + # <<< get_model_responses, clear_cache are modified to: >>> + # <<< 1. Remove INDEXED BY hints >>> + # <<< 2. Add checks for self.conn existence >>> + # <<< 3. Improve error logging >>> + def get_available_models(self): """Get list of all available models""" - # 缓存在实例变量中 - if hasattr(self, '_models_cache') and self._models_cache: + if not self.conn: return [] + # Check cache first + if hasattr(self, '_models_cache') and self._models_cache is not None: return self._models_cache - try: cursor = self.conn.cursor() + # Query without explicit index hints cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name") models = [row['model_name'] for row in cursor.fetchall()] - self._models_cache = models # 存储到实例缓存 + self._models_cache = models # Store in instance cache return models - except sqlite3.OperationalError: - return [] - + except sqlite3.Error as e: + print(f"Database error in get_available_models: {e}") + return [] # Return empty list on error + def get_available_datasets(self): """Get list of all available datasets""" - # 缓存在实例变量中 - if hasattr(self, '_datasets_cache') and self._datasets_cache: + if not self.conn: return DATASETS # Fallback if connection failed + # Check cache first + if hasattr(self, '_datasets_cache') and self._datasets_cache is not None: return self._datasets_cache - try: cursor = self.conn.cursor() + # Query without explicit index hints cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset") + # Ensure uppercase consistency datasets = [row['dataset'].upper() for row in cursor.fetchall()] - self._datasets_cache = datasets # 存储到实例缓存 + self._datasets_cache = datasets # Store in instance cache return datasets - except sqlite3.OperationalError: - return DATASETS - + except sqlite3.Error as e: + print(f"Database error in get_available_datasets: {e}") + return DATASETS # Fallback on error + def get_model_statistics(self, model_name, dataset): """Get statistics for a model on a specific dataset""" + if not self.conn: return [["Database Error", "No connection"]] + # Sanitize inputs if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value - + if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]] + cache_key = f"stats_{model_name}_{dataset}" - if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] - - cursor = self.conn.cursor() + + stats_data = [] try: - # 优化查询1: 整体准确率 - 使用索引提示加速 + cursor = self.conn.cursor() + # Query 1: Overall accuracy - No INDEXED BY hint cursor.execute(""" SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy - FROM responses INDEXED BY idx_responses_model_dataset + FROM responses WHERE model_name = ? AND dataset = ? """, (model_name, dataset.lower())) overall_stats = cursor.fetchone() - - # 优化查询2: 按学科统计 - 避免子查询和复杂JOIN + + if overall_stats and overall_stats['accuracy'] is not None: + stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"]) + elif overall_stats and overall_stats['total_samples'] == 0: + stats_data.append(["Overall Acc.", "No Samples"]) + else: + stats_data.append(["Overall Acc.", "N/A"]) + + # Query 2: Per-subject statistics - No INDEXED BY hint cursor.execute(""" SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy FROM responses r JOIN problems p ON r.unique_id = p.unique_id @@ -161,1471 +262,1488 @@ class ModelDatabase: GROUP BY p.subject ORDER BY p.subject """, (model_name, dataset.lower())) subject_stats_rows = cursor.fetchall() - - stats_data = [] - if overall_stats and overall_stats['accuracy'] is not None: - stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"]) - else: - stats_data.append(["Overall Acc.", "N/A"]) for subject_row in subject_stats_rows: acc_val = f"{subject_row['accuracy']:.2%}" if subject_row['accuracy'] is not None else "N/A" subject_name = subject_row['subject'] - # 使用翻译表翻译科目名称 translated_subject = SUBJECT_TRANS.get(subject_name, subject_name) stats_data.append([f"{translated_subject} Acc.", acc_val]) - - self._cache[cache_key] = stats_data + + self._cache[cache_key] = stats_data # Cache the result return stats_data - except sqlite3.OperationalError: - return [["Database Error", "No data available"]] - + except sqlite3.Error as e: + print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}") + # Return partial data if overall stats succeeded but subject failed? Or just error. + return [["Database Error", f"Query failed: {e}"]] + def get_all_model_accuracies(self, dataset): - """获取所有模型在特定数据集上的准确率 (优化版本)""" + """获取所有模型在特定数据集上的准确率""" + if not self.conn: return [] if hasattr(dataset, 'value'): dataset = dataset.value + if not dataset: return [] + cache_key = f"all_accuracies_{dataset}" - if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] + try: cursor = self.conn.cursor() - # 使用索引提示加速查询 + # No INDEXED BY hint needed, rely on idx_responses_model_dataset cursor.execute(""" SELECT model_name, AVG(correctness) as accuracy - FROM responses INDEXED BY idx_responses_model_dataset + FROM responses WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC """, (dataset.lower(),)) - results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall()] - self._cache[cache_key] = results + # Fetchall directly into list comprehension + results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None] + self._cache[cache_key] = results # Cache result return results - except sqlite3.OperationalError: + except sqlite3.Error as e: + print(f"Database error in get_all_model_accuracies({dataset}): {e}") return [] def get_problems_by_model_dataset(self, model_name, dataset): - """获取模型在特定数据集上的所有问题 (优化版本)""" + """获取模型在特定数据集上的所有问题""" + if not self.conn: return [] if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value + if not model_name or not dataset: return [] + cache_key = f"problems_{model_name}_{dataset}" - if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] - - cursor = self.conn.cursor() + try: - # 优化查询:使用索引提示和优化JOIN策略 + cursor = self.conn.cursor() + # No INDEXED BY hint, rely on indices on responses and problems tables + # Ensure AVG returns 0 if no correct responses, not NULL -> COALESCE(AVG(r.correctness), 0.0) cursor.execute(""" - SELECT DISTINCT r.unique_id, p.problem, AVG(r.correctness) as accuracy - FROM responses r INDEXED BY idx_responses_model_dataset - JOIN problems p INDEXED BY idx_problems_unique_id ON r.unique_id = p.unique_id + SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy + FROM responses r + JOIN problems p ON r.unique_id = p.unique_id WHERE r.model_name = ? AND r.dataset = ? - GROUP BY r.unique_id ORDER BY r.unique_id + GROUP BY r.unique_id, p.problem ORDER BY r.unique_id """, (model_name, dataset.lower())) - results = [(row['unique_id'], row['accuracy'] if row['accuracy'] is not None else 0.0, row['problem']) for row in cursor.fetchall()] - - # Sort by the integer part of unique_id - sorted_results = sorted(results, key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0) - self._cache[cache_key] = sorted_results + # Fetchall directly + results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()] + + # Sort in Python - pre-compile regex for slight speedup + id_extractor = re.compile(r'\d+') + def get_sort_key(problem_tuple): + match = id_extractor.search(problem_tuple[0]) # problem_tuple[0] is unique_id + # Handle cases where ID might not have numbers gracefully + return int(match.group(0)) if match else 0 + + # Sort the results list using the defined key + sorted_results = sorted(results, key=get_sort_key) + + self._cache[cache_key] = sorted_results # Cache the sorted list return sorted_results - except sqlite3.OperationalError: + except sqlite3.Error as e: + print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}") return [] + except Exception as e: # Catch potential errors during sorting + print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}") + return [] + def get_problem_data(self, model_name, dataset, problem_id): - """获取问题和响应数据 (采用局部缓存策略)""" + """获取问题和响应数据 (using in-memory DB and cache)""" + if not self.conn: return None, None + # Sanitize inputs if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value - - # 问题数据缓存 - 问题数据通常不会变化,可长期缓存 + if not dataset or not problem_id: return None, None # Need dataset and problem_id + + # Problem data cache check problem_cache_key = f"problem_{problem_id}" - if problem_cache_key in self._problem_cache: - problem = self._problem_cache[problem_cache_key] - else: - if not self.conn: - return None, None - + problem = self._problem_cache.get(problem_cache_key) + + if problem is None: # Not in cache, fetch from DB try: cursor = self.conn.cursor() + # Query uses index idx_problems_unique_id automatically cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,)) - problem = cursor.fetchone() - if problem: - # 转为字典存储,避免SQLite连接依赖 - self._problem_cache[problem_cache_key] = dict(problem) - problem = self._problem_cache[problem_cache_key] - except Exception: - return None, None - - if not problem: - return None, None - - # 响应数据缓存 - 更细粒度的缓存键 - if model_name: + problem_row = cursor.fetchone() + if problem_row: + problem = dict(problem_row) # Convert to dict for caching + self._problem_cache[problem_cache_key] = problem + else: + print(f"Problem not found in DB: {problem_id}") + return None, None # Problem ID does not exist in the database + except sqlite3.Error as e: + print(f"Database error fetching problem {problem_id}: {e}") + return None, None # Return None if problem fetch fails + + # If problem is still None here, it wasn't found in DB or cache + if problem is None: + return None, None + + # --- Response data fetching --- + responses = None + if model_name: # Fetch for a specific model resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: - return problem, self._response_cache[resp_cache_key] - - if not self.conn: - return problem, None - - # 获取特定模型的响应 - try: - cursor = self.conn.cursor() - cursor.execute(""" - SELECT * FROM responses - WHERE model_name = ? AND dataset = ? AND unique_id = ? - ORDER BY response_id - """, (model_name, dataset.lower(), problem_id)) - responses = cursor.fetchall() - - # 转换为字典列表存储 - if responses: - responses = [dict(r) for r in responses] - self._response_cache[resp_cache_key] = responses - return problem, responses - except Exception: - return problem, None - else: - # 获取所有模型对此问题的响应 + responses = self._response_cache[resp_cache_key] + else: + try: + cursor = self.conn.cursor() + # Query uses indices idx_responses_model_dataset and idx_responses_unique_id + cursor.execute(""" + SELECT * FROM responses + WHERE model_name = ? AND dataset = ? AND unique_id = ? + ORDER BY response_id + """, (model_name, dataset.lower(), problem_id)) + response_rows = cursor.fetchall() + # Convert rows to list of dicts for easier handling and caching + responses = [dict(r) for r in response_rows] if response_rows else [] + self._response_cache[resp_cache_key] = responses # Cache the result (even if empty) + except sqlite3.Error as e: + print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}") + responses = None # Indicate error fetching responses + else: # Fetch for all models for this problem resp_cache_key = f"all_responses_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: - return problem, self._response_cache[resp_cache_key] - - if not self.conn: - return problem, None - - try: - cursor = self.conn.cursor() - cursor.execute(""" - SELECT * FROM responses - WHERE dataset = ? AND unique_id = ? - ORDER BY model_name, response_id - """, (dataset.lower(), problem_id)) - responses = cursor.fetchall() - - # 转换为字典列表存储 - if responses: - responses = [dict(r) for r in responses] - self._response_cache[resp_cache_key] = responses - return problem, responses - except Exception: - return problem, None + responses = self._response_cache[resp_cache_key] + else: + try: + cursor = self.conn.cursor() + # Query uses indices idx_responses_dataset and idx_responses_unique_id + # Need to create idx_responses_dataset if not exists, or rely on model_dataset index scan + # Let's add CREATE INDEX IF NOT EXISTS idx_responses_dataset ON responses(dataset); in _ensure_indices + # --> Added index idx_responses_model_dataset which covers (dataset, unique_id) lookups too. + cursor.execute(""" + SELECT * FROM responses + WHERE dataset = ? AND unique_id = ? + ORDER BY model_name, response_id + """, (dataset.lower(), problem_id)) + response_rows = cursor.fetchall() + responses = [dict(r) for r in response_rows] if response_rows else [] + self._response_cache[resp_cache_key] = responses # Cache the result + except sqlite3.Error as e: + print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}") + responses = None # Indicate error + + return problem, responses + def get_model_responses(self, selected_models, dataset, problem_id): - """获取多个模型对特定问题的响应(优化版本)""" + """获取多个模型对特定问题的响应 (optimized for in-memory)""" + if not self.conn: return None, {} + # Sanitize inputs if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value - if not selected_models or not dataset or not problem_id: + if not selected_models or not dataset or not problem_id: return None, {} - # 获取问题数据 - 可共享缓存 + # Get problem data first (uses cache/fast in-memory lookup) problem, _ = self.get_problem_data(None, dataset, problem_id) - if not problem: + if not problem: + print(f"Problem data not found for {problem_id} in get_model_responses") return None, {} - + model_responses_data = {} + # Get the *real* model names from the display names using the stored map + real_model_names_map = {} # Map display name -> real name + real_names_list = [] for model_display in selected_models: model_display_val = model_display.value if hasattr(model_display, 'value') else model_display - # 从显示名称中获取真实模型名称 - model = self.comp_model_display_to_real.get(model_display_val, model_display_val) - - _, responses_for_model = self.get_problem_data(model, dataset, problem_id) - if responses_for_model: - # 尝试找到正确的响应,否则使用第一个 - correct_resp = next((r for r in responses_for_model if r['correctness'] == 1), None) - model_responses_data[model_display_val] = correct_resp if correct_resp else responses_for_model[0] - else: - model_responses_data[model_display_val] = None - + # Use comp_model_display_to_real if available, otherwise model_display_to_real + real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val) + # Fallback if map lookup fails (try parsing) + if not real_name: + raw_name_part = model_display_val.split(" (")[0] + # Reverse lookup MODEL_TRANS + for db_name, display_lookup in MODEL_TRANS.items(): + if display_lookup == raw_name_part: + real_name = db_name + break + if not real_name: # If still not found, assume display name *is* real name (less accuracy suffix) + real_name = raw_name_part + print(f"Warning: Could not map display name '{model_display_val}' to real name via maps. Using inferred '{real_name}'.") + + if real_name: # Ensure we have a name to query + real_model_names_map[model_display_val] = real_name + if real_name not in real_names_list: # Avoid duplicates in IN clause + real_names_list.append(real_name) + + if not real_names_list: + print("No valid real model names found to query.") + return problem, {} # Return problem data but empty responses + + # Optimized: Fetch all relevant responses in a single query + try: + cursor = self.conn.cursor() + placeholders = ','.join('?' * len(real_names_list)) + query = f""" + SELECT * FROM responses + WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ? + ORDER BY model_name, correctness DESC, response_id -- Prioritize correct responses + """ + params = real_names_list + [dataset.lower(), problem_id] + cursor.execute(query, params) + all_fetched_responses = cursor.fetchall() + + # Group responses by *real* model name, keeping only the best (correct first, then by ID) + responses_by_real_model = {} + for resp_row in all_fetched_responses: + resp_dict = dict(resp_row) + model = resp_dict['model_name'] + if model not in responses_by_real_model: # Only store the first one encountered (due to ORDER BY) + responses_by_real_model[model] = resp_dict + + # Populate the result dictionary using display names as keys + for display_name, real_name in real_model_names_map.items(): + model_responses_data[display_name] = responses_by_real_model.get(real_name) # Will be None if no response found + + except sqlite3.Error as e: + print(f"Database error in bulk get_model_responses: {e}. Falling back to individual fetches.") + # Fallback to individual fetching using get_problem_data (which uses cache) + for display_name, real_name in real_model_names_map.items(): + _ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id) + if responses_for_model: + # Find correct one first, otherwise take first response + correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None) + model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0] + else: + model_responses_data[display_name] = None + return problem, model_responses_data + def clear_cache(self, section=None): - """清除指定部分或全部缓存""" + """Clear specified cache sections.""" + print(f"Clearing cache section: {section if section else 'All'}") + cleared_something = False if section == 'main' or section is None: - self._cache = {} + if self._cache: + count = len(self._cache) + self._cache = {} + print(f"Cleared main cache ({count} items).") + cleared_something = True if section == 'problem' or section is None: - self._problem_cache = {} + if self._problem_cache: + count = len(self._problem_cache) + self._problem_cache = {} + print(f"Cleared problem cache ({count} items).") + cleared_something = True if section == 'response' or section is None: - self._response_cache = {} + if self._response_cache: + count = len(self._response_cache) + self._response_cache = {} + print(f"Cleared response cache ({count} items).") + cleared_something = True + # Clear model/dataset list caches if section == 'models' or section is None: - if hasattr(self, '_models_cache'): + if hasattr(self, '_models_cache') and self._models_cache is not None: self._models_cache = None - if hasattr(self, '_datasets_cache'): + print("Cleared models list cache.") + cleared_something = True + if hasattr(self, '_datasets_cache') and self._datasets_cache is not None: self._datasets_cache = None - + print("Cleared datasets list cache.") + cleared_something = True + + if cleared_something: + print("Running garbage collection...") + gc.collect() # Explicitly trigger garbage collection + else: + print("Cache section(s) already empty or invalid section specified.") + + def close(self): - """关闭数据库连接并释放资源""" + """Close the database connection.""" + print("Closing database connection...") if hasattr(self, 'conn') and self.conn: try: + # Optional: Backup in-memory changes to disk if needed (not in this scenario) + # Optional: Run final pragmas like optimize before closing if desired + # self.conn.execute("PRAGMA optimize;") self.conn.close() - except Exception: - pass - - # 清理所有缓存 + self.conn = None # Ensure the attribute is None after closing + print("In-memory database connection closed.") + except sqlite3.Error as e: + print(f"Error closing database connection: {e}") + else: + print("Database connection already closed or never established.") + # Clear caches on close as well self.clear_cache() + +# <<< Keep helper functions: format_latex, format_markdown_with_math, >>> +# <<< get_gradient_color, get_contrasting_text_color, format_sample_metadata, >>> +# <<< format_sample_response >>> def format_latex(text): if text is None: return "" - # Process the text for proper LaTeX rendering with KaTeX - # KaTeX requires LaTeX backslashes to be preserved - # Only replace newlines with HTML breaks text = text.replace('\n', '
') - # Wrap in a span that KaTeX can detect and render return f'{text}' def format_markdown_with_math(text): if text is None: return "" - - # Don't add HTML tags or do special processing for LaTeX - let Gradio handle it - # Just clean up basic issues that might affect rendering - - # Convert newlines for markdown text = text.replace('\r\n', '\n').replace('\r', '\n') - - # Return the cleaned text for Gradio's markdown component to render + # Ensure math delimiters are properly handled by Gradio's Markdown component + # No need for complex regex if Gradio handles $, $$, \(, \), \[, \] return text def get_gradient_color(accuracy, color_map='RdYlGn'): - if accuracy is None or not isinstance(accuracy, (int, float)): - return "#505050" # Default for missing or invalid accuracy + if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0): + return "#808080" # Use gray for invalid/missing accuracy try: - # 使用更深的颜色映射 + # Use the specified colormap cmap = plt.colormaps.get_cmap(color_map) - rgba = cmap(float(accuracy)) - - # 确保颜色足够深以与白色文本形成对比 - r, g, b, a = rgba - # 降低颜色亮度,确保文本可读性 - r = r * 0.7 - g = g * 0.7 - b = b * 0.7 - - # 转回十六进制 - hex_color = mpl.colors.rgb2hex((r, g, b, a)) + # Apply a power transform to make colors darker/more distinct, especially greens + # Power < 1 darkens low values more, Power > 1 darkens high values more + power_adjust = 0.7 + rgba = cmap(accuracy ** power_adjust) + # Convert RGBA to Hex + hex_color = mpl.colors.rgb2hex(rgba) return hex_color - except Exception: - return "#505050" - -def get_contrasting_text_color(bg_color): - """计算最佳对比文本颜色""" - # 如果背景是十六进制格式,转换为RGB - if bg_color.startswith('#'): - r = int(bg_color[1:3], 16) - g = int(bg_color[3:5], 16) - b = int(bg_color[5:7], 16) - else: - # 未知格式默认返回黑色 - return "#000" - - # 计算YIQ亮度值 - 更精确地表示人眼对亮度的感知 - yiq = (r * 299 + g * 587 + b * 114) / 1000 - - # 黄色检测 - 黄色通常R和G高,B低 - is_yellow = r > 200 and g > 200 and b < 150 - - # 浅绿色检测 - 通常G高,R中等,B低 - is_light_green = g > 200 and r > 100 and r < 180 and b < 150 - - # 米色/浅棕色检测 - R高,G中高,B低 - is_beige = r > 220 and g > 160 and g < 220 and b < 160 - - # 强制这些特定颜色使用黑色文本 - if is_yellow or is_light_green or is_beige: - return "#000" - - # 其他颜色根据亮度决定 - return "#000" if yiq > 160 else "#fff" + except Exception as e: + print(f"Error generating gradient color for accuracy {accuracy}: {e}") + return "#808080" # Fallback gray + +def get_contrasting_text_color(bg_color_hex): + """Calculate contrasting text color (black or white) for a given hex background.""" + try: + if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7: + return "#000000" # Default to black for invalid input + + # Convert hex to RGB + r = int(bg_color_hex[1:3], 16) + g = int(bg_color_hex[3:5], 16) + b = int(bg_color_hex[5:7], 16) + + # Calculate luminance using the WCAG formula (more accurate than YIQ for accessibility) + # Normalize RGB values to 0-1 range + rgb = [val / 255.0 for val in (r, g, b)] + # Apply gamma correction approximation + rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb] + # Calculate relative luminance + luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2] + + # WCAG contrast ratio threshold is complex, but a luminance threshold works well for black/white text + # Threshold of 0.179 is often cited, but empirical testing might be needed + # Let's use a slightly higher threshold towards white text for better readability on mid-tones + return "#000000" if luminance > 0.22 else "#FFFFFF" # Black text on lighter backgrounds, White on darker + + except Exception as e: + print(f"Error calculating contrasting color for {bg_color_hex}: {e}") + return "#000000" # Default to black on error + def format_sample_metadata(sample, show_correctness=True): - """生成样本元数据的HTML格式显示""" - if sample is None: return "" - sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {} - if not sample_dict: return "No sample data" + """Generates HTML for sample metadata display.""" + if sample is None: return "
No sample data provided.
" + # Ensure sample is a dictionary + sample_dict = dict(sample) if hasattr(sample, 'keys') else {} + if not sample_dict: return "
Empty sample data.
" - # 提取所需信息 + # Use .get() with defaults for safety extracted = sample_dict.get('extracted', '') - correctness = sample_dict.get('correctness', 0) - correctness_label = "✓ Correct" if correctness else "✗ Incorrect" - correctness_color = "var(--color-green)" if correctness else "var(--color-red)" - - # 获取token信息 - output_tokens = sample_dict.get('output_tokens', None) - reasoning_tokens = sample_dict.get('reasoning_tokens', None) - - # 创建元数据HTML - html = f"
" - - # 创建信息行 + correctness = sample_dict.get('correctness', None) # Keep None distinct from False/0 + output_tokens = sample_dict.get('output_tokens') + reasoning_tokens = sample_dict.get('reasoning_tokens') + + # Correctness display + if correctness == 1: + correctness_label = "✓ Correct" + correctness_color = "var(--color-acc-green, #28a745)" # Use CSS variable with fallback + elif correctness == 0: + correctness_label = "✗ Incorrect" + correctness_color = "var(--color-acc-red, #dc3545)" + else: # None or other values + correctness_label = "? Unknown" + correctness_color = "var(--color-acc-grey, #6c757d)" + + # Build HTML using f-string for clarity + html = f""" +
+ " + html += f"" + + if extracted: + # Basic escaping for extracted answer to prevent HTML injection if it contains < > + extracted_safe = extracted.replace('<', '<').replace('>', '>') + # Wrap in $ for math rendering by Gradio/MathJax if appropriate, else just display + # Heuristic: Render as math if it looks like a number, fraction, or simple expression + if re.match(r'^-?\d+(\.\d+)?(/(-?\d+(\.\d+)?))?$', extracted_safe.strip()): + extracted_display = f"${extracted_safe}$" + else: + extracted_display = extracted_safe # Display as is if complex + html += f"" + + if output_tokens is not None: + html += f"" + + if reasoning_tokens is not None: + html += f"" + + html += """ +
+
+ + """ return html + def format_sample_response(sample): - """生成样本响应的Markdown格式显示""" - if sample is None: return "" - sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {} - if not sample_dict: return "No sample data" - - # 获取响应内容 + """Generates Markdown-compatible string for the sample response.""" + if sample is None: return "No response data." + sample_dict = dict(sample) if hasattr(sample, 'keys') else {} + if not sample_dict: return "Empty response data." + response = sample_dict.get('response', '') - - # 转义特殊标签以防止被解析为HTML - # 替换标签 - response = response.replace("", "<think>") - response = response.replace("", "</think>") - - # 替换其他可能的特殊标签 - response = response.replace("", "<reasoning>") - response = response.replace("", "</reasoning>") - response = response.replace("", "<answer>") - response = response.replace("", "</answer>") - + if not response: return "*(Empty Response)*" + + # Escape HTML tags that might interfere with Markdown rendering + # Focus on < > & characters. + response = response.replace('&', '&') + response = response.replace('<', '<') + response = response.replace('>', '>') + + # Gradio's Markdown component should handle LaTeX delimiters like $...$, $$...$$ + # No need for manual replacement here if delimiters are set correctly in gr.Markdown return response -def handle_sample_select(sample_number, samples_data): - # 确保从Gradio State对象中提取实际值 - if hasattr(samples_data, 'value'): - samples_list = samples_data.value - else: - samples_list = samples_data - - # 确保样本编号是整数 + +# <<< Keep handler functions: handle_sample_select, handle_first_sample, >>> +# <<< handle_comparison_problem_update, handle_problem_select >>> +# <<< Ensure they use the global db instance and handle potential None values >>> +# <<< And use the updated format_ functions >>> + +def handle_sample_select(sample_number_str, samples_data_state): + """Handles selection of a specific sample index.""" + # Extract list from state + samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) + + if not samples_list or not isinstance(samples_list, list): + err_msg = "**Error:** No sample data available or invalid format." + return err_msg, "" # Return error for metadata, empty for response + try: - sample_idx = int(sample_number) - except ValueError: - return "Error: Sample number must be an integer.", "" - - # 确保样本数据存在且为非空列表 - if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0: - return "No sample data available. Please select a problem first.", "" - - # 检查索引是否在有效范围内,如果不在范围内,显示错误消息 - if sample_idx < 0: - err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}." + sample_idx = int(sample_number_str) # Convert input string to int + except (ValueError, TypeError): + err_msg = f"**Error:** Invalid sample number '{sample_number_str}'. Must be an integer." return err_msg, "" - - if sample_idx >= len(samples_list): - err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}." + + if not (0 <= sample_idx < len(samples_list)): + err_msg = f"**Error:** Sample index {sample_idx} out of range (0 to {len(samples_list) - 1})." return err_msg, "" - - # 获取所选样本的数据 + try: - sample = samples_list[sample_idx] - formatted_metadata = format_sample_metadata(sample) - formatted_response = format_sample_response(sample) + selected_sample = samples_list[sample_idx] + # Ensure the selected sample is a dict before formatting + if not isinstance(selected_sample, dict): + selected_sample = dict(selected_sample) if hasattr(selected_sample, 'keys') else {} + + formatted_metadata = format_sample_metadata(selected_sample) + formatted_response = format_sample_response(selected_sample) return formatted_metadata, formatted_response except Exception as e: + print(f"Error formatting sample {sample_idx}: {e}") err_msg = f"**Error displaying sample {sample_idx}:** {str(e)}" - return err_msg, "" + # Return error message in metadata, keep response empty + return f"
{err_msg}
", "" -def handle_first_sample(samples_data): - """处理并显示第一个样本(索引0)""" - # 确保从Gradio State对象中提取实际值 - if hasattr(samples_data, 'value'): - samples_list = samples_data.value +def handle_first_sample(samples_data_state): + """Handles displaying the first sample (index 0) from the state.""" + # Delegate to handle_sample_select with index 0 + # Provide default empty display if no samples + samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) + if not samples_list: + return format_sample_metadata(None), format_sample_response(None) # Display "No data" messages else: - samples_list = samples_data - - # 检查样本数据是否存在 - if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0: - return "No sample data available. Please select the problem and dataset first.", "" - - # 直接获取第一个样本,避免错误处理逻辑 - try: - sample = samples_list[0] - formatted_metadata = format_sample_metadata(sample) - formatted_response = format_sample_response(sample) - return formatted_metadata, formatted_response - except Exception as e: - err_msg = f"**Error displaying first sample:** {str(e)}" - return err_msg, "" + # Use the main handler to display sample 0 + return handle_sample_select("0", samples_data_state) -def handle_comparison_problem_update(problem_id, dataset_state): - """处理比较页面的问题更新,仅更新问题和答案内容,不需要模型""" +def handle_comparison_problem_update(problem_id_state, dataset_state): + """Updates only the Problem/Answer display in the comparison tab.""" global db - # 确保从Gradio State对象中提取实际值 + if not db or not db.conn: return "Database not initialized.", "Error" + dataset_name = dataset_state.value if hasattr(dataset_state, 'value') else dataset_state - problem_id_value = problem_id.value if hasattr(problem_id, 'value') else problem_id - - if not problem_id_value or not dataset_name: - return "Please select a dataset and enter a problem ID.", "No answer available." - - # 处理纯数字输入,构建完整unique_id - if problem_id_value and problem_id_value.isdigit(): - # 构建格式:OlymMATH-HARD-0-EN 或类似格式 + problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state + + # Allow entering just the number part of the ID + if problem_id and problem_id.isdigit() and dataset_name: parts = dataset_name.split('-') - if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD") + if len(parts) == 2: language, difficulty = parts - # 构建完整ID - problem_id_value = f"OlymMATH-{difficulty}-{problem_id_value}-{language}" - + problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" + else: + print(f"Warning: Cannot reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") + # Proceed with the entered value, might fail if not a full ID + + if not problem_id or not dataset_name: + return "Please select dataset and enter problem ID.", "N/A" + try: - # 只获取问题数据,不获取特定模型的响应 - problem_data, _ = db.get_problem_data(None, dataset_name, problem_id_value) - + # Fetch only problem data, no responses needed here + problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) + if not problem_data: - return f"Problem not found: {problem_id_value}. Please check the ID and try again.", "No answer available." - - problem_dict = dict(problem_data) - # Use format_markdown_with_math for proper rendering - problem_content = format_markdown_with_math(problem_dict.get('problem', '')) - - # 将答案中的双美元符号替换为单美元符号 - answer_text = problem_dict.get('answer', '') - # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 + # Check if just the number was entered and failed reconstruction + if problem_id.isdigit(): + return f"Problem number {problem_id} not found for {dataset_name}. Enter full ID or check dataset.", "N/A" + else: + return f"Problem ID '{problem_id}' not found for {dataset_name}.", "N/A" + + # Ensure problem_data is a dictionary + problem_dict = dict(problem_data) if problem_data else {} + + # Format problem statement for Markdown rendering + problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) + + # Format answer, handling LaTeX and ensuring $...$ + answer_text = problem_dict.get('answer', '*(Answer not available)*') + # Simplify $$...$$ to $...$ answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - - # 检查答案是否已经包含美元符号,如果没有则添加 - if '$' not in answer_text and answer_text.strip(): + # Add $...$ if missing (basic check) + if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): answer_text = f"${answer_text}$" - answer_content = format_markdown_with_math(answer_text) - + return problem_content, answer_content except Exception as e: - return f"Error: {str(e)}", "No answer available." + print(f"Error in handle_comparison_problem_update for {problem_id}, {dataset_name}: {e}") + return f"Error fetching problem details: {e}", "Error" + -def handle_problem_select(problem_id_from_js, current_model_state, current_dataset_state, mode='default'): +def handle_problem_select(problem_id_state, current_model_state, current_dataset_state, mode='default'): + """Handles problem selection, fetching details, responses, and generating sample grid.""" global db - # Ensure we're using the actual values from Gradio State objects + if not db or not db.conn: + return "DB Error.", "DB Error.", gr.HTML("

Database Connection Error

"), gr.State([]) + + # --- Get values from Gradio State objects --- model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state - problem_id = problem_id_from_js.value if hasattr(problem_id_from_js, 'value') else problem_id_from_js + problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state + + # --- Input Validation --- + if not dataset_name: + return "Please select a dataset first.", "N/A", gr.HTML(""), gr.State([]) + if not problem_id: + return "Problem ID is missing.", "N/A", gr.HTML(""), gr.State([]) + # Model name is required unless in comparison mode initial load + if not model_name and mode != 'comparison_initial_problem_load': + # In single mode, model is always required here + if mode == 'default': + return "Please select a model first.", "N/A", gr.HTML(""), gr.State([]) + # In comparison mode, if model state is None, means user selected problem before model + # We can still show problem/answer, but no samples for that side yet. + # Let's handle this by fetching problem/answer but returning empty samples for the specific call. + pass # Allow proceeding without model name in comparison mode to fetch problem/answer - # 处理纯数字输入,构建完整unique_id - if problem_id and problem_id.isdigit(): - # 构建格式:OlymMATH-HARD-0-EN 或类似格式 - # 从dataset_name (例如 "EN-HARD") 解析语言和难度 + # --- Reconstruct full problem ID if only number is entered --- + original_problem_id = problem_id # Keep original for messages + if problem_id.isdigit(): parts = dataset_name.split('-') - if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD") + if len(parts) == 2: language, difficulty = parts - # 构建完整ID problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" + print(f"Reconstructed problem ID: {problem_id}") + else: + # Cannot reconstruct, use the entered value but it might fail + print(f"Warning: Could not reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") - if not problem_id or not dataset_name: - error_message = f"Missing data: problem_id='{problem_id}', dataset='{dataset_name}'" - return "Please fill in all the fields.", "No answer available.", "", gr.State([]) - - # For comparison mode, we might not have a model selected yet - if not model_name and mode == 'comparison': - try: - # Just get the problem data without model-specific responses - problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) - - if not problem_data: - error_message = f"Problem data not found: problem_id='{problem_id}', dataset='{dataset_name}'" - return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([]) - - problem_dict = dict(problem_data) - # Process problem and answer text for Markdown rendering - problem_content = format_markdown_with_math(problem_dict.get('problem', '')) - - # 将答案中的双美元符号替换为单美元符号 - answer_text = problem_dict.get('answer', '') - # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 - answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - - # 检查答案是否已经包含美元符号,如果没有则添加 - if '$' not in answer_text and answer_text.strip(): - answer_text = f"${answer_text}$" - - answer_content = format_markdown_with_math(answer_text) - - # For comparison without model, we don't have samples to display - return problem_content, answer_content, "", gr.State([]) - except Exception as e: - error_message = f"Database error: {str(e)}" - return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([]) - - # The regular flow for model-specific data - if not model_name: - error_message = f"Missing data: model='{model_name}'" - return "Please fill in all the fields.", "No answer available.", "", gr.State([]) - # The problem_id from JS should be the full unique_id. No reconstruction needed normally. + # --- Fetch Data --- try: + # Fetch problem details and responses for the specific model (if provided) problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id) - + if not problem_data: - error_message = f"Problem data not found: problem_id='{problem_id}', model='{model_name}', dataset='{dataset_name}'" - return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([]) - except Exception as e: - error_message = f"Database error: {str(e)}" - return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([]) - - problem_dict = dict(problem_data) - problem_display_num = re.search(r'\d+', problem_id).group(0) if re.search(r'\d+', problem_id) else problem_id - - # Process problem and answer text for Markdown rendering - problem_content = format_markdown_with_math(problem_dict.get('problem', '')) - - # 将答案中的双美元符号替换为单美元符号 - answer_text = problem_dict.get('answer', '') - # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 - answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - - # 检查答案是否已经包含美元符号,如果没有则添加 - if '$' not in answer_text and answer_text.strip(): - answer_text = f"${answer_text}$" - - answer_content = format_markdown_with_math(answer_text) - - # Rest of the function remains the same - if not responses_data: - samples_grid_html = "
No samples available for this problem.
" - # 返回空的样本数据状态 - return problem_content, answer_content, samples_grid_html, gr.State([]) - else: - # 准备所有样本数据,用于后续处理 - samples_data = [] - for i, resp in enumerate(responses_data): - resp_dict = dict(resp) - samples_data.append(resp_dict) - - # 计算正确率 - correct_count = sum(1 for r in samples_data if r['correctness']) - total_samples = len(samples_data) - accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 - - # 创建样本网格显示 (最多显示 64 个样本) - displayed_samples = samples_data[:64] - actual_display_count = len(displayed_samples) - - # 根据模式确定每行的样本数 - samples_per_row = 16 if mode == 'comparison' else 32 - - # 第一行: 样本 0-samples_per_row - samples_grid_html = f'
' - - for i, resp in enumerate(displayed_samples[:samples_per_row]): - correctness = resp.get('correctness', 0) - bg_color = get_gradient_color(1.0 if correctness else 0.0) - - # 移除点击事件和data属性,只保留纯显示 - samples_grid_html += f""" -
- {i} -
- """ - - # 如果少于samples_per_row个样本,填充剩余空间 - for i in range(min(actual_display_count, samples_per_row), samples_per_row): - samples_grid_html += f""" -
- """ - - samples_grid_html += '
' - - # 如果有更多样本,显示第二行 - if actual_display_count > samples_per_row: - row_samples = displayed_samples[samples_per_row:2*samples_per_row] - samples_grid_html += f'
' - - for i, resp in enumerate(row_samples): - actual_idx = i + samples_per_row - correctness = resp.get('correctness', 0) - bg_color = get_gradient_color(1.0 if correctness else 0.0) - - samples_grid_html += f""" -
- {actual_idx} -
- """ - - # 填充剩余空间 - for i in range(len(row_samples), samples_per_row): - samples_grid_html += f""" -
- """ - - samples_grid_html += '
' - - # 第三行和第四行 - 允许所有模式显示完整的64个样本 - if actual_display_count > 2*samples_per_row: - # 第三行 - row_samples = displayed_samples[2*samples_per_row:3*samples_per_row] - if row_samples: - samples_grid_html += f'
' - + return f"Problem ID '{original_problem_id}' not found for dataset '{dataset_name}'.", "N/A", gr.HTML(f"

Problem ID '{original_problem_id}' not found.

"), gr.State([]) + + # Ensure problem_data is a dict + problem_dict = dict(problem_data) if problem_data else {} + + # --- Format Problem and Answer --- + problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) + answer_text = problem_dict.get('answer', '*(Answer not available)*') + answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) + if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): + answer_text = f"${answer_text}$" + answer_content = format_markdown_with_math(answer_text) + + # --- Handle Responses and Generate Sample Grid --- + if responses_data is None: # Indicates a DB error occurred fetching responses + samples_grid_html = gr.HTML("

Error fetching model responses.

") + samples_data_for_state = gr.State([]) # Empty state on error + elif not responses_data: # Empty list means no responses found + samples_grid_html = gr.HTML("

No responses found for this model on this problem.

") + samples_data_for_state = gr.State([]) # Empty state + else: + # responses_data should already be a list of dicts from get_problem_data + samples_data = responses_data # Use directly + samples_data_for_state = gr.State(samples_data) # Store the list in state + + correct_count = sum(1 for r in samples_data if r.get('correctness') == 1) + total_samples = len(samples_data) + accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 + + # --- Generate Sample Grid HTML (with onclick) --- + displayed_samples = samples_data[:64] # Limit display + actual_display_count = len(displayed_samples) + # Determine grid columns based on mode + samples_per_row = 16 if mode.startswith('comparison') else 32 + num_rows = math.ceil(actual_display_count / samples_per_row) + grid_html_content = "" + + # Determine the correct JS function call based on mode + js_mode = "'default'" # Default for single model tab + if mode == 'comparison_left': js_mode = "'comparison_left'" + elif mode == 'comparison_right': js_mode = "'comparison_right'" + + for row_idx in range(num_rows): + grid_html_content += f'
' + start_idx = row_idx * samples_per_row + end_idx = min(start_idx + samples_per_row, actual_display_count) + row_samples = displayed_samples[start_idx:end_idx] + for i, resp in enumerate(row_samples): - actual_idx = i + 2*samples_per_row - correctness = resp.get('correctness', 0) - bg_color = get_gradient_color(1.0 if correctness else 0.0) - - samples_grid_html += f""" -
- {actual_idx} -
- """ - - # 填充剩余空间 - for i in range(len(row_samples), samples_per_row): - samples_grid_html += f""" -
+ actual_idx = start_idx + i + correctness = resp.get('correctness', None) # Handle None correctness + # Get background color based on correctness (1=green, 0=red, None=grey) + if correctness == 1: bg_color = get_gradient_color(1.0) + elif correctness == 0: bg_color = get_gradient_color(0.0) + else: bg_color = "#808080" # Grey for unknown + text_color = get_contrasting_text_color(bg_color) + + # Add onclick event to call the JavaScript handler + grid_html_content += f""" + """ - - samples_grid_html += '
' - - # 第四行 - if actual_display_count > 3*samples_per_row: - row_samples = displayed_samples[3*samples_per_row:4*samples_per_row] - if row_samples: - samples_grid_html += f'
' - - for i, resp in enumerate(row_samples): - actual_idx = i + 3*samples_per_row - correctness = resp.get('correctness', 0) - bg_color = get_gradient_color(1.0 if correctness else 0.0) - - samples_grid_html += f""" -
- {actual_idx} -
- """ - - # 填充剩余空间 - for i in range(len(row_samples), samples_per_row): - samples_grid_html += f""" -
- """ - - samples_grid_html += '
' - - # 组合HTML内容 - final_html = f""" -
-

Samples {actual_display_count} - Model Accuracy: {correct_count}/{actual_display_count} = {accuracy_on_problem:.1%}

- {samples_grid_html} -
- """ - - # 获取第一个样本作为初始样本 - if samples_data: - # 这样样本会在选择问题后立即显示 - return problem_content, answer_content, final_html, gr.State(samples_data) - else: - return problem_content, answer_content, final_html, gr.State([]) + # Fill remaining columns in the row with gray placeholders if needed + for _ in range(len(row_samples), samples_per_row): + grid_html_content += "
" + grid_html_content += '
' + # Add filler rows if less than the max were generated (e.g., 4 for comparison, 2 for single) + max_rows = 4 if mode.startswith('comparison') else 2 + for _ in range(num_rows, max_rows): + grid_html_content += f'
' + for _ in range(samples_per_row): + grid_html_content += "
" + grid_html_content += '
' + + + # Assemble the final HTML for the samples section + samples_grid_html = gr.HTML(f""" +
+

Samples ({actual_display_count} shown) – Model Accuracy: {correct_count}/{total_samples} ({accuracy_on_problem:.1%})

+
{grid_html_content}
+
+ + """) + + + # Return all results + return problem_content, answer_content, samples_grid_html, samples_data_for_state + + except Exception as e: + print(f"Unexpected error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}") + # Log the full traceback for debugging if possible + import traceback + traceback.print_exc() + error_msg = f"**Internal Error processing problem {original_problem_id}:** {str(e)}" + return error_msg, "Error", gr.HTML(f"

{error_msg}

"), gr.State([]) + + +# <<< Keep create_problem_grid_html, modified to use onclick >>> def create_problem_grid_html(problems, mode='default'): - """Create HTML for problem grid buttons. The JS function will be defined globally.""" + """Create HTML for problem grid buttons with onclick handlers.""" if not problems: - return "
No problems found for this model/dataset. Please select a model and dataset.
" + return "
No problems found for this model/dataset.
" html_buttons = "" + # Sort problems based on the numeric part of the ID try: - sorted_problems = sorted( - [(str(p[0]), float(p[1]) if p[1] is not None else 0.0, p[2]) for p in problems], - key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0 - ) + id_extractor = re.compile(r'\d+') + def get_sort_key(p): + # p is expected to be a tuple/list like (unique_id, accuracy, problem_text) + match = id_extractor.search(str(p[0])) + return int(match.group(0)) if match else 0 + + # Ensure problems is a list of tuples/lists before sorting + if isinstance(problems, list) and all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in problems): + # Convert accuracy to float, handle None or potential errors + processed_problems = [] + for p in problems: + try: + pid = str(p[0]) + acc = float(p[1]) if p[1] is not None else 0.0 + processed_problems.append((pid, acc)) + except (IndexError, TypeError, ValueError) as conv_err: + print(f"Skipping problem entry due to conversion error: {p} - {conv_err}") + sorted_problems = sorted(processed_problems, key=get_sort_key) + else: + print(f"Warning: Problem data format unexpected in create_problem_grid_html (mode={mode}). Skipping sort.") + # Attempt to process anyway if possible, otherwise return error message + if isinstance(problems, list): + processed_problems = [] + for p in problems: + try: + pid = str(p[0]) + # Try to get accuracy, default to 0.0 if fails + acc = 0.0 + if len(p) > 1: + try: + acc = float(p[1]) if p[1] is not None else 0.0 + except (TypeError, ValueError): pass # Keep acc as 0.0 + processed_problems.append((pid, acc)) + except (IndexError, TypeError, ValueError): + print(f"Skipping invalid problem entry: {p}") + sorted_problems = processed_problems # No sort if format is wrong initially + else: + return "
Error: Invalid problem data format.
" + except Exception as e: - return f"
Error displaying problems. Check logs. {e}
" + print(f"Error sorting/processing problems in create_problem_grid_html (mode={mode}): {e}") + return f"
Error displaying problems: {e}
" + + id_extractor = re.compile(r'\d+') # Re-use compiled regex + # Determine JS mode argument based on Python mode + js_mode_arg = "'comparison'" if mode == 'comparison' else "'default'" - for pid, accuracy, _ in sorted_problems: - match = re.search(r'\d+', pid) - num_display = match.group(0) if match else pid + for pid, accuracy in sorted_problems: + match = id_extractor.search(pid) + num_display = match.group(0) if match else pid[:6] # Fallback display if no number found acc_pct = int(accuracy * 100) - - # 获取背景颜色 bg_color = get_gradient_color(accuracy) - # 统一使用白色文本,添加!important确保不被覆盖 - text_color = "#ffffff" + text_color = get_contrasting_text_color(bg_color) # Calculate contrast + # Add onclick event calling the global JavaScript function + # Escape pid if it could contain quotes or special chars (unlikely but safer) + escaped_pid = pid.replace("'", "\\'") html_buttons += f""" -
-
{num_display}
-
{acc_pct}%
-
+ """ - - # 添加自定义样式强制文本颜色为白色 - custom_style = "" - # 根据模式设置每行显示的列数 + grid_cols = 20 if mode == 'comparison' else 10 - grid_html = f"{custom_style}
{html_buttons}
" + # Add CSS for the button layout within the grid item + grid_html = f""" +
+ {html_buttons} +
+ + """ return grid_html -def create_ui(db_path): + +# <<< Keep create_ui, modified for hidden state and JS >>> +def create_ui(db_instance): + """Creates the Gradio UI application.""" global db - db = ModelDatabase(db_path) - - AVAILABLE_DATASETS = db.get_available_datasets() + db = db_instance # Use the passed-in, initialized DB instance + + if not db or not db.conn: + print("Error: Database instance is not valid in create_ui.") + with gr.Blocks() as error_demo: + gr.Markdown("# Error: Database Initialization Failed\nThe application cannot start. Check logs for details.") + return error_demo + + AVAILABLE_DATASETS = db.get_available_datasets() # Fetch datasets from the initialized DB if not AVAILABLE_DATASETS: + print("Warning: No datasets found in the database. Using fallback list.") AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback - - # Add MathJax support to the CSS + + # --- CSS --- (Add styles for metadata, sample grid buttons if not already present) custom_css = """ - .padding.svelte-phx28p { padding: unset !important; } - body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; } - .sample-btn { transition: all 0.15s ease-in-out; } - .sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); } - .problem-grid-container { overflow-y: auto; } - .math-content { overflow-x: auto; padding: 5px; } - .sample-response { overflow-y: clip !important; max-height: none !important; height: auto !important; } - h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); } - .gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; } - .gr-dropdown select { font-size: 0.9em; } - .gr-radio label span { font-size: 0.9em; } - .gr-checkboxgroup label span { font-size: 0.9em; } - .gr-button { font-size: 0.9em; padding: 8px 12px; } - .gr-dataframe table { font-size:0.85em; } - .gr-markdown { font-size: 1em; } - - /* 适应深色模式的样式 */ - .dark-mode-compatible { - background-color: var(--background-fill-primary); - color: var(--color-text); - border-color: var(--border-color-primary); - } - .dark-mode-bg-secondary { - background-color: var(--background-fill-secondary); - } + body, .gradio-container { font-family: sans-serif; font-size: 0.98em; line-height: 1.6; } /* Slightly larger base font */ + .gradio-tabs > div[role='tablist'] button { font-size: 0.95em; padding: 8px 14px; } + .gr-dropdown select { font-size: 0.95em; } + .gr-radio label span { font-size: 0.95em; } + .gr-checkboxgroup label span { font-size: 0.95em; } + .gr-button { font-size: 0.95em; padding: 8px 14px; } + .gr-dataframe table { font-size:0.9em; } /* Slightly larger dataframe font */ + .gr-markdown { font-size: 1.0em; line-height: 1.6; } /* Ensure markdown line height is good */ + /* Dark mode adjustments */ + .dark .dark-mode-compatible { background-color: var(--neutral-800); color: var(--neutral-100); border-color: var(--neutral-700); } + .dark .dark-mode-bg-secondary { background-color: var(--neutral-900); } + .dark .dataframe-container table { color: var(--neutral-100); border-color: var(--neutral-600); } + .dark .dataframe-container th { background-color: var(--neutral-700); } + /* Math rendering adjustments */ + .math-inline, .math-display { font-size: 105%; } /* Slightly smaller math */ + /* Custom classes for layout */ + .compact-row { margin-bottom: 5px !important; padding: 0 !important; } + /* Ensure hidden inputs don't take up space */ + .hidden-state > .svelte-kit-component { display: none !important; } + """ + # --- JavaScript for Button Clicks --- + javascript = """ + function handleProblemClick(problemId, mode) { + console.log(`Problem clicked: ${problemId}, Mode: ${mode}`); + // Determine the target hidden input element ID based on the mode + let targetInputId = (mode === 'comparison') ? 'comp-problem-id-state' : 'problem-id-state'; + let targetInputElement = document.getElementById(targetInputId); - /* DataTable深色模式样式 */ - .dataframe-container { - //padding: 12px; - //border-radius: 8px; - //margin-top: 10px; + if (targetInputElement) { + // Gradio Textbox often wraps the actual input (textarea) + let actualInput = targetInputElement.querySelector('textarea'); + if (actualInput) { + console.log(`Updating ${targetInputId} value to: ${problemId}`); + // Set the value of the hidden input + actualInput.value = problemId; + // Dispatch 'input' and 'change' events to notify Gradio backend + actualInput.dispatchEvent(new Event('input', { bubbles: true })); + actualInput.dispatchEvent(new Event('change', { bubbles: true })); + console.log(`Events dispatched for ${targetInputId}`); + } else { + console.error(`Could not find textarea within #${targetInputId}`); + } + } else { + console.error(`Target input element #${targetInputId} not found.`); + } } - /* MathJax Styles for Gradio's Built-in LaTeX */ - .math-inline, .math-display { - font-size: 110%; - } - .math-container p { - margin: 0.5em 0; - } + function handleSampleClick(sampleIndex, mode) { + console.log(`Sample clicked: ${sampleIndex}, Mode: ${mode}`); + let targetInputId; + // Determine target based on which sample grid was clicked + if (mode === 'comparison_left') { + targetInputId = 'comp-sample-index-state-left'; + } else if (mode === 'comparison_right') { + targetInputId = 'comp-sample-index-state-right'; + } else { // Default single model mode + targetInputId = 'sample-index-state'; + } - /* Markdown content styles */ - .gr-markdown strong { - font-weight: bold; - } - .gr-markdown em { - font-style: italic; - } - .gr-markdown ul, .gr-markdown ol { - padding-left: 2em; - margin: 0.5em 0; - } - .gr-markdown blockquote { - border-left: 3px solid #ccc; - margin: 0.5em 0; - padding-left: 1em; - color: #666; - } - .gr-markdown pre, .gr-markdown code { - background-color: rgba(0,0,0,0.05); - padding: 2px 4px; - border-radius: 3px; - font-family: monospace; - } - .gr-markdown table { - border-collapse: collapse; - margin: 0.5em 0; - } - .gr-markdown th, .gr-markdown td { - border: 1px solid #ddd; - padding: 4px 8px; + let targetInputElement = document.getElementById(targetInputId); + if (targetInputElement) { + let actualInput = targetInputElement.querySelector('textarea'); // Assuming Textbox uses textarea + if (actualInput) { + console.log(`Updating ${targetInputId} value to: ${sampleIndex}`); + actualInput.value = sampleIndex; // Set the index + // Dispatch events + actualInput.dispatchEvent(new Event('input', { bubbles: true })); + actualInput.dispatchEvent(new Event('change', { bubbles: true })); + console.log(`Events dispatched for ${targetInputId}`); + } else { + console.error(`Could not find textarea within #${targetInputId}`); + } + } else { + console.error(`Target input element #${targetInputId} not found.`); + } } + + // Make functions globally available for onclick handlers + window.handleProblemClick = handleProblemClick; + window.handleSampleClick = handleSampleClick; """ - - with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo: - # Remove KaTeX loading script since we're using Gradio's native Markdown with LaTeX + with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky), head=f"", title="Model Performance Analyzer") as demo: + + # --- Define Hidden State Components --- + # These hold the actual state triggered by UI interactions (clicks, dropdowns) + # Single Model Tab States current_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") - current_model_state = gr.State(value=None) - comparison_data_state = gr.State(value={}) - # 添加当前样本状态 - current_sample_state = gr.State(value="0") - # 添加当前问题的样本数据状态 - current_samples_data_state = gr.State(value=[]) - - # 为Comparison标签页添加独立状态 + current_model_state = gr.State(value=None) # Holds the *real* model name + problem_id_state = gr.Textbox(elem_id="problem-id-state", visible=False, label="Selected Problem ID") # Updated by JS problem click + sample_index_state = gr.Textbox(value="0", elem_id="sample-index-state", visible=False, label="Selected Sample Index") # Updated by JS sample click + current_samples_data_state = gr.State(value=[]) # Holds the list of dicts for the current problem's samples + + # Comparison Tab States comp_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") + comp_problem_id_state = gr.Textbox(elem_id="comp-problem-id-state", visible=False, label="Selected Comparison Problem ID") # Updated by JS problem click + # Left Side comp_model_state_left = gr.State(value=None) - comp_sample_state_left = gr.State(value="0") + comp_sample_index_state_left = gr.Textbox(value="0", elem_id="comp-sample-index-state-left", visible=False, label="Selected Left Sample Index") comp_samples_data_state_left = gr.State(value=[]) + # Right Side comp_model_state_right = gr.State(value=None) - comp_sample_state_right = gr.State(value="0") + comp_sample_index_state_right = gr.Textbox(value="0", elem_id="comp-sample-index-state-right", visible=False, label="Selected Right Sample Index") comp_samples_data_state_right = gr.State(value=[]) - - # 创建占位符State组件替代None + # Dummy state for outputs we need to provide but don't use dummy_state = gr.State(value=None) + # --- UI Layout --- with gr.Tabs(): + # == Single Model Analysis Tab == with gr.TabItem("Single Model Analysis"): - with gr.Row(variant='compact'): - with gr.Column(scale=1, min_width=280): - + with gr.Row(): + # Column 1: Controls & Stats + with gr.Column(scale=1, min_width=300): + gr.Markdown("### Controls") dataset_radio_single = gr.Radio( - choices=AVAILABLE_DATASETS, - value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, - label="Select Dataset", - interactive=True + choices=AVAILABLE_DATASETS, value=current_dataset_state.value, # Use initial state value + label="Select Dataset", interactive=True ) - model_dropdown = gr.Dropdown( - choices=[], # Populated by callback - label="Select Model", - interactive=True - ) - - problem_state_input = gr.Textbox( - value="", - elem_id="problem-state-input", - visible=True, - label="Enter Problem ID (0 - 99, acc. below)", - container=True, - interactive=True, - every=0.5 + choices=[], label="Select Model (Name + Acc%)", interactive=True ) - - #gr.Markdown("#### Problem Grid") - problem_grid_html_output = gr.HTML( - value="
Select model and dataset to see problems.
", - elem_id="problem-grid-container" + # Problem ID Input (for manual entry or display - could be combined with state?) + # Let's keep it visible for manual entry as an alternative to clicking grid + problem_id_input_display = gr.Textbox( + label="Enter Problem ID (e.g., 42) or click grid", + placeholder="Enter number or full ID", + interactive=True, + # elem_id="problem_id_input_display" # Use different ID than state ) - gr.Markdown("#### Model Statistics") - model_stats_df = gr.DataFrame( - headers=["Metric", "Value"], - wrap=True, - elem_classes="dataframe-container dark-mode-compatible dark-mode-bg-secondary" - ) - - with gr.Column(scale=3, min_width=400): + + gr.Markdown("### Problem Grid (Click to Select)") + problem_grid_html_output = gr.HTML("Select model and dataset.") # Populated by callback + + gr.Markdown("### Model Statistics") + model_stats_df = gr.DataFrame(headers=["Metric", "Value"], wrap=True) + + # Column 2: Problem Details & Samples + with gr.Column(scale=3, min_width=500): + gr.Markdown("### Problem Details") with gr.Tabs(): with gr.TabItem("Problem Statement"): - problem_markdown_output = gr.Markdown( - "Please fill in all the fields.", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) + problem_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) with gr.TabItem("Reference Answer"): - answer_markdown_output = gr.Markdown( - "No answer available.", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) - - # 样本网格 + answer_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + + gr.Markdown("### Model Responses (Click Grid Below to Select)") + # Sample Grid HTML (generated by handle_problem_select) samples_grid_output = gr.HTML("") - - # 在样本网格下方添加样本选择输入框 - with gr.Row(): - # 样本选择输入框 - sample_number_input = gr.Textbox( - value="0", - elem_id="sample-number-input", - visible=True, - label="Enter Sample Number (0 - 63)", - container=True, - interactive=True, - every=0.5 - ) - - # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 - sample_metadata_output = gr.HTML( - value="
Select a problem first to view samples.
", - elem_classes="sample-metadata dark-mode-bg-secondary", - elem_id="sample-metadata-area" - ) - - sample_response_output = gr.Markdown( - value="Select a problem first to view samples.", - elem_classes="sample-response dark-mode-bg-secondary", - elem_id="sample-response-area", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) - + + # Sample Display Area + gr.Markdown("#### Selected Sample Details") + sample_metadata_output = gr.HTML("Select a sample from the grid above.") + sample_response_output = gr.Markdown("*(Response will appear here)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + + + # == Model Comparison Tab == with gr.TabItem("Model Comparison"): - # 共享部分 + # Row 1: Shared Controls with gr.Row(variant='compact'): - comp_dataset_radio = gr.Radio( - choices=AVAILABLE_DATASETS, - value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, - label="Select Dataset", - interactive=True - ) - - comp_problem_state_input = gr.Textbox( - value="", - elem_id="comp-problem-state-input", - visible=True, - label="Enter Problem ID (0 - 99, acc. below)", - container=True, - interactive=True, - every=0.5 - ) - - # 移动的共享问题和答案显示到这里 + comp_dataset_radio = gr.Radio( + choices=AVAILABLE_DATASETS, value=comp_dataset_state.value, + label="Select Dataset", interactive=True, scale=1 + ) + # Visible Problem ID input for comparison tab + comp_problem_id_input_display = gr.Textbox( + label="Enter Problem ID (e.g., 42) or click grid", + placeholder="Enter number or full ID", + interactive=True, scale=1 + ) + + # Row 2: Shared Problem/Answer Display with gr.Row(variant='compact'): - with gr.Column(scale=1): + with gr.Column(scale=1): + gr.Markdown("### Problem Details") with gr.Tabs(): with gr.TabItem("Problem Statement"): - comp_problem_markdown_output = gr.Markdown( - "Please select models and problem.", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) + comp_problem_markdown_output = gr.Markdown("Select models and problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) with gr.TabItem("Reference Answer"): - comp_answer_markdown_output = gr.Markdown( - "No answer available.", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) - - # 左右两部分模型比较 - with gr.Row(variant='compact'): - # 左侧模型 - with gr.Column(scale=1): - comp_model_dropdown_left = gr.Dropdown( - choices=[], # Populated by callback - label="Select Model 1", - interactive=True - ) - - gr.Markdown("#### Problem Grid") - comp_problem_grid_html_output_left = gr.HTML( - value="
Select model and dataset to see problems.
", - elem_id="comp-problem-grid-container-left" - ) - - # 样本网格和选择器 + comp_answer_markdown_output = gr.Markdown("Select problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + + # Row 3: Left vs Right Columns + with gr.Row(variant='compact', equal_height=False): # Allow columns to size independently + # Left Column + with gr.Column(scale=1, min_width=400): + gr.Markdown("### Model 1") + comp_model_dropdown_left = gr.Dropdown(choices=[], label="Select Model 1", interactive=True) + gr.Markdown("#### Problem Grid (Model 1 - Click to Select Problem)") + comp_problem_grid_html_output_left = gr.HTML("Select model 1.") + gr.Markdown("#### Model 1 Responses (Click Grid Below)") comp_samples_grid_output_left = gr.HTML("") - - with gr.Row(): - comp_sample_number_input_left = gr.Textbox( - value="0", - elem_id="comp-sample-number-input-left", - visible=True, - label="Enter Sample Number (0 - 63)", - container=True, - interactive=True, - every=0.5 - ) - - # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 - comp_sample_metadata_output_left = gr.HTML( - value="
Select a problem first to view samples.
", - elem_classes="sample-metadata dark-mode-bg-secondary", - elem_id="comp-sample-metadata-area-left" - ) - - comp_sample_response_output_left = gr.Markdown( - value="Select a problem first to view samples.", - elem_classes="sample-response dark-mode-bg-secondary", - elem_id="comp-sample-response-area-left", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) - - # 右侧模型 - with gr.Column(scale=1): - comp_model_dropdown_right = gr.Dropdown( - choices=[], # Populated by callback - label="Select Model 2", - interactive=True - ) - - gr.Markdown("#### Problem Grid") - comp_problem_grid_html_output_right = gr.HTML( - value="
Select model and dataset to see problems.
", - elem_id="comp-problem-grid-container-right" - ) - - # 样本网格和选择器 + gr.Markdown("##### Selected Sample (Model 1)") + comp_sample_metadata_output_left = gr.HTML("Select sample.") + comp_sample_response_output_left = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + + # Right Column + with gr.Column(scale=1, min_width=400): + gr.Markdown("### Model 2") + comp_model_dropdown_right = gr.Dropdown(choices=[], label="Select Model 2", interactive=True) + gr.Markdown("#### Problem Grid (Model 2 - Click to Select Problem)") + comp_problem_grid_html_output_right = gr.HTML("Select model 2.") + gr.Markdown("#### Model 2 Responses (Click Grid Below)") comp_samples_grid_output_right = gr.HTML("") - - with gr.Row(): - comp_sample_number_input_right = gr.Textbox( - value="0", - elem_id="comp-sample-number-input-right", - visible=True, - label="Enter Sample Number (0 - 63)", - container=True, - interactive=True, - every=0.5 - ) - - # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 - comp_sample_metadata_output_right = gr.HTML( - value="
Select a problem first to view samples.
", - elem_classes="sample-metadata dark-mode-bg-secondary", - elem_id="comp-sample-metadata-area-right" - ) - - comp_sample_response_output_right = gr.Markdown( - value="Select a problem first to view samples.", - elem_classes="sample-response dark-mode-bg-secondary", - elem_id="comp-sample-response-area-right", - latex_delimiters=[ - {"left": "$", "right": "$", "display": False}, - {"left": "$$", "right": "$$", "display": True}, - {"left": "\\(", "right": "\\)", "display": False}, - {"left": "\\[", "right": "\\]", "display": True} - ] - ) - - # --- Event Handlers --- + gr.Markdown("##### Selected Sample (Model 2)") + comp_sample_metadata_output_right = gr.HTML("Select sample.") + comp_sample_response_output_right = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + + + # --- Event Handlers --- + def update_available_models_for_dropdowns(selected_dataset): - # This function can be used to update model lists if they are dataset-dependent - # For now, assume get_available_models() gets all models irrespective of dataset for dropdown population - all_models = db.get_available_models() - # For single model tab, format with accuracy on the selected dataset - single_model_options = [] - model_to_display_map = {} # 映射用于存储真实模型名称到显示名称的映射 - + # Fetches all models, gets accuracies for the selected dataset, formats choices + if not db or not db.conn: + print("Error: DB not available in update_available_models_for_dropdowns") + # Return empty updates for all dropdowns + return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[]) + + all_models = db.get_available_models() # Uses cache/in-memory DB + model_acc_map = {} if selected_dataset and all_models: + # Fetch accuracies (uses cache/in-memory DB) model_accs = db.get_all_model_accuracies(selected_dataset) model_acc_map = {name: acc for name, acc in model_accs} - single_model_options = [] - for name in all_models: - # 使用MODEL_TRANS映射模型名称 - display_name = MODEL_TRANS.get(name, name) - acc_display = f" ({model_acc_map.get(name, 0):.1%})" if model_acc_map.get(name) is not None else "" - display_text = f"{display_name}{acc_display}" - single_model_options.append(display_text) - model_to_display_map[display_text] = name # 存储映射关系 - else: - for name in all_models: - display_name = MODEL_TRANS.get(name, name) - single_model_options.append(display_name) - model_to_display_map[display_name] = name - - # 将映射存储到全局数据库对象中以便后续使用 - db.model_display_to_real = model_to_display_map - - # For comparison tab, also use formatted model names with accuracy - comp_model_choices = single_model_options # 使用和单模型相同的选项,包含准确率 - db.comp_model_display_to_real = model_to_display_map # 使用相同的映射 - - return gr.Dropdown(choices=single_model_options if single_model_options else [], value=None), \ - gr.Dropdown(choices=comp_model_choices if comp_model_choices else [], value=None) + + display_options = [] + # Clear previous maps before repopulating + db.model_display_to_real = {} + db.comp_model_display_to_real = {} + + # Sort models by accuracy (descending), handle missing accuracy (treat as -1) + sorted_models = sorted(all_models, key=lambda m: model_acc_map.get(m, -1), reverse=True) + + for name in sorted_models: + display_name = MODEL_TRANS.get(name, name) # Use translation map + acc = model_acc_map.get(name) + # Format accuracy nicely, handle None + acc_display = f" ({acc:.1%})" if acc is not None else " (N/A)" + display_text = f"{display_name}{acc_display}" + display_options.append(display_text) + # Store mapping from formatted name back to real name for both contexts + db.model_display_to_real[display_text] = name + db.comp_model_display_to_real[display_text] = name + + # Return updates for all three dropdowns + # Use label argument to set labels dynamically if needed, or keep static + return gr.Dropdown(choices=display_options, value=None, label="Select Model (Name + Acc%)", interactive=True), \ + gr.Dropdown(choices=display_options, value=None, label="Select Model 1", interactive=True), \ + gr.Dropdown(choices=display_options, value=None, label="Select Model 2", interactive=True) + def update_problem_grid_and_stats(selected_model_formatted, selected_dataset, mode='default'): + # Fetches stats/problems, returns stats DF, grid HTML, and *real* model name state + if not db or not db.conn: + print("Error: DB not available in update_problem_grid_and_stats") + return gr.DataFrame(value=[]), gr.HTML("

DB Error.

"), None if not selected_model_formatted or not selected_dataset: - # Return empty/default values for all outputs, including the state - return gr.DataFrame(value=[]), gr.HTML("
Please select a model and dataset first.
"), None - - # 从映射中获取真实模型名称 - model_name = db.model_display_to_real.get(selected_model_formatted, selected_model_formatted) - # 如果找不到确切匹配,可能是因为准确率等动态内容导致,尝试前缀匹配 - if model_name == selected_model_formatted: - for display_name, real_name in db.model_display_to_real.items(): - if selected_model_formatted.startswith(display_name.split(" (")[0]): - model_name = real_name + return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None + + # Use the appropriate map (comp or default) to get the real model name + real_model_name = None + if mode == 'comparison': + real_model_name = db.comp_model_display_to_real.get(selected_model_formatted) + if not real_model_name: # Fallback to default map or parsing if comp map missed + real_model_name = db.model_display_to_real.get(selected_model_formatted) + + # If still not found via map, try parsing (less reliable) + if not real_model_name: + raw_name_part = selected_model_formatted.split(" (")[0] + for db_name, display_lookup in MODEL_TRANS.items(): + if display_lookup == raw_name_part: + real_model_name = db_name break - - stats_data = db.get_model_statistics(model_name, selected_dataset) - problem_list = db.get_problems_by_model_dataset(model_name, selected_dataset) + if not real_model_name: real_model_name = raw_name_part # Assume it's the real name + print(f"Warning: Using fallback lookup/parsing for model name: '{selected_model_formatted}' -> '{real_model_name}'") + + if not real_model_name: # Final check if name resolution failed + print(f"Error: Could not determine real model name for '{selected_model_formatted}'") + return gr.DataFrame(value=[]), gr.HTML("

Internal error resolving model name.

"), None + + # Fetch data using the resolved real model name + stats_data = db.get_model_statistics(real_model_name, selected_dataset) + problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset) grid_html = create_problem_grid_html(problem_list, mode=mode) - - # Correctly return the actual value for the current_model_state output - return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), model_name - # Single Model Tab interactions + # Return stats DF, grid HTML, and the *real* model name for state update + return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name + + + # Helper function to clear problem/sample displays + def clear_problem_outputs(): + return "Select a problem.", "N/A", gr.HTML(""), gr.State([]), "Select a sample.", "*(Response)*", "0" + + # Helper function to clear comparison problem/sample displays for one side + def clear_comparison_side_outputs(): + # Doesn't clear the main problem/answer display, only the side's samples + return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" + + + # --- Single Model Tab Event Connections --- + + # Dataset selection changes model dropdown options and clears everything else dataset_radio_single.change( fn=update_available_models_for_dropdowns, inputs=[dataset_radio_single], - outputs=[model_dropdown, comp_model_dropdown_left] - ).then( - lambda ds: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), gr.State(value=None), ds, ""), # 清空所有输出,包括problem_state_input + outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] # Update all model lists + ).then( # Chain: Reset dependent states and UI components + lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", # Clear stats, grid, model state, problem ID state + *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state + ""), # Clear problem input display inputs=[dataset_radio_single], - outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_state_input] - ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[sample_number_input] - ).then( - lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""), - inputs=[], - outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] + outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, + problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, + sample_metadata_output, sample_response_output, sample_index_state, + problem_id_input_display] # Also clear the visible input box ) - - # Initial population of model dropdowns based on default dataset - demo.load( - fn=update_available_models_for_dropdowns, - inputs=[current_dataset_state], # Uses initial value of state - outputs=[model_dropdown, comp_model_dropdown_left] - ).then( - lambda ds_val: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), ds_val), # Also update dataset state for single tab - inputs=[current_dataset_state], - outputs=[model_stats_df, problem_grid_html_output, current_dataset_state] - ).then( - lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""), - inputs=[], - outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] - ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[sample_number_input] + + # Model selection updates stats, grid, and clears problem/sample display + model_dropdown.change( + fn=update_problem_grid_and_stats, # Mode defaults to 'default' + inputs=[model_dropdown, current_dataset_state], + outputs=[model_stats_df, problem_grid_html_output, current_model_state] # Update stats, grid, REAL model state + ).then( # Chain: Clear problem-specific outputs + lambda: ("", # Clear problem ID state + *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state + ""), # Clear problem input display + inputs=[], + outputs=[problem_id_state, + problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, + sample_metadata_output, sample_response_output, sample_index_state, + problem_id_input_display] ) - # ==== 比较页面事件处理 ==== - # 初始化两侧模型下拉列表 - demo.load( - fn=update_available_models_for_dropdowns, - inputs=[comp_dataset_state], - outputs=[model_dropdown, comp_model_dropdown_left] - ).then( - fn=update_available_models_for_dropdowns, - inputs=[comp_dataset_state], - outputs=[model_dropdown, comp_model_dropdown_right] + # --- Problem Selection Handling (Single Tab) --- + # Option 1: User types into the visible input box + problem_id_input_display.submit( # Trigger on Enter press + # Copy value from visible input to hidden state input to trigger main handler + fn=lambda x: x, + inputs=[problem_id_input_display], + outputs=[problem_id_state] # This change triggers the next handler + ) + # Option 2: User clicks problem grid (JS updates hidden problem_id_state) + # Main handler triggered by change in hidden problem_id_state + problem_id_state.change( + fn=handle_problem_select, # Fetches problem, answer, sample grid, sample data state + inputs=[problem_id_state, current_model_state, current_dataset_state], # Mode defaults to 'default' + outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] + ).then( # Chain: Display the first sample after data is loaded + fn=handle_first_sample, + inputs=[current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ).then( # Chain: Reset sample index state to 0 + lambda: "0", inputs=[], outputs=[sample_index_state] + ).then( # Chain: Update the visible input box to reflect the selected ID (useful if clicked) + fn=lambda x: x.value if hasattr(x,'value') else x, # Get value from state object + inputs=[problem_id_state], + outputs=[problem_id_input_display] ) - - # 数据集改变事件 + + + # --- Sample Selection Handling (Single Tab) --- + # Triggered by change in hidden sample_index_state (updated by JS sample click) + sample_index_state.change( + fn=handle_sample_select, + inputs=[sample_index_state, current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ) + + + # --- Comparison Tab Event Connections --- + + # Dataset change updates dropdowns and clears everything comp_dataset_radio.change( - fn=lambda ds: ds, + fn=lambda ds: ds, # Update comparison dataset state first inputs=[comp_dataset_radio], outputs=[comp_dataset_state] ).then( fn=update_available_models_for_dropdowns, inputs=[comp_dataset_state], - outputs=[model_dropdown, comp_model_dropdown_left] - ).then( - fn=update_available_models_for_dropdowns, - inputs=[comp_dataset_state], - outputs=[model_dropdown, comp_model_dropdown_right] - ).then( - lambda: ("Please select a dataset and enter a problem ID.", "No answer available."), - inputs=[], - outputs=[comp_problem_markdown_output, comp_answer_markdown_output] - ) - - # 为比较页面的问题ID添加单独的更新逻辑 - comp_problem_state_input.change( - fn=handle_comparison_problem_update, - inputs=[comp_problem_state_input, comp_dataset_state], - outputs=[comp_problem_markdown_output, comp_answer_markdown_output] + outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] + ).then( # Clear everything dependent on dataset/models/problem + lambda: (None, None, # Clear model states + "Select models and problem.", "Select problem.", # Clear problem/answer display + gr.HTML("Select model."), gr.HTML("Select model."), # Clear grids + *clear_comparison_side_outputs(), # Clear left samples + *clear_comparison_side_outputs(), # Clear right samples + "", # Clear comp problem ID state + ""), # Clear comp problem input display + inputs=[], + outputs=[comp_model_state_left, comp_model_state_right, + comp_problem_markdown_output, comp_answer_markdown_output, + comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, + # Left side sample outputs + state + comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, + # Right side sample outputs + state + comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, + comp_problem_id_state, comp_problem_id_input_display] ) - - # 创建包装函数,预设模式参数 - def update_problem_grid_comparison(model, dataset): - return update_problem_grid_and_stats(model, dataset, mode='comparison') - - # 问题选择的包装函数 - def handle_problem_select_comparison(problem_id, model_state, dataset_state): - return handle_problem_select(problem_id, model_state, dataset_state, mode='comparison') - - # 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面左侧 - def update_model_and_requery_problem_left(model_dropdown_value, current_dataset, current_problem_id): - # 首先更新模型统计和问题网格 - _, grid_html, new_model_state = update_problem_grid_comparison(model_dropdown_value, current_dataset) - - # 如果有选择的问题ID,重新查询它的响应 - if current_problem_id: - problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset) - - # 获取第一个样本的内容 - first_metadata, first_response = handle_first_sample(new_samples_data) - - return grid_html, new_model_state, problem_content, answer_content, samples_grid_html, new_samples_data, first_metadata, first_response - else: - # 没有问题ID,只返回更新的模型状态 - return grid_html, new_model_state, "Please enter a problem ID.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", "" - - # 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面右侧 - def update_model_and_requery_problem_right(model_dropdown_value, current_dataset, current_problem_id): - # 首先更新模型统计和问题网格 - _, grid_html, new_model_state = update_problem_grid_comparison(model_dropdown_value, current_dataset) - - # 如果有选择的问题ID,重新查询它的响应 - if current_problem_id: - # 对于右侧,我们不需要更新问题和答案内容 - _, _, samples_grid_html, new_samples_data = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset) - - # 获取第一个样本的内容 - first_metadata, first_response = handle_first_sample(new_samples_data) - - return grid_html, new_model_state, samples_grid_html, new_samples_data, first_metadata, first_response - else: - # 没有问题ID,只返回更新的模型状态 - return grid_html, new_model_state, "", gr.State([]), "
Select a problem first to view samples.
", "" - # 左侧模型选择事件 + # Left model selection comp_model_dropdown_left.change( - fn=update_model_and_requery_problem_left, - inputs=[comp_model_dropdown_left, comp_dataset_state, comp_problem_state_input], - outputs=[comp_problem_grid_html_output_left, comp_model_state_left, comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left] + # 1. Update grid and model state for left side + fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), + inputs=[comp_model_dropdown_left, comp_dataset_state], + # Output to dummy state for stats DF, then grid HTML, then REAL model state + outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left] ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[comp_sample_number_input_left] - ) - - # 右侧模型选择事件 - comp_model_dropdown_right.change( - fn=update_model_and_requery_problem_right, - inputs=[comp_model_dropdown_right, comp_dataset_state, comp_problem_state_input], - outputs=[comp_problem_grid_html_output_right, comp_model_state_right, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right] + # 2. If a problem is already selected, refetch left samples + # Use 'comparison_left' mode for handle_problem_select + fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), + inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], + # Outputs for handle_problem_select: problem, answer, sample_grid, sample_state + # We only care about sample_grid and sample_state here + outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left] ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[comp_sample_number_input_right] - ) - - # 左侧样本选择 - comp_sample_number_input_left.change( - fn=handle_sample_select, - inputs=[comp_sample_number_input_left, comp_samples_data_state_left], - outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] - ) - - # 右侧样本选择 - comp_sample_number_input_right.change( - fn=handle_sample_select, - inputs=[comp_sample_number_input_right, comp_samples_data_state_right], - outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] - ) - - # 为比较页面问题选择事件添加处理 - comp_problem_state_input.change( - fn=handle_problem_select_comparison, - inputs=[comp_problem_state_input, comp_model_state_left, comp_dataset_state], - outputs=[comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left] - ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[comp_sample_number_input_left] + # 3. Display first sample for left side (if data exists) + fn=handle_first_sample, + inputs=[comp_samples_data_state_left], + outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] ).then( - fn=handle_first_sample, - inputs=[comp_samples_data_state_left], - outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] + # 4. Reset left sample index state + lambda: "0", inputs=[], outputs=[comp_sample_index_state_left] ) - - # 问题选择事件 - 右侧模型 - comp_problem_state_input.change( - fn=handle_problem_select_comparison, - inputs=[comp_problem_state_input, comp_model_state_right, comp_dataset_state], + + # Right model selection (mirrors left side logic) + comp_model_dropdown_right.change( + fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), + inputs=[comp_model_dropdown_right, comp_dataset_state], + outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right] + ).then( + # Use 'comparison_right' mode for handle_problem_select + fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), + inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[comp_sample_number_input_right] + fn=handle_first_sample, + inputs=[comp_samples_data_state_right], + outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] ).then( - fn=handle_first_sample, - inputs=[comp_samples_data_state_right], - outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] + lambda: "0", inputs=[], outputs=[comp_sample_index_state_right] ) - # This is the crucial link: problem_state_input is changed by user, triggers this Python callback. - problem_state_input.change( - fn=handle_problem_select, - inputs=[problem_state_input, current_model_state, current_dataset_state], - outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] + # --- Problem Selection Handling (Comparison Tab) --- + # Option 1: User types into the visible input box + comp_problem_id_input_display.submit( + fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state] + ) + # Option 2: User clicks problem grid (JS updates hidden comp_problem_id_state) + # Main handler triggered by change in hidden comp_problem_id_state + comp_problem_id_state.change( + # 1. Update main problem/answer display (doesn't need model info) + fn=handle_comparison_problem_update, + inputs=[comp_problem_id_state, comp_dataset_state], + outputs=[comp_problem_markdown_output, comp_answer_markdown_output] ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[sample_number_input] + # 2. Update left samples (if left model is selected) + fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])), + inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], + outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left] ).then( - fn=handle_first_sample, - inputs=[current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] - ) - - # Also listen for direct input event which may be more reliable than change - problem_state_input.input( - fn=handle_problem_select, - inputs=[problem_state_input, current_model_state, current_dataset_state], - outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] + # 3. Display first left sample + fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[sample_number_input] + # 4. Reset left sample index + lambda: "0", inputs=[], outputs=[comp_sample_index_state_left] ).then( - fn=handle_first_sample, - inputs=[current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] + # 5. Update right samples (if right model is selected) + fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])), + inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], + outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] + ).then( + # 6. Display first right sample + fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] + ).then( + # 7. Reset right sample index + lambda: "0", inputs=[], outputs=[comp_sample_index_state_right] + ).then( + # 8. Update the visible input box to reflect the selected ID + fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display] ) - # 添加样本编号的事件处理 - sample_number_input.change( + + # --- Sample Selection Handling (Comparison Tab) --- + # Left sample selection (triggered by JS updating hidden state) + comp_sample_index_state_left.change( fn=handle_sample_select, - inputs=[sample_number_input, current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] + inputs=[comp_sample_index_state_left, comp_samples_data_state_left], + outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] ) - - sample_number_input.input( + # Right sample selection (triggered by JS updating hidden state) + comp_sample_index_state_right.change( fn=handle_sample_select, - inputs=[sample_number_input, current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] + inputs=[comp_sample_index_state_right, comp_samples_data_state_right], + outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] ) - - # 修改model_dropdown.change处理函数,以重新查询当前问题响应 - def update_model_and_requery_problem(model_dropdown_value, current_dataset, current_problem_id): - # 首先更新模型统计和问题网格 - stats_df, grid_html, new_model_state = update_problem_grid_and_stats(model_dropdown_value, current_dataset) - - # 如果有选择的问题ID,重新查询它的响应 - if current_problem_id: - problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select(current_problem_id, new_model_state, current_dataset) - - # 获取第一个样本的内容 - first_metadata, first_response = handle_first_sample(new_samples_data) - - return stats_df, grid_html, new_model_state, problem_content, answer_content, samples_grid_html, new_samples_data, first_metadata, first_response - else: - # 没有问题ID,只返回更新的模型状态 - return stats_df, grid_html, new_model_state, "Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", "" - model_dropdown.change( - fn=update_model_and_requery_problem, - inputs=[model_dropdown, current_dataset_state, problem_state_input], - outputs=[model_stats_df, problem_grid_html_output, current_model_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] - ).then( - # 重置Sample Number为0 - fn=lambda: "0", - inputs=[], - outputs=[sample_number_input] + # --- Initial Population on App Load --- + # Populate model dropdowns based on the initial dataset state + demo.load( + fn=update_available_models_for_dropdowns, + inputs=[current_dataset_state], # Use initial state for single tab dataset + outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] ) return demo +# <<< Keep monitor_memory_usage function >>> def monitor_memory_usage(): - """监控内存使用情况并在必要时释放缓存""" + """Monitors memory usage and clears caches if thresholds are exceeded.""" global db - try: process = psutil.Process(os.getpid()) + # Use resident set size (RSS) as a measure of physical RAM usage memory_info = process.memory_info() - memory_usage_mb = memory_info.rss / 1024 / 1024 - - # 如果内存使用超过12GB (激进设置),清理缓存 - if memory_usage_mb > 12000: # 12GB - if db: - db.clear_cache('response') # 优先清理响应缓存 - gc.collect() - # 如果内存使用超过14GB,更激进地清理 - if memory_usage_mb > 14000: # 14GB - if db: - db.clear_cache() # 清理所有缓存 - gc.collect() - - return f"Memory: {memory_usage_mb:.1f} MB" + memory_usage_mb = memory_info.rss / (1024 * 1024) + # Get total system memory for context + total_memory_gb = psutil.virtual_memory().total / (1024**3) + print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB available.") + + # Define thresholds based on 18GB total available RAM + # Threshold 1: ~70% = 12.6 GB + threshold_1_mb = 18 * 1024 * 0.70 + # Threshold 2: ~85% = 15.3 GB + threshold_2_mb = 18 * 1024 * 0.85 + + if db and db.conn: # Only clear caches if DB is active + if memory_usage_mb > threshold_2_mb: + print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_2_mb:.1f} MB. Clearing ALL caches.") + db.clear_cache() # Clear all Python-level caches + elif memory_usage_mb > threshold_1_mb: + print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_1_mb:.1f} MB. Clearing response cache.") + db.clear_cache('response') # Clear less critical response cache first + else: + print("[Memory Monitor] DB not active, skipping cache check.") + + + # Return status string (optional, could be used in UI if needed) + return f"Memory OK: {memory_usage_mb:.1f} MB" + except Exception as e: + print(f"[Memory Monitor] Error: {e}") return "Memory monitor error" -# 修改主函数以使用优化策略 +# <<< Keep __main__ block, ensuring DB initialization happens before UI creation >>> if __name__ == "__main__": - DB_PATH = "data.db" - - # 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载 + DB_FILE_NAME = "data.db" + # Determine expected path (e.g., in the current directory) + DB_PATH = os.path.abspath(DB_FILE_NAME) + + # --- Database Download/Check --- if not os.path.exists(DB_PATH): + print(f"{DB_PATH} not found. Attempting to download from Hugging Face Hub...") try: - # 从环境变量获取 HF_TOKEN + # Attempt to get token from environment or local file hf_token = os.environ.get("HF_TOKEN") if not hf_token: - raise ValueError("HF_TOKEN environment variable is not set") - - # 从 Hugging Face 下载数据库文件 - DB_PATH = hf_hub_download( + token_path = Path.home() / ".huggingface" / "token" + if token_path.exists(): + print("Using token from ~/.huggingface/token") + hf_token = token_path.read_text().strip() + # Optional: Add interactive prompt for token if needed + # else: hf_token = input("Enter your Hugging Face token: ").strip() + + if not hf_token: + raise ValueError("Hugging Face token not found (set HF_TOKEN env var or place in ~/.huggingface/token).") + + print("Downloading data.db (this might take time)...") + # Download directly to the expected DB_PATH location's directory + downloaded_path = hf_hub_download( repo_id="CoderBak/OlymMATH-data", - filename="data.db", + filename=DB_FILE_NAME, repo_type="dataset", - token=hf_token + token=hf_token, + local_dir=os.path.dirname(DB_PATH), # Download to target directory + local_dir_use_symlinks=False # Force copy, avoid symlink issues ) + # Verify download path matches expected path + if os.path.abspath(downloaded_path) != DB_PATH: + print(f"Warning: Downloaded path '{downloaded_path}' differs from expected path '{DB_PATH}'.") + # Optionally, attempt to move/rename or update DB_PATH + # For simplicity, assume download worked if no exception + DB_PATH = os.path.abspath(downloaded_path) + + print(f"Database downloaded successfully to: {DB_PATH}") except Exception as e: - # 创建一个显示错误信息的简单 Gradio 应用 + print(f"Error downloading database: {e}") + # Display error UI and exit with gr.Blocks() as error_demo: - gr.Markdown(f"# Error: Database Download Failed\n{str(e)}\nPlease ensure HF_TOKEN is set correctly and try again.") - error_demo.launch(server_name="0.0.0.0") + gr.Markdown(f"# Error: Database Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists in the script's directory, or provide a valid Hugging Face token and network access.") + error_demo.launch(server_name="0.0.0.0", show_error=True) exit(1) - - if os.path.exists(DB_PATH): - # 创建UI并启动 - db = ModelDatabase(DB_PATH) - - # 添加清理函数 - def cleanup(): - global db - if db: - db.close() - - # 注册清理函数 - import atexit - atexit.register(cleanup) - - # 创建UI - main_demo = create_ui(DB_PATH) - - # 使用兼容的启动参数 - main_demo.launch( - server_name="0.0.0.0", - share=False, - inbrowser=False - ) - else: - # 创建一个显示错误信息的简单 Gradio 应用 - with gr.Blocks() as error_demo: - gr.Markdown(f"# Error: Database Not Found\nCould not find `{DB_PATH}`. Please ensure the database file is correctly placed and accessible.") - error_demo.launch(server_name="0.0.0.0") \ No newline at end of file + + # --- Database Initialization (Loads into Memory) --- + print(f"Initializing ModelDatabase from: {DB_PATH} (loading data into memory)...") + start_init = time.time() + try: + # Instantiate the database class - this performs the load + db = ModelDatabase(DB_PATH) # db becomes the global instance + except Exception as e: + print(f"Fatal Error during ModelDatabase initialization: {e}") + import traceback + traceback.print_exc() # Print full traceback for debugging + # Display error UI and exit + with gr.Blocks() as error_demo: + gr.Markdown(f"# Error: Database Initialization Failed\n`{str(e)}`\nCheck database file integrity, permissions, and available memory ({psutil.virtual_memory().available / (1024**3):.1f} GB free).") + error_demo.launch(server_name="0.0.0.0", show_error=True) + exit(1) + end_init = time.time() + print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.") + # Initial memory check after load + monitor_memory_usage() + + # --- Cleanup Registration --- + def cleanup(): + global db + print("\nRunning cleanup before exit...") + if db: + db.close() # Ensures connection is closed properly + print("Cleanup finished.") + atexit.register(cleanup) # Register cleanup to run on script exit + + # --- Create and Launch UI --- + print("Creating Gradio UI...") + # Pass the initialized (in-memory) db instance to the UI creation function + main_demo = create_ui(db) + + # --- Optional: Periodic Memory Monitor --- + # Use Gradio's `every` parameter on a dummy component if you want checks tied to UI activity, + # or run a separate thread (carefully). For simplicity, we'll rely on checks during requests (if implemented) + # or manual monitoring based on the initial check. + + print("Launching Gradio application...") + # Use queue() for better performance under load + main_demo.queue().launch( + server_name="0.0.0.0", # Listen on all interfaces + share=os.environ.get("GRADIO_SHARE", False), # Allow sharing via env var if needed + inbrowser=False, # Don't automatically open browser + show_error=True, # Show Python errors in browser console + # Optional: Increase workers if needed, but be mindful of memory/CPU + # max_threads=4 + ) + + # Server runs here until interrupted (Ctrl+C) + print("Application is running. Press Ctrl+C to stop.") + # Keep the main thread alive if needed (though launch() usually blocks) + # while True: time.sleep(1) \ No newline at end of file