diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import json import pandas as pd @@ -14,16 +13,15 @@ import time from huggingface_hub import hf_hub_download import psutil import gc -import atexit -# 翻译表 (Unchanged) +# 翻译表 SUBJECT_TRANS = { "代数": "Algebra", "数论": "Number Theory", "几何": "Geometry", "组合": "Combinatorics" } -# MODEL_TRANS (Unchanged) + MODEL_TRANS = { "acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B", "deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B", @@ -53,172 +51,109 @@ MODEL_TRANS = { "gemini-2.5-pro-exp-03-25": "Gemini 2.5 Pro Exp 0325", "o3-mini-high": "OpenAI o3-mini (high)", "qwen3-0.6b": "Qwen3-0.6B" + # 添加更多模型映射 } -# Matplotlib Config (Unchanged) +# Configure matplotlib for better display plt.style.use('ggplot') mpl.rcParams['figure.figsize'] = (10, 6) mpl.rcParams['font.size'] = 10 -# Constants (Unchanged) +# Constants DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] -# Global DB Instance +# 全局数据库实例 db = None class ModelDatabase: - """Database access class - Optimized for disk-based access""" + """Database access class""" + def __init__(self, db_path): - """Initialize database connection directly to the disk file.""" + """Initialize database connection""" self.db_path = db_path - self.conn = None + # Use connection pool pattern to avoid too many connections + self.conn = sqlite3.connect(db_path, check_same_thread=False, isolation_level=None, timeout=60) + self.conn.execute("PRAGMA journal_mode = WAL") # Use Write-Ahead Logging for better performance + self.conn.execute("PRAGMA synchronous = NORMAL") # Reduce synchronization overhead + self.conn.execute("PRAGMA cache_size = -8000") # 8MB cache (比原来大4倍) + self.conn.execute("PRAGMA temp_store = MEMORY") # 临时表存储在内存中 + self.conn.execute("PRAGMA mmap_size = 8589934592") # 尝试使用8GB内存映射 + self.conn.row_factory = sqlite3.Row + + # 创建索引以加速查询 + self._ensure_indices() + + # 初始化模型名称映射 + self.model_display_to_real = {} + self.comp_model_display_to_real = {} + + # 初始化缓存 self._cache = {} self._problem_cache = {} self._response_cache = {} - self.model_display_to_real = {} - self.comp_model_display_to_real = {} - - try: - print(f"Connecting to database file: {db_path}") - if not os.path.exists(db_path): - raise FileNotFoundError(f"Database file not found at {db_path}") - - # Connect directly to the database file - # Increased timeout for potentially slower disk operations - self.conn = sqlite3.connect(db_path, check_same_thread=False, timeout=120) - self.conn.row_factory = sqlite3.Row - - # --- Apply PRAGMAs optimized for disk access --- - print("Applying PRAGMAs for disk-based access...") - # WAL mode generally provides better concurrency and performance - self.conn.execute("PRAGMA journal_mode = WAL") - # NORMAL synchronous is a good balance of safety and speed - self.conn.execute("PRAGMA synchronous = NORMAL") - # Allocate a cache size in KiB (e.g., 1GB = -1048576, 2GB = -2097152) - # Adjust based on available RAM (10GB total limit) - cache_size_kib = -1048576 # Start with 1GB cache - print(f"Setting cache_size to {cache_size_kib} KiB") - self.conn.execute(f"PRAGMA cache_size = {cache_size_kib}") - # Keep temporary storage in memory if possible - self.conn.execute("PRAGMA temp_store = MEMORY") - # Avoid setting mmap_size explicitly when DB >> RAM initially - # self.conn.execute("PRAGMA mmap_size = XXXXXX") # Experiment later if needed - - # Ensure indices exist (critical for disk performance) - print("Ensuring indices exist...") - start_index = time.time() - self._ensure_indices() - end_index = time.time() - # Index check/creation might be very fast if they already exist - print(f"Index check/creation completed in {end_index - start_index:.2f} seconds.") - - print("Database connection established successfully.") - - except sqlite3.Error as e: - print(f"SQLite error during database initialization: {e}") - if self.conn: self.conn.close(); self.conn = None - raise - except FileNotFoundError as e: - print(f"Error: {e}") - raise - except Exception as e: - print(f"Unexpected error during database initialization: {e}") - if self.conn: self.conn.close(); self.conn = None - raise - - if not self.conn: - raise RuntimeError("Failed to establish database connection.") - - + def _ensure_indices(self): - """Ensure necessary indices exist on the database connection.""" - if not self.conn: - print("Error: Connection not established. Cannot ensure indices.") - return + """确保数据库有必要的索引""" try: cursor = self.conn.cursor() - print("Checking/Creating index: idx_responses_model_dataset") + # 添加最常用查询的索引 cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)") - print("Checking/Creating index: idx_responses_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)") - print("Checking/Creating index: idx_problems_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)") - print("Checking/Creating index: idx_problems_subject") - cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)") - # Analyze is important for the query planner, especially on disk - print("Running ANALYZE (might take time on large DB)...") - cursor.execute("ANALYZE") - self.conn.commit() - print("Indices checked/created and table analyzed.") - except sqlite3.Error as e: - print(f"Warning: Could not create or analyze indices: {e}") - # Attempt rollback - try: self.conn.rollback() - except sqlite3.Error as rb_e: print(f"Rollback attempt failed: {rb_e}") - - # --- Methods performing queries are adjusted to use INDEXED BY hints --- - + cursor.execute("ANALYZE") # 分析表以优化查询计划 + except Exception as e: + pass + def get_available_models(self): - # (No change needed here, simple query, uses cache) - if not self.conn: return [] - if hasattr(self, '_models_cache') and self._models_cache is not None: + """Get list of all available models""" + # 缓存在实例变量中 + if hasattr(self, '_models_cache') and self._models_cache: return self._models_cache + try: cursor = self.conn.cursor() cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name") models = [row['model_name'] for row in cursor.fetchall()] - self._models_cache = models + self._models_cache = models # 存储到实例缓存 return models - except sqlite3.Error as e: - print(f"Database error in get_available_models: {e}") + except sqlite3.OperationalError: return [] - + def get_available_datasets(self): - # (No change needed here, simple query, uses cache) - if not self.conn: return DATASETS - if hasattr(self, '_datasets_cache') and self._datasets_cache is not None: + """Get list of all available datasets""" + # 缓存在实例变量中 + if hasattr(self, '_datasets_cache') and self._datasets_cache: return self._datasets_cache + try: cursor = self.conn.cursor() cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset") datasets = [row['dataset'].upper() for row in cursor.fetchall()] - self._datasets_cache = datasets + self._datasets_cache = datasets # 存储到实例缓存 return datasets - except sqlite3.Error as e: - print(f"Database error in get_available_datasets: {e}") + except sqlite3.OperationalError: return DATASETS - + def get_model_statistics(self, model_name, dataset): - """Get statistics, using INDEXED BY hints for disk access.""" - if not self.conn: return [["Database Error", "No connection"]] + """Get statistics for a model on a specific dataset""" if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value - if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]] - + cache_key = f"stats_{model_name}_{dataset}" + if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] - - stats_data = [] + + cursor = self.conn.cursor() try: - cursor = self.conn.cursor() - # Query 1: Overall accuracy - Use index hint + # 优化查询1: 整体准确率 - 使用索引提示加速 cursor.execute(""" SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy - FROM responses INDEXED BY idx_responses_model_dataset + FROM responses INDEXED BY idx_responses_model_dataset WHERE model_name = ? AND dataset = ? """, (model_name, dataset.lower())) overall_stats = cursor.fetchone() - - if overall_stats and overall_stats['accuracy'] is not None: - stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"]) - elif overall_stats and overall_stats['total_samples'] == 0: - stats_data.append(["Overall Acc.", "No Samples"]) - else: - stats_data.append(["Overall Acc.", "N/A"]) - - # Query 2: Per-subject statistics - Join still needed, rely on indices - # (Adding explicit index hints on joins can sometimes be complex/less effective) - # Rely on ANALYZE and standard indices (idx_responses_unique_id, idx_problems_unique_id, idx_problems_subject) + + # 优化查询2: 按学科统计 - 避免子查询和复杂JOIN cursor.execute(""" SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy FROM responses r JOIN problems p ON r.unique_id = p.unique_id @@ -226,616 +161,1471 @@ class ModelDatabase: GROUP BY p.subject ORDER BY p.subject """, (model_name, dataset.lower())) subject_stats_rows = cursor.fetchall() + + stats_data = [] + if overall_stats and overall_stats['accuracy'] is not None: + stats_data.append(["Overall Acc.", f"{overall_stats['accuracy']:.2%}"]) + else: + stats_data.append(["Overall Acc.", "N/A"]) for subject_row in subject_stats_rows: acc_val = f"{subject_row['accuracy']:.2%}" if subject_row['accuracy'] is not None else "N/A" subject_name = subject_row['subject'] + # 使用翻译表翻译科目名称 translated_subject = SUBJECT_TRANS.get(subject_name, subject_name) stats_data.append([f"{translated_subject} Acc.", acc_val]) - + self._cache[cache_key] = stats_data return stats_data - except sqlite3.Error as e: - print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}") - return [["Database Error", f"Query failed: {e}"]] - + except sqlite3.OperationalError: + return [["Database Error", "No data available"]] + def get_all_model_accuracies(self, dataset): - """Get all accuracies, using INDEXED BY hint.""" - if not self.conn: return [] + """获取所有模型在特定数据集上的准确率 (优化版本)""" if hasattr(dataset, 'value'): dataset = dataset.value - if not dataset: return [] - cache_key = f"all_accuracies_{dataset}" + if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] - try: cursor = self.conn.cursor() - # Use index hint for potentially faster filtering/grouping + # 使用索引提示加速查询 cursor.execute(""" SELECT model_name, AVG(correctness) as accuracy FROM responses INDEXED BY idx_responses_model_dataset WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC """, (dataset.lower(),)) - results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None] + results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall()] self._cache[cache_key] = results return results - except sqlite3.Error as e: - print(f"Database error in get_all_model_accuracies({dataset}): {e}") + except sqlite3.OperationalError: return [] def get_problems_by_model_dataset(self, model_name, dataset): - """Get problems, using INDEXED BY hint for the primary table.""" - if not self.conn: return [] + """获取模型在特定数据集上的所有问题 (优化版本)""" if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value - if not model_name or not dataset: return [] - cache_key = f"problems_{model_name}_{dataset}" + if not hasattr(self, '_cache'): self._cache = {} if cache_key in self._cache: return self._cache[cache_key] - + + cursor = self.conn.cursor() try: - cursor = self.conn.cursor() - # Add index hint to the 'responses' table scan + # 优化查询:使用索引提示和优化JOIN策略 cursor.execute(""" - SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy + SELECT DISTINCT r.unique_id, p.problem, AVG(r.correctness) as accuracy FROM responses r INDEXED BY idx_responses_model_dataset - JOIN problems p ON r.unique_id = p.unique_id + JOIN problems p INDEXED BY idx_problems_unique_id ON r.unique_id = p.unique_id WHERE r.model_name = ? AND r.dataset = ? - GROUP BY r.unique_id, p.problem ORDER BY r.unique_id + GROUP BY r.unique_id ORDER BY r.unique_id """, (model_name, dataset.lower())) - results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()] - - # Sorting in Python - id_extractor = re.compile(r'\d+') - def get_sort_key(problem_tuple): - match = id_extractor.search(problem_tuple[0]) - return int(match.group(0)) if match else 0 - sorted_results = sorted(results, key=get_sort_key) - + results = [(row['unique_id'], row['accuracy'] if row['accuracy'] is not None else 0.0, row['problem']) for row in cursor.fetchall()] + + # Sort by the integer part of unique_id + sorted_results = sorted(results, key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0) self._cache[cache_key] = sorted_results return sorted_results - except sqlite3.Error as e: - print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}") + except sqlite3.OperationalError: return [] - except Exception as e: - print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}") - return [] - def get_problem_data(self, model_name, dataset, problem_id): - """Get problem/responses, relying on automatic index usage (hints less common here).""" - # (This method's logic relies heavily on primary key lookups or specific filters, - # where SQLite is usually good at picking the right index (idx_problems_unique_id, idx_responses_unique_id). - # Adding hints here is less likely to be necessary unless performance proves otherwise.) - if not self.conn: return None, None + """获取问题和响应数据 (采用局部缓存策略)""" if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value - if not dataset or not problem_id: return None, None - + + # 问题数据缓存 - 问题数据通常不会变化,可长期缓存 problem_cache_key = f"problem_{problem_id}" - problem = self._problem_cache.get(problem_cache_key) - - if problem is None: + if problem_cache_key in self._problem_cache: + problem = self._problem_cache[problem_cache_key] + else: + if not self.conn: + return None, None + try: cursor = self.conn.cursor() - # Uses idx_problems_unique_id cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,)) - problem_row = cursor.fetchone() - if problem_row: - problem = dict(problem_row) - self._problem_cache[problem_cache_key] = problem - else: - print(f"Problem not found in DB: {problem_id}") - return None, None - except sqlite3.Error as e: - print(f"Database error fetching problem {problem_id}: {e}") - return None, None - - if problem is None: return None, None - - responses = None + problem = cursor.fetchone() + if problem: + # 转为字典存储,避免SQLite连接依赖 + self._problem_cache[problem_cache_key] = dict(problem) + problem = self._problem_cache[problem_cache_key] + except Exception: + return None, None + + if not problem: + return None, None + + # 响应数据缓存 - 更细粒度的缓存键 if model_name: resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: - responses = self._response_cache[resp_cache_key] - else: - try: - cursor = self.conn.cursor() - # Uses idx_responses_model_dataset & idx_responses_unique_id composite/scan - cursor.execute(""" - SELECT * FROM responses -- INDEXED BY ??? (can be complex) - WHERE model_name = ? AND dataset = ? AND unique_id = ? - ORDER BY response_id - """, (model_name, dataset.lower(), problem_id)) - response_rows = cursor.fetchall() - responses = [dict(r) for r in response_rows] if response_rows else [] + return problem, self._response_cache[resp_cache_key] + + if not self.conn: + return problem, None + + # 获取特定模型的响应 + try: + cursor = self.conn.cursor() + cursor.execute(""" + SELECT * FROM responses + WHERE model_name = ? AND dataset = ? AND unique_id = ? + ORDER BY response_id + """, (model_name, dataset.lower(), problem_id)) + responses = cursor.fetchall() + + # 转换为字典列表存储 + if responses: + responses = [dict(r) for r in responses] self._response_cache[resp_cache_key] = responses - except sqlite3.Error as e: - print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}") - responses = None - else: # Fetch all responses for the problem + return problem, responses + except Exception: + return problem, None + else: + # 获取所有模型对此问题的响应 resp_cache_key = f"all_responses_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: - responses = self._response_cache[resp_cache_key] - else: - try: - cursor = self.conn.cursor() - # Uses idx_responses_model_dataset or idx_responses_unique_id scan - cursor.execute(""" - SELECT * FROM responses -- INDEXED BY ??? - WHERE dataset = ? AND unique_id = ? - ORDER BY model_name, response_id - """, (dataset.lower(), problem_id)) - response_rows = cursor.fetchall() - responses = [dict(r) for r in response_rows] if response_rows else [] + return problem, self._response_cache[resp_cache_key] + + if not self.conn: + return problem, None + + try: + cursor = self.conn.cursor() + cursor.execute(""" + SELECT * FROM responses + WHERE dataset = ? AND unique_id = ? + ORDER BY model_name, response_id + """, (dataset.lower(), problem_id)) + responses = cursor.fetchall() + + # 转换为字典列表存储 + if responses: + responses = [dict(r) for r in responses] self._response_cache[resp_cache_key] = responses - except sqlite3.Error as e: - print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}") - responses = None - - return problem, responses - + return problem, responses + except Exception: + return problem, None def get_model_responses(self, selected_models, dataset, problem_id): - """Get responses for multiple models, bulk query preferred.""" - # (Keeping the bulk query logic using IN clause is still good for disk access) - if not self.conn: return None, {} + """获取多个模型对特定问题的响应(优化版本)""" if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value - if not selected_models or not dataset or not problem_id: + if not selected_models or not dataset or not problem_id: return None, {} + # 获取问题数据 - 可共享缓存 problem, _ = self.get_problem_data(None, dataset, problem_id) - if not problem: - print(f"Problem data not found for {problem_id} in get_model_responses") + if not problem: return None, {} - + model_responses_data = {} - real_model_names_map = {} - real_names_list = [] for model_display in selected_models: model_display_val = model_display.value if hasattr(model_display, 'value') else model_display - real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val) - if not real_name: - raw_name_part = model_display_val.split(" (")[0] - for db_name, display_lookup in MODEL_TRANS.items(): - if display_lookup == raw_name_part: real_name = db_name; break - if not real_name: real_name = raw_name_part - print(f"Warning: Using fallback lookup/parsing for model name: '{model_display_val}' -> '{real_name}'.") - - if real_name: - real_model_names_map[model_display_val] = real_name - if real_name not in real_names_list: real_names_list.append(real_name) - - if not real_names_list: - print("No valid real model names found to query.") - return problem, {} - - # Optimized: Fetch all relevant responses in a single query - try: - cursor = self.conn.cursor() - placeholders = ','.join('?' * len(real_names_list)) - # Rely on index idx_responses_model_dataset for the IN clause + other filters - query = f""" - SELECT * FROM responses -- INDEXED BY ??? (idx_responses_model_dataset likely used) - WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ? - ORDER BY model_name, correctness DESC, response_id - """ - params = real_names_list + [dataset.lower(), problem_id] - cursor.execute(query, params) - all_fetched_responses = cursor.fetchall() - - responses_by_real_model = {} - for resp_row in all_fetched_responses: - resp_dict = dict(resp_row) - model = resp_dict['model_name'] - if model not in responses_by_real_model: # Only store the first (best) one - responses_by_real_model[model] = resp_dict - - for display_name, real_name in real_model_names_map.items(): - model_responses_data[display_name] = responses_by_real_model.get(real_name) - - except sqlite3.Error as e: - print(f"Database error in bulk get_model_responses: {e}. Falling back.") - # Fallback to individual fetches (uses cache) - for display_name, real_name in real_model_names_map.items(): - _ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id) - if responses_for_model: - correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None) - model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0] - else: model_responses_data[display_name] = None - + # 从显示名称中获取真实模型名称 + model = self.comp_model_display_to_real.get(model_display_val, model_display_val) + + _, responses_for_model = self.get_problem_data(model, dataset, problem_id) + if responses_for_model: + # 尝试找到正确的响应,否则使用第一个 + correct_resp = next((r for r in responses_for_model if r['correctness'] == 1), None) + model_responses_data[model_display_val] = correct_resp if correct_resp else responses_for_model[0] + else: + model_responses_data[model_display_val] = None + return problem, model_responses_data - def clear_cache(self, section=None): - # (Unchanged - Python cache clearing is still relevant) - print(f"Clearing cache section: {section if section else 'All'}") - cleared_something = False + """清除指定部分或全部缓存""" if section == 'main' or section is None: - if self._cache: - count = len(self._cache) - self._cache = {} - print(f"Cleared main cache ({count} items).") - cleared_something = True + self._cache = {} if section == 'problem' or section is None: - if self._problem_cache: - count = len(self._problem_cache) - self._problem_cache = {} - print(f"Cleared problem cache ({count} items).") - cleared_something = True + self._problem_cache = {} if section == 'response' or section is None: - if self._response_cache: - count = len(self._response_cache) - self._response_cache = {} - print(f"Cleared response cache ({count} items).") - cleared_something = True + self._response_cache = {} if section == 'models' or section is None: - if hasattr(self, '_models_cache') and self._models_cache is not None: + if hasattr(self, '_models_cache'): self._models_cache = None - print("Cleared models list cache.") - cleared_something = True - if hasattr(self, '_datasets_cache') and self._datasets_cache is not None: + if hasattr(self, '_datasets_cache'): self._datasets_cache = None - print("Cleared datasets list cache.") - cleared_something = True - if cleared_something: print("Running garbage collection..."); gc.collect() - else: print("Cache section(s) already empty or invalid section specified.") - + def close(self): - # (Unchanged - closing the disk connection) - print("Closing database connection...") + """关闭数据库连接并释放资源""" if hasattr(self, 'conn') and self.conn: try: - # Maybe run optimize before closing large WAL file? Might take time. - # print("Running PRAGMA optimize...") - # self.conn.execute("PRAGMA optimize;") self.conn.close() - self.conn = None - print("Database connection closed.") - except sqlite3.Error as e: - print(f"Error closing database connection: {e}") - else: print("Database connection already closed or never established.") + except Exception: + pass + + # 清理所有缓存 self.clear_cache() - -# --- Helper functions (format_*, get_color_*, etc.) --- -# (These remain unchanged as they don't depend on the DB access method) def format_latex(text): if text is None: return "" + # Process the text for proper LaTeX rendering with KaTeX + # KaTeX requires LaTeX backslashes to be preserved + # Only replace newlines with HTML breaks text = text.replace('\n', '
') + # Wrap in a span that KaTeX can detect and render return f'{text}' def format_markdown_with_math(text): if text is None: return "" + + # Don't add HTML tags or do special processing for LaTeX - let Gradio handle it + # Just clean up basic issues that might affect rendering + + # Convert newlines for markdown text = text.replace('\r\n', '\n').replace('\r', '\n') + + # Return the cleaned text for Gradio's markdown component to render return text def get_gradient_color(accuracy, color_map='RdYlGn'): - if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0): - return "#808080" + if accuracy is None or not isinstance(accuracy, (int, float)): + return "#505050" # Default for missing or invalid accuracy try: + # 使用更深的颜色映射 cmap = plt.colormaps.get_cmap(color_map) - power_adjust = 0.7 - rgba = cmap(accuracy ** power_adjust) - hex_color = mpl.colors.rgb2hex(rgba) + rgba = cmap(float(accuracy)) + + # 确保颜色足够深以与白色文本形成对比 + r, g, b, a = rgba + # 降低颜色亮度,确保文本可读性 + r = r * 0.7 + g = g * 0.7 + b = b * 0.7 + + # 转回十六进制 + hex_color = mpl.colors.rgb2hex((r, g, b, a)) return hex_color - except Exception as e: - print(f"Error generating gradient color for accuracy {accuracy}: {e}") - return "#808080" - -def get_contrasting_text_color(bg_color_hex): - try: - if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7: - return "#000000" - r = int(bg_color_hex[1:3], 16); g = int(bg_color_hex[3:5], 16); b = int(bg_color_hex[5:7], 16) - rgb = [val / 255.0 for val in (r, g, b)] - rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb] - luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2] - return "#000000" if luminance > 0.22 else "#FFFFFF" - except Exception as e: - print(f"Error calculating contrasting color for {bg_color_hex}: {e}") - return "#000000" + except Exception: + return "#505050" + +def get_contrasting_text_color(bg_color): + """计算最佳对比文本颜色""" + # 如果背景是十六进制格式,转换为RGB + if bg_color.startswith('#'): + r = int(bg_color[1:3], 16) + g = int(bg_color[3:5], 16) + b = int(bg_color[5:7], 16) + else: + # 未知格式默认返回黑色 + return "#000" + + # 计算YIQ亮度值 - 更精确地表示人眼对亮度的感知 + yiq = (r * 299 + g * 587 + b * 114) / 1000 + + # 黄色检测 - 黄色通常R和G高,B低 + is_yellow = r > 200 and g > 200 and b < 150 + + # 浅绿色检测 - 通常G高,R中等,B低 + is_light_green = g > 200 and r > 100 and r < 180 and b < 150 + + # 米色/浅棕色检测 - R高,G中高,B低 + is_beige = r > 220 and g > 160 and g < 220 and b < 160 + + # 强制这些特定颜色使用黑色文本 + if is_yellow or is_light_green or is_beige: + return "#000" + + # 其他颜色根据亮度决定 + return "#000" if yiq > 160 else "#fff" def format_sample_metadata(sample, show_correctness=True): - if sample is None: return "
No sample data provided.
" - sample_dict = dict(sample) if hasattr(sample, 'keys') else {} - if not sample_dict: return "
Empty sample data.
" + """生成样本元数据的HTML格式显示""" + if sample is None: return "" + sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {} + if not sample_dict: return "No sample data" + + # 提取所需信息 extracted = sample_dict.get('extracted', '') - correctness = sample_dict.get('correctness', None) - output_tokens = sample_dict.get('output_tokens') - reasoning_tokens = sample_dict.get('reasoning_tokens') - if correctness == 1: correctness_label = "✓ Correct"; correctness_color = "var(--color-acc-green, #28a745)" - elif correctness == 0: correctness_label = "✗ Incorrect"; correctness_color = "var(--color-acc-red, #dc3545)" - else: correctness_label = "? Unknown"; correctness_color = "var(--color-acc-grey, #6c757d)" - html = f"
" - if show_correctness: html += f"" - if extracted: - extracted_safe = extracted.replace('<', '<').replace('>', '>') - extracted_display = f"${extracted_safe}$" if re.match(r'^-?\d+(\.\d+)?(/(-?\d+(\.\d+)?))?$', extracted_safe.strip()) else extracted_safe - html += f"" - if output_tokens is not None: html += f"" - if reasoning_tokens is not None: html += f"" - html += "
" + correctness = sample_dict.get('correctness', 0) + correctness_label = "✓ Correct" if correctness else "✗ Incorrect" + correctness_color = "var(--color-green)" if correctness else "var(--color-red)" + + # 获取token信息 + output_tokens = sample_dict.get('output_tokens', None) + reasoning_tokens = sample_dict.get('reasoning_tokens', None) + + # 创建元数据HTML + html = f"
" + + # 创建信息行 + if show_correctness: + html += f"
" + # 正确性指示器 + html += f"{correctness_label}" + + # 提取的答案 + if extracted: + html += f"Extracted: ${extracted}$" + + # 输出token数 + if output_tokens is not None: + html += f"Output Tokens: {output_tokens}" + + # 推理token数 - 仅在可用时 + if reasoning_tokens is not None: + html += f"Reasoning Tokens: {reasoning_tokens}" + + html += f"
" + + html += "
" return html def format_sample_response(sample): - if sample is None: return "No response data." - sample_dict = dict(sample) if hasattr(sample, 'keys') else {} - if not sample_dict: return "Empty response data." + """生成样本响应的Markdown格式显示""" + if sample is None: return "" + sample_dict = dict(sample) if hasattr(sample, 'keys') else sample if isinstance(sample, dict) else {} + if not sample_dict: return "No sample data" + + # 获取响应内容 response = sample_dict.get('response', '') - if not response: return "*(Empty Response)*" - response = response.replace('&', '&').replace('<', '<').replace('>', '>') + + # 转义特殊标签以防止被解析为HTML + # 替换标签 + response = response.replace("", "<think>") + response = response.replace("", "</think>") + + # 替换其他可能的特殊标签 + response = response.replace("", "<reasoning>") + response = response.replace("", "</reasoning>") + response = response.replace("", "<answer>") + response = response.replace("", "</answer>") + return response - -# --- Handler functions (handle_sample_select, handle_first_sample, etc.) --- -# (These remain unchanged as they interact with the DB class interface) -def handle_sample_select(sample_number_str, samples_data_state): - samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) - if not samples_list or not isinstance(samples_list, list): - err_msg = "**Error:** No sample data available or invalid format." +def handle_sample_select(sample_number, samples_data): + # 确保从Gradio State对象中提取实际值 + if hasattr(samples_data, 'value'): + samples_list = samples_data.value + else: + samples_list = samples_data + + # 确保样本编号是整数 + try: + sample_idx = int(sample_number) + except ValueError: + return "Error: Sample number must be an integer.", "" + + # 确保样本数据存在且为非空列表 + if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0: + return "No sample data available. Please select a problem first.", "" + + # 检查索引是否在有效范围内,如果不在范围内,显示错误消息 + if sample_idx < 0: + err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}." + return err_msg, "" + + if sample_idx >= len(samples_list): + err_msg = f"**Error:** Sample number {sample_idx} is out of range. Valid range is 0 to {len(samples_list) - 1}." return err_msg, "" - try: sample_idx = int(sample_number_str) - except (ValueError, TypeError): return f"**Error:** Invalid sample number '{sample_number_str}'. Must be an integer.", "" - if not (0 <= sample_idx < len(samples_list)): return f"**Error:** Sample index {sample_idx} out of range (0 to {len(samples_list) - 1}).", "" + + # 获取所选样本的数据 try: - selected_sample = samples_list[sample_idx] - if not isinstance(selected_sample, dict): selected_sample = dict(selected_sample) if hasattr(selected_sample, 'keys') else {} - formatted_metadata = format_sample_metadata(selected_sample) - formatted_response = format_sample_response(selected_sample) + sample = samples_list[sample_idx] + formatted_metadata = format_sample_metadata(sample) + formatted_response = format_sample_response(sample) return formatted_metadata, formatted_response except Exception as e: - print(f"Error formatting sample {sample_idx}: {e}") err_msg = f"**Error displaying sample {sample_idx}:** {str(e)}" - return f"
{err_msg}
", "" + return err_msg, "" -def handle_first_sample(samples_data_state): - samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) - if not samples_list: return format_sample_metadata(None), format_sample_response(None) - else: return handle_sample_select("0", samples_data_state) +def handle_first_sample(samples_data): + """处理并显示第一个样本(索引0)""" + # 确保从Gradio State对象中提取实际值 + if hasattr(samples_data, 'value'): + samples_list = samples_data.value + else: + samples_list = samples_data + + # 检查样本数据是否存在 + if not samples_list or not isinstance(samples_list, list) or len(samples_list) == 0: + return "No sample data available. Please select the problem and dataset first.", "" + + # 直接获取第一个样本,避免错误处理逻辑 + try: + sample = samples_list[0] + formatted_metadata = format_sample_metadata(sample) + formatted_response = format_sample_response(sample) + return formatted_metadata, formatted_response + except Exception as e: + err_msg = f"**Error displaying first sample:** {str(e)}" + return err_msg, "" -def handle_comparison_problem_update(problem_id_state, dataset_state): +def handle_comparison_problem_update(problem_id, dataset_state): + """处理比较页面的问题更新,仅更新问题和答案内容,不需要模型""" global db - if not db or not db.conn: return "Database not initialized.", "Error" + # 确保从Gradio State对象中提取实际值 dataset_name = dataset_state.value if hasattr(dataset_state, 'value') else dataset_state - problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - original_problem_id = problem_id - if problem_id and problem_id.isdigit() and dataset_name: + problem_id_value = problem_id.value if hasattr(problem_id, 'value') else problem_id + + if not problem_id_value or not dataset_name: + return "Please select a dataset and enter a problem ID.", "No answer available." + + # 处理纯数字输入,构建完整unique_id + if problem_id_value and problem_id_value.isdigit(): + # 构建格式:OlymMATH-HARD-0-EN 或类似格式 parts = dataset_name.split('-') - if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" - else: print(f"Warning: Cannot reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") - if not problem_id or not dataset_name: return "Please select dataset and enter problem ID.", "N/A" + if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD") + language, difficulty = parts + # 构建完整ID + problem_id_value = f"OlymMATH-{difficulty}-{problem_id_value}-{language}" + try: - problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) + # 只获取问题数据,不获取特定模型的响应 + problem_data, _ = db.get_problem_data(None, dataset_name, problem_id_value) + if not problem_data: - return f"Problem ID '{original_problem_id}' not found for {dataset_name}.", "N/A" - problem_dict = dict(problem_data) if problem_data else {} - problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) - answer_text = problem_dict.get('answer', '*(Answer not available)*') + return f"Problem not found: {problem_id_value}. Please check the ID and try again.", "No answer available." + + problem_dict = dict(problem_data) + # Use format_markdown_with_math for proper rendering + problem_content = format_markdown_with_math(problem_dict.get('problem', '')) + + # 将答案中的双美元符号替换为单美元符号 + answer_text = problem_dict.get('answer', '') + # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): answer_text = f"${answer_text}$" + + # 检查答案是否已经包含美元符号,如果没有则添加 + if '$' not in answer_text and answer_text.strip(): + answer_text = f"${answer_text}$" + answer_content = format_markdown_with_math(answer_text) + return problem_content, answer_content - except Exception as e: print(f"Error in handle_comparison_problem_update for {problem_id}, {dataset_name}: {e}"); return f"Error fetching problem details: {e}", "Error" + except Exception as e: + return f"Error: {str(e)}", "No answer available." -def handle_problem_select(problem_id_state, current_model_state, current_dataset_state, mode='default'): +def handle_problem_select(problem_id_from_js, current_model_state, current_dataset_state, mode='default'): global db - if not db or not db.conn: return "DB Error.", "DB Error.", gr.HTML("

DB Conn Error

"), gr.State([]) + # Ensure we're using the actual values from Gradio State objects model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state - problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - original_problem_id = problem_id - if not dataset_name: return "Select dataset.", "N/A", gr.HTML(""), gr.State([]) - if not problem_id: return "Enter problem ID.", "N/A", gr.HTML(""), gr.State([]) - if not model_name and mode == 'default': return "Select model.", "N/A", gr.HTML(""), gr.State([]) - if problem_id.isdigit(): - parts = dataset_name.split('-'); - if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"; print(f"Reconstructed ID: {problem_id}") - else: print(f"Warning: Could not reconstruct ID from {original_problem_id} and {dataset_name}") + problem_id = problem_id_from_js.value if hasattr(problem_id_from_js, 'value') else problem_id_from_js + + # 处理纯数字输入,构建完整unique_id + if problem_id and problem_id.isdigit(): + # 构建格式:OlymMATH-HARD-0-EN 或类似格式 + # 从dataset_name (例如 "EN-HARD") 解析语言和难度 + parts = dataset_name.split('-') + if len(parts) == 2: # 确保格式正确 (例如 "EN-HARD") + language, difficulty = parts + # 构建完整ID + problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" + + if not problem_id or not dataset_name: + error_message = f"Missing data: problem_id='{problem_id}', dataset='{dataset_name}'" + return "Please fill in all the fields.", "No answer available.", "", gr.State([]) + + # For comparison mode, we might not have a model selected yet + if not model_name and mode == 'comparison': + try: + # Just get the problem data without model-specific responses + problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) + + if not problem_data: + error_message = f"Problem data not found: problem_id='{problem_id}', dataset='{dataset_name}'" + return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([]) + + problem_dict = dict(problem_data) + # Process problem and answer text for Markdown rendering + problem_content = format_markdown_with_math(problem_dict.get('problem', '')) + + # 将答案中的双美元符号替换为单美元符号 + answer_text = problem_dict.get('answer', '') + # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 + answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) + + # 检查答案是否已经包含美元符号,如果没有则添加 + if '$' not in answer_text and answer_text.strip(): + answer_text = f"${answer_text}$" + + answer_content = format_markdown_with_math(answer_text) + + # For comparison without model, we don't have samples to display + return problem_content, answer_content, "", gr.State([]) + except Exception as e: + error_message = f"Database error: {str(e)}" + return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([]) + + # The regular flow for model-specific data + if not model_name: + error_message = f"Missing data: model='{model_name}'" + return "Please fill in all the fields.", "No answer available.", "", gr.State([]) + + # The problem_id from JS should be the full unique_id. No reconstruction needed normally. try: problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id) - if not problem_data: return f"Problem '{original_problem_id}' not found.", "N/A", gr.HTML(f"

ID '{original_problem_id}' not found.

"), gr.State([]) - problem_dict = dict(problem_data) if problem_data else {} - problem_content = format_markdown_with_math(problem_dict.get('problem', 'N/A')) - answer_text = problem_dict.get('answer', 'N/A') - answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*(') and answer_text != 'N/A': answer_text = f"${answer_text}$" - answer_content = format_markdown_with_math(answer_text) - if responses_data is None: samples_grid_html = gr.HTML("

Error fetching responses.

"); samples_data_for_state = gr.State([]) - elif not responses_data: samples_grid_html = gr.HTML("

No responses found.

"); samples_data_for_state = gr.State([]) - else: - samples_data = responses_data - samples_data_for_state = gr.State(samples_data) - correct_count = sum(1 for r in samples_data if r.get('correctness') == 1) - total_samples = len(samples_data); accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 - displayed_samples = samples_data[:64]; actual_display_count = len(displayed_samples) - samples_per_row = 16 if mode.startswith('comparison') else 32 - num_rows = math.ceil(actual_display_count / samples_per_row); grid_html_content = "" - js_mode = "'comparison_left'" if mode == 'comparison_left' else "'comparison_right'" if mode == 'comparison_right' else "'default'" - for row_idx in range(num_rows): - grid_html_content += f'
' - start_idx = row_idx * samples_per_row; end_idx = min(start_idx + samples_per_row, actual_display_count) - row_samples = displayed_samples[start_idx:end_idx] + + if not problem_data: + error_message = f"Problem data not found: problem_id='{problem_id}', model='{model_name}', dataset='{dataset_name}'" + return f"Problem not found: {problem_id}. Please check the ID and try again.", "No answer available.", "", gr.State([]) + except Exception as e: + error_message = f"Database error: {str(e)}" + return f"Database error occurred. Please try again.", "No answer available.", "", gr.State([]) + + problem_dict = dict(problem_data) + problem_display_num = re.search(r'\d+', problem_id).group(0) if re.search(r'\d+', problem_id) else problem_id + + # Process problem and answer text for Markdown rendering + problem_content = format_markdown_with_math(problem_dict.get('problem', '')) + + # 将答案中的双美元符号替换为单美元符号 + answer_text = problem_dict.get('answer', '') + # 先将$$...$$替换为单个$...$,使用re.DOTALL处理多行 + answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) + + # 检查答案是否已经包含美元符号,如果没有则添加 + if '$' not in answer_text and answer_text.strip(): + answer_text = f"${answer_text}$" + + answer_content = format_markdown_with_math(answer_text) + + # Rest of the function remains the same + if not responses_data: + samples_grid_html = "
No samples available for this problem.
" + # 返回空的样本数据状态 + return problem_content, answer_content, samples_grid_html, gr.State([]) + else: + # 准备所有样本数据,用���后续处理 + samples_data = [] + for i, resp in enumerate(responses_data): + resp_dict = dict(resp) + samples_data.append(resp_dict) + + # 计算正确率 + correct_count = sum(1 for r in samples_data if r['correctness']) + total_samples = len(samples_data) + accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 + + # 创建样本网格显示 (最多显示 64 个样本) + displayed_samples = samples_data[:64] + actual_display_count = len(displayed_samples) + + # 根据模式确定每行的样本数 + samples_per_row = 16 if mode == 'comparison' else 32 + + # 第一行: 样本 0-samples_per_row + samples_grid_html = f'
' + + for i, resp in enumerate(displayed_samples[:samples_per_row]): + correctness = resp.get('correctness', 0) + bg_color = get_gradient_color(1.0 if correctness else 0.0) + + # 移除点击事件和data属性,只保留纯显示 + samples_grid_html += f""" +
+ {i} +
+ """ + + # 如果少于samples_per_row个样本,填充剩余空间 + for i in range(min(actual_display_count, samples_per_row), samples_per_row): + samples_grid_html += f""" +
+ """ + + samples_grid_html += '
' + + # 如果有更多样本,显示第二行 + if actual_display_count > samples_per_row: + row_samples = displayed_samples[samples_per_row:2*samples_per_row] + samples_grid_html += f'
' + + for i, resp in enumerate(row_samples): + actual_idx = i + samples_per_row + correctness = resp.get('correctness', 0) + bg_color = get_gradient_color(1.0 if correctness else 0.0) + + samples_grid_html += f""" +
+ {actual_idx} +
+ """ + + # 填充剩余空间 + for i in range(len(row_samples), samples_per_row): + samples_grid_html += f""" +
+ """ + + samples_grid_html += '
' + + # 第三行和第四行 - 允许所有模式显示完整的64个样本 + if actual_display_count > 2*samples_per_row: + # 第三行 + row_samples = displayed_samples[2*samples_per_row:3*samples_per_row] + if row_samples: + samples_grid_html += f'
' + for i, resp in enumerate(row_samples): - actual_idx = start_idx + i; correctness = resp.get('correctness', None) - if correctness == 1: bg_color = get_gradient_color(1.0) - elif correctness == 0: bg_color = get_gradient_color(0.0) - else: bg_color = "#808080" - text_color = get_contrasting_text_color(bg_color) - grid_html_content += f"""""" - for _ in range(len(row_samples), samples_per_row): grid_html_content += "
" - grid_html_content += '
' - max_rows = 4 if mode.startswith('comparison') else 2 - for _ in range(num_rows, max_rows): - grid_html_content += f'
'; - for _ in range(samples_per_row): grid_html_content += "
"; grid_html_content += '
' - samples_grid_html = gr.HTML(f"""

