diff --git "a/app.py" "b/app.py"
--- "a/app.py"
+++ "b/app.py"
@@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
-# <<< Keep all existing imports >>>
import os
import json
import pandas as pd
@@ -15,17 +14,16 @@ import time
from huggingface_hub import hf_hub_download
import psutil
import gc
-import atexit # Import atexit
+import atexit
-# <<< Keep SUBJECT_TRANS and MODEL_TRANS dictionaries >>>
-# 翻译表
+# Translation table: Chinese subject names -> English display names
SUBJECT_TRANS = {
"代数": "Algebra",
"数论": "Number Theory",
"几何": "Geometry",
"组合": "Combinatorics"
}
-# MODEL_TRANS
+# MODEL_TRANS: lowercase model identifiers -> display names
MODEL_TRANS = {
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
@@ -57,21 +55,21 @@ MODEL_TRANS = {
"qwen3-0.6b": "Qwen3-0.6B"
}
-# <<< Keep Matplotlib configuration >>>
+# Matplotlib Config (Unchanged)
plt.style.use('ggplot')
mpl.rcParams['figure.figsize'] = (10, 6)
mpl.rcParams['font.size'] = 10
-# <<< Keep Constants >>>
+# Constants (Unchanged)
DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
-# Global database instance
+# Global DB Instance
db = None
class ModelDatabase:
- """Database access class - Optimized to use in-memory database"""
+ """Database access class - Optimized for disk-based access"""
def __init__(self, db_path):
- """Initialize database connection by copying disk DB to memory."""
+ """Initialize database connection directly to the disk file."""
self.db_path = db_path
self.conn = None
self._cache = {}
@@ -80,61 +78,46 @@ class ModelDatabase:
self.model_display_to_real = {}
self.comp_model_display_to_real = {}
- disk_conn = None
try:
- # 1. Connect to the source disk database in read-only mode
- print(f"Connecting to source database (read-only): {db_path}")
- # Ensure the file exists before trying to connect
+ print(f"Connecting to database file: {db_path}")
if not os.path.exists(db_path):
raise FileNotFoundError(f"Database file not found at {db_path}")
- # Use a longer timeout just in case connection takes time, e.g. on network drives
- disk_conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True, check_same_thread=False, timeout=120)
-
- # --- Adjusted PRAGMAs for disk_conn (read-only source) ---
- # Apply PRAGMAs that are safe for read-only and might speed up the backup read
- print("Applying safe PRAGMAs to source connection for potentially faster read...")
- # synchronous=OFF is generally safe for read-only access
- disk_conn.execute("PRAGMA synchronous = OFF")
- # Set a large cache size to potentially speed up reading from disk during backup
- disk_conn.execute("PRAGMA cache_size = -2097152") # ~2GB cache, adjust if needed
- # temp_store=MEMORY is safe and might help if temporary tables are used internally
- disk_conn.execute("PRAGMA temp_store = MEMORY")
- # Removed PRAGMA journal_mode = OFF as it caused the disk I/O error
- # Removed PRAGMA locking_mode = EXCLUSIVE as it's not needed for read-only source
- print("Read-safe PRAGMAs applied to source connection.")
- # --- End of Adjusted PRAGMAs ---
-
-
- # 2. Connect to the target in-memory database
- print("Creating in-memory database...")
- self.conn = sqlite3.connect(':memory:', check_same_thread=False, timeout=120)
- self.conn.row_factory = sqlite3.Row
- # 3. Backup data from disk to memory
- print("Starting database backup from disk to memory (this may take a while)...")
- start_backup = time.time()
- with self.conn:
- disk_conn.backup(self.conn)
- end_backup = time.time()
- print(f"Database backup completed in {end_backup - start_backup:.2f} seconds.")
+ # Connect directly to the database file
+ # timeout=120: wait up to 120 seconds for a locked database before raising an error
+ self.conn = sqlite3.connect(db_path, check_same_thread=False, timeout=120)
+ self.conn.row_factory = sqlite3.Row
- # 4. Apply PRAGMAs suitable for the in-memory database
- print("Applying PRAGMAs to in-memory database...")
+ # --- Apply PRAGMAs optimized for disk access ---
+ print("Applying PRAGMAs for disk-based access...")
+ # WAL mode generally provides better concurrency and performance
+ self.conn.execute("PRAGMA journal_mode = WAL")
+ # NORMAL synchronous is a good balance of safety and speed
+ self.conn.execute("PRAGMA synchronous = NORMAL")
+ # A negative cache_size is a size in KiB rather than a page count (e.g., 1GB = -1048576, 2GB = -2097152)
+ # Adjust based on available RAM (10GB total limit)
+ cache_size_kib = -1048576 # Start with 1GB cache
+ print(f"Setting cache_size to {cache_size_kib} KiB")
+ self.conn.execute(f"PRAGMA cache_size = {cache_size_kib}")
+ # Keep temporary storage in memory if possible
self.conn.execute("PRAGMA temp_store = MEMORY")
- # Optional: Set cache size for in-memory DB if desired, though less critical
- # self.conn.execute("PRAGMA cache_size = -4194304") # e.g., 4GB
+ # Avoid setting mmap_size explicitly when DB >> RAM initially
+ # self.conn.execute("PRAGMA mmap_size = XXXXXX") # Experiment later if needed
- # 5. Ensure indices exist on the in-memory database *after* data loading
- print("Creating indices on in-memory database...")
+ # Ensure indices exist (critical for disk performance)
+ print("Ensuring indices exist...")
start_index = time.time()
- self._ensure_indices() # Operates on self.conn (memory DB)
+ self._ensure_indices()
end_index = time.time()
- print(f"Index creation completed in {end_index - start_index:.2f} seconds.")
+ # Index check/creation might be very fast if they already exist
+ print(f"Index check/creation completed in {end_index - start_index:.2f} seconds.")
+
+ print("Database connection established successfully.")
except sqlite3.Error as e:
print(f"SQLite error during database initialization: {e}")
if self.conn: self.conn.close(); self.conn = None
- raise # Re-raise to signal failure
+ raise
except FileNotFoundError as e:
print(f"Error: {e}")
raise
@@ -142,96 +125,72 @@ class ModelDatabase:
print(f"Unexpected error during database initialization: {e}")
if self.conn: self.conn.close(); self.conn = None
raise
- finally:
- # 6. Close the disk connection
- if disk_conn:
- disk_conn.close()
- print("Closed connection to disk database.")
-
- if self.conn:
- print("In-memory database initialized successfully.")
- else:
- # This path should ideally not be reached if exceptions are raised properly
- print("Error: In-memory database connection failed.")
- raise RuntimeError("Failed to establish in-memory database connection.")
+
+ if not self.conn:
+ raise RuntimeError("Failed to establish database connection.")
def _ensure_indices(self):
- """Ensure necessary indices exist on the database connection (self.conn)."""
+ """Ensure necessary indices exist on the database connection."""
if not self.conn:
print("Error: Connection not established. Cannot ensure indices.")
return
try:
cursor = self.conn.cursor()
- print("Creating index: idx_responses_model_dataset")
+ print("Checking/Creating index: idx_responses_model_dataset")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)")
- print("Creating index: idx_responses_unique_id")
+ print("Checking/Creating index: idx_responses_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)")
- print("Creating index: idx_problems_unique_id")
+ print("Checking/Creating index: idx_problems_unique_id")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)")
- print("Creating index: idx_problems_subject")
+ print("Checking/Creating index: idx_problems_subject")
cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)")
- # Analyze the tables after creating indices for optimal query plans
- print("Running ANALYZE...")
+ # Analyze is important for the query planner, especially on disk
+ print("Running ANALYZE (might take time on large DB)...")
cursor.execute("ANALYZE")
- self.conn.commit() # Commit index creation and analysis
- print("Indices created and table analyzed successfully.")
+ self.conn.commit()
+ print("Indices checked/created and table analyzed.")
except sqlite3.Error as e:
- # Log error but don't necessarily crash the app
print(f"Warning: Could not create or analyze indices: {e}")
- # Attempt rollback if something failed partially
- try:
- self.conn.rollback()
- except sqlite3.Error as rb_e:
- print(f"Rollback attempt failed after index error: {rb_e}")
- # Depending on severity, you might want to raise e here
-
- # <<< Methods get_available_models, get_available_datasets, get_model_statistics, >>>
- # <<< get_all_model_accuracies, get_problems_by_model_dataset, get_problem_data, >>>
- # <<< get_model_responses, clear_cache are modified to: >>>
- # <<< 1. Remove INDEXED BY hints >>>
- # <<< 2. Add checks for self.conn existence >>>
- # <<< 3. Improve error logging >>>
+ # Attempt rollback
+ try: self.conn.rollback()
+ except sqlite3.Error as rb_e: print(f"Rollback attempt failed: {rb_e}")
+
+ # --- Query methods; the aggregate queries on responses use INDEXED BY hints ---
def get_available_models(self):
- """Get list of all available models"""
+ # Return the distinct model names from responses; the list is cached after the first query
if not self.conn: return []
- # Check cache first
if hasattr(self, '_models_cache') and self._models_cache is not None:
return self._models_cache
try:
cursor = self.conn.cursor()
- # Query without explicit index hints
cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name")
models = [row['model_name'] for row in cursor.fetchall()]
- self._models_cache = models # Store in instance cache
+ self._models_cache = models
return models
except sqlite3.Error as e:
print(f"Database error in get_available_models: {e}")
- return [] # Return empty list on error
+ return []
def get_available_datasets(self):
- """Get list of all available datasets"""
- if not self.conn: return DATASETS # Fallback if connection failed
- # Check cache first
+ # Return the distinct dataset names (upper-cased), cached after the first query; falls back to DATASETS on failure
+ if not self.conn: return DATASETS
if hasattr(self, '_datasets_cache') and self._datasets_cache is not None:
return self._datasets_cache
try:
cursor = self.conn.cursor()
- # Query without explicit index hints
cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset")
- # Ensure uppercase consistency
datasets = [row['dataset'].upper() for row in cursor.fetchall()]
- self._datasets_cache = datasets # Store in instance cache
+ self._datasets_cache = datasets
return datasets
except sqlite3.Error as e:
print(f"Database error in get_available_datasets: {e}")
- return DATASETS # Fallback on error
+ return DATASETS
def get_model_statistics(self, model_name, dataset):
- """Get statistics for a model on a specific dataset"""
+ """Get statistics, using INDEXED BY hints for disk access."""
if not self.conn: return [["Database Error", "No connection"]]
- # Sanitize inputs
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]]
@@ -242,10 +201,10 @@ class ModelDatabase:
stats_data = []
try:
cursor = self.conn.cursor()
- # Query 1: Overall accuracy - No INDEXED BY hint
+ # Query 1: Overall accuracy - Use index hint
cursor.execute("""
SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy
- FROM responses
+ FROM responses INDEXED BY idx_responses_model_dataset
WHERE model_name = ? AND dataset = ?
""", (model_name, dataset.lower()))
overall_stats = cursor.fetchone()
@@ -257,7 +216,9 @@ class ModelDatabase:
else:
stats_data.append(["Overall Acc.", "N/A"])
- # Query 2: Per-subject statistics - No INDEXED BY hint
+ # Query 2: Per-subject statistics - Join still needed, rely on indices
+ # (Adding explicit index hints on joins can sometimes be complex/less effective)
+ # Rely on ANALYZE and standard indices (idx_responses_unique_id, idx_problems_unique_id, idx_problems_subject)
cursor.execute("""
SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy
FROM responses r JOIN problems p ON r.unique_id = p.unique_id
@@ -272,15 +233,14 @@ class ModelDatabase:
translated_subject = SUBJECT_TRANS.get(subject_name, subject_name)
stats_data.append([f"{translated_subject} Acc.", acc_val])
- self._cache[cache_key] = stats_data # Cache the result
+ self._cache[cache_key] = stats_data
return stats_data
except sqlite3.Error as e:
print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}")
- # Return partial data if overall stats succeeded but subject failed? Or just error.
return [["Database Error", f"Query failed: {e}"]]
def get_all_model_accuracies(self, dataset):
- """获取所有模型在特定数据集上的准确率"""
+ """Get all accuracies, using INDEXED BY hint."""
if not self.conn: return []
if hasattr(dataset, 'value'): dataset = dataset.value
if not dataset: return []
@@ -290,22 +250,21 @@ class ModelDatabase:
try:
cursor = self.conn.cursor()
- # No INDEXED BY hint needed, rely on idx_responses_model_dataset
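+ # Index hint mainly serves the GROUP BY (the index leads with model_name, so it does not directly match the dataset-only filter) - NOTE(review): confirm with EXPLAIN QUERY PLAN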
cursor.execute("""
SELECT model_name, AVG(correctness) as accuracy
- FROM responses
+ FROM responses INDEXED BY idx_responses_model_dataset
WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC
""", (dataset.lower(),))
- # Fetchall directly into list comprehension
results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None]
- self._cache[cache_key] = results # Cache result
+ self._cache[cache_key] = results
return results
except sqlite3.Error as e:
print(f"Database error in get_all_model_accuracies({dataset}): {e}")
return []
def get_problems_by_model_dataset(self, model_name, dataset):
- """获取模型在特定数据集上的所有问题"""
+ """Get problems, using INDEXED BY hint for the primary table."""
if not self.conn: return []
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
@@ -316,205 +275,182 @@ class ModelDatabase:
try:
cursor = self.conn.cursor()
- # No INDEXED BY hint, rely on indices on responses and problems tables
- # Ensure AVG returns 0 if no correct responses, not NULL -> COALESCE(AVG(r.correctness), 0.0)
+ # Add index hint to the 'responses' table scan
cursor.execute("""
SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy
- FROM responses r
+ FROM responses r INDEXED BY idx_responses_model_dataset
JOIN problems p ON r.unique_id = p.unique_id
WHERE r.model_name = ? AND r.dataset = ?
GROUP BY r.unique_id, p.problem ORDER BY r.unique_id
""", (model_name, dataset.lower()))
- # Fetchall directly
results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()]
- # Sort in Python - pre-compile regex for slight speedup
+ # Sorting in Python
id_extractor = re.compile(r'\d+')
def get_sort_key(problem_tuple):
- match = id_extractor.search(problem_tuple[0]) # problem_tuple[0] is unique_id
- # Handle cases where ID might not have numbers gracefully
+ match = id_extractor.search(problem_tuple[0])
return int(match.group(0)) if match else 0
-
- # Sort the results list using the defined key
sorted_results = sorted(results, key=get_sort_key)
- self._cache[cache_key] = sorted_results # Cache the sorted list
+ self._cache[cache_key] = sorted_results
return sorted_results
except sqlite3.Error as e:
print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}")
return []
- except Exception as e: # Catch potential errors during sorting
+ except Exception as e:
print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}")
return []
def get_problem_data(self, model_name, dataset, problem_id):
- """获取问题和响应数据 (using in-memory DB and cache)"""
+ """Get a problem and its responses, leaving index selection to SQLite (no INDEXED BY hints)."""
+ # (These queries filter on unique_id, optionally together with model_name and dataset,
+ # where SQLite can normally pick idx_problems_unique_id or idx_responses_unique_id by itself.
+ # Add hints here only if measured performance shows they are needed.)
if not self.conn: return None, None
- # Sanitize inputs
if hasattr(model_name, 'value'): model_name = model_name.value
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
- if not dataset or not problem_id: return None, None # Need dataset and problem_id
+ if not dataset or not problem_id: return None, None
- # Problem data cache check
problem_cache_key = f"problem_{problem_id}"
problem = self._problem_cache.get(problem_cache_key)
- if problem is None: # Not in cache, fetch from DB
+ if problem is None:
try:
cursor = self.conn.cursor()
- # Query uses index idx_problems_unique_id automatically
+ # Uses idx_problems_unique_id
cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,))
problem_row = cursor.fetchone()
if problem_row:
- problem = dict(problem_row) # Convert to dict for caching
+ problem = dict(problem_row)
self._problem_cache[problem_cache_key] = problem
else:
print(f"Problem not found in DB: {problem_id}")
- return None, None # Problem ID does not exist in the database
+ return None, None
except sqlite3.Error as e:
print(f"Database error fetching problem {problem_id}: {e}")
- return None, None # Return None if problem fetch fails
+ return None, None
- # If problem is still None here, it wasn't found in DB or cache
- if problem is None:
- return None, None
+ if problem is None: return None, None
- # --- Response data fetching ---
responses = None
- if model_name: # Fetch for a specific model
+ if model_name:
resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
responses = self._response_cache[resp_cache_key]
else:
try:
cursor = self.conn.cursor()
- # Query uses indices idx_responses_model_dataset and idx_responses_unique_id
+ # SQLite picks a single index here (idx_responses_model_dataset or idx_responses_unique_id); no composite index covers all three columns
cursor.execute("""
- SELECT * FROM responses
+ SELECT * FROM responses -- INDEXED BY ??? (can be complex)
WHERE model_name = ? AND dataset = ? AND unique_id = ?
ORDER BY response_id
""", (model_name, dataset.lower(), problem_id))
response_rows = cursor.fetchall()
- # Convert rows to list of dicts for easier handling and caching
responses = [dict(r) for r in response_rows] if response_rows else []
- self._response_cache[resp_cache_key] = responses # Cache the result (even if empty)
+ self._response_cache[resp_cache_key] = responses
except sqlite3.Error as e:
print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}")
- responses = None # Indicate error fetching responses
- else: # Fetch for all models for this problem
+ responses = None
+ else: # Fetch all responses for the problem
resp_cache_key = f"all_responses_{dataset}_{problem_id}"
if resp_cache_key in self._response_cache:
responses = self._response_cache[resp_cache_key]
else:
try:
cursor = self.conn.cursor()
- # Query uses indices idx_responses_dataset and idx_responses_unique_id
- # Need to create idx_responses_dataset if not exists, or rely on model_dataset index scan
- # Let's add CREATE INDEX IF NOT EXISTS idx_responses_dataset ON responses(dataset); in _ensure_indices
- # --> Added index idx_responses_model_dataset which covers (dataset, unique_id) lookups too.
+ # Most likely served by idx_responses_unique_id; idx_responses_model_dataset leads with model_name and does not directly match a dataset-only filter
cursor.execute("""
- SELECT * FROM responses
+ SELECT * FROM responses -- INDEXED BY ???
WHERE dataset = ? AND unique_id = ?
ORDER BY model_name, response_id
""", (dataset.lower(), problem_id))
response_rows = cursor.fetchall()
responses = [dict(r) for r in response_rows] if response_rows else []
- self._response_cache[resp_cache_key] = responses # Cache the result
+ self._response_cache[resp_cache_key] = responses
except sqlite3.Error as e:
print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}")
- responses = None # Indicate error
+ responses = None
return problem, responses
def get_model_responses(self, selected_models, dataset, problem_id):
- """获取多个模型对特定问题的响应 (optimized for in-memory)"""
+ """Get responses for multiple models, bulk query preferred."""
+ # (One bulk query with an IN clause replaces a separate query per model, which still pays off with a disk-based DB)
if not self.conn: return None, {}
- # Sanitize inputs
if hasattr(dataset, 'value'): dataset = dataset.value
if hasattr(problem_id, 'value'): problem_id = problem_id.value
if not selected_models or not dataset or not problem_id:
return None, {}
- # Get problem data first (uses cache/fast in-memory lookup)
problem, _ = self.get_problem_data(None, dataset, problem_id)
if not problem:
print(f"Problem data not found for {problem_id} in get_model_responses")
return None, {}
model_responses_data = {}
- # Get the *real* model names from the display names using the stored map
- real_model_names_map = {} # Map display name -> real name
+ real_model_names_map = {}
real_names_list = []
for model_display in selected_models:
model_display_val = model_display.value if hasattr(model_display, 'value') else model_display
- # Use comp_model_display_to_real if available, otherwise model_display_to_real
real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val)
- # Fallback if map lookup fails (try parsing)
if not real_name:
raw_name_part = model_display_val.split(" (")[0]
- # Reverse lookup MODEL_TRANS
for db_name, display_lookup in MODEL_TRANS.items():
- if display_lookup == raw_name_part:
- real_name = db_name
- break
- if not real_name: # If still not found, assume display name *is* real name (less accuracy suffix)
- real_name = raw_name_part
- print(f"Warning: Could not map display name '{model_display_val}' to real name via maps. Using inferred '{real_name}'.")
-
- if real_name: # Ensure we have a name to query
+ if display_lookup == raw_name_part: real_name = db_name; break
+ if not real_name: real_name = raw_name_part
+ print(f"Warning: Using fallback lookup/parsing for model name: '{model_display_val}' -> '{real_name}'.")
+
+ if real_name:
real_model_names_map[model_display_val] = real_name
- if real_name not in real_names_list: # Avoid duplicates in IN clause
- real_names_list.append(real_name)
+ if real_name not in real_names_list: real_names_list.append(real_name)
if not real_names_list:
print("No valid real model names found to query.")
- return problem, {} # Return problem data but empty responses
+ return problem, {}
# Optimized: Fetch all relevant responses in a single query
try:
cursor = self.conn.cursor()
placeholders = ','.join('?' * len(real_names_list))
+ # Rely on index idx_responses_model_dataset for the IN clause + other filters
query = f"""
- SELECT * FROM responses
+ SELECT * FROM responses -- INDEXED BY ??? (idx_responses_model_dataset likely used)
WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ?
- ORDER BY model_name, correctness DESC, response_id -- Prioritize correct responses
+ ORDER BY model_name, correctness DESC, response_id
"""
params = real_names_list + [dataset.lower(), problem_id]
cursor.execute(query, params)
all_fetched_responses = cursor.fetchall()
- # Group responses by *real* model name, keeping only the best (correct first, then by ID)
responses_by_real_model = {}
for resp_row in all_fetched_responses:
resp_dict = dict(resp_row)
model = resp_dict['model_name']
- if model not in responses_by_real_model: # Only store the first one encountered (due to ORDER BY)
+ if model not in responses_by_real_model: # Only store the first (best) one
responses_by_real_model[model] = resp_dict
- # Populate the result dictionary using display names as keys
for display_name, real_name in real_model_names_map.items():
- model_responses_data[display_name] = responses_by_real_model.get(real_name) # Will be None if no response found
+ model_responses_data[display_name] = responses_by_real_model.get(real_name)
except sqlite3.Error as e:
- print(f"Database error in bulk get_model_responses: {e}. Falling back to individual fetches.")
- # Fallback to individual fetching using get_problem_data (which uses cache)
+ print(f"Database error in bulk get_model_responses: {e}. Falling back.")
+ # Fallback to individual fetches (uses cache)
for display_name, real_name in real_model_names_map.items():
_ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id)
if responses_for_model:
- # Find correct one first, otherwise take first response
correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None)
model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0]
- else:
- model_responses_data[display_name] = None
+ else: model_responses_data[display_name] = None
return problem, model_responses_data
def clear_cache(self, section=None):
- """Clear specified cache sections."""
+ # Clear the named cache section, or every cache when section is None
print(f"Clearing cache section: {section if section else 'All'}")
cleared_something = False
if section == 'main' or section is None:
@@ -535,7 +471,6 @@ class ModelDatabase:
self._response_cache = {}
print(f"Cleared response cache ({count} items).")
cleared_something = True
- # Clear model/dataset list caches
if section == 'models' or section is None:
if hasattr(self, '_models_cache') and self._models_cache is not None:
self._models_cache = None
@@ -545,36 +480,28 @@ class ModelDatabase:
self._datasets_cache = None
print("Cleared datasets list cache.")
cleared_something = True
-
- if cleared_something:
- print("Running garbage collection...")
- gc.collect() # Explicitly trigger garbage collection
- else:
- print("Cache section(s) already empty or invalid section specified.")
-
+ if cleared_something: print("Running garbage collection..."); gc.collect()
+ else: print("Cache section(s) already empty or invalid section specified.")
def close(self):
- """Close the database connection."""
+ # Close the database connection (if open) and clear all caches
print("Closing database connection...")
if hasattr(self, 'conn') and self.conn:
try:
- # Optional: Backup in-memory changes to disk if needed (not in this scenario)
- # Optional: Run final pragmas like optimize before closing if desired
+ # Maybe run optimize before closing large WAL file? Might take time.
+ # print("Running PRAGMA optimize...")
# self.conn.execute("PRAGMA optimize;")
self.conn.close()
- self.conn = None # Ensure the attribute is None after closing
- print("In-memory database connection closed.")
+ self.conn = None
+ print("Database connection closed.")
except sqlite3.Error as e:
print(f"Error closing database connection: {e}")
- else:
- print("Database connection already closed or never established.")
- # Clear caches on close as well
+ else: print("Database connection already closed or never established.")
self.clear_cache()
-# <<< Keep helper functions: format_latex, format_markdown_with_math, >>>
-# <<< get_gradient_color, get_contrasting_text_color, format_sample_metadata, >>>
-# <<< format_sample_response >>>
+# --- Helper functions (format_*, get_*_color, etc.) ---
+# (These remain unchanged as they don't depend on the DB access method)
def format_latex(text):
if text is None: return ""
text = text.replace('\n', '
')
@@ -583,1170 +510,332 @@ def format_latex(text):
def format_markdown_with_math(text):
if text is None: return ""
text = text.replace('\r\n', '\n').replace('\r', '\n')
- # Ensure math delimiters are properly handled by Gradio's Markdown component
- # No need for complex regex if Gradio handles $, $$, \(, \), \[, \]
return text
def get_gradient_color(accuracy, color_map='RdYlGn'):
if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0):
- return "#808080" # Use gray for invalid/missing accuracy
+ return "#808080"
try:
- # Use the specified colormap
cmap = plt.colormaps.get_cmap(color_map)
- # Apply a power transform to make colors darker/more distinct, especially greens
- # Power < 1 darkens low values more, Power > 1 darkens high values more
power_adjust = 0.7
rgba = cmap(accuracy ** power_adjust)
- # Convert RGBA to Hex
hex_color = mpl.colors.rgb2hex(rgba)
return hex_color
except Exception as e:
print(f"Error generating gradient color for accuracy {accuracy}: {e}")
- return "#808080" # Fallback gray
+ return "#808080"
def get_contrasting_text_color(bg_color_hex):
- """Calculate contrasting text color (black or white) for a given hex background."""
try:
if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7:
- return "#000000" # Default to black for invalid input
-
- # Convert hex to RGB
- r = int(bg_color_hex[1:3], 16)
- g = int(bg_color_hex[3:5], 16)
- b = int(bg_color_hex[5:7], 16)
-
- # Calculate luminance using the WCAG formula (more accurate than YIQ for accessibility)
- # Normalize RGB values to 0-1 range
+ return "#000000"
+ r = int(bg_color_hex[1:3], 16); g = int(bg_color_hex[3:5], 16); b = int(bg_color_hex[5:7], 16)
rgb = [val / 255.0 for val in (r, g, b)]
- # Apply gamma correction approximation
rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb]
- # Calculate relative luminance
luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2]
-
- # WCAG contrast ratio threshold is complex, but a luminance threshold works well for black/white text
- # Threshold of 0.179 is often cited, but empirical testing might be needed
- # Let's use a slightly higher threshold towards white text for better readability on mid-tones
- return "#000000" if luminance > 0.22 else "#FFFFFF" # Black text on lighter backgrounds, White on darker
-
+ return "#000000" if luminance > 0.22 else "#FFFFFF"
except Exception as e:
print(f"Error calculating contrasting color for {bg_color_hex}: {e}")
- return "#000000" # Default to black on error
-
+ return "#000000"
def format_sample_metadata(sample, show_correctness=True):
- """Generates HTML for sample metadata display."""
if sample is None: return "
Database Connection Error
"), gr.State([]) - - # --- Get values from Gradio State objects --- + if not db or not db.conn: return "DB Error.", "DB Error.", gr.HTML("DB Conn Error
"), gr.State([]) model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - - # --- Input Validation --- - if not dataset_name: - return "Please select a dataset first.", "N/A", gr.HTML(""), gr.State([]) - if not problem_id: - return "Problem ID is missing.", "N/A", gr.HTML(""), gr.State([]) - # Model name is required unless in comparison mode initial load - if not model_name and mode != 'comparison_initial_problem_load': - # In single mode, model is always required here - if mode == 'default': - return "Please select a model first.", "N/A", gr.HTML(""), gr.State([]) - # In comparison mode, if model state is None, means user selected problem before model - # We can still show problem/answer, but no samples for that side yet. - # Let's handle this by fetching problem/answer but returning empty samples for the specific call. 
- pass # Allow proceeding without model name in comparison mode to fetch problem/answer - - # --- Reconstruct full problem ID if only number is entered --- - original_problem_id = problem_id # Keep original for messages + original_problem_id = problem_id + if not dataset_name: return "Select dataset.", "N/A", gr.HTML(""), gr.State([]) + if not problem_id: return "Enter problem ID.", "N/A", gr.HTML(""), gr.State([]) + if not model_name and mode == 'default': return "Select model.", "N/A", gr.HTML(""), gr.State([]) if problem_id.isdigit(): - parts = dataset_name.split('-') - if len(parts) == 2: - language, difficulty = parts - problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" - print(f"Reconstructed problem ID: {problem_id}") - else: - # Cannot reconstruct, use the entered value but it might fail - print(f"Warning: Could not reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") - - - # --- Fetch Data --- + parts = dataset_name.split('-'); + if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"; print(f"Reconstructed ID: {problem_id}") + else: print(f"Warning: Could not reconstruct ID from {original_problem_id} and {dataset_name}") try: - # Fetch problem details and responses for the specific model (if provided) problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id) - - if not problem_data: - return f"Problem ID '{original_problem_id}' not found for dataset '{dataset_name}'.", "N/A", gr.HTML(f"Problem ID '{original_problem_id}' not found.
"), gr.State([]) - - # Ensure problem_data is a dict + if not problem_data: return f"Problem '{original_problem_id}' not found.", "N/A", gr.HTML(f"ID '{original_problem_id}' not found.
"), gr.State([]) problem_dict = dict(problem_data) if problem_data else {} - - # --- Format Problem and Answer --- - problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) - answer_text = problem_dict.get('answer', '*(Answer not available)*') + problem_content = format_markdown_with_math(problem_dict.get('problem', 'N/A')) + answer_text = problem_dict.get('answer', 'N/A') answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): - answer_text = f"${answer_text}$" + if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*(') and answer_text != 'N/A': answer_text = f"${answer_text}$" answer_content = format_markdown_with_math(answer_text) - - # --- Handle Responses and Generate Sample Grid --- - if responses_data is None: # Indicates a DB error occurred fetching responses - samples_grid_html = gr.HTML("Error fetching model responses.
") - samples_data_for_state = gr.State([]) # Empty state on error - elif not responses_data: # Empty list means no responses found - samples_grid_html = gr.HTML("No responses found for this model on this problem.
") - samples_data_for_state = gr.State([]) # Empty state + if responses_data is None: samples_grid_html = gr.HTML("Error fetching responses.
"); samples_data_for_state = gr.State([]) + elif not responses_data: samples_grid_html = gr.HTML("No responses found.
"); samples_data_for_state = gr.State([]) else: - # responses_data should already be a list of dicts from get_problem_data - samples_data = responses_data # Use directly - samples_data_for_state = gr.State(samples_data) # Store the list in state - + samples_data = responses_data + samples_data_for_state = gr.State(samples_data) correct_count = sum(1 for r in samples_data if r.get('correctness') == 1) - total_samples = len(samples_data) - accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 - - # --- Generate Sample Grid HTML (with onclick) --- - displayed_samples = samples_data[:64] # Limit display - actual_display_count = len(displayed_samples) - # Determine grid columns based on mode + total_samples = len(samples_data); accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 + displayed_samples = samples_data[:64]; actual_display_count = len(displayed_samples) samples_per_row = 16 if mode.startswith('comparison') else 32 - num_rows = math.ceil(actual_display_count / samples_per_row) - grid_html_content = "" - - # Determine the correct JS function call based on mode - js_mode = "'default'" # Default for single model tab - if mode == 'comparison_left': js_mode = "'comparison_left'" - elif mode == 'comparison_right': js_mode = "'comparison_right'" - + num_rows = math.ceil(actual_display_count / samples_per_row); grid_html_content = "" + js_mode = "'comparison_left'" if mode == 'comparison_left' else "'comparison_right'" if mode == 'comparison_right' else "'default'" for row_idx in range(num_rows): grid_html_content += f'{error_msg}
"), gr.State([]) + except Exception as e: print(f"Error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}"); import traceback; traceback.print_exc(); error_msg = f"**Internal Error:** {str(e)}"; return error_msg, "Error", gr.HTML(f"{error_msg}
"), gr.State([]) -# <<< Keep create_problem_grid_html, modified to use onclick >>> +# --- UI Creation function (create_problem_grid_html, create_ui) --- +# (These remain unchanged as they interact with the DB class interface and handlers) def create_problem_grid_html(problems, mode='default'): - """Create HTML for problem grid buttons with onclick handlers.""" - if not problems: - return "DB Error.
"), None - if not selected_model_formatted or not selected_dataset: - return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None - - # Use the appropriate map (comp or default) to get the real model name - real_model_name = None - if mode == 'comparison': - real_model_name = db.comp_model_display_to_real.get(selected_model_formatted) - if not real_model_name: # Fallback to default map or parsing if comp map missed - real_model_name = db.model_display_to_real.get(selected_model_formatted) - - # If still not found via map, try parsing (less reliable) + if not db or not db.conn: print("Error: DB not available"); return gr.DataFrame(value=[]), gr.HTML("DB Error.
"), None + if not selected_model_formatted or not selected_dataset: return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None + real_model_name = None; + if mode == 'comparison': real_model_name = db.comp_model_display_to_real.get(selected_model_formatted) + if not real_model_name: real_model_name = db.model_display_to_real.get(selected_model_formatted) if not real_model_name: - raw_name_part = selected_model_formatted.split(" (")[0] + raw_name_part = selected_model_formatted.split(" (")[0]; for db_name, display_lookup in MODEL_TRANS.items(): - if display_lookup == raw_name_part: - real_model_name = db_name - break - if not real_model_name: real_model_name = raw_name_part # Assume it's the real name - print(f"Warning: Using fallback lookup/parsing for model name: '{selected_model_formatted}' -> '{real_model_name}'") - - if not real_model_name: # Final check if name resolution failed - print(f"Error: Could not determine real model name for '{selected_model_formatted}'") - return gr.DataFrame(value=[]), gr.HTML("Internal error resolving model name.
"), None - - # Fetch data using the resolved real model name - stats_data = db.get_model_statistics(real_model_name, selected_dataset) - problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset) - grid_html = create_problem_grid_html(problem_list, mode=mode) - - # Return stats DF, grid HTML, and the *real* model name for state update + if display_lookup == raw_name_part: real_model_name = db_name; break + if not real_model_name: real_model_name = raw_name_part + print(f"Warning: Using fallback lookup for model name: '{selected_model_formatted}' -> '{real_model_name}'") + if not real_model_name: print(f"Error: Could not determine real model name for '{selected_model_formatted}'"); return gr.DataFrame(value=[]), gr.HTML("Error.
"), None + stats_data = db.get_model_statistics(real_model_name, selected_dataset); problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset); grid_html = create_problem_grid_html(problem_list, mode=mode) return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name - - - # Helper function to clear problem/sample displays - def clear_problem_outputs(): - return "Select a problem.", "N/A", gr.HTML(""), gr.State([]), "Select a sample.", "*(Response)*", "0" - - # Helper function to clear comparison problem/sample displays for one side - def clear_comparison_side_outputs(): - # Doesn't clear the main problem/answer display, only the side's samples - return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" - - - # --- Single Model Tab Event Connections --- - - # Dataset selection changes model dropdown options and clears everything else - dataset_radio_single.change( - fn=update_available_models_for_dropdowns, - inputs=[dataset_radio_single], - outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] # Update all model lists - ).then( # Chain: Reset dependent states and UI components - lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", # Clear stats, grid, model state, problem ID state - *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state - ""), # Clear problem input display - inputs=[dataset_radio_single], - outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, - problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, - sample_metadata_output, sample_response_output, sample_index_state, - problem_id_input_display] # Also clear the visible input box - ) - - # Model selection updates stats, grid, and clears problem/sample display - model_dropdown.change( - fn=update_problem_grid_and_stats, # Mode defaults to 'default' - 
inputs=[model_dropdown, current_dataset_state], - outputs=[model_stats_df, problem_grid_html_output, current_model_state] # Update stats, grid, REAL model state - ).then( # Chain: Clear problem-specific outputs - lambda: ("", # Clear problem ID state - *clear_problem_outputs(), # Clear problem/answer/samples/state/metadata/response/index state - ""), # Clear problem input display - inputs=[], - outputs=[problem_id_state, - problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, - sample_metadata_output, sample_response_output, sample_index_state, - problem_id_input_display] - ) - - # --- Problem Selection Handling (Single Tab) --- - # Option 1: User types into the visible input box - problem_id_input_display.submit( # Trigger on Enter press - # Copy value from visible input to hidden state input to trigger main handler - fn=lambda x: x, - inputs=[problem_id_input_display], - outputs=[problem_id_state] # This change triggers the next handler - ) - # Option 2: User clicks problem grid (JS updates hidden problem_id_state) - # Main handler triggered by change in hidden problem_id_state - problem_id_state.change( - fn=handle_problem_select, # Fetches problem, answer, sample grid, sample data state - inputs=[problem_id_state, current_model_state, current_dataset_state], # Mode defaults to 'default' - outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state] - ).then( # Chain: Display the first sample after data is loaded - fn=handle_first_sample, - inputs=[current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] - ).then( # Chain: Reset sample index state to 0 - lambda: "0", inputs=[], outputs=[sample_index_state] - ).then( # Chain: Update the visible input box to reflect the selected ID (useful if clicked) - fn=lambda x: x.value if hasattr(x,'value') else x, # Get value from state object - inputs=[problem_id_state], - 
outputs=[problem_id_input_display] - ) - - - # --- Sample Selection Handling (Single Tab) --- - # Triggered by change in hidden sample_index_state (updated by JS sample click) - sample_index_state.change( - fn=handle_sample_select, - inputs=[sample_index_state, current_samples_data_state], - outputs=[sample_metadata_output, sample_response_output] - ) - - - # --- Comparison Tab Event Connections --- - - # Dataset change updates dropdowns and clears everything - comp_dataset_radio.change( - fn=lambda ds: ds, # Update comparison dataset state first - inputs=[comp_dataset_radio], - outputs=[comp_dataset_state] - ).then( - fn=update_available_models_for_dropdowns, - inputs=[comp_dataset_state], - outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] - ).then( # Clear everything dependent on dataset/models/problem - lambda: (None, None, # Clear model states - "Select models and problem.", "Select problem.", # Clear problem/answer display - gr.HTML("Select model."), gr.HTML("Select model."), # Clear grids - *clear_comparison_side_outputs(), # Clear left samples - *clear_comparison_side_outputs(), # Clear right samples - "", # Clear comp problem ID state - ""), # Clear comp problem input display - inputs=[], - outputs=[comp_model_state_left, comp_model_state_right, - comp_problem_markdown_output, comp_answer_markdown_output, - comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, - # Left side sample outputs + state - comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, - # Right side sample outputs + state - comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, - comp_problem_id_state, comp_problem_id_input_display] - ) - - # Left model selection - comp_model_dropdown_left.change( - # 1. 
Update grid and model state for left side - fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), - inputs=[comp_model_dropdown_left, comp_dataset_state], - # Output to dummy state for stats DF, then grid HTML, then REAL model state - outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left] - ).then( - # 2. If a problem is already selected, refetch left samples - # Use 'comparison_left' mode for handle_problem_select - fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), - inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], - # Outputs for handle_problem_select: problem, answer, sample_grid, sample_state - # We only care about sample_grid and sample_state here - outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left] - ).then( - # 3. Display first sample for left side (if data exists) - fn=handle_first_sample, - inputs=[comp_samples_data_state_left], - outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] - ).then( - # 4. 
Reset left sample index state - lambda: "0", inputs=[], outputs=[comp_sample_index_state_left] - ) - - # Right model selection (mirrors left side logic) - comp_model_dropdown_right.change( - fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), - inputs=[comp_model_dropdown_right, comp_dataset_state], - outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right] - ).then( - # Use 'comparison_right' mode for handle_problem_select - fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), - inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], - outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] - ).then( - fn=handle_first_sample, - inputs=[comp_samples_data_state_right], - outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] - ).then( - lambda: "0", inputs=[], outputs=[comp_sample_index_state_right] - ) - - # --- Problem Selection Handling (Comparison Tab) --- - # Option 1: User types into the visible input box - comp_problem_id_input_display.submit( - fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state] - ) - # Option 2: User clicks problem grid (JS updates hidden comp_problem_id_state) - # Main handler triggered by change in hidden comp_problem_id_state - comp_problem_id_state.change( - # 1. Update main problem/answer display (doesn't need model info) - fn=handle_comparison_problem_update, - inputs=[comp_problem_id_state, comp_dataset_state], - outputs=[comp_problem_markdown_output, comp_answer_markdown_output] - ).then( - # 2. 
Update left samples (if left model is selected) - fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])), - inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], - outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left] - ).then( - # 3. Display first left sample - fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] - ).then( - # 4. Reset left sample index - lambda: "0", inputs=[], outputs=[comp_sample_index_state_left] - ).then( - # 5. Update right samples (if right model is selected) - fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])), - inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], - outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right] - ).then( - # 6. Display first right sample - fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] - ).then( - # 7. Reset right sample index - lambda: "0", inputs=[], outputs=[comp_sample_index_state_right] - ).then( - # 8. 
Update the visible input box to reflect the selected ID - fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display] - ) - - - # --- Sample Selection Handling (Comparison Tab) --- - # Left sample selection (triggered by JS updating hidden state) - comp_sample_index_state_left.change( - fn=handle_sample_select, - inputs=[comp_sample_index_state_left, comp_samples_data_state_left], - outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left] - ) - # Right sample selection (triggered by JS updating hidden state) - comp_sample_index_state_right.change( - fn=handle_sample_select, - inputs=[comp_sample_index_state_right, comp_samples_data_state_right], - outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right] - ) - - # --- Initial Population on App Load --- - # Populate model dropdowns based on the initial dataset state - demo.load( - fn=update_available_models_for_dropdowns, - inputs=[current_dataset_state], # Use initial state for single tab dataset - outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] - ) - + def clear_problem_outputs(): return "Select problem.", "N/A", gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" + def clear_comparison_side_outputs(): return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0" + # Single Model Event Connections + dataset_radio_single.change(fn=update_available_models_for_dropdowns, inputs=[dataset_radio_single], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", *clear_problem_outputs(), ""), inputs=[dataset_radio_single], outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, 
sample_response_output, sample_index_state, problem_id_input_display]) + model_dropdown.change(fn=update_problem_grid_and_stats, inputs=[model_dropdown, current_dataset_state], outputs=[model_stats_df, problem_grid_html_output, current_model_state]).then(lambda: ("", *clear_problem_outputs(), ""), inputs=[], outputs=[problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display]) + problem_id_input_display.submit(fn=lambda x: x, inputs=[problem_id_input_display], outputs=[problem_id_state]) + problem_id_state.change(fn=handle_problem_select, inputs=[problem_id_state, current_model_state, current_dataset_state], outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]).then(fn=handle_first_sample, inputs=[current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]).then(lambda: "0", inputs=[], outputs=[sample_index_state]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[problem_id_state], outputs=[problem_id_input_display]) + sample_index_state.change(fn=handle_sample_select, inputs=[sample_index_state, current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]) + # Comparison Tab Event Connections + comp_dataset_radio.change(fn=lambda ds: ds, inputs=[comp_dataset_radio], outputs=[comp_dataset_state]).then(fn=update_available_models_for_dropdowns, inputs=[comp_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda: (None, None, "Select models and problem.", "Select problem.", gr.HTML("Select model."), gr.HTML("Select model."), *clear_comparison_side_outputs(), *clear_comparison_side_outputs(), "", ""), inputs=[], outputs=[comp_model_state_left, comp_model_state_right, comp_problem_markdown_output, comp_answer_markdown_output, 
comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, comp_problem_id_state, comp_problem_id_input_display]) + comp_model_dropdown_left.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_left, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]) + comp_model_dropdown_right.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_right, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, 
inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]) + comp_problem_id_input_display.submit(fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state]) + comp_problem_id_state.change(fn=handle_comparison_problem_update, inputs=[comp_problem_id_state, comp_dataset_state], outputs=[comp_problem_markdown_output, comp_answer_markdown_output]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display]) + comp_sample_index_state_left.change(fn=handle_sample_select, inputs=[comp_sample_index_state_left, comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]) 
+ comp_sample_index_state_right.change(fn=handle_sample_select, inputs=[comp_sample_index_state_right, comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]) + # Initial Load + demo.load(fn=update_available_models_for_dropdowns, inputs=[current_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]) return demo -# <<< Keep monitor_memory_usage function >>> + +# --- Memory Monitor (Adjust thresholds for 10GB limit) --- def monitor_memory_usage(): - """Monitors memory usage and clears caches if thresholds are exceeded.""" + """Monitors memory usage and clears caches if thresholds are exceeded (10GB Limit).""" global db try: process = psutil.Process(os.getpid()) - # Use resident set size (RSS) as a measure of physical RAM usage - memory_info = process.memory_info() - memory_usage_mb = memory_info.rss / (1024 * 1024) - # Get total system memory for context - total_memory_gb = psutil.virtual_memory().total / (1024**3) - print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB available.") - - # Define thresholds based on 18GB total available RAM - # Threshold 1: ~70% = 12.6 GB - threshold_1_mb = 18 * 1024 * 0.70 - # Threshold 2: ~85% = 15.3 GB - threshold_2_mb = 18 * 1024 * 0.85 - - if db and db.conn: # Only clear caches if DB is active - if memory_usage_mb > threshold_2_mb: - print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_2_mb:.1f} MB. Clearing ALL caches.") - db.clear_cache() # Clear all Python-level caches - elif memory_usage_mb > threshold_1_mb: - print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_1_mb:.1f} MB. 
Clearing response cache.") - db.clear_cache('response') # Clear less critical response cache first - else: - print("[Memory Monitor] DB not active, skipping cache check.") + memory_info = process.memory_info(); memory_usage_mb = memory_info.rss / (1024 * 1024) + total_memory_gb = 10.0 # Assumed total available + print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB limit.") + # Define thresholds based on 10GB limit + threshold_1_mb = 10 * 1024 * 0.70 # Warn at 7GB + threshold_2_mb = 10 * 1024 * 0.85 # Critical at 8.5GB - # Return status string (optional, could be used in UI if needed) + if db and db.conn: + if memory_usage_mb > threshold_2_mb: + print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) > {threshold_2_mb:.1f} MB. Clearing ALL caches.") + db.clear_cache() + elif memory_usage_mb > threshold_1_mb: + print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) > {threshold_1_mb:.1f} MB. Clearing response cache.") + db.clear_cache('response') + else: print("[Memory Monitor] DB not active.") return f"Memory OK: {memory_usage_mb:.1f} MB" + except Exception as e: print(f"[Memory Monitor] Error: {e}"); return "Memory monitor error" - except Exception as e: - print(f"[Memory Monitor] Error: {e}") - return "Memory monitor error" -# <<< Keep __main__ block, ensuring DB initialization happens before UI creation >>> +# --- Main execution block --- +# (Initialization connects to disk, UI launch remains the same) if __name__ == "__main__": - DB_FILE_NAME = "data.db" - # Determine expected path (e.g., in the current directory) - DB_PATH = os.path.abspath(DB_FILE_NAME) - - # --- Database Download/Check --- + DB_FILE_NAME = "data.db"; DB_PATH = os.path.abspath(DB_FILE_NAME) if not os.path.exists(DB_PATH): - print(f"{DB_PATH} not found. Attempting to download from Hugging Face Hub...") + print(f"{DB_PATH} not found. 
Attempting download...") try: - # Attempt to get token from environment or local file hf_token = os.environ.get("HF_TOKEN") - if not hf_token: - token_path = Path.home() / ".huggingface" / "token" - if token_path.exists(): - print("Using token from ~/.huggingface/token") - hf_token = token_path.read_text().strip() - # Optional: Add interactive prompt for token if needed - # else: hf_token = input("Enter your Hugging Face token: ").strip() - - if not hf_token: - raise ValueError("Hugging Face token not found (set HF_TOKEN env var or place in ~/.huggingface/token).") - - print("Downloading data.db (this might take time)...") - # Download directly to the expected DB_PATH location's directory - downloaded_path = hf_hub_download( - repo_id="CoderBak/OlymMATH-data", - filename=DB_FILE_NAME, - repo_type="dataset", - token=hf_token, - local_dir=os.path.dirname(DB_PATH), # Download to target directory - local_dir_use_symlinks=False # Force copy, avoid symlink issues - ) - # Verify download path matches expected path - if os.path.abspath(downloaded_path) != DB_PATH: - print(f"Warning: Downloaded path '{downloaded_path}' differs from expected path '{DB_PATH}'.") - # Optionally, attempt to move/rename or update DB_PATH - # For simplicity, assume download worked if no exception - DB_PATH = os.path.abspath(downloaded_path) - - print(f"Database downloaded successfully to: {DB_PATH}") + if not hf_token: token_path = Path.home() / ".huggingface" / "token"; hf_token = token_path.read_text().strip() if token_path.exists() else None + if not hf_token: raise ValueError("HF token not found.") + print("Downloading data.db...") + downloaded_path = hf_hub_download(repo_id="CoderBak/OlymMATH-data", filename=DB_FILE_NAME, repo_type="dataset", token=hf_token, local_dir=os.path.dirname(DB_PATH), local_dir_use_symlinks=False) + DB_PATH = os.path.abspath(downloaded_path); print(f"Download complete: {DB_PATH}") except Exception as e: - print(f"Error downloading database: {e}") - # Display error 
UI and exit - with gr.Blocks() as error_demo: - gr.Markdown(f"# Error: Database Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists in the script's directory, or provide a valid Hugging Face token and network access.") - error_demo.launch(server_name="0.0.0.0", show_error=True) - exit(1) - - # --- Database Initialization (Loads into Memory) --- - print(f"Initializing ModelDatabase from: {DB_PATH} (loading data into memory)...") + print(f"Error downloading DB: {e}"); + with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists or HF token is valid."); error_demo.launch(server_name="0.0.0.0"); exit(1) + + print(f"Initializing ModelDatabase from disk: {DB_PATH}...") start_init = time.time() try: - # Instantiate the database class - this performs the load - db = ModelDatabase(DB_PATH) # db becomes the global instance + db = ModelDatabase(DB_PATH) # Connects to disk except Exception as e: - print(f"Fatal Error during ModelDatabase initialization: {e}") - import traceback - traceback.print_exc() # Print full traceback for debugging - # Display error UI and exit - with gr.Blocks() as error_demo: - gr.Markdown(f"# Error: Database Initialization Failed\n`{str(e)}`\nCheck database file integrity, permissions, and available memory ({psutil.virtual_memory().available / (1024**3):.1f} GB free).") - error_demo.launch(server_name="0.0.0.0", show_error=True) - exit(1) - end_init = time.time() - print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.") - # Initial memory check after load - monitor_memory_usage() - - # --- Cleanup Registration --- - def cleanup(): - global db - print("\nRunning cleanup before exit...") - if db: - db.close() # Ensures connection is closed properly - print("Cleanup finished.") - atexit.register(cleanup) # Register cleanup to run on script exit - - # --- Create and Launch UI --- - print("Creating Gradio UI...") - # Pass the initialized (in-memory) db instance to the UI 
creation function - main_demo = create_ui(db) - - # --- Optional: Periodic Memory Monitor --- - # Use Gradio's `every` parameter on a dummy component if you want checks tied to UI activity, - # or run a separate thread (carefully). For simplicity, we'll rely on checks during requests (if implemented) - # or manual monitoring based on the initial check. - - print("Launching Gradio application...") - # Use queue() for better performance under load - main_demo.queue().launch( - server_name="0.0.0.0", # Listen on all interfaces - share=os.environ.get("GRADIO_SHARE", False), # Allow sharing via env var if needed - inbrowser=False, # Don't automatically open browser - show_error=True, # Show Python errors in browser console - # Optional: Increase workers if needed, but be mindful of memory/CPU - # max_threads=4 - ) - - # Server runs here until interrupted (Ctrl+C) - print("Application is running. Press Ctrl+C to stop.") - # Keep the main thread alive if needed (though launch() usually blocks) - # while True: time.sleep(1) \ No newline at end of file + print(f"Fatal Error during DB initialization: {e}"); import traceback; traceback.print_exc() + with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Init Failed\n`{str(e)}`\nCheck file/permissions."); error_demo.launch(server_name="0.0.0.0"); exit(1) + end_init = time.time(); print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.") + monitor_memory_usage() # Initial check + + def cleanup(): global db; print("\nRunning cleanup..."); db.close() if db else None; print("Cleanup finished.") + atexit.register(cleanup) + + print("Creating Gradio UI..."); main_demo = create_ui(db) + print("Launching Gradio application..."); + main_demo.queue().launch(server_name="0.0.0.0", share=os.environ.get("GRADIO_SHARE", False), inbrowser=False, show_error=True) + print("Application running. Press Ctrl+C to stop.")