diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,3 +1,5 @@
+# -*- coding: utf-8 -*-
+# <<< Keep all existing imports >>>
import os
import json
import pandas as pd
@@ -13,7 +15,9 @@ import time
from huggingface_hub import hf_hub_download
import psutil
import gc
+import atexit # Import atexit
+# <<< Keep SUBJECT_TRANS and MODEL_TRANS dictionaries >>>
# 翻译表
SUBJECT_TRANS = {
"代数": "Algebra",
@@ -21,7 +25,7 @@ SUBJECT_TRANS = {
"几何": "Geometry",
"组合": "Combinatorics"
}
-
+# MODEL_TRANS
MODEL_TRANS = {
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
@@ -51,109 +55,206 @@ MODEL_TRANS = {
"gemini-2.5-pro-exp-03-25": "Gemini 2.5 Pro Exp 0325",
"o3-mini-high": "OpenAI o3-mini (high)",
"qwen3-0.6b": "Qwen3-0.6B"
- # 添加更多模型映射
}
-# Configure matplotlib for better display
+# <<< Keep Matplotlib configuration >>>
plt.style.use('ggplot')
mpl.rcParams['figure.figsize'] = (10, 6)
mpl.rcParams['font.size'] = 10
-# Constants
+# <<< Keep Constants >>>
DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
-# 全局数据库实例
+# Global database instance
db = None
class ModelDatabase:
- """Database access class"""
-
+ """Database access class - Optimized to use in-memory database"""
def __init__(self, db_path):
- """Initialize database connection"""
+ """Initialize database connection by copying disk DB to memory."""
self.db_path = db_path
- # Use connection pool pattern to avoid too many connections
- self.conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None, timeout=60)
- self.conn.execute("PRAGMA journal_mode = WAL") # Use Write-Ahead Logging for better performance
- self.conn.execute("PRAGMA synchronous = NORMAL") # Reduce synchronization overhead
- self.conn.execute("PRAGMA cache_size = -8000") # 8MB cache (比原来大4倍)
- self.conn.execute("PRAGMA temp_store = MEMORY") # 临时表存储在内存中
- self.conn.execute("PRAGMA mmap_size = 8589934592") # 尝试使用8GB内存映射
- self.conn.row_factory = sqlite3.Row
-
- # 创建索引以加速查询
- self._ensure_indices()
-
- # 初始化模型名称映射
- self.model_display_to_real = {}
- self.comp_model_display_to_real = {}
-
- # 初始化缓存
+ self.conn = None
self._cache = {}
self._problem_cache = {}
self._response_cache = {}
-
+ self.model_display_to_real = {}
+ self.comp_model_display_to_real = {}
+
+ disk_conn = None
+ try:
+ # 1. Connect to the source disk database in read-only mode
+ print(f"Connecting to source database (read-only): {db_path}")
+ # Ensure the file exists before trying to connect
+ if not os.path.exists(db_path):
+ raise FileNotFoundError(f"Database file not found at {db_path}")
+ disk_conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True, check_same_thread=False, timeout=120) # Increased timeout
+ print("Applying PRAGMAs to source connection for backup performance...")
+ disk_conn.execute("PRAGMA journal_mode = OFF")
+ disk_conn.execute("PRAGMA synchronous = OFF")
+ # Use a larger cache for reading from disk, e.g., 2GB = -2097152 KiB
+ disk_conn.execute("PRAGMA cache_size = -2097152")
+ disk_conn.execute("PRAGMA temp_store = MEMORY")
+ disk_conn.execute("PRAGMA locking_mode = EXCLUSIVE") # Prevent interference during backup
+
+ # 2. Connect to the target in-memory database
+ print("Creating in-memory database...")
+ # Increase timeout for potential long operations on the in-memory DB too
+ self.conn = sqlite3.connect(':memory:', check_same_thread=False, timeout=120)
+ self.conn.row_factory = sqlite3.Row # Use Row factory for dict-like access
+
+ # 3. Backup data from disk to memory
+ print("Starting database backup from disk to memory (this may take a while)...")
+ start_backup = time.time()
+ # Use a context manager for the destination connection to handle commits/rollbacks
+ with self.conn:
+ disk_conn.backup(self.conn)
+ end_backup = time.time()
+ print(f"Database backup completed in {end_backup - start_backup:.2f} seconds.")
+
+ # 4. Apply PRAGMAs suitable for the in-memory database
+ print("Applying PRAGMAs to in-memory database...")
+ # temp_store=MEMORY is default for :memory:, but explicit is fine
+ self.conn.execute("PRAGMA temp_store = MEMORY")
+ # cache_size matters little here: the database already lives in process memory, not on disk. Can be omitted.
+ # self.conn.execute("PRAGMA cache_size = -4194304") # e.g., 4GB cache within RAM
+
+ # 5. Ensure indices exist on the in-memory database *after* data loading
+ print("Creating indices on in-memory database...")
+ start_index = time.time()
+ self._ensure_indices() # This now operates on self.conn (the memory DB)
+ end_index = time.time()
+ print(f"Index creation completed in {end_index - start_index:.2f} seconds.")
+
+ except sqlite3.Error as e:
+ print(f"SQLite error during database initialization: {e}")
+ if self.conn:
+ self.conn.close()
+ self.conn = None
+ raise # Re-raise the exception to signal failure
+ except FileNotFoundError as e:
+ print(f"Error: {e}")
+ raise
+ except Exception as e:
+ print(f"Unexpected error during database initialization: {e}")
+ if self.conn:
+ self.conn.close()
+ self.conn = None
+ raise
+ finally:
+ # 6. Close the disk connection, it's no longer needed
+ if disk_conn:
+ disk_conn.close()
+ print("Closed connection to disk database.")
+
+ if self.conn:
+ print("In-memory database initialized successfully.")
+ else:
+ print("Error: In-memory database connection failed.")
+ raise RuntimeError("Failed to establish in-memory database connection.")
+
+
def _ensure_indices(self):
- """确保数据库有必要的索引"""
+ """Ensure necessary indices exist on the database connection (self.conn)."""
+ if not self.conn:
+ print("Error: Connection not established. Cannot ensure indices.")
+ return
try:
cursor = self.conn.cursor()
- # 添加最常用查询的索引
+ print("Creating index: idx_responses_model_dataset")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)")
+ print("Creating index: idx_responses_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)")
+ print("Creating index: idx_problems_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)")
- cursor.execute("ANALYZE") # 分析表以优化查询计划
- except Exception as e:
- pass
-
+ print("Creating index: idx_problems_subject")
+ cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)")
+ # Analyze the tables after creating indices for optimal query plans
+ print("Running ANALYZE...")
+ cursor.execute("ANALYZE")
+ self.conn.commit() # Commit index creation and analysis
+ print("Indices created and table analyzed successfully.")
+ except sqlite3.Error as e:
+ # Log error but don't necessarily crash the app
+ print(f"Warning: Could not create or analyze indices: {e}")
+ # Attempt rollback if something failed partially
+ try:
+ self.conn.rollback()
+ except sqlite3.Error as rb_e:
+ print(f"Rollback attempt failed after index error: {rb_e}")
+ # Depending on severity, you might want to raise e here
+
+ # <<< Methods get_available_models, get_available_datasets, get_model_statistics, >>>
+ # <<< get_all_model_accuracies, get_problems_by_model_dataset, get_problem_data, >>>
+ # <<< get_model_responses, clear_cache are modified to: >>>
+ # <<< 1. Remove INDEXED BY hints >>>
+ # <<< 2. Add checks for self.conn existence >>>
+ # <<< 3. Improve error logging >>>
+
def get_available_models(self):
"""Get list of all available models"""
- # 缓存在实例变量中
- if hasattr(self, '_models_cache') and self._models_cache:
+ if not self.conn: return []
+ # Check cache first
+ if hasattr(self, '_models_cache') and self._models_cache is not None:
return self._models_cache
-
try:
cursor = self.conn.cursor()
+ # Query without explicit index hints
cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name")
models = [row['model_name'] for row in cursor.fetchall()]
- self._models_cache = models # 存储到实例缓存
+ self._models_cache = models # Store in instance cache
return models
- except sqlite3.OperationalError:
- return []
-
+ except sqlite3.Error as e:
+ print(f"Database error in get_available_models: {e}")
+ return [] # Return empty list on error
+
def get_available_datasets(self):
"""Get list of all available datasets"""
- # 缓存在实例变量中
- if hasattr(self, '_datasets_cache') and self._datasets_cache:
+ if not self.conn: return DATASETS # Fallback if connection failed
+ # Check cache first
+ if hasattr(self, '_datasets_cache') and self._datasets_cache is not None:
return self._datasets_cache
-
try:
cursor = self.conn.cursor()
+ # Query without explicit index hints
cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset")
+ # Ensure uppercase consistency
datasets = [row['dataset'].upper() for row in cursor.fetchall()]
- self._datasets_cache = datasets # 存储到实例缓存
+ self._datasets_cache = datasets # Store in instance cache
return datasets
- except sqlite3.OperationalError:
- return DATASETS
-
+ except sqlite3.Error as e:
+ print(f"Database error in get_available_datasets: {e}")
+ return DATASETS # Fallback on error
+
def get_model_statistics(self, model_name, dataset):
"""Get statistics for a model on a specific dataset"""
+ if not self.conn: return [["Database Error", "No connection"]]
+ # Sanitize inputs
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
-
+ if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]]
+
cache_key = f"stats_{model_name}_{dataset}"
- if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
-
- cursor = self.conn.cursor()
+
+ stats_data = []
try:
- # 优化查询1: 整体准确率 - 使用索引提示加速
+ cursor = self.conn.cursor()
+ # Query 1: Overall accuracy - No INDEXED BY hint
cursor.execute("""
SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy
- FROM responses INDEXED BY idx_responses_model_dataset
+ FROM responses
WHERE model_name = ? AND dataset = ?
""", (model_name, dataset.lower()))
overall_stats = cursor.fetchone()
-
- # 优化查询2: 按学科统计 - 避免子查询和复杂JOIN
+
+ if overall_stats and overall_stats['accuracy'] is not None:
+ stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"])
+ elif overall_stats and overall_stats['total_samples'] == 0:
+ stats_data.append(["Overall Acc.", "No Samples"])
+ else:
+ stats_data.append(["Overall Acc.", "N/A"])
+
+ # Query 2: Per-subject statistics - No INDEXED BY hint
cursor.execute("""
SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy
FROM responses r JOIN problems p ON r.unique_id = p.unique_id
@@ -161,1471 +262,1488 @@ class ModelDatabase:
GROUP BY p.subject ORDER BY p.subject
""", (model_name, dataset.lower()))
subject_stats_rows = cursor.fetchall()
-
- stats_data = []
- if overall_stats and overall_stats['accuracy'] is not None:
- stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"])
- else:
- stats_data.append(["Overall Acc.", "N/A"])
for subject_row in subject_stats_rows:
acc_val = f"{subject_row['accuracy']:.2%}" if subject_row['accuracy'] is not None else "N/A"
subject_name = subject_row['subject']
- # 使用翻译表翻译科目名称
translated_subject = SUBJECT_TRANS.get(subject_name, subject_name)
stats_data.append([f"{translated_subject} Acc.", acc_val])
-
- self._cache[cache_key] = stats_data
+
+ self._cache[cache_key] = stats_data # Cache the result
return stats_data
- except sqlite3.OperationalError:
- return [["Database Error", "No data available"]]
-
+ except sqlite3.Error as e:
+ print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}")
+ # Return partial data if overall stats succeeded but subject failed? Or just error.
+ return [["Database Error", f"Query failed: {e}"]]
+
def get_all_model_accuracies(self, dataset):
- """获取所有模型在特定数据集上的准确率 (优化版本)"""
+ """Get the accuracy of every model on a specific dataset."""
+ if not self.conn: return []
if hasattr(dataset, 'value'): dataset = dataset.value
+ if not dataset: return []
+
cache_key = f"all_accuracies_{dataset}"
- if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
+
try:
cursor = self.conn.cursor()
- # 使用索引提示加速查询
+ # No INDEXED BY hint needed, rely on idx_responses_model_dataset
cursor.execute("""
SELECT model_name, AVG(correctness) as accuracy
- FROM responses INDEXED BY idx_responses_model_dataset
+ FROM responses
WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC
""", (dataset.lower(),))
- results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall()]
- self._cache[cache_key] = results
+ # Fetchall directly into list comprehension
+ results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None]
+ self._cache[cache_key] = results # Cache result
return results
- except sqlite3.OperationalError:
+ except sqlite3.Error as e:
+ print(f"Database error in get_all_model_accuracies({dataset}): {e}")
return []
def get_problems_by_model_dataset(self, model_name, dataset):
- """获取模型在特定数据集上的所有问题 (优化版本)"""
+ """Get all problems for a model on a specific dataset."""
+ if not self.conn: return []
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
+ if not model_name or not dataset: return []
+
cache_key = f"problems_{model_name}_{dataset}"
- if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
-
- cursor = self.conn.cursor()
+
try:
- # 优化查询:使用索引提示和优化JOIN策略
+ cursor = self.conn.cursor()
+ # No INDEXED BY hint, rely on indices on responses and problems tables
+ # AVG() is NULL only when every correctness value in the group is NULL; map that case to 0.0 -> COALESCE(AVG(r.correctness), 0.0)
cursor.execute("""
- SELECT DISTINCT r.unique_id, p.problem, AVG(r.correctness) as accuracy
- FROM responses r INDEXED BY idx_responses_model_dataset
- JOIN problems p INDEXED BY idx_problems_unique_id ON r.unique_id = p.unique_id
+ SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy
+ FROM responses r
+ JOIN problems p ON r.unique_id = p.unique_id
WHERE r.model_name = ? AND r.dataset = ?
- GROUP BY r.unique_id ORDER BY r.unique_id
+ GROUP BY r.unique_id, p.problem ORDER BY r.unique_id
""", (model_name, dataset.lower()))
- results = [(row['unique_id'], row['accuracy'] if row['accuracy'] is not None else 0.0, row['problem']) for row in cursor.fetchall()]
-
- # Sort by the integer part of unique_id
- sorted_results = sorted(results, key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0)
- self._cache[cache_key] = sorted_results
+ # Fetchall directly
+ results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()]
+
+ # Sort in Python - pre-compile regex for slight speedup
+ id_extractor = re.compile(r'\d+')
+ def get_sort_key(problem_tuple):
+ match = id_extractor.search(problem_tuple[0]) # problem_tuple[0] is unique_id
+ # Handle cases where ID might not have numbers gracefully
+ return int(match.group(0)) if match else 0
+
+ # Sort the results list using the defined key
+ sorted_results = sorted(results, key=get_sort_key)
+
+ self._cache[cache_key] = sorted_results # Cache the sorted list
return sorted_results
- except sqlite3.OperationalError:
+ except sqlite3.Error as e:
+ print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}")
return []
+ except Exception as e: # Catch potential errors during sorting
+ print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}")
+ return []
+
def get_problem_data(self, model_name, dataset, problem_id):
- """获取问题和响应数据 (采用局部缓存策略)"""
+ """Get problem and response data (using in-memory DB and cache)"""
+ if not self.conn: return None, None
+ # Sanitize inputs
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
-
- # 问题数据缓存 - 问题数据通常不会变化,可长期缓存
+ if not dataset or not problem_id: return None, None # Need dataset and problem_id
+
+ # Problem data cache check
problem_cache_key = f"problem_{problem_id}"
- if problem_cache_key in self._problem_cache:
- problem = self._problem_cache[problem_cache_key]
- else:
- if not self.conn:
- return None, None
-
+ problem = self._problem_cache.get(problem_cache_key)
+
+ if problem is None: # Not in cache, fetch from DB
try:
cursor = self.conn.cursor()
+ # Query uses index idx_problems_unique_id automatically
cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,))
- problem = cursor.fetchone()
- if problem:
- # 转为字典存储,避免SQLite连接依赖
- self._problem_cache[problem_cache_key] = dict(problem)
- problem = self._problem_cache[problem_cache_key]
- except Exception:
- return None, None
-
- if not problem:
- return None, None
-
- # 响应数据缓存 - 更细粒度的缓存键
- if model_name:
+ problem_row = cursor.fetchone()
+ if problem_row:
+ problem = dict(problem_row) # Convert to dict for caching
+ self._problem_cache[problem_cache_key] = problem
+ else:
+ print(f"Problem not found in DB: {problem_id}")
+ return None, None # Problem ID does not exist in the database
+ except sqlite3.Error as e:
+ print(f"Database error fetching problem {problem_id}: {e}")
+ return None, None # Return None if problem fetch fails
+
+ # If problem is still None here, it wasn't found in DB or cache
+ if problem is None:
+ return None, None
+
+ # --- Response data fetching ---
+ responses = None
+ if model_name: # Fetch for a specific model
resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
- return problem, self._response_cache[resp_cache_key]
-
- if not self.conn:
- return problem, None
-
- # 获取特定模型的响应
- try:
- cursor = self.conn.cursor()
- cursor.execute("""
- SELECT * FROM responses
- WHERE model_name = ? AND dataset = ? AND unique_id = ?
- ORDER BY response_id
- """, (model_name, dataset.lower(), problem_id))
- responses = cursor.fetchall()
-
- # 转换为字典列表存储
- if responses:
- responses = [dict(r) for r in responses]
- self._response_cache[resp_cache_key] = responses
- return problem, responses
- except Exception:
- return problem, None
- else:
- # 获取所有模型对此问题的响应
+ responses = self._response_cache[resp_cache_key]
+ else:
+ try:
+ cursor = self.conn.cursor()
+ # The planner can serve this lookup from idx_responses_model_dataset or idx_responses_unique_id (typically one index per table)
+ cursor.execute("""
+ SELECT * FROM responses
+ WHERE model_name = ? AND dataset = ? AND unique_id = ?
+ ORDER BY response_id
+ """, (model_name, dataset.lower(), problem_id))
+ response_rows = cursor.fetchall()
+ # Convert rows to list of dicts for easier handling and caching
+ responses = [dict(r) for r in response_rows] if response_rows else []
+ self._response_cache[resp_cache_key] = responses # Cache the result (even if empty)
+ except sqlite3.Error as e:
+ print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}")
+ responses = None # Indicate error fetching responses
+ else: # Fetch for all models for this problem
resp_cache_key = f"all_responses_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
- return problem, self._response_cache[resp_cache_key]
-
- if not self.conn:
- return problem, None
-
- try:
- cursor = self.conn.cursor()
- cursor.execute("""
- SELECT * FROM responses
- WHERE dataset = ? AND unique_id = ?
- ORDER BY model_name, response_id
- """, (dataset.lower(), problem_id))
- responses = cursor.fetchall()
-
- # 转换为字典列表存储
- if responses:
- responses = [dict(r) for r in responses]
- self._response_cache[resp_cache_key] = responses
- return problem, responses
- except Exception:
- return problem, None
+ responses = self._response_cache[resp_cache_key]
+ else:
+ try:
+ cursor = self.conn.cursor()
+ # There is no model filter here, so idx_responses_model_dataset (leading column model_name)
+ # is unlikely to help with this lookup; idx_responses_unique_id is the index that can
+ # narrow the scan by unique_id, with the dataset condition applied to the matching rows.
+ # NOTE: _ensure_indices does not create an idx_responses_dataset index.
+ cursor.execute("""
+ SELECT * FROM responses
+ WHERE dataset = ? AND unique_id = ?
+ ORDER BY model_name, response_id
+ """, (dataset.lower(), problem_id))
+ response_rows = cursor.fetchall()
+ responses = [dict(r) for r in response_rows] if response_rows else []
+ self._response_cache[resp_cache_key] = responses # Cache the result
+ except sqlite3.Error as e:
+ print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}")
+ responses = None # Indicate error
+
+ return problem, responses
+
def get_model_responses(self, selected_models, dataset, problem_id):
- """获取多个模型对特定问题的响应(优化版本)"""
+ """Get responses from multiple models for a specific problem (optimized for in-memory)"""
+ if not self.conn: return None, {}
+ # Sanitize inputs
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
- if not selected_models or not dataset or not problem_id:
+ if not selected_models or not dataset or not problem_id:
return None, {}
- # 获取问题数据 - 可共享缓存
+ # Get problem data first (uses cache/fast in-memory lookup)
problem, _ = self.get_problem_data(None, dataset, problem_id)
- if not problem:
+ if not problem:
+ print(f"Problem data not found for {problem_id} in get_model_responses")
return None, {}
-
+
model_responses_data = {}
+ # Get the *real* model names from the display names using the stored map
+ real_model_names_map = {} # Map display name -> real name
+ real_names_list = []
for model_display in selected_models:
model_display_val = model_display.value if hasattr(model_display, 'value') else model_display
- # 从显示名称中获取真实模型名称
- model = self.comp_model_display_to_real.get(model_display_val, model_display_val)
-
- _, responses_for_model = self.get_problem_data(model, dataset, problem_id)
- if responses_for_model:
- # 尝试找到正确的响应,否则使用第一个
- correct_resp = next((r for r in responses_for_model if r['correctness'] == 1), None)
- model_responses_data[model_display_val] = correct_resp if correct_resp else responses_for_model[0]
- else:
- model_responses_data[model_display_val] = None
-
+ # Use comp_model_display_to_real if available, otherwise model_display_to_real
+ real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val)
+ # Fallback if map lookup fails (try parsing)
+ if not real_name:
+ raw_name_part = model_display_val.split(" (")[0]
+ # Reverse lookup MODEL_TRANS
+ for db_name, display_lookup in MODEL_TRANS.items():
+ if display_lookup == raw_name_part:
+ real_name = db_name
+ break
+ if not real_name: # If still not found, assume display name *is* real name (less accuracy suffix)
+ real_name = raw_name_part
+ print(f"Warning: Could not map display name '{model_display_val}' to real name via maps. Using inferred '{real_name}'.")
+
+ if real_name: # Ensure we have a name to query
+ real_model_names_map[model_display_val] = real_name
+ if real_name not in real_names_list: # Avoid duplicates in IN clause
+ real_names_list.append(real_name)
+
+ if not real_names_list:
+ print("No valid real model names found to query.")
+ return problem, {} # Return problem data but empty responses
+
+ # Optimized: Fetch all relevant responses in a single query
+ try:
+ cursor = self.conn.cursor()
+ placeholders = ','.join('?' * len(real_names_list))
+ query = f"""
+ SELECT * FROM responses
+ WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ?
+ ORDER BY model_name, correctness DESC, response_id -- Prioritize correct responses
+ """
+ params = real_names_list + [dataset.lower(), problem_id]
+ cursor.execute(query, params)
+ all_fetched_responses = cursor.fetchall()
+
+ # Group responses by *real* model name, keeping only the best (correct first, then by ID)
+ responses_by_real_model = {}
+ for resp_row in all_fetched_responses:
+ resp_dict = dict(resp_row)
+ model = resp_dict['model_name']
+ if model not in responses_by_real_model: # Only store the first one encountered (due to ORDER BY)
+ responses_by_real_model[model] = resp_dict
+
+ # Populate the result dictionary using display names as keys
+ for display_name, real_name in real_model_names_map.items():
+ model_responses_data[display_name] = responses_by_real_model.get(real_name) # Will be None if no response found
+
+ except sqlite3.Error as e:
+ print(f"Database error in bulk get_model_responses: {e}. Falling back to individual fetches.")
+ # Fallback to individual fetching using get_problem_data (which uses cache)
+ for display_name, real_name in real_model_names_map.items():
+ _ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id)
+ if responses_for_model:
+ # Find correct one first, otherwise take first response
+ correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None)
+ model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0]
+ else:
+ model_responses_data[display_name] = None
+
return problem, model_responses_data
+
def clear_cache(self, section=None):
- """清除指定部分或全部缓存"""
+ """Clear specified cache sections."""
+ print(f"Clearing cache section: {section if section else 'All'}")
+ cleared_something = False
if section == 'main' or section is None:
- self._cache = {}
+ if self._cache:
+ count = len(self._cache)
+ self._cache = {}
+ print(f"Cleared main cache ({count} items).")
+ cleared_something = True
if section == 'problem' or section is None:
- self._problem_cache = {}
+ if self._problem_cache:
+ count = len(self._problem_cache)
+ self._problem_cache = {}
+ print(f"Cleared problem cache ({count} items).")
+ cleared_something = True
if section == 'response' or section is None:
- self._response_cache = {}
+ if self._response_cache:
+ count = len(self._response_cache)
+ self._response_cache = {}
+ print(f"Cleared response cache ({count} items).")
+ cleared_something = True
+ # Clear model/dataset list caches
if section == 'models' or section is None:
- if hasattr(self, '_models_cache'):
+ if hasattr(self, '_models_cache') and self._models_cache is not None:
self._models_cache = None
- if hasattr(self, '_datasets_cache'):
+ print("Cleared models list cache.")
+ cleared_something = True
+ if hasattr(self, '_datasets_cache') and self._datasets_cache is not None:
self._datasets_cache = None
-
+ print("Cleared datasets list cache.")
+ cleared_something = True
+
+ if cleared_something:
+ print("Running garbage collection...")
+ gc.collect() # Explicitly trigger garbage collection
+ else:
+ print("Cache section(s) already empty or invalid section specified.")
+
+
def close(self):
- """关闭数据库连接并释放资源"""
+ """Close the database connection."""
+ print("Closing database connection...")
if hasattr(self, 'conn') and self.conn:
try:
+ # Optional: Backup in-memory changes to disk if needed (not in this scenario)
+ # Optional: Run final pragmas like optimize before closing if desired
+ # self.conn.execute("PRAGMA optimize;")
self.conn.close()
- except Exception:
- pass
-
- # 清理所有缓存
+ self.conn = None # Ensure the attribute is None after closing
+ print("In-memory database connection closed.")
+ except sqlite3.Error as e:
+ print(f"Error closing database connection: {e}")
+ else:
+ print("Database connection already closed or never established.")
+ # Clear caches on close as well
self.clear_cache()
+
+# <<< Keep helper functions: format_latex, format_markdown_with_math, >>>
+# <<< get_gradient_color, get_contrasting_text_color, format_sample_metadata, >>>
+# <<< format_sample_response >>>
def format_latex(text):
if text is None: return ""
- # Process the text for proper LaTeX rendering with KaTeX
- # KaTeX requires LaTeX backslashes to be preserved
- # Only replace newlines with HTML breaks
text = text.replace('\n', ' ')
- # Wrap in a span that KaTeX can detect and render
return f'{text}'
def format_markdown_with_math(text):
if text is None: return ""
-
- # Don't add HTML tags or do special processing for LaTeX - let Gradio handle it
- # Just clean up basic issues that might affect rendering
-
- # Convert newlines for markdown
text = text.replace('\r\n', '\n').replace('\r', '\n')
-
- # Return the cleaned text for Gradio's markdown component to render
+ # Ensure math delimiters are properly handled by Gradio's Markdown component
+ # No need for complex regex if Gradio handles $, $$, \(, \), \[, \]
return text
def get_gradient_color(accuracy, color_map='RdYlGn'):
- if accuracy is None or not isinstance(accuracy, (int, float)):
- return "#505050" # Default for missing or invalid accuracy
+ if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0):
+ return "#808080" # Use gray for invalid/missing accuracy
try:
- # 使用更深的颜色映射
+ # Use the specified colormap
cmap = plt.colormaps.get_cmap(color_map)
- rgba = cmap(float(accuracy))
-
- # 确保颜色足够深以与白色文本形成对比
- r, g, b, a = rgba
- # 降低颜色亮度,确保文本可读性
- r = r * 0.7
- g = g * 0.7
- b = b * 0.7
-
- # 转回十六进制
- hex_color = mpl.colors.rgb2hex((r, g, b, a))
+ # Apply a power transform to the accuracy before the colormap lookup.
+ # For inputs in [0, 1]: exponent < 1 raises values (shifts toward the high end of the map), exponent > 1 lowers them
+ power_adjust = 0.7
+ rgba = cmap(accuracy ** power_adjust)
+ # Convert RGBA to Hex
+ hex_color = mpl.colors.rgb2hex(rgba)
return hex_color
- except Exception:
- return "#505050"
-
-def get_contrasting_text_color(bg_color):
- """计算最佳对比文本颜色"""
- # 如果背景是十六进制格式,转换为RGB
- if bg_color.startswith('#'):
- r = int(bg_color[1:3], 16)
- g = int(bg_color[3:5], 16)
- b = int(bg_color[5:7], 16)
- else:
- # 未知格式默认返回黑色
- return "#000"
-
- # 计算YIQ亮度值 - 更精确地表示人眼对亮度的感知
- yiq = (r * 299 + g * 587 + b * 114) / 1000
-
- # 黄色检测 - 黄色通常R和G高,B低
- is_yellow = r > 200 and g > 200 and b < 150
-
- # 浅绿色检测 - 通常G高,R中等,B低
- is_light_green = g > 200 and r > 100 and r < 180 and b < 150
-
- # 米色/浅棕色检测 - R高,G中高,B低
- is_beige = r > 220 and g > 160 and g < 220 and b < 160
-
- # 强制这些特定颜色使用黑色文本
- if is_yellow or is_light_green or is_beige:
- return "#000"
-
- # 其他颜色根据亮度决定
- return "#000" if yiq > 160 else "#fff"
+ except Exception as e:
+ print(f"Error generating gradient color for accuracy {accuracy}: {e}")
+ return "#808080" # Fallback gray
+
+def get_contrasting_text_color(bg_color_hex):
+ """Calculate contrasting text color (black or white) for a given hex background."""
+ try:
+ if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7:
+ return "#000000" # Default to black for invalid input
+
+ # Convert hex to RGB
+ r = int(bg_color_hex[1:3], 16)
+ g = int(bg_color_hex[3:5], 16)
+ b = int(bg_color_hex[5:7], 16)
+
+ # Calculate luminance using the WCAG formula (more accurate than YIQ for accessibility)
+ # Normalize RGB values to 0-1 range
+ rgb = [val / 255.0 for val in (r, g, b)]
+ # Linearize each sRGB channel (undo gamma encoding) with the WCAG piecewise formula
+ rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb]
+ # Calculate relative luminance
+ luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2]
+
+ # WCAG contrast ratio threshold is complex, but a luminance threshold works well for black/white text
+ # Threshold of 0.179 is often cited, but empirical testing might be needed
+ # Let's use a slightly higher threshold towards white text for better readability on mid-tones
+ return "#000000" if luminance > 0.22 else "#FFFFFF" # Black text on lighter backgrounds, White on darker
+
+ except Exception as e:
+ print(f"Error calculating contrasting color for {bg_color_hex}: {e}")
+ return "#000000" # Default to black on error
+
def format_sample_metadata(sample, show_correctness=True):
- """生成样本元数据的HTML格式显示"""
- if sample is None: return ""
- sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {}
- if not sample_dict: return "No sample data"
+ """Generates HTML for sample metadata display."""
+ if sample is None: return "
No sample data provided.
"
+ # Ensure sample is a dictionary
+ sample_dict = dict(sample) if hasattr(sample, 'keys') else {}
+ if not sample_dict: return "
"
-
- # 创建信息行
+ correctness = sample_dict.get('correctness', None) # Keep None distinct from False/0
+ output_tokens = sample_dict.get('output_tokens')
+ reasoning_tokens = sample_dict.get('reasoning_tokens')
+
+ # Correctness display
+ if correctness == 1:
+ correctness_label = "✓ Correct"
+ correctness_color = "var(--color-acc-green, #28a745)" # Use CSS variable with fallback
+ elif correctness == 0:
+ correctness_label = "✗ Incorrect"
+ correctness_color = "var(--color-acc-red, #dc3545)"
+ else: # None or other values
+ correctness_label = "? Unknown"
+ correctness_color = "var(--color-acc-grey, #6c757d)"
+
+ # Build HTML using f-string for clarity
+ html = f"""
+
+
""" # Use flexbox for alignment
+
if show_correctness:
- html += f"
"
- # 正确性指示器
- html += f"{correctness_label}"
-
- # 提取的答案
- if extracted:
- html += f"Extracted: ${extracted}$"
-
- # 输出token数
- if output_tokens is not None:
- html += f"Output Tokens: {output_tokens}"
-
- # 推理token数 - 仅在可用时
- if reasoning_tokens is not None:
- html += f"Reasoning Tokens: {reasoning_tokens}"
-
- html += f"
"
-
- html += "
"
+ html += f"{correctness_label}"
+
+ if extracted:
+ # Basic escaping for extracted answer to prevent HTML injection if it contains < >
+ extracted_safe = extracted.replace('<', '<').replace('>', '>')
+ # Wrap in $ for math rendering by Gradio/MathJax if appropriate, else just display
+ # Heuristic: Render as math if it looks like a number, fraction, or simple expression
+ if re.match(r'^-?\d+(\.\d+)?(/(-?\d+(\.\d+)?))?$', extracted_safe.strip()):
+ extracted_display = f"${extracted_safe}$"
+ else:
+ extracted_display = extracted_safe # Display as is if complex
+ html += f"Extracted: {extracted_display}"
+
+ if output_tokens is not None:
+ html += f"Output Tokens: {output_tokens}"
+
+ if reasoning_tokens is not None:
+ html += f"Reasoning Tokens: {reasoning_tokens}"
+
+ html += """
+
+
+
+ """
return html
+
def format_sample_response(sample):
- """生成样本响应的Markdown格式显示"""
- if sample is None: return ""
- sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {}
- if not sample_dict: return "No sample data"
-
- # 获取响应内容
+    """Return a sample's response text, HTML-escaped for Markdown display.
+
+    Args:
+        sample: Mapping-like object (anything exposing ``keys``) holding a
+            ``'response'`` entry, or None.
+
+    Returns:
+        The response with ``&``, ``<`` and ``>`` replaced by HTML entities,
+        or a short placeholder string when the sample or its response is
+        missing or empty.
+    """
+    # Function-scope import keeps this hunk self-contained.
+    from html import escape
+
+    if sample is None: return "No response data."
+    sample_dict = dict(sample) if hasattr(sample, 'keys') else {}
+    if not sample_dict: return "Empty response data."
+
response = sample_dict.get('response', '')
-
- # 转义特殊标签以防止被解析为HTML
- # 替换标签
- response = response.replace("", "<think>")
- response = response.replace("", "</think>")
-
- # 替换其他可能的特殊标签
- response = response.replace("", "<reasoning>")
- response = response.replace("", "</reasoning>")
- response = response.replace("", "<answer>")
- response = response.replace("", "</answer>")
-
+    if not response: return "*(Empty Response)*"
+
+    # Escape &, < and > so tags in the model output (e.g. <think>) are shown
+    # literally instead of being parsed as HTML. escape() handles '&' first,
+    # so the entities it produces are not escaped a second time.
+    response = escape(response, quote=False)
+
+    # Gradio's Markdown component should handle LaTeX delimiters like $...$, $$...$$
+    # No need for manual replacement here if delimiters are set correctly in gr.Markdown
return response
-def handle_sample_select(sample_number, samples_data):
- # 确保从Gradio State对象中提取实际值
- if hasattr(samples_data, 'value'):
- samples_list = samples_data.value
- else:
- samples_list = samples_data
-
- # 确保样本编号是整数
+
+# --- Event handler functions ---
+# handle_sample_select, handle_first_sample, handle_comparison_problem_update
+# and handle_problem_select. The handlers that query data use the global `db`
+# instance; sample rendering goes through format_sample_metadata / format_sample_response.
+
+def handle_sample_select(sample_number_str, samples_data_state):
+ """Handles selection of a specific sample index."""
+ # Extract list from state
+ samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else [])
+
+ if not samples_list or not isinstance(samples_list, list):
+ err_msg = "**Error:** No sample data available or invalid format."
+ return err_msg, "" # Return error for metadata, empty for response
+
try:
- sample_idx = int(sample_number)
- except ValueError:
- return "Error: Sample number must be an integer.", ""
-
- # 确保样本数据存在且为非空列表
- if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0:
- return "No sample data available. Please select a problem first.", ""
-
- # 检查索引是否在有效范围内,如果不在范围内,显示错误消息
- if sample_idx < 0:
- err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}."
+ sample_idx = int(sample_number_str) # Convert input string to int
+ except (ValueError, TypeError):
+ err_msg = f"**Error:** Invalid sample number '{sample_number_str}'. Must be an integer."
return err_msg, ""
-
- if sample_idx >= len(samples_list):
- err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}."
+
+ if not (0 <= sample_idx < len(samples_list)):
+ err_msg = f"**Error:** Sample index {sample_idx} out of range (0 to {len(samples_list) - 1})."
return err_msg, ""
-
- # 获取所选样本的数据
+
try:
- sample = samples_list[sample_idx]
- formatted_metadata = format_sample_metadata(sample)
- formatted_response = format_sample_response(sample)
+ selected_sample = samples_list[sample_idx]
+ # Ensure the selected sample is a dict before formatting
+ if not isinstance(selected_sample, dict):
+ selected_sample = dict(selected_sample) if hasattr(selected_sample, 'keys') else {}
+
+ formatted_metadata = format_sample_metadata(selected_sample)
+ formatted_response = format_sample_response(selected_sample)
return formatted_metadata, formatted_response
except Exception as e:
+ print(f"Error formatting sample {sample_idx}: {e}")
err_msg = f"**Error displaying sample {sample_idx}:** {str(e)}"
- return err_msg, ""
+ # Return error message in metadata, keep response empty
+ return f"
{err_msg}
", ""
-def handle_first_sample(samples_data):
- """处理并显示第一个样本(索引0)"""
- # 确保从Gradio State对象中提取实际值
- if hasattr(samples_data, 'value'):
- samples_list = samples_data.value
+def handle_first_sample(samples_data_state):
+    """Show sample 0 of the current selection, or placeholders when there is none.
+
+    Args:
+        samples_data_state: Either the list of samples itself or an object
+            exposing that list as ``.value``.
+
+    Returns:
+        A ``(metadata, response)`` pair: the formatters' "no data" output
+        when there are no samples, otherwise whatever
+        ``handle_sample_select("0", ...)`` returns.
+    """
+    # Unwrap the sample list only to decide whether there is anything to show;
+    # the untouched state object is what gets passed on below.
+    if isinstance(samples_data_state, list):
+        samples = samples_data_state
+    elif hasattr(samples_data_state, 'value'):
+        samples = samples_data_state.value
+    else:
+        samples = []
+
+    if not samples:
+        # Nothing to display: let the formatters produce their "no data" messages.
+        return format_sample_metadata(None), format_sample_response(None)
else:
- samples_list = samples_data
-
- # 检查样本数据是否存在
- if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0:
- return "No sample data available. Please select the problem and dataset first.", ""
-
- # 直接获取第一个样本,避免错误处理逻辑
- try:
- sample = samples_list[0]
- formatted_metadata = format_sample_metadata(sample)
- formatted_response = format_sample_response(sample)
- return formatted_metadata, formatted_response
- except Exception as e:
- err_msg = f"**Error displaying first sample:** {str(e)}"
- return err_msg, ""
+        # Delegate to the index-based handler for sample 0.
+        return handle_sample_select("0", samples_data_state)
-def handle_comparison_problem_update(problem_id, dataset_state):
- """处理比较页面的问题更新,仅更新问题和答案内容,不需要模型"""
+def handle_comparison_problem_update(problem_id_state, dataset_state):
+    """Refresh only the Problem/Answer panes of the comparison tab.
+
+    Args:
+        problem_id_state: Problem ID, or an object exposing it as ``.value``.
+            A bare number is expanded to ``OlymMATH-<difficulty>-<n>-<language>``
+            using the dataset name.
+        dataset_state: Dataset name such as ``"EN-HARD"``, or an object
+            exposing it as ``.value``.
+
+    Returns:
+        A ``(problem_markdown, answer_markdown)`` pair. On any failure the
+        first element carries a user-facing message and the second is
+        ``"N/A"`` or ``"Error"``.
+    """
global db
- # 确保从Gradio State对象中提取实际值
+    if not db or not db.conn: return "Database not initialized.", "Error"
+
dataset_name = dataset_state.value if hasattr(dataset_state, 'value') else dataset_state
- problem_id_value = problem_id.value if hasattr(problem_id, 'value') else problem_id
-
- if not problem_id_value or not dataset_name:
- return "Please select a dataset and enter a problem ID.", "No answer available."
-
- # 处理纯数字输入,构建完整unique_id
- if problem_id_value and problem_id_value.isdigit():
- # 构建格式:OlymMATH-HARD-0-EN 或类似格式
+    problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state
+    # Normalise to a stripped string so a padded or non-string ID cannot break
+    # the .isdigit() checks below, which run outside the try block.
+    problem_id = str(problem_id).strip() if problem_id is not None else ""
+
+    # Allow entering just the number part of the ID
+    if problem_id and problem_id.isdigit() and dataset_name:
parts = dataset_name.split('-')
- if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD")
+        if len(parts) == 2:
language, difficulty = parts
- # 构建完整ID
- problem_id_value = f"OlymMATH-{difficulty}-{problem_id_value}-{language}"
-
+            problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"
+        else:
+            print(f"Warning: Cannot reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'")
+            # Proceed with the entered value, might fail if not a full ID
+
+    if not problem_id or not dataset_name:
+        return "Please select dataset and enter problem ID.", "N/A"
+
try:
- # 只获取问题数据,不获取特定模型的响应
- problem_data, _ = db.get_problem_data(None, dataset_name, problem_id_value)
-
+        # Fetch only problem data, no responses needed here
+        problem_data, _ = db.get_problem_data(None, dataset_name, problem_id)
+
if not problem_data:
- return f"Problem not found: {problem_id_value}. Please check the ID and try again.", "No answer available."
-
- problem_dict = dict(problem_data)
- # Use format_markdown_with_math for proper rendering
- problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
-
- # 将答案中的双美元符号替换为单美元符号
- answer_text = problem_dict.get('answer', '')
- # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行
+            # Still all digits here means the ID could not be expanded above
+            if problem_id.isdigit():
+                return f"Problem number {problem_id} not found for {dataset_name}. Enter full ID or check dataset.", "N/A"
+            else:
+                return f"Problem ID '{problem_id}' not found for {dataset_name}.", "N/A"
+
+        # Ensure problem_data is a dictionary
+        problem_dict = dict(problem_data) if problem_data else {}
+
+        # A key that is present but None would bypass .get()'s default, so
+        # test for None explicitly before formatting.
+        problem_text = problem_dict.get('problem')
+        if problem_text is None:
+            problem_text = '*(Problem text not available)*'
+        problem_content = format_markdown_with_math(problem_text)
+
+        # Same None guard for the answer; str() keeps re.sub from raising on
+        # a non-string value.
+        answer_text = problem_dict.get('answer')
+        answer_text = '*(Answer not available)*' if answer_text is None else str(answer_text)
+        # Simplify $$...$$ to $...$
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
-
- # 检查答案是否已经包含美元符号,如果没有则添加
- if '$' not in answer_text and answer_text.strip():
+        # Add $...$ if missing (basic check); placeholders starting with "*(" stay as text
+        if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('):
answer_text = f"${answer_text}$"
-
answer_content = format_markdown_with_math(answer_text)
-
+
return problem_content, answer_content
except Exception as e:
- return f"Error: {str(e)}", "No answer available."
+        print(f"Error in handle_comparison_problem_update for {problem_id}, {dataset_name}: {e}")
+        return f"Error fetching problem details: {e}", "Error"
+
-def handle_problem_select(problem_id_from_js, current_model_state, current_dataset_state, mode='default'):
+def handle_problem_select(problem_id_state, current_model_state, current_dataset_state, mode='default'):
+ """Handles problem selection, fetching details, responses, and generating sample grid."""
global db
- # Ensure we're using the actual values from Gradio State objects
+ if not db or not db.conn:
+ return "DB Error.", "DB Error.", gr.HTML("
Database Connection Error
"), gr.State([])
+
+ # --- Get values from Gradio State objects ---
model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state
dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state
- problem_id = problem_id_from_js.value if hasattr(problem_id_from_js, 'value') else problem_id_from_js
+ problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state
+
+ # --- Input Validation ---
+ if not dataset_name:
+ return "Please select a dataset first.", "N/A", gr.HTML(""), gr.State([])
+ if not problem_id:
+ return "Problem ID is missing.", "N/A", gr.HTML(""), gr.State([])
+ # Model name is required unless in comparison mode initial load
+ if not model_name and mode != 'comparison_initial_problem_load':
+ # In single mode, model is always required here
+ if mode == 'default':
+ return "Please select a model first.", "N/A", gr.HTML(""), gr.State([])
+ # In comparison mode, if model state is None, means user selected problem before model
+ # We can still show problem/answer, but no samples for that side yet.
+ # Let's handle this by fetching problem/answer but returning empty samples for the specific call.
+ pass # Allow proceeding without model name in comparison mode to fetch problem/answer
- # 处理纯数字输入,构建完整unique_id
- if problem_id and problem_id.isdigit():
- # 构建格式:OlymMATH-HARD-0-EN 或类似格式
- # 从dataset_name (例如 "EN-HARD") 解析语言和难度
+ # --- Reconstruct full problem ID if only number is entered ---
+ original_problem_id = problem_id # Keep original for messages
+ if problem_id.isdigit():
parts = dataset_name.split('-')
- if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD")
+ if len(parts) == 2:
language, difficulty = parts
- # 构建完整ID
problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"
+ print(f"Reconstructed problem ID: {problem_id}")
+ else:
+ # Cannot reconstruct, use the entered value but it might fail
+ print(f"Warning: Could not reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'")
- if not problem_id or not dataset_name:
- error_message = f"Missing data: problem_id='{problem_id}', dataset='{dataset_name}'"
- return "Please fill in all the fields.", "No answer available.", "", gr.State([])
-
- # For comparison mode, we might not have a model selected yet
- if not model_name and mode == 'comparison':
- try:
- # Just get the problem data without model-specific responses
- problem_data, _ = db.get_problem_data(None, dataset_name, problem_id)
-
- if not problem_data:
- error_message = f"Problem data not found: problem_id='{problem_id}', dataset='{dataset_name}'"
- return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([])
-
- problem_dict = dict(problem_data)
- # Process problem and answer text for Markdown rendering
- problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
-
- # 将答案中的双美元符号替换为单美元符号
- answer_text = problem_dict.get('answer', '')
- # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行
- answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
-
- # 检查答案是否已经包含美元符号,如果没有则添加
- if '$' not in answer_text and answer_text.strip():
- answer_text = f"${answer_text}$"
-
- answer_content = format_markdown_with_math(answer_text)
-
- # For comparison without model, we don't have samples to display
- return problem_content, answer_content, "", gr.State([])
- except Exception as e:
- error_message = f"Database error: {str(e)}"
- return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([])
-
- # The regular flow for model-specific data
- if not model_name:
- error_message = f"Missing data: model='{model_name}'"
- return "Please fill in all the fields.", "No answer available.", "", gr.State([])
- # The problem_id from JS should be the full unique_id. No reconstruction needed normally.
+ # --- Fetch Data ---
try:
+ # Fetch problem details and responses for the specific model (if provided)
problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id)
-
+
if not problem_data:
- error_message = f"Problem data not found: problem_id='{problem_id}', model='{model_name}', dataset='{dataset_name}'"
- return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([])
- except Exception as e:
- error_message = f"Database error: {str(e)}"
- return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([])
-
- problem_dict = dict(problem_data)
- problem_display_num = re.search(r'\d+', problem_id).group(0) if re.search(r'\d+', problem_id) else problem_id
-
- # Process problem and answer text for Markdown rendering
- problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
-
- # 将答案中的双美元符号替换为单美元符号
- answer_text = problem_dict.get('answer', '')
- # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行
- answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
-
- # 检查答案是否已经包含美元符号,如果没有则添加
- if '$' not in answer_text and answer_text.strip():
- answer_text = f"${answer_text}$"
-
- answer_content = format_markdown_with_math(answer_text)
-
- # Rest of the function remains the same
- if not responses_data:
- samples_grid_html = "
No samples available for this problem.
"
- # 返回空的样本数据状态
- return problem_content, answer_content, samples_grid_html, gr.State([])
- else:
- # 准备所有样本数据,用于后续处理
- samples_data = []
- for i, resp in enumerate(responses_data):
- resp_dict = dict(resp)
- samples_data.append(resp_dict)
-
- # 计算正确率
- correct_count = sum(1 for r in samples_data if r['correctness'])
- total_samples = len(samples_data)
- accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0
-
- # 创建样本网格显示 (最多显示 64 个样本)
- displayed_samples = samples_data[:64]
- actual_display_count = len(displayed_samples)
-
- # 根据模式确定每行的样本数
- samples_per_row = 16 if mode == 'comparison' else 32
-
- # 第一行: 样本 0-samples_per_row
- samples_grid_html = f'
'
-
- for i, resp in enumerate(displayed_samples[:samples_per_row]):
- correctness = resp.get('correctness', 0)
- bg_color = get_gradient_color(1.0 if correctness else 0.0)
-
- # 移除点击事件和data属性,只保留纯显示
- samples_grid_html += f"""
-
- {i}
-
- """
-
- # 如果少于samples_per_row个样本,填充剩余空间
- for i in range(min(actual_display_count, samples_per_row), samples_per_row):
- samples_grid_html += f"""
-
- """
-
- samples_grid_html += '
'
-
- # 如果有更多样本,显示第二行
- if actual_display_count > samples_per_row:
- row_samples = displayed_samples[samples_per_row:2*samples_per_row]
- samples_grid_html += f'
'
-
- for i, resp in enumerate(row_samples):
- actual_idx = i + samples_per_row
- correctness = resp.get('correctness', 0)
- bg_color = get_gradient_color(1.0 if correctness else 0.0)
-
- samples_grid_html += f"""
-
- {actual_idx}
-
- """
-
- # 填充剩余空间
- for i in range(len(row_samples), samples_per_row):
- samples_grid_html += f"""
-
- """
-
- samples_grid_html += '
'
-
- # 第三行和第四行 - 允许所有模式显示完整的64个样本
- if actual_display_count > 2*samples_per_row:
- # 第三行
- row_samples = displayed_samples[2*samples_per_row:3*samples_per_row]
- if row_samples:
- samples_grid_html += f'
'
-
+ return f"Problem ID '{original_problem_id}' not found for dataset '{dataset_name}'.", "N/A", gr.HTML(f"
Problem ID '{original_problem_id}' not found.
"), gr.State([])
+
+ # Ensure problem_data is a dict
+ problem_dict = dict(problem_data) if problem_data else {}
+
+ # --- Format Problem and Answer ---
+ problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*'))
+ answer_text = problem_dict.get('answer', '*(Answer not available)*')
+ answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
+ if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('):
+ answer_text = f"${answer_text}$"
+ answer_content = format_markdown_with_math(answer_text)
+
+ # --- Handle Responses and Generate Sample Grid ---
+ if responses_data is None: # Indicates a DB error occurred fetching responses
+ samples_grid_html = gr.HTML("
Error fetching model responses.
")
+ samples_data_for_state = gr.State([]) # Empty state on error
+ elif not responses_data: # Empty list means no responses found
+ samples_grid_html = gr.HTML("
No responses found for this model on this problem.
")
+ samples_data_for_state = gr.State([]) # Empty state
+ else:
+ # responses_data should already be a list of dicts from get_problem_data
+ samples_data = responses_data # Use directly
+ samples_data_for_state = gr.State(samples_data) # Store the list in state
+
+ correct_count = sum(1 for r in samples_data if r.get('correctness') == 1)
+ total_samples = len(samples_data)
+ accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0
+
+ # --- Generate Sample Grid HTML (with onclick) ---
+ displayed_samples = samples_data[:64] # Limit display
+ actual_display_count = len(displayed_samples)
+ # Determine grid columns based on mode
+ samples_per_row = 16 if mode.startswith('comparison') else 32
+ num_rows = math.ceil(actual_display_count / samples_per_row)
+ grid_html_content = ""
+
+ # Determine the correct JS function call based on mode
+ js_mode = "'default'" # Default for single model tab
+ if mode == 'comparison_left': js_mode = "'comparison_left'"
+ elif mode == 'comparison_right': js_mode = "'comparison_right'"
+
+ for row_idx in range(num_rows):
+ grid_html_content += f'
'
+ start_idx = row_idx * samples_per_row
+ end_idx = min(start_idx + samples_per_row, actual_display_count)
+ row_samples = displayed_samples[start_idx:end_idx]
+
for i, resp in enumerate(row_samples):
- actual_idx = i + 2*samples_per_row
- correctness = resp.get('correctness', 0)
- bg_color = get_gradient_color(1.0 if correctness else 0.0)
-
- samples_grid_html += f"""
-
- {actual_idx}
-
- """
-
- # 填充剩余空间
- for i in range(len(row_samples), samples_per_row):
- samples_grid_html += f"""
-
+ actual_idx = start_idx + i
+ correctness = resp.get('correctness', None) # Handle None correctness
+ # Get background color based on correctness (1=green, 0=red, None=grey)
+ if correctness == 1: bg_color = get_gradient_color(1.0)
+ elif correctness == 0: bg_color = get_gradient_color(0.0)
+ else: bg_color = "#808080" # Grey for unknown
+ text_color = get_contrasting_text_color(bg_color)
+
+ # Add onclick event to call the JavaScript handler
+ grid_html_content += f"""
+
"""
-
- samples_grid_html += '
'
-
- # 第四行
- if actual_display_count > 3*samples_per_row:
- row_samples = displayed_samples[3*samples_per_row:4*samples_per_row]
- if row_samples:
- samples_grid_html += f'
'
-
- for i, resp in enumerate(row_samples):
- actual_idx = i + 3*samples_per_row
- correctness = resp.get('correctness', 0)
- bg_color = get_gradient_color(1.0 if correctness else 0.0)
-
- samples_grid_html += f"""
-
- {actual_idx}
-
- """
-
- # 填充剩余空间
- for i in range(len(row_samples), samples_per_row):
- samples_grid_html += f"""
-
- """
-
- samples_grid_html += '
'
-
- # 组合HTML内容
- final_html = f"""
-
-
Samples {actual_display_count} - Model Accuracy: {correct_count}/{actual_display_count} = {accuracy_on_problem:.1%}
- {samples_grid_html}
-
- """
-
- # 获取第一个样本作为初始样本
- if samples_data:
- # 这样样本会在选择问题后立即显示
- return problem_content, answer_content, final_html, gr.State(samples_data)
- else:
- return problem_content, answer_content, final_html, gr.State([])
+ # Fill remaining columns in the row with gray placeholders if needed
+ for _ in range(len(row_samples), samples_per_row):
+ grid_html_content += ""
+ grid_html_content += '
'
+ # Add filler rows if less than the max were generated (e.g., 4 for comparison, 2 for single)
+ max_rows = 4 if mode.startswith('comparison') else 2
+ for _ in range(num_rows, max_rows):
+ grid_html_content += f'
'
+ for _ in range(samples_per_row):
+ grid_html_content += ""
+ grid_html_content += '
'
+
+
+ # Assemble the final HTML for the samples section
+ samples_grid_html = gr.HTML(f"""
+
+
Samples ({actual_display_count} shown) – Model Accuracy: {correct_count}/{total_samples} ({accuracy_on_problem:.1%})
+
{grid_html_content}
+
+
+ """)
+
+
+ # Return all results
+ return problem_content, answer_content, samples_grid_html, samples_data_for_state
+
+ except Exception as e:
+ print(f"Unexpected error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}")
+ # Log the full traceback for debugging if possible
+ import traceback
+ traceback.print_exc()
+ error_msg = f"**Internal Error processing problem {original_problem_id}:** {str(e)}"
+ return error_msg, "Error", gr.HTML(f"
{error_msg}
"), gr.State([])
+
+
+# Problem grid rendering: each problem button carries an onclick handler.
def create_problem_grid_html(problems, mode='default'):
- """Create HTML for problem grid buttons. The JS function will be defined globally."""
+ """Create HTML for problem grid buttons with onclick handlers."""
if not problems:
- return "
No problems found for this model/dataset. Please select a model and dataset.
"
+ return "
No problems found for this model/dataset.
"
html_buttons = ""
+ # Sort problems based on the numeric part of the ID
try:
- sorted_problems = sorted(
- [(str(p[0]), float(p[1]) if p[1] is not None else 0.0, p[2]) for p in problems],
- key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0
- )
+ id_extractor = re.compile(r'\d+')
+ def get_sort_key(p):
+ # p is expected to be a tuple/list like (unique_id, accuracy, problem_text)
+ match = id_extractor.search(str(p[0]))
+ return int(match.group(0)) if match else 0
+
+ # Ensure problems is a list of tuples/lists before sorting
+ if isinstance(problems, list) and all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in problems):
+ # Convert accuracy to float, handle None or potential errors
+ processed_problems = []
+ for p in problems:
+ try:
+ pid = str(p[0])
+ acc = float(p[1]) if p[1] is not None else 0.0
+ processed_problems.append((pid, acc))
+ except (IndexError, TypeError, ValueError) as conv_err:
+ print(f"Skipping problem entry due to conversion error: {p} - {conv_err}")
+ sorted_problems = sorted(processed_problems, key=get_sort_key)
+ else:
+ print(f"Warning: Problem data format unexpected in create_problem_grid_html (mode={mode}). Skipping sort.")
+ # Attempt to process anyway if possible, otherwise return error message
+ if isinstance(problems, list):
+ processed_problems = []
+ for p in problems:
+ try:
+ pid = str(p[0])
+ # Try to get accuracy, default to 0.0 if fails
+ acc = 0.0
+ if len(p) > 1:
+ try:
+ acc = float(p[1]) if p[1] is not None else 0.0
+ except (TypeError, ValueError): pass # Keep acc as 0.0
+ processed_problems.append((pid, acc))
+ except (IndexError, TypeError, ValueError):
+ print(f"Skipping invalid problem entry: {p}")
+ sorted_problems = processed_problems # No sort if format is wrong initially
+ else:
+ return "
",
- elem_classes="sample-metadata dark-mode-bg-secondary",
- elem_id="comp-sample-metadata-area-right"
- )
-
- comp_sample_response_output_right = gr.Markdown(
- value="Select a problem first to view samples.",
- elem_classes="sample-response dark-mode-bg-secondary",
- elem_id="comp-sample-response-area-right",
- latex_delimiters=[
- {"left": "$", "right": "$", "display": False},
- {"left": "$$", "right": "$$", "display": True},
- {"left": "\\(", "right": "\\)", "display": False},
- {"left": "\\[", "right": "\\]", "display": True}
- ]
- )
-
- # --- Event Handlers ---
+ gr.Markdown("##### Selected Sample (Model 2)")
+ comp_sample_metadata_output_right = gr.HTML("Select sample.")
+ comp_sample_response_output_right = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+
+
+ # --- Event Handlers ---
+
def update_available_models_for_dropdowns(selected_dataset):
- # This function can be used to update model lists if they are dataset-dependent
- # For now, assume get_available_models() gets all models irrespective of dataset for dropdown population
- all_models = db.get_available_models()
- # For single model tab, format with accuracy on the selected dataset
- single_model_options = []
- model_to_display_map = {} # 映射用于存储真实模型名称到显示名称的映射
-
+    """Rebuild the choices of the three model dropdowns for a dataset.
+
+    Every model is labelled "<display name> (<accuracy>)", or "(N/A)" when
+    no accuracy is known for ``selected_dataset``, and the list is ordered
+    by accuracy, best first. The label -> real-name mappings are stored on
+    ``db.model_display_to_real`` and ``db.comp_model_display_to_real``.
+
+    Args:
+        selected_dataset: Dataset name; falsy means no accuracies are fetched.
+
+    Returns:
+        Three ``gr.Dropdown`` updates: single-model tab, comparison model 1,
+        comparison model 2.
+    """
+    if not db or not db.conn:
+        print("Error: DB not available in update_available_models_for_dropdowns")
+        # Return empty updates for all dropdowns
+        return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
+
+    # `or []` keeps the sort and loop below safe if no model list comes back.
+    all_models = db.get_available_models() or []
+    model_acc_map = {}
if selected_dataset and all_models:
+        # Fetch accuracies for the selected dataset
model_accs = db.get_all_model_accuracies(selected_dataset)
model_acc_map = {name: acc for name, acc in model_accs}
- single_model_options = []
- for name in all_models:
- # 使用MODEL_TRANS映射模型名称
- display_name = MODEL_TRANS.get(name, name)
- acc_display = f" ({model_acc_map.get(name, 0):.1%})" if model_acc_map.get(name) is not None else ""
- display_text = f"{display_name}{acc_display}"
- single_model_options.append(display_text)
- model_to_display_map[display_text] = name # 存储映射关系
- else:
- for name in all_models:
- display_name = MODEL_TRANS.get(name, name)
- single_model_options.append(display_name)
- model_to_display_map[display_name] = name
-
- # 将映射存储到全局数据库对象中以便后续使用
- db.model_display_to_real = model_to_display_map
-
- # For comparison tab, also use formatted model names with accuracy
- comp_model_choices = single_model_options # 使用和单模型相同的选项,包含准确率
- db.comp_model_display_to_real = model_to_display_map # 使用相同的映射
-
- return gr.Dropdown(choices=single_model_options if single_model_options else [], value=None), \
- gr.Dropdown(choices=comp_model_choices if comp_model_choices else [], value=None)
+
+    def _accuracy_key(model_name):
+        # A missing entry and a stored None both sort last; returning None
+        # here would make sorted() raise when compared with a float.
+        acc = model_acc_map.get(model_name)
+        return -1 if acc is None else acc
+
+    display_options = []
+    # Clear previous maps before repopulating
+    db.model_display_to_real = {}
+    db.comp_model_display_to_real = {}
+
+    # Sort models by accuracy (descending)
+    sorted_models = sorted(all_models, key=_accuracy_key, reverse=True)
+
+    for name in sorted_models:
+        display_name = MODEL_TRANS.get(name, name)  # Use translation map
+        acc = model_acc_map.get(name)
+        # Format accuracy nicely, handle None
+        acc_display = f" ({acc:.1%})" if acc is not None else " (N/A)"
+        display_text = f"{display_name}{acc_display}"
+        display_options.append(display_text)
+        # Store mapping from formatted name back to real name for both contexts
+        db.model_display_to_real[display_text] = name
+        db.comp_model_display_to_real[display_text] = name
+
+    # Return updates for all three dropdowns
+    return gr.Dropdown(choices=display_options, value=None, label="Select Model (Name + Acc%)", interactive=True), \
+        gr.Dropdown(choices=display_options, value=None, label="Select Model 1", interactive=True), \
+        gr.Dropdown(choices=display_options, value=None, label="Select Model 2", interactive=True)
+
def update_problem_grid_and_stats(selected_model_formatted, selected_dataset, mode='default'):
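+ # Resolves the real model name, fetches its statistics and problem list, and returns (stats DataFrame, grid HTML, real model name); the name is None when the DB, the selection, or name resolution is unavailable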
+ if not db or not db.conn:
+ print("Error: DB not available in update_problem_grid_and_stats")
+ return gr.DataFrame(value=[]), gr.HTML("
DB Error.
"), None
if not selected_model_formatted or not selected_dataset:
- # Return empty/default values for all outputs, including the state
- return gr.DataFrame(value=[]), gr.HTML("
Please select a model and dataset first.
"), None
-
- # 从映射中获取真实模型名称
- model_name = db.model_display_to_real.get(selected_model_formatted, selected_model_formatted)
- # 如果找不到确切匹配,可能是因为准确率等动态内容导致,尝试前缀匹配
- if model_name == selected_model_formatted:
- for display_name, real_name in db.model_display_to_real.items():
- if selected_model_formatted.startswith(display_name.split(" (")[0]):
- model_name = real_name
+ return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None
+
+ # Use the appropriate map (comp or default) to get the real model name
+ real_model_name = None
+ if mode == 'comparison':
+ real_model_name = db.comp_model_display_to_real.get(selected_model_formatted)
+ if not real_model_name: # Fallback to default map or parsing if comp map missed
+ real_model_name = db.model_display_to_real.get(selected_model_formatted)
+
+ # If neither map knows the label, drop the " (acc)" suffix and reverse-look-up the display name in MODEL_TRANS (less reliable)
+ if not real_model_name:
+ raw_name_part = selected_model_formatted.split(" (")[0]
+ for db_name, display_lookup in MODEL_TRANS.items():
+ if display_lookup == raw_name_part:
+ real_model_name = db_name
break
-
- stats_data = db.get_model_statistics(model_name, selected_dataset)
- problem_list = db.get_problems_by_model_dataset(model_name, selected_dataset)
+ if not real_model_name: real_model_name = raw_name_part # Assume it's the real name
+ print(f"Warning: Using fallback lookup/parsing for model name: '{selected_model_formatted}' -> '{real_model_name}'")
+
+ if not real_model_name: # Final check if name resolution failed
+ print(f"Error: Could not determine real model name for '{selected_model_formatted}'")
+ return gr.DataFrame(value=[]), gr.HTML("
Internal error resolving model name.
"), None
+
+ # Fetch data using the resolved real model name
+ stats_data = db.get_model_statistics(real_model_name, selected_dataset)
+ problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset)
grid_html = create_problem_grid_html(problem_list, mode=mode)
-
- # Correctly return the actual value for the current_model_state output
- return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), model_name
- # Single Model Tab interactions
+ # Return stats DF, grid HTML, and the *real* model name for state update
+ return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name
+
+
+ # Helper function to clear problem/sample displays
+ def clear_problem_outputs():
+ return "Select a problem.", "N/A", gr.HTML(""), gr.State([]), "Select a sample.", "*(Response)*", "0"
+
+ # Helper function to clear comparison problem/sample displays for one side
+ def clear_comparison_side_outputs():
+ # Doesn't clear the main problem/answer display, only the side's samples
+ return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0"
+
+
+ # --- Single Model Tab Event Connections ---
+
+ # Dataset selection changes model dropdown options and clears everything else
dataset_radio_single.change(
fn=update_available_models_for_dropdowns,
inputs=[dataset_radio_single],
- outputs=[model_dropdown, comp_model_dropdown_left]
- ).then(
- lambda ds: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), gr.State(value=None), ds, ""), # 清空所有输出,包括problem_state_input
+ outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] # Update all model lists
+ ).then( # Chain: Reset dependent states and UI components
+ lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", # Clear stats, grid, model state, problem ID state
+ *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state
+ ""), # Clear problem input display
inputs=[dataset_radio_single],
- outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_state_input]
- ).then(
- # 重置Sample Number为0
- fn=lambda: "0",
- inputs=[],
- outputs=[sample_number_input]
- ).then(
- lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""),
- inputs=[],
- outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output]
+ outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state,
+ problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state,
+ sample_metadata_output, sample_response_output, sample_index_state,
+ problem_id_input_display] # Also clear the visible input box
)
-
- # Initial population of model dropdowns based on default dataset
- demo.load(
- fn=update_available_models_for_dropdowns,
- inputs=[current_dataset_state], # Uses initial value of state
- outputs=[model_dropdown, comp_model_dropdown_left]
- ).then(
- lambda ds_val: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), ds_val), # Also update dataset state for single tab
- inputs=[current_dataset_state],
- outputs=[model_stats_df, problem_grid_html_output, current_dataset_state]
- ).then(
- lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""),
- inputs=[],
- outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output]
- ).then(
- # 重置Sample Number为0
- fn=lambda: "0",
- inputs=[],
- outputs=[sample_number_input]
+
+ # Model selection updates stats, grid, and clears problem/sample display
+ model_dropdown.change(
+ fn=update_problem_grid_and_stats, # Mode defaults to 'default'
+ inputs=[model_dropdown, current_dataset_state],
+ outputs=[model_stats_df, problem_grid_html_output, current_model_state] # Update stats, grid, REAL model state
+ ).then( # Chain: Clear problem-specific outputs
+ lambda: ("", # Clear problem ID state
+ *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state
+ ""), # Clear problem input display
+ inputs=[],
+ outputs=[problem_id_state,
+ problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state,
+ sample_metadata_output, sample_response_output, sample_index_state,
+ problem_id_input_display]
)
- # ==== 比较页面事件处理 ====
- # 初始化两侧模型下拉列表
- demo.load(
- fn=update_available_models_for_dropdowns,
- inputs=[comp_dataset_state],
- outputs=[model_dropdown, comp_model_dropdown_left]
- ).then(
- fn=update_available_models_for_dropdowns,
- inputs=[comp_dataset_state],
- outputs=[model_dropdown, comp_model_dropdown_right]
+ # --- Problem Selection Handling (Single Tab) ---
+ # Option 1: User types into the visible input box
+ problem_id_input_display.submit( # Trigger on Enter press
+ # Copy value from visible input to hidden state input to trigger main handler
+ fn=lambda x: x,
+ inputs=[problem_id_input_display],
+ outputs=[problem_id_state] # This change triggers the next handler
+ )
+ # Option 2: User clicks problem grid (JS updates hidden problem_id_state)
+ # Main handler triggered by change in hidden problem_id_state
+ problem_id_state.change(
+ fn=handle_problem_select, # Fetches problem, answer, sample grid, sample data state
+ inputs=[problem_id_state, current_model_state, current_dataset_state], # Mode defaults to 'default'
+ outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]
+ ).then( # Chain: Display the first sample after data is loaded
+ fn=handle_first_sample,
+ inputs=[current_samples_data_state],
+ outputs=[sample_metadata_output, sample_response_output]
+ ).then( # Chain: Reset sample index state to 0
+ lambda: "0", inputs=[], outputs=[sample_index_state]
+ ).then( # Chain: Update the visible input box to reflect the selected ID (useful if clicked)
+ fn=lambda x: x.value if hasattr(x,'value') else x, # Get value from state object
+ inputs=[problem_id_state],
+ outputs=[problem_id_input_display]
)
-
- # 数据集改变事件
+
+
+ # --- Sample Selection Handling (Single Tab) ---
+ # Triggered by change in hidden sample_index_state (updated by JS sample click)
+ sample_index_state.change(
+ fn=handle_sample_select,
+ inputs=[sample_index_state, current_samples_data_state],
+ outputs=[sample_metadata_output, sample_response_output]
+ )
+
+
+ # --- Comparison Tab Event Connections ---
+
+ # Dataset change updates dropdowns and clears everything
comp_dataset_radio.change(
- fn=lambda ds: ds,
+ fn=lambda ds: ds, # Update comparison dataset state first
inputs=[comp_dataset_radio],
outputs=[comp_dataset_state]
).then(
fn=update_available_models_for_dropdowns,
inputs=[comp_dataset_state],
- outputs=[model_dropdown, comp_model_dropdown_left]
- ).then(
- fn=update_available_models_for_dropdowns,
- inputs=[comp_dataset_state],
- outputs=[model_dropdown, comp_model_dropdown_right]
- ).then(
- lambda: ("Please select a dataset and enter a problem ID.", "No answer available."),
- inputs=[],
- outputs=[comp_problem_markdown_output, comp_answer_markdown_output]
- )
-
- # 为比较页面的问题ID添加单独的更新逻辑
- comp_problem_state_input.change(
- fn=handle_comparison_problem_update,
- inputs=[comp_problem_state_input, comp_dataset_state],
- outputs=[comp_problem_markdown_output, comp_answer_markdown_output]
+ outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]
+ ).then( # Clear everything dependent on dataset/models/problem
+ lambda: (None, None, # Clear model states
+ "Select models and problem.", "Select problem.", # Clear problem/answer display
+ gr.HTML("Select model."), gr.HTML("Select model."), # Clear grids
+ *clear_comparison_side_outputs(), # Clear left samples
+ *clear_comparison_side_outputs(), # Clear right samples
+ "", # Clear comp problem ID state
+ ""), # Clear comp problem input display
+ inputs=[],
+ outputs=[comp_model_state_left, comp_model_state_right,
+ comp_problem_markdown_output, comp_answer_markdown_output,
+ comp_problem_grid_html_output_left, comp_problem_grid_html_output_right,
+ # Left side sample outputs + state
+ comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left,
+ # Right side sample outputs + state
+ comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right,
+ comp_problem_id_state, comp_problem_id_input_display]
)
-
- # 创建包装函数,预设模式参数
- def update_problem_grid_comparison(model, dataset):
- return update_problem_grid_and_stats(model, dataset, mode='comparison')
-
- # 问题选择的包装函数
- def handle_problem_select_comparison(problem_id, model_state, dataset_state):
- return handle_problem_select(problem_id, model_state, dataset_state, mode='comparison')
-
- # 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面左侧
- def update_model_and_requery_problem_left(model_dropdown_value, current_dataset, current_problem_id):
- # 首先更新模型统计和问题网格
- _, grid_html, new_model_state = update_problem_grid_comparison(model_dropdown_value, current_dataset)
-
- # 如果有选择的问题ID,重新查询它的响应
- if current_problem_id:
- problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset)
-
- # 获取第一个样本的内容
- first_metadata, first_response = handle_first_sample(new_samples_data)
-
- return grid_html, new_model_state, problem_content, answer_content, samples_grid_html, new_samples_data, first_metadata, first_response
- else:
- # 没有问题ID,只返回更新的模型状态
- return grid_html, new_model_state, "Please enter a problem ID.", "No answer available.", "", gr.State([]), "