Samples ({actual_display_count} shown) – Accuracy: {accuracy_on_problem:.1%}

{grid_html_content}
""") - return problem_content, answer_content, samples_grid_html, samples_data_for_state - except Exception as e: print(f"Error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}"); import traceback; traceback.print_exc(); error_msg = f"**Internal Error:** {str(e)}"; return error_msg, "Error", gr.HTML(f"

{error_msg}

"), gr.State([]) - - -# --- UI Creation function (create_problem_grid_html, create_ui) --- -# (These remain unchanged as they interact with the DB class interface and handlers) + actual_idx = i + 2*samples_per_row + correctness = resp.get('correctness', 0) + bg_color = get_gradient_color(1.0 if correctness else 0.0) + + samples_grid_html += f""" +
+ {actual_idx} +
+ """ + + # 填充剩余空间 + for i in range(len(row_samples), samples_per_row): + samples_grid_html += f""" +
+ """ + + samples_grid_html += '
' + + # 第四行 + if actual_display_count > 3*samples_per_row: + row_samples = displayed_samples[3*samples_per_row:4*samples_per_row] + if row_samples: + samples_grid_html += f'
' + + for i, resp in enumerate(row_samples): + actual_idx = i + 3*samples_per_row + correctness = resp.get('correctness', 0) + bg_color = get_gradient_color(1.0 if correctness else 0.0) + + samples_grid_html += f""" +
+ {actual_idx} +
+ """ + + # 填充剩余空间 + for i in range(len(row_samples), samples_per_row): + samples_grid_html += f""" +
+ """ + + samples_grid_html += '
' + + # 组合HTML内容 + final_html = f""" +
+

