diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
import os
import json
import pandas as pd
@@ -14,16 +13,15 @@ import time
from huggingface_hub import hf_hub_download
import psutil
import gc
-import atexit
-# 翻译表 (Unchanged)
+# 翻译表
SUBJECT_TRANS = {
"代数": "Algebra",
"数论": "Number Theory",
"几何": "Geometry",
"组合": "Combinatorics"
}
-# MODEL_TRANS (Unchanged)
+
MODEL_TRANS = {
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
@@ -53,172 +51,109 @@ MODEL_TRANS = {
"gemini-2.5-pro-exp-03-25": "Gemini 2.5 Pro Exp 0325",
"o3-mini-high": "OpenAI o3-mini (high)",
"qwen3-0.6b": "Qwen3-0.6B"
+ # 添加更多模型映射
}
-# Matplotlib Config (Unchanged)
+# Configure matplotlib for better display
plt.style.use('ggplot')
mpl.rcParams['figure.figsize'] = (10, 6)
mpl.rcParams['font.size'] = 10
-# Constants (Unchanged)
+# Constants
DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
-# Global DB Instance
+# 全局数据库实例
db = None
class ModelDatabase:
- """Database access class - Optimized for disk-based access"""
+ """Database access class"""
+
def __init__(self, db_path):
- """Initialize database connection directly to the disk file."""
+ """Initialize database connection"""
self.db_path = db_path
- self.conn = None
+ # Use connection pool pattern to avoid too many connections
+ self.conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None, timeout=60)
+ self.conn.execute("PRAGMA journal_mode = WAL") # Use Write-Ahead Logging for better performance
+ self.conn.execute("PRAGMA synchronous = NORMAL") # Reduce synchronization overhead
+ self.conn.execute("PRAGMA cache_size = -8000") # 8MB cache (比原来大4倍)
+ self.conn.execute("PRAGMA temp_store = MEMORY") # 临时表存储在内存中
+ self.conn.execute("PRAGMA mmap_size = 8589934592") # 尝试使用8GB内存映射
+ self.conn.row_factory = sqlite3.Row
+
+ # 创建索引以加速查询
+ self._ensure_indices()
+
+ # 初始化模型名称映射
+ self.model_display_to_real = {}
+ self.comp_model_display_to_real = {}
+
+ # 初始化缓存
self._cache = {}
self._problem_cache = {}
self._response_cache = {}
- self.model_display_to_real = {}
- self.comp_model_display_to_real = {}
-
- try:
- print(f"Connecting to database file: {db_path}")
- if not os.path.exists(db_path):
- raise FileNotFoundError(f"Database file not found at {db_path}")
-
- # Connect directly to the database file
- # Increased timeout for potentially slower disk operations
- self.conn = sqlite3.connect(db_path, check_same_thread=False, timeout=120)
- self.conn.row_factory = sqlite3.Row
-
- # --- Apply PRAGMAs optimized for disk access ---
- print("Applying PRAGMAs for disk-based access...")
- # WAL mode generally provides better concurrency and performance
- self.conn.execute("PRAGMA journal_mode = WAL")
- # NORMAL synchronous is a good balance of safety and speed
- self.conn.execute("PRAGMA synchronous = NORMAL")
- # Allocate a cache size in KiB (e.g., 1GB = -1048576, 2GB = -2097152)
- # Adjust based on available RAM (10GB total limit)
- cache_size_kib = -1048576 # Start with 1GB cache
- print(f"Setting cache_size to {cache_size_kib} KiB")
- self.conn.execute(f"PRAGMA cache_size = {cache_size_kib}")
- # Keep temporary storage in memory if possible
- self.conn.execute("PRAGMA temp_store = MEMORY")
- # Avoid setting mmap_size explicitly when DB >> RAM initially
- # self.conn.execute("PRAGMA mmap_size = XXXXXX") # Experiment later if needed
-
- # Ensure indices exist (critical for disk performance)
- print("Ensuring indices exist...")
- start_index = time.time()
- self._ensure_indices()
- end_index = time.time()
- # Index check/creation might be very fast if they already exist
- print(f"Index check/creation completed in {end_index - start_index:.2f} seconds.")
-
- print("Database connection established successfully.")
-
- except sqlite3.Error as e:
- print(f"SQLite error during database initialization: {e}")
- if self.conn: self.conn.close(); self.conn = None
- raise
- except FileNotFoundError as e:
- print(f"Error: {e}")
- raise
- except Exception as e:
- print(f"Unexpected error during database initialization: {e}")
- if self.conn: self.conn.close(); self.conn = None
- raise
-
- if not self.conn:
- raise RuntimeError("Failed to establish database connection.")
-
-
+
def _ensure_indices(self):
- """Ensure necessary indices exist on the database connection."""
- if not self.conn:
- print("Error: Connection not established. Cannot ensure indices.")
- return
+ """确保数据库有必要的索引"""
try:
cursor = self.conn.cursor()
- print("Checking/Creating index: idx_responses_model_dataset")
+ # 添加最常用查询的索引
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)")
- print("Checking/Creating index: idx_responses_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)")
- print("Checking/Creating index: idx_problems_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)")
- print("Checking/Creating index: idx_problems_subject")
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)")
- # Analyze is important for the query planner, especially on disk
- print("Running ANALYZE (might take time on large DB)...")
- cursor.execute("ANALYZE")
- self.conn.commit()
- print("Indices checked/created and table analyzed.")
- except sqlite3.Error as e:
- print(f"Warning: Could not create or analyze indices: {e}")
- # Attempt rollback
- try: self.conn.rollback()
- except sqlite3.Error as rb_e: print(f"Rollback attempt failed: {rb_e}")
-
- # --- Methods performing queries are adjusted to use INDEXED BY hints ---
-
+ cursor.execute("ANALYZE") # 分析表以优化查询计划
+ except Exception as e:
+ pass
+
def get_available_models(self):
- # (No change needed here, simple query, uses cache)
- if not self.conn: return []
- if hasattr(self, '_models_cache') and self._models_cache is not None:
+ """Get list of all available models"""
+ # 缓存在实例变量中
+ if hasattr(self, '_models_cache') and self._models_cache:
return self._models_cache
+
try:
cursor = self.conn.cursor()
cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name")
models = [row['model_name'] for row in cursor.fetchall()]
- self._models_cache = models
+ self._models_cache = models # 存储到实例缓存
return models
- except sqlite3.Error as e:
- print(f"Database error in get_available_models: {e}")
+ except sqlite3.OperationalError:
return []
-
+
def get_available_datasets(self):
- # (No change needed here, simple query, uses cache)
- if not self.conn: return DATASETS
- if hasattr(self, '_datasets_cache') and self._datasets_cache is not None:
+ """Get list of all available datasets"""
+ # 缓存在实例变量中
+ if hasattr(self, '_datasets_cache') and self._datasets_cache:
return self._datasets_cache
+
try:
cursor = self.conn.cursor()
cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset")
datasets = [row['dataset'].upper() for row in cursor.fetchall()]
- self._datasets_cache = datasets
+ self._datasets_cache = datasets # 存储到实例缓存
return datasets
- except sqlite3.Error as e:
- print(f"Database error in get_available_datasets: {e}")
+ except sqlite3.OperationalError:
return DATASETS
-
+
def get_model_statistics(self, model_name, dataset):
- """Get statistics, using INDEXED BY hints for disk access."""
- if not self.conn: return [["Database Error", "No connection"]]
+ """Get statistics for a model on a specific dataset"""
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
- if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]]
-
+
cache_key = f"stats_{model_name}_{dataset}"
+ if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
-
- stats_data = []
+
+ cursor = self.conn.cursor()
try:
- cursor = self.conn.cursor()
- # Query 1: Overall accuracy - Use index hint
+ # 优化查询1: 整体准确率 - 使用索引提示加速
cursor.execute("""
SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy
- FROM responses INDEXED BY idx_responses_model_dataset
+ FROM responses INDEXED BY idx_responses_model_dataset
WHERE model_name = ? AND dataset = ?
""", (model_name, dataset.lower()))
overall_stats = cursor.fetchone()
-
- if overall_stats and overall_stats['accuracy'] is not None:
- stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"])
- elif overall_stats and overall_stats['total_samples'] == 0:
- stats_data.append(["Overall Acc.", "No Samples"])
- else:
- stats_data.append(["Overall Acc.", "N/A"])
-
- # Query 2: Per-subject statistics - Join still needed, rely on indices
- # (Adding explicit index hints on joins can sometimes be complex/less effective)
- # Rely on ANALYZE and standard indices (idx_responses_unique_id, idx_problems_unique_id, idx_problems_subject)
+
+ # 优化查询2: 按学科统计 - 避免子查询和复杂JOIN
cursor.execute("""
SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy
FROM responses r JOIN problems p ON r.unique_id = p.unique_id
@@ -226,616 +161,1471 @@ class ModelDatabase:
GROUP BY p.subject ORDER BY p.subject
""", (model_name, dataset.lower()))
subject_stats_rows = cursor.fetchall()
+
+ stats_data = []
+ if overall_stats and overall_stats['accuracy'] is not None:
+ stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"])
+ else:
+ stats_data.append(["Overall Acc.", "N/A"])
for subject_row in subject_stats_rows:
acc_val = f"{subject_row['accuracy']:.2%}" if subject_row['accuracy'] is not None else "N/A"
subject_name = subject_row['subject']
+ # 使用翻译表翻译科目名称
translated_subject = SUBJECT_TRANS.get(subject_name, subject_name)
stats_data.append([f"{translated_subject} Acc.", acc_val])
-
+
self._cache[cache_key] = stats_data
return stats_data
- except sqlite3.Error as e:
- print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}")
- return [["Database Error", f"Query failed: {e}"]]
-
+ except sqlite3.OperationalError:
+ return [["Database Error", "No data available"]]
+
def get_all_model_accuracies(self, dataset):
- """Get all accuracies, using INDEXED BY hint."""
- if not self.conn: return []
+ """获取所有模型在特定数据集上的准确率 (优化版本)"""
if hasattr(dataset, 'value'): dataset = dataset.value
- if not dataset: return []
-
cache_key = f"all_accuracies_{dataset}"
+ if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
-
try:
cursor = self.conn.cursor()
- # Use index hint for potentially faster filtering/grouping
+ # 使用索引提示加速查询
cursor.execute("""
SELECT model_name, AVG(correctness) as accuracy
FROM responses INDEXED BY idx_responses_model_dataset
WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC
""", (dataset.lower(),))
- results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None]
+ results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall()]
self._cache[cache_key] = results
return results
- except sqlite3.Error as e:
- print(f"Database error in get_all_model_accuracies({dataset}): {e}")
+ except sqlite3.OperationalError:
return []
def get_problems_by_model_dataset(self, model_name, dataset):
- """Get problems, using INDEXED BY hint for the primary table."""
- if not self.conn: return []
+ """获取模型在特定数据集上的所有问题 (优化版本)"""
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
- if not model_name or not dataset: return []
-
cache_key = f"problems_{model_name}_{dataset}"
+ if not hasattr(self, '_cache'): self._cache = {}
if cache_key in self._cache: return self._cache[cache_key]
-
+
+ cursor = self.conn.cursor()
try:
- cursor = self.conn.cursor()
- # Add index hint to the 'responses' table scan
+ # 优化查询:使用索引提示和优化JOIN策略
cursor.execute("""
- SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy
+ SELECT DISTINCT r.unique_id, p.problem, AVG(r.correctness) as accuracy
FROM responses r INDEXED BY idx_responses_model_dataset
- JOIN problems p ON r.unique_id = p.unique_id
+ JOIN problems p INDEXED BY idx_problems_unique_id ON r.unique_id = p.unique_id
WHERE r.model_name = ? AND r.dataset = ?
- GROUP BY r.unique_id, p.problem ORDER BY r.unique_id
+ GROUP BY r.unique_id ORDER BY r.unique_id
""", (model_name, dataset.lower()))
- results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()]
-
- # Sorting in Python
- id_extractor = re.compile(r'\d+')
- def get_sort_key(problem_tuple):
- match = id_extractor.search(problem_tuple[0])
- return int(match.group(0)) if match else 0
- sorted_results = sorted(results, key=get_sort_key)
-
+ results = [(row['unique_id'], row['accuracy'] if row['accuracy'] is not None else 0.0, row['problem']) for row in cursor.fetchall()]
+
+ # Sort by the integer part of unique_id
+ sorted_results = sorted(results, key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0)
self._cache[cache_key] = sorted_results
return sorted_results
- except sqlite3.Error as e:
- print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}")
+ except sqlite3.OperationalError:
return []
- except Exception as e:
- print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}")
- return []
-
def get_problem_data(self, model_name, dataset, problem_id):
- """Get problem/responses, relying on automatic index usage (hints less common here)."""
- # (This method's logic relies heavily on primary key lookups or specific filters,
- # where SQLite is usually good at picking the right index (idx_problems_unique_id, idx_responses_unique_id).
- # Adding hints here is less likely to be necessary unless performance proves otherwise.)
- if not self.conn: return None, None
+ """获取问题和响应数据 (采用局部缓存策略)"""
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
- if not dataset or not problem_id: return None, None
-
+
+ # 问题数据缓存 - 问题数据通常不会变化,可长期缓存
problem_cache_key = f"problem_{problem_id}"
- problem = self._problem_cache.get(problem_cache_key)
-
- if problem is None:
+ if problem_cache_key in self._problem_cache:
+ problem = self._problem_cache[problem_cache_key]
+ else:
+ if not self.conn:
+ return None, None
+
try:
cursor = self.conn.cursor()
- # Uses idx_problems_unique_id
cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,))
- problem_row = cursor.fetchone()
- if problem_row:
- problem = dict(problem_row)
- self._problem_cache[problem_cache_key] = problem
- else:
- print(f"Problem not found in DB: {problem_id}")
- return None, None
- except sqlite3.Error as e:
- print(f"Database error fetching problem {problem_id}: {e}")
- return None, None
-
- if problem is None: return None, None
-
- responses = None
+ problem = cursor.fetchone()
+ if problem:
+ # 转为字典存储,避免SQLite连接依赖
+ self._problem_cache[problem_cache_key] = dict(problem)
+ problem = self._problem_cache[problem_cache_key]
+ except Exception:
+ return None, None
+
+ if not problem:
+ return None, None
+
+ # 响应数据缓存 - 更细粒度的缓存键
if model_name:
resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
- responses = self._response_cache[resp_cache_key]
- else:
- try:
- cursor = self.conn.cursor()
- # Uses idx_responses_model_dataset & idx_responses_unique_id composite/scan
- cursor.execute("""
- SELECT * FROM responses -- INDEXED BY ??? (can be complex)
- WHERE model_name = ? AND dataset = ? AND unique_id = ?
- ORDER BY response_id
- """, (model_name, dataset.lower(), problem_id))
- response_rows = cursor.fetchall()
- responses = [dict(r) for r in response_rows] if response_rows else []
+ return problem, self._response_cache[resp_cache_key]
+
+ if not self.conn:
+ return problem, None
+
+ # 获取特定模型的响应
+ try:
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT * FROM responses
+ WHERE model_name = ? AND dataset = ? AND unique_id = ?
+ ORDER BY response_id
+ """, (model_name, dataset.lower(), problem_id))
+ responses = cursor.fetchall()
+
+ # 转换为字典列表存储
+ if responses:
+ responses = [dict(r) for r in responses]
self._response_cache[resp_cache_key] = responses
- except sqlite3.Error as e:
- print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}")
- responses = None
- else: # Fetch all responses for the problem
+ return problem, responses
+ except Exception:
+ return problem, None
+ else:
+ # 获取所有模型对此问题的响应
resp_cache_key = f"all_responses_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
- responses = self._response_cache[resp_cache_key]
- else:
- try:
- cursor = self.conn.cursor()
- # Uses idx_responses_model_dataset or idx_responses_unique_id scan
- cursor.execute("""
- SELECT * FROM responses -- INDEXED BY ???
- WHERE dataset = ? AND unique_id = ?
- ORDER BY model_name, response_id
- """, (dataset.lower(), problem_id))
- response_rows = cursor.fetchall()
- responses = [dict(r) for r in response_rows] if response_rows else []
+ return problem, self._response_cache[resp_cache_key]
+
+ if not self.conn:
+ return problem, None
+
+ try:
+ cursor = self.conn.cursor()
+ cursor.execute("""
+ SELECT * FROM responses
+ WHERE dataset = ? AND unique_id = ?
+ ORDER BY model_name, response_id
+ """, (dataset.lower(), problem_id))
+ responses = cursor.fetchall()
+
+ # 转换为字典列表存储
+ if responses:
+ responses = [dict(r) for r in responses]
self._response_cache[resp_cache_key] = responses
- except sqlite3.Error as e:
- print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}")
- responses = None
-
- return problem, responses
-
+ return problem, responses
+ except Exception:
+ return problem, None
def get_model_responses(self, selected_models, dataset, problem_id):
- """Get responses for multiple models, bulk query preferred."""
- # (Keeping the bulk query logic using IN clause is still good for disk access)
- if not self.conn: return None, {}
+ """获取多个模型对特定问题的响应(优化版本)"""
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
- if not selected_models or not dataset or not problem_id:
+ if not selected_models or not dataset or not problem_id:
return None, {}
+ # 获取问题数据 - 可共享缓存
problem, _ = self.get_problem_data(None, dataset, problem_id)
- if not problem:
- print(f"Problem data not found for {problem_id} in get_model_responses")
+ if not problem:
return None, {}
-
+
model_responses_data = {}
- real_model_names_map = {}
- real_names_list = []
for model_display in selected_models:
model_display_val = model_display.value if hasattr(model_display, 'value') else model_display
- real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val)
- if not real_name:
- raw_name_part = model_display_val.split(" (")[0]
- for db_name, display_lookup in MODEL_TRANS.items():
- if display_lookup == raw_name_part: real_name = db_name; break
- if not real_name: real_name = raw_name_part
- print(f"Warning: Using fallback lookup/parsing for model name: '{model_display_val}' -> '{real_name}'.")
-
- if real_name:
- real_model_names_map[model_display_val] = real_name
- if real_name not in real_names_list: real_names_list.append(real_name)
-
- if not real_names_list:
- print("No valid real model names found to query.")
- return problem, {}
-
- # Optimized: Fetch all relevant responses in a single query
- try:
- cursor = self.conn.cursor()
- placeholders = ','.join('?' * len(real_names_list))
- # Rely on index idx_responses_model_dataset for the IN clause + other filters
- query = f"""
- SELECT * FROM responses -- INDEXED BY ??? (idx_responses_model_dataset likely used)
- WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ?
- ORDER BY model_name, correctness DESC, response_id
- """
- params = real_names_list + [dataset.lower(), problem_id]
- cursor.execute(query, params)
- all_fetched_responses = cursor.fetchall()
-
- responses_by_real_model = {}
- for resp_row in all_fetched_responses:
- resp_dict = dict(resp_row)
- model = resp_dict['model_name']
- if model not in responses_by_real_model: # Only store the first (best) one
- responses_by_real_model[model] = resp_dict
-
- for display_name, real_name in real_model_names_map.items():
- model_responses_data[display_name] = responses_by_real_model.get(real_name)
-
- except sqlite3.Error as e:
- print(f"Database error in bulk get_model_responses: {e}. Falling back.")
- # Fallback to individual fetches (uses cache)
- for display_name, real_name in real_model_names_map.items():
- _ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id)
- if responses_for_model:
- correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None)
- model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0]
- else: model_responses_data[display_name] = None
-
+ # 从显示名称中获取真实模型名称
+ model = self.comp_model_display_to_real.get(model_display_val, model_display_val)
+
+ _, responses_for_model = self.get_problem_data(model, dataset, problem_id)
+ if responses_for_model:
+ # 尝试找到正确的响应,否则使用第一个
+ correct_resp = next((r for r in responses_for_model if r['correctness'] == 1), None)
+ model_responses_data[model_display_val] = correct_resp if correct_resp else responses_for_model[0]
+ else:
+ model_responses_data[model_display_val] = None
+
return problem, model_responses_data
-
def clear_cache(self, section=None):
- # (Unchanged - Python cache clearing is still relevant)
- print(f"Clearing cache section: {section if section else 'All'}")
- cleared_something = False
+ """清除指定部分或全部缓存"""
if section == 'main' or section is None:
- if self._cache:
- count = len(self._cache)
- self._cache = {}
- print(f"Cleared main cache ({count} items).")
- cleared_something = True
+ self._cache = {}
if section == 'problem' or section is None:
- if self._problem_cache:
- count = len(self._problem_cache)
- self._problem_cache = {}
- print(f"Cleared problem cache ({count} items).")
- cleared_something = True
+ self._problem_cache = {}
if section == 'response' or section is None:
- if self._response_cache:
- count = len(self._response_cache)
- self._response_cache = {}
- print(f"Cleared response cache ({count} items).")
- cleared_something = True
+ self._response_cache = {}
if section == 'models' or section is None:
- if hasattr(self, '_models_cache') and self._models_cache is not None:
+ if hasattr(self, '_models_cache'):
self._models_cache = None
- print("Cleared models list cache.")
- cleared_something = True
- if hasattr(self, '_datasets_cache') and self._datasets_cache is not None:
+ if hasattr(self, '_datasets_cache'):
self._datasets_cache = None
- print("Cleared datasets list cache.")
- cleared_something = True
- if cleared_something: print("Running garbage collection..."); gc.collect()
- else: print("Cache section(s) already empty or invalid section specified.")
-
+
def close(self):
- # (Unchanged - closing the disk connection)
- print("Closing database connection...")
+ """关闭数据库连接并释放资源"""
if hasattr(self, 'conn') and self.conn:
try:
- # Maybe run optimize before closing large WAL file? Might take time.
- # print("Running PRAGMA optimize...")
- # self.conn.execute("PRAGMA optimize;")
self.conn.close()
- self.conn = None
- print("Database connection closed.")
- except sqlite3.Error as e:
- print(f"Error closing database connection: {e}")
- else: print("Database connection already closed or never established.")
+ except Exception:
+ pass
+
+ # 清理所有缓存
self.clear_cache()
-
-# --- Helper functions (format_*, get_color_*, etc.) ---
-# (These remain unchanged as they don't depend on the DB access method)
def format_latex(text):
if text is None: return ""
+ # Process the text for proper LaTeX rendering with KaTeX
+ # KaTeX requires LaTeX backslashes to be preserved
+ # Only replace newlines with HTML breaks
text = text.replace('\n', '
')
+ # Wrap in a span that KaTeX can detect and render
return f'{text}'
def format_markdown_with_math(text):
if text is None: return ""
+
+ # Don't add HTML tags or do special processing for LaTeX - let Gradio handle it
+ # Just clean up basic issues that might affect rendering
+
+ # Convert newlines for markdown
text = text.replace('\r\n', '\n').replace('\r', '\n')
+
+ # Return the cleaned text for Gradio's markdown component to render
return text
def get_gradient_color(accuracy, color_map='RdYlGn'):
- if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0):
- return "#808080"
+ if accuracy is None or not isinstance(accuracy, (int, float)):
+ return "#505050" # Default for missing or invalid accuracy
try:
+ # 使用更深的颜色映射
cmap = plt.colormaps.get_cmap(color_map)
- power_adjust = 0.7
- rgba = cmap(accuracy ** power_adjust)
- hex_color = mpl.colors.rgb2hex(rgba)
+ rgba = cmap(float(accuracy))
+
+ # 确保颜色足够深以与白色文本形成对比
+ r, g, b, a = rgba
+ # 降低颜色亮度,确保文本可读性
+ r = r * 0.7
+ g = g * 0.7
+ b = b * 0.7
+
+ # 转回十六进制
+ hex_color = mpl.colors.rgb2hex((r, g, b, a))
return hex_color
- except Exception as e:
- print(f"Error generating gradient color for accuracy {accuracy}: {e}")
- return "#808080"
-
-def get_contrasting_text_color(bg_color_hex):
- try:
- if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7:
- return "#000000"
- r = int(bg_color_hex[1:3], 16); g = int(bg_color_hex[3:5], 16); b = int(bg_color_hex[5:7], 16)
- rgb = [val / 255.0 for val in (r, g, b)]
- rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb]
- luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2]
- return "#000000" if luminance > 0.22 else "#FFFFFF"
- except Exception as e:
- print(f"Error calculating contrasting color for {bg_color_hex}: {e}")
- return "#000000"
+ except Exception:
+ return "#505050"
+
+def get_contrasting_text_color(bg_color):
+ """计算最佳对比文本颜色"""
+ # 如果背景是十六进制格式,转换为RGB
+ if bg_color.startswith('#'):
+ r = int(bg_color[1:3], 16)
+ g = int(bg_color[3:5], 16)
+ b = int(bg_color[5:7], 16)
+ else:
+ # 未知格式默认返回黑色
+ return "#000"
+
+ # 计算YIQ亮度值 - 更精确地表示人眼对亮度的感知
+ yiq = (r * 299 + g * 587 + b * 114) / 1000
+
+ # 黄色检测 - 黄色通常R和G高,B低
+ is_yellow = r > 200 and g > 200 and b < 150
+
+ # 浅绿色检测 - 通常G高,R中等,B低
+ is_light_green = g > 200 and r > 100 and r < 180 and b < 150
+
+ # 米色/浅棕色检测 - R高,G中高,B低
+ is_beige = r > 220 and g > 160 and g < 220 and b < 160
+
+ # 强制这些特定颜色使用黑色文本
+ if is_yellow or is_light_green or is_beige:
+ return "#000"
+
+ # 其他颜色根据亮度决定
+ return "#000" if yiq > 160 else "#fff"
def format_sample_metadata(sample, show_correctness=True):
- if sample is None: return "
DB Conn Error
"), gr.State([]) + # Ensure we're using the actual values from Gradio State objects model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state - problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - original_problem_id = problem_id - if not dataset_name: return "Select dataset.", "N/A", gr.HTML(""), gr.State([]) - if not problem_id: return "Enter problem ID.", "N/A", gr.HTML(""), gr.State([]) - if not model_name and mode == 'default': return "Select model.", "N/A", gr.HTML(""), gr.State([]) - if problem_id.isdigit(): - parts = dataset_name.split('-'); - if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"; print(f"Reconstructed ID: {problem_id}") - else: print(f"Warning: Could not reconstruct ID from {original_problem_id} and {dataset_name}") + problem_id = problem_id_from_js.value if hasattr(problem_id_from_js, 'value') else problem_id_from_js + + # 处理纯数字输入,构建完整unique_id + if problem_id and problem_id.isdigit(): + # 构建格式:OlymMATH-HARD-0-EN 或类似格式 + # 从dataset_name (例如 "EN-HARD") 解析语言和难度 + parts = dataset_name.split('-') + if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD") + language, difficulty = parts + # 构建完整ID + problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" + + if not problem_id or not dataset_name: + error_message = f"Missing data: problem_id='{problem_id}', dataset='{dataset_name}'" + return "Please fill in all the fields.", "No answer available.", "", gr.State([]) + + # For comparison mode, we might not have a model selected yet + if not model_name and mode == 'comparison': + try: + # Just get the problem data without model-specific responses + problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) + + if not problem_data: + error_message = f"Problem data not found: problem_id='{problem_id}', dataset='{dataset_name}'" + return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([]) + + problem_dict = dict(problem_data) + # Process problem and answer text for Markdown rendering + problem_content = format_markdown_with_math(problem_dict.get('problem', '')) + + # 将答案中的双美元符号替换为单美元符号 + answer_text = problem_dict.get('answer', '') + # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 + answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) + + # 检查答案是否已经包含美元符号,如果没有则添加 + if '$' not in answer_text and answer_text.strip(): + answer_text = f"${answer_text}$" + + answer_content = format_markdown_with_math(answer_text) + + # For comparison without model, we don't have samples to display + return problem_content, answer_content, "", gr.State([]) + except Exception as e: + error_message = f"Database error: {str(e)}" + return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([]) + + # The regular flow for model-specific data + if not model_name: + error_message = f"Missing data: model='{model_name}'" + return "Please fill in all the fields.", "No answer available.", "", gr.State([]) + + # The problem_id from JS should be the full unique_id. No reconstruction needed normally. try: problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id) - if not problem_data: return f"Problem '{original_problem_id}' not found.", "N/A", gr.HTML(f"ID '{original_problem_id}' not found.
"), gr.State([]) - problem_dict = dict(problem_data) if problem_data else {} - problem_content = format_markdown_with_math(problem_dict.get('problem', 'N/A')) - answer_text = problem_dict.get('answer', 'N/A') - answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*(') and answer_text != 'N/A': answer_text = f"${answer_text}$" - answer_content = format_markdown_with_math(answer_text) - if responses_data is None: samples_grid_html = gr.HTML("Error fetching responses.
"); samples_data_for_state = gr.State([]) - elif not responses_data: samples_grid_html = gr.HTML("No responses found.
"); samples_data_for_state = gr.State([]) - else: - samples_data = responses_data - samples_data_for_state = gr.State(samples_data) - correct_count = sum(1 for r in samples_data if r.get('correctness') == 1) - total_samples = len(samples_data); accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 - displayed_samples = samples_data[:64]; actual_display_count = len(displayed_samples) - samples_per_row = 16 if mode.startswith('comparison') else 32 - num_rows = math.ceil(actual_display_count / samples_per_row); grid_html_content = "" - js_mode = "'comparison_left'" if mode == 'comparison_left' else "'comparison_right'" if mode == 'comparison_right' else "'default'" - for row_idx in range(num_rows): - grid_html_content += f'{error_msg}
"), gr.State([]) - - -# --- UI Creation function (create_problem_grid_html, create_ui) --- -# (These remain unchanged as they interact with the DB class interface and handlers) + actual_idx = i + 2*samples_per_row + correctness = resp.get('correctness', 0) + bg_color = get_gradient_color(1.0 if correctness else 0.0) + + samples_grid_html += f""" +DB Error.
"), None - if not selected_model_formatted or not selected_dataset: return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None - real_model_name = None; - if mode == 'comparison': real_model_name = db.comp_model_display_to_real.get(selected_model_formatted) - if not real_model_name: real_model_name = db.model_display_to_real.get(selected_model_formatted) - if not real_model_name: - raw_name_part = selected_model_formatted.split(" (")[0]; - for db_name, display_lookup in MODEL_TRANS.items(): - if display_lookup == raw_name_part: real_model_name = db_name; break - if not real_model_name: real_model_name = raw_name_part - print(f"Warning: Using fallback lookup for model name: '{selected_model_formatted}' -> '{real_model_name}'") - if not real_model_name: print(f"Error: Could not determine real model name for '{selected_model_formatted}'"); return gr.DataFrame(value=[]), gr.HTML("Error.
"), None - stats_data = db.get_model_statistics(real_model_name, selected_dataset); problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset); grid_html = create_problem_grid_html(problem_list, mode=mode) - return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name - def clear_problem_outputs(): return "Select problem.", "N/A", gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" - def clear_comparison_side_outputs(): return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" - # Single Model Event Connections - dataset_radio_single.change(fn=update_available_models_for_dropdowns, inputs=[dataset_radio_single], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", *clear_problem_outputs(), ""), inputs=[dataset_radio_single], outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display]) - model_dropdown.change(fn=update_problem_grid_and_stats, inputs=[model_dropdown, current_dataset_state], outputs=[model_stats_df, problem_grid_html_output, current_model_state]).then(lambda: ("", *clear_problem_outputs(), ""), inputs=[], outputs=[problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display]) - problem_id_input_display.submit(fn=lambda x: x, inputs=[problem_id_input_display], outputs=[problem_id_state]) - problem_id_state.change(fn=handle_problem_select, inputs=[problem_id_state, current_model_state, current_dataset_state], outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]).then(fn=handle_first_sample, inputs=[current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]).then(lambda: "0", inputs=[], outputs=[sample_index_state]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[problem_id_state], outputs=[problem_id_input_display]) - sample_index_state.change(fn=handle_sample_select, inputs=[sample_index_state, current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]) - # Comparison Tab Event Connections - comp_dataset_radio.change(fn=lambda ds: ds, inputs=[comp_dataset_radio], outputs=[comp_dataset_state]).then(fn=update_available_models_for_dropdowns, inputs=[comp_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda: (None, None, "Select models and problem.", "Select problem.", gr.HTML("Select model."), gr.HTML("Select model."), *clear_comparison_side_outputs(), *clear_comparison_side_outputs(), "", ""), inputs=[], outputs=[comp_model_state_left, comp_model_state_right, comp_problem_markdown_output, comp_answer_markdown_output, comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, comp_problem_id_state, comp_problem_id_input_display]) - comp_model_dropdown_left.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_left, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]) - comp_model_dropdown_right.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_right, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]) - comp_problem_id_input_display.submit(fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state]) - comp_problem_id_state.change(fn=handle_comparison_problem_update, inputs=[comp_problem_id_state, comp_dataset_state], outputs=[comp_problem_markdown_output, comp_answer_markdown_output]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display]) - comp_sample_index_state_left.change(fn=handle_sample_select, inputs=[comp_sample_index_state_left, comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]) - comp_sample_index_state_right.change(fn=handle_sample_select, inputs=[comp_sample_index_state_right, comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]) - # Initial Load - demo.load(fn=update_available_models_for_dropdowns, inputs=[current_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]) - return demo + if not selected_model_formatted or not selected_dataset: + # Return empty/default values for all outputs, including the state + return gr.DataFrame(value=[]), gr.HTML("