Samples {actual_display_count} - Model Accuracy: {correct_count}/{actual_display_count} = {accuracy_on_problem:.1%}

+ {samples_grid_html} +
+ """ + + # 获取第一个样本作为初始样本 + if samples_data: + # 这样样本会在选择问题后立即显示 + return problem_content, answer_content, final_html, gr.State(samples_data) + else: + return problem_content, answer_content, final_html, gr.State([]) + def create_problem_grid_html(problems, mode='default'): - if not problems: return "
No problems found.
" + """Create HTML for problem grid buttons. The JS function will be defined globally.""" + if not problems: + return "
No problems found for this model/dataset. Please select a model and dataset.
" + html_buttons = "" try: - id_extractor = re.compile(r'\d+'); get_sort_key = lambda p: int(id_extractor.search(str(p[0])).group(0)) if id_extractor.search(str(p[0])) else 0 - processed_problems = [] - if isinstance(problems, list): - for p in problems: - try: pid = str(p[0]); acc = float(p[1]) if len(p)>1 and p[1] is not None else 0.0; processed_problems.append((pid, acc)) - except (IndexError, TypeError, ValueError) as conv_err: print(f"Skipping problem entry: {p} - {conv_err}") - sorted_problems = sorted(processed_problems, key=get_sort_key) - else: print(f"Problem data format unexpected: {type(problems)}"); return "
Error: Invalid problem data format.
" - except Exception as e: print(f"Error sorting/processing problems: {e}"); return f"
Error displaying problems: {e}
" - id_extractor = re.compile(r'\d+'); js_mode_arg = "'comparison'" if mode == 'comparison' else "'default'" - for pid, accuracy in sorted_problems: - match = id_extractor.search(pid); num_display = match.group(0) if match else pid[:6] - acc_pct = int(accuracy * 100); bg_color = get_gradient_color(accuracy); text_color = get_contrasting_text_color(bg_color) - escaped_pid = pid.replace("'", "\\'") - html_buttons += f"""""" + sorted_problems = sorted( + [(str(p[0]), float(p[1]) if p[1] is not None else 0.0, p[2]) for p in problems], + key=lambda x: int(re.search(r'\d+', x[0]).group(0)) if re.search(r'\d+', x[0]) else 0 + ) + except Exception as e: + return f"
Error displaying problems. Check logs. {e}
" + + for pid, accuracy, _ in sorted_problems: + match = re.search(r'\d+', pid) + num_display = match.group(0) if match else pid + acc_pct = int(accuracy * 100) + + # 获取背景颜色 + bg_color = get_gradient_color(accuracy) + # 统一使用白色文本,添加!important确保不被覆盖 + text_color = "#ffffff" + + html_buttons += f""" +
+
{num_display}
+
{acc_pct}%
+
+ """ + + # 添加自定义样式强制文本颜色为白色 + custom_style = "" + # 根据模式设置每行显示的列数 grid_cols = 20 if mode == 'comparison' else 10 - grid_html = f"""
{html_buttons}
""" + grid_html = f"{custom_style}
{html_buttons}
" return grid_html -def create_ui(db_instance): - global db; db = db_instance - if not db or not db.conn: - with gr.Blocks() as error_demo: gr.Markdown("# Error: DB Init Failed"); return error_demo - AVAILABLE_DATASETS = db.get_available_datasets(); - if not AVAILABLE_DATASETS: AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] - custom_css = """body, .gradio-container { font-family: sans-serif; font-size: 0.98em; line-height: 1.6; }.gradio-tabs > div[role='tablist'] button { font-size: 0.95em; padding: 8px 14px; }.gr-dropdown select, .gr-radio label span, .gr-checkboxgroup label span, .gr-button { font-size: 0.95em; }.gr-dataframe table { font-size:0.9em; }.gr-markdown { font-size: 1.0em; line-height: 1.6; }.dark .dark-mode-compatible { background-color: var(--neutral-800); color: var(--neutral-100); border-color: var(--neutral-700); }.dark .dark-mode-bg-secondary { background-color: var(--neutral-900); }.dark .dataframe-container table { color: var(--neutral-100); border-color: var(--neutral-600); }.dark .dataframe-container th { background-color: var(--neutral-700); }.math-inline, .math-display { font-size: 105%; }.compact-row { margin-bottom: 5px !important; padding: 0 !important; }.hidden-state > .svelte-kit-component { display: none !important; }""" - javascript = """function handleProblemClick(p,m){console.log(`Problem: ${p}, Mode: ${m}`);let i=(m==='comparison')?'comp-problem-id-state':'problem-id-state';let e=document.getElementById(i);if(e){let t=e.querySelector('textarea');if(t){console.log(`Updating ${i} value: ${p}`);t.value=p;t.dispatchEvent(new Event('input',{bubbles:!0}));t.dispatchEvent(new Event('change',{bubbles:!0}));console.log(`Events dispatched: ${i}`)}else{console.error(`No textarea in #${i}`)}}else{console.error(`Element #${i} not found`)}} function handleSampleClick(s,m){console.log(`Sample: ${s}, Mode: ${m}`);let i;if(m==='comparison_left'){i='comp-sample-index-state-left'}else if(m==='comparison_right'){i='comp-sample-index-state-right'}else{i='sample-index-state'}let e=document.getElementById(i);if(e){let t=e.querySelector('textarea');if(t){console.log(`Updating ${i} value: ${s}`);t.value=s;t.dispatchEvent(new Event('input',{bubbles:!0}));t.dispatchEvent(new Event('change',{bubbles:!0}));console.log(`Events dispatched: ${i}`)}else{console.error(`No textarea in #${i}`)}}else{console.error(`Element #${i} not found`)}} window.handleProblemClick=handleProblemClick;window.handleSampleClick=handleSampleClick;""" - with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky), head=f"", title="Model Performance Analyzer") as demo: +def create_ui(db_path): + global db + db = ModelDatabase(db_path) + + AVAILABLE_DATASETS = db.get_available_datasets() + if not AVAILABLE_DATASETS: + AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback + + # Add MathJax support to the CSS + custom_css = """ + .padding.svelte-phx28p { padding: unset !important; } + body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; } + .sample-btn { transition: all 0.15s ease-in-out; } + .sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); } + .problem-grid-container { overflow-y: auto; } + .math-content { overflow-x: auto; padding: 5px; } + .sample-response { overflow-y: clip !important; max-height: none !important; height: auto !important; } + h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); } + .gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; } + .gr-dropdown select { font-size: 0.9em; } + .gr-radio label span { font-size: 0.9em; } + .gr-checkboxgroup label span { font-size: 0.9em; } + .gr-button { font-size: 0.9em; padding: 8px 12px; } + .gr-dataframe table { font-size:0.85em; } + .gr-markdown { font-size: 1em; } + + /* 适应深色模式的样式 */ + .dark-mode-compatible { + background-color: var(--background-fill-primary); + color: var(--color-text); + border-color: var(--border-color-primary); + } + .dark-mode-bg-secondary { + background-color: var(--background-fill-secondary); + } + + /* DataTable深色模式样式 */ + .dataframe-container { + //padding: 12px; + //border-radius: 8px; + //margin-top: 10px; + } + + /* MathJax Styles for Gradio's Built-in LaTeX */ + .math-inline, .math-display { + font-size: 110%; + } + .math-container p { + margin: 0.5em 0; + } + + /* Markdown content styles */ + .gr-markdown strong { + font-weight: bold; + } + .gr-markdown em { + font-style: italic; + } + .gr-markdown ul, .gr-markdown ol { + padding-left: 2em; + margin: 0.5em 0; + } + .gr-markdown blockquote { + border-left: 3px solid #ccc; + margin: 0.5em 0; + padding-left: 1em; + color: #666; + } + .gr-markdown pre, .gr-markdown code { + background-color: rgba(0,0,0,0.05); + padding: 2px 4px; + border-radius: 3px; + font-family: monospace; + } + .gr-markdown table { + border-collapse: collapse; + margin: 0.5em 0; + } + .gr-markdown th, .gr-markdown td { + border: 1px solid #ddd; + padding: 4px 8px; + } + """ + + with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo: + # Remove KaTeX loading script since we're using Gradio's native Markdown with LaTeX + current_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") current_model_state = gr.State(value=None) - problem_id_state = gr.Textbox(elem_id="problem-id-state", visible=False, label="Selected Problem ID") - sample_index_state = gr.Textbox(value="0", elem_id="sample-index-state", visible=False, label="Selected Sample Index") + comparison_data_state = gr.State(value={}) + # 添加当前样本状态 + current_sample_state = gr.State(value="0") + # 添加当前问题的样本数据状态 current_samples_data_state = gr.State(value=[]) + + # 为Comparison标签页添加独立状态 comp_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "") - comp_problem_id_state = gr.Textbox(elem_id="comp-problem-id-state", visible=False, label="Selected Comparison Problem ID") - comp_model_state_left = gr.State(value=None); comp_sample_index_state_left = gr.Textbox(value="0", elem_id="comp-sample-index-state-left", visible=False, label="Selected Left Sample Index"); comp_samples_data_state_left = gr.State(value=[]) - comp_model_state_right = gr.State(value=None); comp_sample_index_state_right = gr.Textbox(value="0", elem_id="comp-sample-index-state-right", visible=False, label="Selected Right Sample Index"); comp_samples_data_state_right = gr.State(value=[]) + comp_model_state_left = gr.State(value=None) + comp_sample_state_left = gr.State(value="0") + comp_samples_data_state_left = gr.State(value=[]) + comp_model_state_right = gr.State(value=None) + comp_sample_state_right = gr.State(value="0") + comp_samples_data_state_right = gr.State(value=[]) + + # 创建占位符State组件替代None dummy_state = gr.State(value=None) + with gr.Tabs(): with gr.TabItem("Single Model Analysis"): - with gr.Row(): - with gr.Column(scale=1, min_width=300): - gr.Markdown("### Controls"); dataset_radio_single = gr.Radio(choices=AVAILABLE_DATASETS, value=current_dataset_state.value, label="Select Dataset", interactive=True); model_dropdown = gr.Dropdown(choices=[], label="Select Model (Name + Acc%)", interactive=True); problem_id_input_display = gr.Textbox(label="Enter Problem ID (e.g., 42) or click grid", placeholder="Enter number or full ID", interactive=True); gr.Markdown("### Problem Grid (Click to Select)"); problem_grid_html_output = gr.HTML("Select model and dataset."); gr.Markdown("### Model Statistics"); model_stats_df = gr.DataFrame(headers=["Metric", "Value"], wrap=True) - with gr.Column(scale=3, min_width=500): - gr.Markdown("### Problem Details"); with gr.Tabs(): with gr.TabItem("Problem Statement"): problem_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]); with gr.TabItem("Reference Answer"): answer_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]); gr.Markdown("### Model Responses (Click Grid Below to Select)"); samples_grid_output = gr.HTML(""); gr.Markdown("#### Selected Sample Details"); sample_metadata_output = gr.HTML("Select a sample from the grid above."); sample_response_output = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) + with gr.Row(variant='compact'): + with gr.Column(scale=1, min_width=280): + + dataset_radio_single = gr.Radio( + choices=AVAILABLE_DATASETS, + value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, + label="Select Dataset", + interactive=True + ) + + model_dropdown = gr.Dropdown( + choices=[], # Populated by callback + label="Select Model", + interactive=True + ) + + problem_state_input = gr.Textbox( + value="", + elem_id="problem-state-input", + visible=True, + label="Enter Problem ID (0 - 99, acc. below)", + container=True, + interactive=True, + every=0.5 + ) + + #gr.Markdown("#### Problem Grid") + problem_grid_html_output = gr.HTML( + value="
Select model and dataset to see problems.
", + elem_id="problem-grid-container" + ) + gr.Markdown("#### Model Statistics") + model_stats_df = gr.DataFrame( + headers=["Metric", "Value"], + wrap=True, + elem_classes="dataframe-container dark-mode-compatible dark-mode-bg-secondary" + ) + + with gr.Column(scale=3, min_width=400): + with gr.Tabs(): + with gr.TabItem("Problem Statement"): + problem_markdown_output = gr.Markdown( + "Please fill in all the fields.", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + with gr.TabItem("Reference Answer"): + answer_markdown_output = gr.Markdown( + "No answer available.", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + + # 样本网格 + samples_grid_output = gr.HTML("") + + # 在样本网格下方添加样本选择输入框 + with gr.Row(): + # 样本选择输入框 + sample_number_input = gr.Textbox( + value="0", + elem_id="sample-number-input", + visible=True, + label="Enter Sample Number (0 - 63)", + container=True, + interactive=True, + every=0.5 + ) + + # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 + sample_metadata_output = gr.HTML( + value="
Select a problem first to view samples.
", + elem_classes="sample-metadata dark-mode-bg-secondary", + elem_id="sample-metadata-area" + ) + + sample_response_output = gr.Markdown( + value="Select a problem first to view samples.", + elem_classes="sample-response dark-mode-bg-secondary", + elem_id="sample-response-area", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + with gr.TabItem("Model Comparison"): - with gr.Row(variant='compact'): comp_dataset_radio = gr.Radio(choices=AVAILABLE_DATASETS, value=comp_dataset_state.value, label="Select Dataset", interactive=True, scale=1); comp_problem_id_input_display = gr.Textbox(label="Enter Problem ID (e.g., 42) or click grid", placeholder="Enter number or full ID", interactive=True, scale=1) - with gr.Row(variant='compact'): with gr.Column(scale=1): gr.Markdown("### Problem Details"); with gr.Tabs(): with gr.TabItem("Problem Statement"): comp_problem_markdown_output = gr.Markdown("Select models and problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]); with gr.TabItem("Reference Answer"): comp_answer_markdown_output = gr.Markdown("Select problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) - with gr.Row(variant='compact', equal_height=False): - with gr.Column(scale=1, min_width=400): gr.Markdown("### Model 1"); comp_model_dropdown_left = gr.Dropdown(choices=[], label="Select Model 1", interactive=True); gr.Markdown("#### Problem Grid (Model 1 - Click)"); comp_problem_grid_html_output_left = gr.HTML("Select model 1."); gr.Markdown("#### Model 1 Responses (Click Grid Below)"); comp_samples_grid_output_left = gr.HTML(""); gr.Markdown("##### Selected Sample (Model 1)"); comp_sample_metadata_output_left = gr.HTML("Select sample."); comp_sample_response_output_left = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) - with gr.Column(scale=1, min_width=400): gr.Markdown("### Model 2"); comp_model_dropdown_right = gr.Dropdown(choices=[], label="Select Model 2", interactive=True); gr.Markdown("#### Problem Grid (Model 2 - Click)"); comp_problem_grid_html_output_right = gr.HTML("Select model 2."); gr.Markdown("#### Model 2 Responses (Click Grid Below)"); comp_samples_grid_output_right = gr.HTML(""); gr.Markdown("##### Selected Sample (Model 2)"); comp_sample_metadata_output_right = gr.HTML("Select sample."); comp_sample_response_output_right = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}]) - # --- Event Handlers (Remain the same, interact with DB interface) --- + # 共享部分 + with gr.Row(variant='compact'): + comp_dataset_radio = gr.Radio( + choices=AVAILABLE_DATASETS, + value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else None, + label="Select Dataset", + interactive=True + ) + + comp_problem_state_input = gr.Textbox( + value="", + elem_id="comp-problem-state-input", + visible=True, + label="Enter Problem ID (0 - 99, acc. below)", + container=True, + interactive=True, + every=0.5 + ) + + # 移动的共享问题和答案显示到这里 + with gr.Row(variant='compact'): + with gr.Column(scale=1): + with gr.Tabs(): + with gr.TabItem("Problem Statement"): + comp_problem_markdown_output = gr.Markdown( + "Please select models and problem.", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + with gr.TabItem("Reference Answer"): + comp_answer_markdown_output = gr.Markdown( + "No answer available.", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + + # 左右两部分模型比较 + with gr.Row(variant='compact'): + # 左侧模型 + with gr.Column(scale=1): + comp_model_dropdown_left = gr.Dropdown( + choices=[], # Populated by callback + label="Select Model 1", + interactive=True + ) + + gr.Markdown("#### Problem Grid") + comp_problem_grid_html_output_left = gr.HTML( + value="
Select model and dataset to see problems.
", + elem_id="comp-problem-grid-container-left" + ) + + # 样本网格和选择器 + comp_samples_grid_output_left = gr.HTML("") + + with gr.Row(): + comp_sample_number_input_left = gr.Textbox( + value="0", + elem_id="comp-sample-number-input-left", + visible=True, + label="Enter Sample Number (0 - 63)", + container=True, + interactive=True, + every=0.5 + ) + + # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 + comp_sample_metadata_output_left = gr.HTML( + value="
Select a problem first to view samples.
", + elem_classes="sample-metadata dark-mode-bg-secondary", + elem_id="comp-sample-metadata-area-left" + ) + + comp_sample_response_output_left = gr.Markdown( + value="Select a problem first to view samples.", + elem_classes="sample-response dark-mode-bg-secondary", + elem_id="comp-sample-response-area-left", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + + # 右侧模型 + with gr.Column(scale=1): + comp_model_dropdown_right = gr.Dropdown( + choices=[], # Populated by callback + label="Select Model 2", + interactive=True + ) + + gr.Markdown("#### Problem Grid") + comp_problem_grid_html_output_right = gr.HTML( + value="
Select model and dataset to see problems.
", + elem_id="comp-problem-grid-container-right" + ) + + # 样本网格和选择器 + comp_samples_grid_output_right = gr.HTML("") + + with gr.Row(): + comp_sample_number_input_right = gr.Textbox( + value="0", + elem_id="comp-sample-number-input-right", + visible=True, + label="Enter Sample Number (0 - 63)", + container=True, + interactive=True, + every=0.5 + ) + + # 样本内容显示区域 - 使用HTML和Markdown组件分别显示元数据和响应内容 + comp_sample_metadata_output_right = gr.HTML( + value="
Select a problem first to view samples.
", + elem_classes="sample-metadata dark-mode-bg-secondary", + elem_id="comp-sample-metadata-area-right" + ) + + comp_sample_response_output_right = gr.Markdown( + value="Select a problem first to view samples.", + elem_classes="sample-response dark-mode-bg-secondary", + elem_id="comp-sample-response-area-right", + latex_delimiters=[ + {"left": "$", "right": "$", "display": False}, + {"left": "$$", "right": "$$", "display": True}, + {"left": "\\(", "right": "\\)", "display": False}, + {"left": "\\[", "right": "\\]", "display": True} + ] + ) + + # --- Event Handlers --- def update_available_models_for_dropdowns(selected_dataset): - if not db or not db.conn: print("Error: DB not available"); return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[]) - all_models = db.get_available_models(); model_acc_map = {}; - if selected_dataset and all_models: model_accs = db.get_all_model_accuracies(selected_dataset); model_acc_map = {name: acc for name, acc in model_accs} - display_options = []; db.model_display_to_real = {}; db.comp_model_display_to_real = {} - sorted_models = sorted(all_models, key=lambda m: model_acc_map.get(m, -1), reverse=True) - for name in sorted_models: - display_name = MODEL_TRANS.get(name, name); acc = model_acc_map.get(name); acc_display = f" ({acc:.1%})" if acc is not None else " (N/A)"; display_text = f"{display_name}{acc_display}"; display_options.append(display_text); db.model_display_to_real[display_text] = name; db.comp_model_display_to_real[display_text] = name - return gr.Dropdown(choices=display_options, value=None, label="Select Model (Name + Acc%)", interactive=True), gr.Dropdown(choices=display_options, value=None, label="Select Model 1", interactive=True), gr.Dropdown(choices=display_options, value=None, label="Select Model 2", interactive=True) + # This function can be used to update model lists if they are dataset-dependent + # For now, assume get_available_models() gets all models irrespective of dataset for dropdown population + all_models = db.get_available_models() + # For single model tab, format with accuracy on the selected dataset + single_model_options = [] + model_to_display_map = {} # 映射用于存储真实模型名称到显示名称的映射 + + if selected_dataset and all_models: + model_accs = db.get_all_model_accuracies(selected_dataset) + model_acc_map = {name: acc for name, acc in model_accs} + single_model_options = [] + for name in all_models: + # 使用MODEL_TRANS映射模型名称 + display_name = MODEL_TRANS.get(name, name) + acc_display = f" ({model_acc_map.get(name, 0):.1%})" if model_acc_map.get(name) is not None else "" + display_text = f"{display_name}{acc_display}" + single_model_options.append(display_text) + model_to_display_map[display_text] = name # 存储映射关系 + else: + for name in all_models: + display_name = MODEL_TRANS.get(name, name) + single_model_options.append(display_name) + model_to_display_map[display_name] = name + + # 将映射存储到全局数据库对象中以便后续使用 + db.model_display_to_real = model_to_display_map + + # For comparison tab, also use formatted model names with accuracy + comp_model_choices = single_model_options # 使用和单模型相同的选项,包含准确率 + db.comp_model_display_to_real = model_to_display_map # 使用相同的映射 + + return gr.Dropdown(choices=single_model_options if single_model_options else [], value=None), \ + gr.Dropdown(choices=comp_model_choices if comp_model_choices else [], value=None) + def update_problem_grid_and_stats(selected_model_formatted, selected_dataset, mode='default'): - if not db or not db.conn: print("Error: DB not available"); return gr.DataFrame(value=[]), gr.HTML("

DB Error.

"), None - if not selected_model_formatted or not selected_dataset: return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None - real_model_name = None; - if mode == 'comparison': real_model_name = db.comp_model_display_to_real.get(selected_model_formatted) - if not real_model_name: real_model_name = db.model_display_to_real.get(selected_model_formatted) - if not real_model_name: - raw_name_part = selected_model_formatted.split(" (")[0]; - for db_name, display_lookup in MODEL_TRANS.items(): - if display_lookup == raw_name_part: real_model_name = db_name; break - if not real_model_name: real_model_name = raw_name_part - print(f"Warning: Using fallback lookup for model name: '{selected_model_formatted}' -> '{real_model_name}'") - if not real_model_name: print(f"Error: Could not determine real model name for '{selected_model_formatted}'"); return gr.DataFrame(value=[]), gr.HTML("

Error.

"), None - stats_data = db.get_model_statistics(real_model_name, selected_dataset); problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset); grid_html = create_problem_grid_html(problem_list, mode=mode) - return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name - def clear_problem_outputs(): return "Select problem.", "N/A", gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" - def clear_comparison_side_outputs(): return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" - # Single Model Event Connections - dataset_radio_single.change(fn=update_available_models_for_dropdowns, inputs=[dataset_radio_single], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", *clear_problem_outputs(), ""), inputs=[dataset_radio_single], outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display]) - model_dropdown.change(fn=update_problem_grid_and_stats, inputs=[model_dropdown, current_dataset_state], outputs=[model_stats_df, problem_grid_html_output, current_model_state]).then(lambda: ("", *clear_problem_outputs(), ""), inputs=[], outputs=[problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display]) - problem_id_input_display.submit(fn=lambda x: x, inputs=[problem_id_input_display], outputs=[problem_id_state]) - problem_id_state.change(fn=handle_problem_select, inputs=[problem_id_state, current_model_state, current_dataset_state], outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]).then(fn=handle_first_sample, inputs=[current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]).then(lambda: "0", inputs=[], outputs=[sample_index_state]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[problem_id_state], outputs=[problem_id_input_display]) - sample_index_state.change(fn=handle_sample_select, inputs=[sample_index_state, current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]) - # Comparison Tab Event Connections - comp_dataset_radio.change(fn=lambda ds: ds, inputs=[comp_dataset_radio], outputs=[comp_dataset_state]).then(fn=update_available_models_for_dropdowns, inputs=[comp_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda: (None, None, "Select models and problem.", "Select problem.", gr.HTML("Select model."), gr.HTML("Select model."), *clear_comparison_side_outputs(), *clear_comparison_side_outputs(), "", ""), inputs=[], outputs=[comp_model_state_left, comp_model_state_right, comp_problem_markdown_output, comp_answer_markdown_output, comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, comp_problem_id_state, comp_problem_id_input_display]) - comp_model_dropdown_left.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_left, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]) - comp_model_dropdown_right.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_right, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]) - comp_problem_id_input_display.submit(fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state]) - comp_problem_id_state.change(fn=handle_comparison_problem_update, inputs=[comp_problem_id_state, comp_dataset_state], outputs=[comp_problem_markdown_output, comp_answer_markdown_output]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display]) - comp_sample_index_state_left.change(fn=handle_sample_select, inputs=[comp_sample_index_state_left, comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]) - comp_sample_index_state_right.change(fn=handle_sample_select, inputs=[comp_sample_index_state_right, comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]) - # Initial Load - demo.load(fn=update_available_models_for_dropdowns, inputs=[current_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]) - return demo + if not selected_model_formatted or not selected_dataset: + # Return empty/default values for all outputs, including the state + return gr.DataFrame(value=[]), gr.HTML("
Please select a model and dataset first.
"), None + + # 从映射中获取真实模型名称 + model_name = db.model_display_to_real.get(selected_model_formatted, selected_model_formatted) + # 如果找不到确切匹配,可能是因为准确率等动态内容导致,尝试前缀匹配 + if model_name == selected_model_formatted: + for display_name, real_name in db.model_display_to_real.items(): + if selected_model_formatted.startswith(display_name.split(" (")[0]): + model_name = real_name + break + + stats_data = db.get_model_statistics(model_name, selected_dataset) + problem_list = db.get_problems_by_model_dataset(model_name, selected_dataset) + grid_html = create_problem_grid_html(problem_list, mode=mode) + + # Correctly return the actual value for the current_model_state output + return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), model_name + + # Single Model Tab interactions + dataset_radio_single.change( + fn=update_available_models_for_dropdowns, + inputs=[dataset_radio_single], + outputs=[model_dropdown, comp_model_dropdown_left] + ).then( + lambda ds: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), gr.State(value=None), ds, ""), # 清空所有输出,包括problem_state_input + inputs=[dataset_radio_single], + outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_state_input] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[sample_number_input] + ).then( + lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""), + inputs=[], + outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] + ) + + # Initial population of model dropdowns based on default dataset + demo.load( + fn=update_available_models_for_dropdowns, + inputs=[current_dataset_state], # Uses initial value of state + outputs=[model_dropdown, comp_model_dropdown_left] + ).then( + lambda ds_val: (gr.DataFrame(value=[]), gr.HTML("
Select a model.
"), ds_val), # Also update dataset state for single tab + inputs=[current_dataset_state], + outputs=[model_stats_df, problem_grid_html_output, current_dataset_state] + ).then( + lambda: ("Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", ""), + inputs=[], + outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[sample_number_input] + ) + + # ==== 比较页面事件处理 ==== + # 初始化两侧模型下拉列表 + demo.load( + fn=update_available_models_for_dropdowns, + inputs=[comp_dataset_state], + outputs=[model_dropdown, comp_model_dropdown_left] + ).then( + fn=update_available_models_for_dropdowns, + inputs=[comp_dataset_state], + outputs=[model_dropdown, comp_model_dropdown_right] + ) + + # 数据集改变事件 + comp_dataset_radio.change( + fn=lambda ds: ds, + inputs=[comp_dataset_radio], + outputs=[comp_dataset_state] + ).then( + fn=update_available_models_for_dropdowns, + inputs=[comp_dataset_state], + outputs=[model_dropdown, comp_model_dropdown_left] + ).then( + fn=update_available_models_for_dropdowns, + inputs=[comp_dataset_state], + outputs=[model_dropdown, comp_model_dropdown_right] + ).then( + lambda: ("Please select a dataset and enter a problem ID.", "No answer available."), + inputs=[], + outputs=[comp_problem_markdown_output, comp_answer_markdown_output] + ) + + # 为比较页面的问题ID添加单独的更新逻辑 + comp_problem_state_input.change( + fn=handle_comparison_problem_update, + inputs=[comp_problem_state_input, comp_dataset_state], + outputs=[comp_problem_markdown_output, comp_answer_markdown_output] + ) + + # 创建包装函数,预设模式参数 + def update_problem_grid_comparison(model, dataset): + return update_problem_grid_and_stats(model, dataset, mode='comparison') + + # 问题选择的包装函数 + def handle_problem_select_comparison(problem_id, model_state, dataset_state): + return handle_problem_select(problem_id, model_state, dataset_state, mode='comparison') + + # 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面左侧 + def update_model_and_requery_problem_left(model_dropdown_value, current_dataset, current_problem_id): + # 首先更新模型统计和问题网格 + _, grid_html, new_model_state = update_problem_grid_comparison(model_dropdown_value, current_dataset) + + # 如果有选择的问题ID,重新查询它的响应 + if current_problem_id: + problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset) + + # 获取第一个样本的内容 + first_metadata, first_response = handle_first_sample(new_samples_data) + + return grid_html, new_model_state, problem_content, answer_content, samples_grid_html, new_samples_data, first_metadata, first_response + else: + # 没有问题ID,只返回更新的模型状态 + return grid_html, new_model_state, "Please enter a problem ID.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", "" + + # 修改model_dropdown的处理函数,以重新查询当前问题响应 - 比较页面右侧 + def update_model_and_requery_problem_right(model_dropdown_value, current_dataset, current_problem_id): + # 首先更新模型统计和问题网格 + _, grid_html, new_model_state = update_problem_grid_comparison(model_dropdown_value, current_dataset) + + # 如果有选择的问题ID,重新查询它的响应 + if current_problem_id: + # 对于右侧,我们不需要更新问题和答案内容 + _, _, samples_grid_html, new_samples_data = handle_problem_select_comparison(current_problem_id, new_model_state, current_dataset) + + # 获取第一个样本的内容 + first_metadata, first_response = handle_first_sample(new_samples_data) + + return grid_html, new_model_state, samples_grid_html, new_samples_data, first_metadata, first_response + else: + # 没有问题ID,只返回更新的模型状态 + return grid_html, new_model_state, "", gr.State([]), "
Select a problem first to view samples.
", "" + + # 左侧模型选择事件 + comp_model_dropdown_left.change( + fn=update_model_and_requery_problem_left, + inputs=[comp_model_dropdown_left, comp_dataset_state, comp_problem_state_input], + outputs=[comp_problem_grid_html_output_left, comp_model_state_left, comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[comp_sample_number_input_left] + ) + + # 右侧模型选择事件 + comp_model_dropdown_right.change( + fn=update_model_and_requery_problem_right, + inputs=[comp_model_dropdown_right, comp_dataset_state, comp_problem_state_input], + outputs=[comp_problem_grid_html_output_right, comp_model_state_right, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[comp_sample_number_input_right] + ) + + # 左侧样本选择 + comp_sample_number_input_left.change( + fn=handle_sample_select, + inputs=[comp_sample_number_input_left, comp_samples_data_state_left], + outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] + ) + + # 右侧样本选择 + comp_sample_number_input_right.change( + fn=handle_sample_select, + inputs=[comp_sample_number_input_right, comp_samples_data_state_right], + outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] + ) + + # 为比较页面问题选择事件添加处理 + comp_problem_state_input.change( + fn=handle_problem_select_comparison, + inputs=[comp_problem_state_input, comp_model_state_left, comp_dataset_state], + outputs=[comp_problem_markdown_output, comp_answer_markdown_output, comp_samples_grid_output_left, comp_samples_data_state_left] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[comp_sample_number_input_left] + ).then( + fn=handle_first_sample, + inputs=[comp_samples_data_state_left], + outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] + ) + + # 问题选择事件 - 右侧模型 + comp_problem_state_input.change( + fn=handle_problem_select_comparison, + inputs=[comp_problem_state_input, comp_model_state_right, comp_dataset_state], + outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[comp_sample_number_input_right] + ).then( + fn=handle_first_sample, + inputs=[comp_samples_data_state_right], + outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] + ) + + # This is the crucial link: problem_state_input is changed by user, triggers this Python callback. + problem_state_input.change( + fn=handle_problem_select, + inputs=[problem_state_input, current_model_state, current_dataset_state], + outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[sample_number_input] + ).then( + fn=handle_first_sample, + inputs=[current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ) + + # Also listen for direct input event which may be more reliable than change + problem_state_input.input( + fn=handle_problem_select, + inputs=[problem_state_input, current_model_state, current_dataset_state], + outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[sample_number_input] + ).then( + fn=handle_first_sample, + inputs=[current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ) + + # 添加样本编号的事件处理 + sample_number_input.change( + fn=handle_sample_select, + inputs=[sample_number_input, current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ) + + sample_number_input.input( + fn=handle_sample_select, + inputs=[sample_number_input, current_samples_data_state], + outputs=[sample_metadata_output, sample_response_output] + ) + + # 修改model_dropdown.change处理函数,以重新查询当前问题响应 + def update_model_and_requery_problem(model_dropdown_value, current_dataset, current_problem_id): + # 首先更新模型统计和问题网格 + stats_df, grid_html, new_model_state = update_problem_grid_and_stats(model_dropdown_value, current_dataset) + + # 如果有选择的问题ID,重新查询它的响应 + if current_problem_id: + problem_content, answer_content, samples_grid_html, new_samples_data = handle_problem_select(current_problem_id, new_model_state, current_dataset) + + # 获取第一个样本的内容 + first_metadata, first_response = handle_first_sample(new_samples_data) + + return stats_df, grid_html, new_model_state, problem_content, answer_content, samples_grid_html, new_samples_data, first_metadata, first_response + else: + # 没有问题ID,只返回更新的模型状态 + return stats_df, grid_html, new_model_state, "Please fill in all the fields.", "No answer available.", "", gr.State([]), "
Select a problem first to view samples.
", "" + + model_dropdown.change( + fn=update_model_and_requery_problem, + inputs=[model_dropdown, current_dataset_state, problem_state_input], + outputs=[model_stats_df, problem_grid_html_output, current_model_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output] + ).then( + # 重置Sample Number为0 + fn=lambda: "0", + inputs=[], + outputs=[sample_number_input] + ) + return demo -# --- Memory Monitor (Adjust thresholds for 10GB limit) --- def monitor_memory_usage(): - """Monitors memory usage and clears caches if thresholds are exceeded (10GB Limit).""" + """监控内存使用情况并在必要时释放缓存""" global db + try: process = psutil.Process(os.getpid()) - memory_info = process.memory_info(); memory_usage_mb = memory_info.rss / (1024 * 1024) - total_memory_gb = 10.0 # Assumed total available - print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB limit.") - - # Define thresholds based on 10GB limit - threshold_1_mb = 10 * 1024 * 0.70 # Warn at 7GB - threshold_2_mb = 10 * 1024 * 0.85 # Critical at 8.5GB - - if db and db.conn: - if memory_usage_mb > threshold_2_mb: - print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) > {threshold_2_mb:.1f} MB. Clearing ALL caches.") - db.clear_cache() - elif memory_usage_mb > threshold_1_mb: - print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) > {threshold_1_mb:.1f} MB. Clearing response cache.") - db.clear_cache('response') - else: print("[Memory Monitor] DB not active.") - return f"Memory OK: {memory_usage_mb:.1f} MB" - except Exception as e: print(f"[Memory Monitor] Error: {e}"); return "Memory monitor error" - - -# --- Main execution block --- -# (Initialization connects to disk, UI launch remains the same) + memory_info = process.memory_info() + memory_usage_mb = memory_info.rss / 1024 / 1024 + + # 如果内存使用超过12GB (激进设置),清理缓存 + if memory_usage_mb > 12000: # 12GB + if db: + db.clear_cache('response') # 优先清理响应缓存 + gc.collect() + # 如果内存使用超过14GB,更激进地清理 + if memory_usage_mb > 14000: # 14GB + if db: + db.clear_cache() # 清理所有缓存 + gc.collect() + + return f"Memory: {memory_usage_mb:.1f} MB" + except Exception as e: + return "Memory monitor error" + +# 修改主函数以使用优化策略 if __name__ == "__main__": - DB_FILE_NAME = "data.db"; DB_PATH = os.path.abspath(DB_FILE_NAME) + DB_PATH = "data.db" + + # 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载 if not os.path.exists(DB_PATH): - print(f"{DB_PATH} not found. Attempting download...") try: + # 从环境变量获取 HF_TOKEN hf_token = os.environ.get("HF_TOKEN") - if not hf_token: token_path = Path.home() / ".huggingface" / "token"; hf_token = token_path.read_text().strip() if token_path.exists() else None - if not hf_token: raise ValueError("HF token not found.") - print("Downloading data.db...") - downloaded_path = hf_hub_download(repo_id="CoderBak/OlymMATH-data", filename=DB_FILE_NAME, repo_type="dataset", token=hf_token, local_dir=os.path.dirname(DB_PATH), local_dir_use_symlinks=False) - DB_PATH = os.path.abspath(downloaded_path); print(f"Download complete: {DB_PATH}") + if not hf_token: + raise ValueError("HF_TOKEN environment variable is not set") + + # 从 Hugging Face 下载数据库文件 + DB_PATH = hf_hub_download( + repo_id="CoderBak/OlymMATH-data", + filename="data.db", + repo_type="dataset", + token=hf_token + ) except Exception as e: - print(f"Error downloading DB: {e}"); - with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists or HF token is valid."); error_demo.launch(server_name="0.0.0.0"); exit(1) - - print(f"Initializing ModelDatabase from disk: {DB_PATH}...") - start_init = time.time() - try: - db = ModelDatabase(DB_PATH) # Connects to disk - except Exception as e: - print(f"Fatal Error during DB initialization: {e}"); import traceback; traceback.print_exc() - with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Init Failed\n`{str(e)}`\nCheck file/permissions."); error_demo.launch(server_name="0.0.0.0"); exit(1) - end_init = time.time(); print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.") - monitor_memory_usage() # Initial check - - def cleanup(): global db; print("\nRunning cleanup..."); db.close() if db else None; print("Cleanup finished.") - atexit.register(cleanup) - - print("Creating Gradio UI..."); main_demo = create_ui(db) - print("Launching Gradio application..."); - main_demo.queue().launch(server_name="0.0.0.0", share=os.environ.get("GRADIO_SHARE", False), inbrowser=False, show_error=True) - print("Application running. Press Ctrl+C to stop.") + # 创建一个显示错误信息的简单 Gradio 应用 + with gr.Blocks() as error_demo: + gr.Markdown(f"# Error: Database Download Failed\n{str(e)}\nPlease ensure HF_TOKEN is set correctly and try again.") + error_demo.launch(server_name="0.0.0.0") + exit(1) + + if os.path.exists(DB_PATH): + # 创建UI并启动 + db = ModelDatabase(DB_PATH) + + # 添加清理函数 + def cleanup(): + global db + if db: + db.close() + + # 注册清理函数 + import atexit + atexit.register(cleanup) + + # 创建UI + main_demo = create_ui(DB_PATH) + + # 使用兼容的启动参数 + main_demo.launch( + server_name="0.0.0.0", + share=False, + inbrowser=False + ) + else: + # 创建一个显示错误信息的简单 Gradio 应用 + with gr.Blocks() as error_demo: + gr.Markdown(f"# Error: Database Not Found\nCould not find `{DB_PATH}`. Please ensure the database file is correctly placed and accessible.") + error_demo.launch(server_name="0.0.0.0") \ No newline at end of file