diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# <<< Keep all existing imports >>> import os import json import pandas as pd @@ -15,17 +14,16 @@ import time from huggingface_hub import hf_hub_download import psutil import gc -import atexit # Import atexit +import atexit -# <<< Keep SUBJECT_TRANS and MODEL_TRANS dictionaries >>> -# 翻译表 +# 翻译表 (Unchanged) SUBJECT_TRANS = { "代数": "Algebra", "数论": "Number Theory", "几何": "Geometry", "组合": "Combinatorics" } -# MODEL_TRANS +# MODEL_TRANS (Unchanged) MODEL_TRANS = { "acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B", "deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B", @@ -57,21 +55,21 @@ MODEL_TRANS = { "qwen3-0.6b": "Qwen3-0.6B" } -# <<< Keep Matplotlib configuration >>> +# Matplotlib Config (Unchanged) plt.style.use('ggplot') mpl.rcParams['figure.figsize'] = (10, 6) mpl.rcParams['font.size'] = 10 -# <<< Keep Constants >>> +# Constants (Unchanged) DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] -# Global database instance +# Global DB Instance db = None class ModelDatabase: - """Database access class - Optimized to use in-memory database""" + """Database access class - Optimized for disk-based access""" def __init__(self, db_path): - """Initialize database connection by copying disk DB to memory.""" + """Initialize database connection directly to the disk file.""" self.db_path = db_path self.conn = None self._cache = {} @@ -80,61 +78,46 @@ class ModelDatabase: self.model_display_to_real = {} self.comp_model_display_to_real = {} - disk_conn = None try: - # 1. Connect to the source disk database in read-only mode - print(f"Connecting to source database (read-only): {db_path}") - # Ensure the file exists before trying to connect + print(f"Connecting to database file: {db_path}") if not os.path.exists(db_path): raise FileNotFoundError(f"Database file not found at {db_path}") - # Use a longer timeout just in case connection takes time, e.g. on network drives - disk_conn = sqlite3.connect(f'file:{db_path}?mode=ro', uri=True, check_same_thread=False, timeout=120) - - # --- Adjusted PRAGMAs for disk_conn (read-only source) --- - # Apply PRAGMAs that are safe for read-only and might speed up the backup read - print("Applying safe PRAGMAs to source connection for potentially faster read...") - # synchronous=OFF is generally safe for read-only access - disk_conn.execute("PRAGMA synchronous = OFF") - # Set a large cache size to potentially speed up reading from disk during backup - disk_conn.execute("PRAGMA cache_size = -2097152") # ~2GB cache, adjust if needed - # temp_store=MEMORY is safe and might help if temporary tables are used internally - disk_conn.execute("PRAGMA temp_store = MEMORY") - # Removed PRAGMA journal_mode = OFF as it caused the disk I/O error - # Removed PRAGMA locking_mode = EXCLUSIVE as it's not needed for read-only source - print("Read-safe PRAGMAs applied to source connection.") - # --- End of Adjusted PRAGMAs --- - - - # 2. Connect to the target in-memory database - print("Creating in-memory database...") - self.conn = sqlite3.connect(':memory:', check_same_thread=False, timeout=120) - self.conn.row_factory = sqlite3.Row - # 3. 
Backup data from disk to memory - print("Starting database backup from disk to memory (this may take a while)...") - start_backup = time.time() - with self.conn: - disk_conn.backup(self.conn) - end_backup = time.time() - print(f"Database backup completed in {end_backup - start_backup:.2f} seconds.") + # Connect directly to the database file + # Increased timeout for potentially slower disk operations + self.conn = sqlite3.connect(db_path, check_same_thread=False, timeout=120) + self.conn.row_factory = sqlite3.Row - # 4. Apply PRAGMAs suitable for the in-memory database - print("Applying PRAGMAs to in-memory database...") + # --- Apply PRAGMAs optimized for disk access --- + print("Applying PRAGMAs for disk-based access...") + # WAL mode generally provides better concurrency and performance + self.conn.execute("PRAGMA journal_mode = WAL") + # NORMAL synchronous is a good balance of safety and speed + self.conn.execute("PRAGMA synchronous = NORMAL") + # Allocate a cache size in KiB (e.g., 1GB = -1048576, 2GB = -2097152) + # Adjust based on available RAM (10GB total limit) + cache_size_kib = -1048576 # Start with 1GB cache + print(f"Setting cache_size to {cache_size_kib} KiB") + self.conn.execute(f"PRAGMA cache_size = {cache_size_kib}") + # Keep temporary storage in memory if possible self.conn.execute("PRAGMA temp_store = MEMORY") - # Optional: Set cache size for in-memory DB if desired, though less critical - # self.conn.execute("PRAGMA cache_size = -4194304") # e.g., 4GB + # Avoid setting mmap_size explicitly when DB >> RAM initially + # self.conn.execute("PRAGMA mmap_size = XXXXXX") # Experiment later if needed - # 5. Ensure indices exist on the in-memory database *after* data loading - print("Creating indices on in-memory database...") + # Ensure indices exist (critical for disk performance) + print("Ensuring indices exist...") start_index = time.time() - self._ensure_indices() # Operates on self.conn (memory DB) + self._ensure_indices() end_index = time.time() - print(f"Index creation completed in {end_index - start_index:.2f} seconds.") + # Index check/creation might be very fast if they already exist + print(f"Index check/creation completed in {end_index - start_index:.2f} seconds.") + + print("Database connection established successfully.") except sqlite3.Error as e: print(f"SQLite error during database initialization: {e}") if self.conn: self.conn.close(); self.conn = None - raise # Re-raise to signal failure + raise except FileNotFoundError as e: print(f"Error: {e}") raise @@ -142,96 +125,72 @@ class ModelDatabase: print(f"Unexpected error during database initialization: {e}") if self.conn: self.conn.close(); self.conn = None raise - finally: - # 6. Close the disk connection - if disk_conn: - disk_conn.close() - print("Closed connection to disk database.") - - if self.conn: - print("In-memory database initialized successfully.") - else: - # This path should ideally not be reached if exceptions are raised properly - print("Error: In-memory database connection failed.") - raise RuntimeError("Failed to establish in-memory database connection.") + + if not self.conn: + raise RuntimeError("Failed to establish database connection.") def _ensure_indices(self): - """Ensure necessary indices exist on the database connection (self.conn).""" + """Ensure necessary indices exist on the database connection.""" if not self.conn: print("Error: Connection not established. 
Cannot ensure indices.") return try: cursor = self.conn.cursor() - print("Creating index: idx_responses_model_dataset") + print("Checking/Creating index: idx_responses_model_dataset") cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_model_dataset ON responses(model_name, dataset)") - print("Creating index: idx_responses_unique_id") + print("Checking/Creating index: idx_responses_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_responses_unique_id ON responses(unique_id)") - print("Creating index: idx_problems_unique_id") + print("Checking/Creating index: idx_problems_unique_id") cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_unique_id ON problems(unique_id)") - print("Creating index: idx_problems_subject") + print("Checking/Creating index: idx_problems_subject") cursor.execute("CREATE INDEX IF NOT EXISTS idx_problems_subject ON problems(subject)") - # Analyze the tables after creating indices for optimal query plans - print("Running ANALYZE...") + # Analyze is important for the query planner, especially on disk + print("Running ANALYZE (might take time on large DB)...") cursor.execute("ANALYZE") - self.conn.commit() # Commit index creation and analysis - print("Indices created and table analyzed successfully.") + self.conn.commit() + print("Indices checked/created and table analyzed.") except sqlite3.Error as e: - # Log error but don't necessarily crash the app print(f"Warning: Could not create or analyze indices: {e}") - # Attempt rollback if something failed partially - try: - self.conn.rollback() - except sqlite3.Error as rb_e: - print(f"Rollback attempt failed after index error: {rb_e}") - # Depending on severity, you might want to raise e here - - # <<< Methods get_available_models, get_available_datasets, get_model_statistics, >>> - # <<< get_all_model_accuracies, get_problems_by_model_dataset, get_problem_data, >>> - # <<< get_model_responses, clear_cache are modified to: >>> - # <<< 1. Remove INDEXED BY hints >>> - # <<< 2. Add checks for self.conn existence >>> - # <<< 3. 
Improve error logging >>> + # Attempt rollback + try: self.conn.rollback() + except sqlite3.Error as rb_e: print(f"Rollback attempt failed: {rb_e}") + + # --- Methods performing queries are adjusted to use INDEXED BY hints --- def get_available_models(self): - """Get list of all available models""" + # (No change needed here, simple query, uses cache) if not self.conn: return [] - # Check cache first if hasattr(self, '_models_cache') and self._models_cache is not None: return self._models_cache try: cursor = self.conn.cursor() - # Query without explicit index hints cursor.execute("SELECT DISTINCT model_name FROM responses ORDER BY model_name") models = [row['model_name'] for row in cursor.fetchall()] - self._models_cache = models # Store in instance cache + self._models_cache = models return models except sqlite3.Error as e: print(f"Database error in get_available_models: {e}") - return [] # Return empty list on error + return [] def get_available_datasets(self): - """Get list of all available datasets""" - if not self.conn: return DATASETS # Fallback if connection failed - # Check cache first + # (No change needed here, simple query, uses cache) + if not self.conn: return DATASETS if hasattr(self, '_datasets_cache') and self._datasets_cache is not None: return self._datasets_cache try: cursor = self.conn.cursor() - # Query without explicit index hints cursor.execute("SELECT DISTINCT dataset FROM responses ORDER BY dataset") - # Ensure uppercase consistency datasets = [row['dataset'].upper() for row in cursor.fetchall()] - self._datasets_cache = datasets # Store in instance cache + self._datasets_cache = datasets return datasets except sqlite3.Error as e: print(f"Database error in get_available_datasets: {e}") - return DATASETS # Fallback on error + return DATASETS def get_model_statistics(self, model_name, dataset): - """Get statistics for a model on a specific dataset""" + """Get statistics, using INDEXED BY hints for disk access.""" if not self.conn: return [["Database Error", "No connection"]] - # Sanitize inputs if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value if not model_name or not dataset: return [["Input Error", "Missing model or dataset"]] @@ -242,10 +201,10 @@ class ModelDatabase: stats_data = [] try: cursor = self.conn.cursor() - # Query 1: Overall accuracy - No INDEXED BY hint + # Query 1: Overall accuracy - Use index hint cursor.execute(""" SELECT COUNT(*) as total_samples, AVG(correctness) as accuracy - FROM responses + FROM responses INDEXED BY idx_responses_model_dataset WHERE model_name = ? AND dataset = ? 
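                -- NOTE: INDEXED BY is binding, not advisory: if the named index is
                -- ever dropped or renamed, preparing this statement fails ("no query
                -- solution") instead of falling back to a table scan. The clause also
                -- only pins the table it is attached to; joined tables, as in the
                -- per-subject query below, are still left to the planner, guided by
                -- the ANALYZE statistics gathered in _ensure_indices().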
""", (model_name, dataset.lower())) overall_stats = cursor.fetchone() @@ -257,7 +216,9 @@ class ModelDatabase: else: stats_data.append(["Overall Acc.", "N/A"]) - # Query 2: Per-subject statistics - No INDEXED BY hint + # Query 2: Per-subject statistics - Join still needed, rely on indices + # (Adding explicit index hints on joins can sometimes be complex/less effective) + # Rely on ANALYZE and standard indices (idx_responses_unique_id, idx_problems_unique_id, idx_problems_subject) cursor.execute(""" SELECT p.subject, COUNT(r.id) as sample_count, AVG(r.correctness) as accuracy FROM responses r JOIN problems p ON r.unique_id = p.unique_id @@ -272,15 +233,14 @@ class ModelDatabase: translated_subject = SUBJECT_TRANS.get(subject_name, subject_name) stats_data.append([f"{translated_subject} Acc.", acc_val]) - self._cache[cache_key] = stats_data # Cache the result + self._cache[cache_key] = stats_data return stats_data except sqlite3.Error as e: print(f"Database error in get_model_statistics({model_name}, {dataset}): {e}") - # Return partial data if overall stats succeeded but subject failed? Or just error. return [["Database Error", f"Query failed: {e}"]] def get_all_model_accuracies(self, dataset): - """获取所有模型在特定数据集上的准确率""" + """Get all accuracies, using INDEXED BY hint.""" if not self.conn: return [] if hasattr(dataset, 'value'): dataset = dataset.value if not dataset: return [] @@ -290,22 +250,21 @@ class ModelDatabase: try: cursor = self.conn.cursor() - # No INDEXED BY hint needed, rely on idx_responses_model_dataset + # Use index hint for potentially faster filtering/grouping cursor.execute(""" SELECT model_name, AVG(correctness) as accuracy - FROM responses + FROM responses INDEXED BY idx_responses_model_dataset WHERE dataset = ? GROUP BY model_name ORDER BY accuracy DESC """, (dataset.lower(),)) - # Fetchall directly into list comprehension results = [(row['model_name'], row['accuracy']) for row in cursor.fetchall() if row['accuracy'] is not None] - self._cache[cache_key] = results # Cache result + self._cache[cache_key] = results return results except sqlite3.Error as e: print(f"Database error in get_all_model_accuracies({dataset}): {e}") return [] def get_problems_by_model_dataset(self, model_name, dataset): - """获取模型在特定数据集上的所有问题""" + """Get problems, using INDEXED BY hint for the primary table.""" if not self.conn: return [] if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value @@ -316,205 +275,182 @@ class ModelDatabase: try: cursor = self.conn.cursor() - # No INDEXED BY hint, rely on indices on responses and problems tables - # Ensure AVG returns 0 if no correct responses, not NULL -> COALESCE(AVG(r.correctness), 0.0) + # Add index hint to the 'responses' table scan cursor.execute(""" SELECT r.unique_id, p.problem, COALESCE(AVG(r.correctness), 0.0) as accuracy - FROM responses r + FROM responses r INDEXED BY idx_responses_model_dataset JOIN problems p ON r.unique_id = p.unique_id WHERE r.model_name = ? AND r.dataset = ? 
GROUP BY r.unique_id, p.problem ORDER BY r.unique_id """, (model_name, dataset.lower())) - # Fetchall directly results = [(row['unique_id'], row['accuracy'], row['problem']) for row in cursor.fetchall()] - # Sort in Python - pre-compile regex for slight speedup + # Sorting in Python id_extractor = re.compile(r'\d+') def get_sort_key(problem_tuple): - match = id_extractor.search(problem_tuple[0]) # problem_tuple[0] is unique_id - # Handle cases where ID might not have numbers gracefully + match = id_extractor.search(problem_tuple[0]) return int(match.group(0)) if match else 0 - - # Sort the results list using the defined key sorted_results = sorted(results, key=get_sort_key) - self._cache[cache_key] = sorted_results # Cache the sorted list + self._cache[cache_key] = sorted_results return sorted_results except sqlite3.Error as e: print(f"Database error in get_problems_by_model_dataset({model_name}, {dataset}): {e}") return [] - except Exception as e: # Catch potential errors during sorting + except Exception as e: print(f"Error processing/sorting problems for {model_name}, {dataset}: {e}") return [] def get_problem_data(self, model_name, dataset, problem_id): - """获取问题和响应数据 (using in-memory DB and cache)""" + """Get problem/responses, relying on automatic index usage (hints less common here).""" + # (This method's logic relies heavily on primary key lookups or specific filters, + # where SQLite is usually good at picking the right index (idx_problems_unique_id, idx_responses_unique_id). + # Adding hints here is less likely to be necessary unless performance proves otherwise.) if not self.conn: return None, None - # Sanitize inputs if hasattr(model_name, 'value'): model_name = model_name.value if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value - if not dataset or not problem_id: return None, None # Need dataset and problem_id + if not dataset or not problem_id: return None, None - # Problem data cache check problem_cache_key = f"problem_{problem_id}" problem = self._problem_cache.get(problem_cache_key) - if problem is None: # Not in cache, fetch from DB + if problem is None: try: cursor = self.conn.cursor() - # Query uses index idx_problems_unique_id automatically + # Uses idx_problems_unique_id cursor.execute("SELECT * FROM problems WHERE unique_id = ?", (problem_id,)) problem_row = cursor.fetchone() if problem_row: - problem = dict(problem_row) # Convert to dict for caching + problem = dict(problem_row) self._problem_cache[problem_cache_key] = problem else: print(f"Problem not found in DB: {problem_id}") - return None, None # Problem ID does not exist in the database + return None, None except sqlite3.Error as e: print(f"Database error fetching problem {problem_id}: {e}") - return None, None # Return None if problem fetch fails + return None, None - # If problem is still None here, it wasn't found in DB or cache - if problem is None: - return None, None + if problem is None: return None, None - # --- Response data fetching --- responses = None - if model_name: # Fetch for a specific model + if model_name: resp_cache_key = f"responses_{model_name}_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: responses = self._response_cache[resp_cache_key] else: try: cursor = self.conn.cursor() - # Query uses indices idx_responses_model_dataset and idx_responses_unique_id + # Uses idx_responses_model_dataset & idx_responses_unique_id composite/scan cursor.execute(""" - SELECT * FROM responses + SELECT * FROM responses -- 
INDEXED BY ??? (can be complex) WHERE model_name = ? AND dataset = ? AND unique_id = ? ORDER BY response_id """, (model_name, dataset.lower(), problem_id)) response_rows = cursor.fetchall() - # Convert rows to list of dicts for easier handling and caching responses = [dict(r) for r in response_rows] if response_rows else [] - self._response_cache[resp_cache_key] = responses # Cache the result (even if empty) + self._response_cache[resp_cache_key] = responses except sqlite3.Error as e: print(f"DB error fetching responses for model {model_name}, dataset {dataset}, problem {problem_id}: {e}") - responses = None # Indicate error fetching responses - else: # Fetch for all models for this problem + responses = None + else: # Fetch all responses for the problem resp_cache_key = f"all_responses_{dataset}_{problem_id}" if resp_cache_key in self._response_cache: responses = self._response_cache[resp_cache_key] else: try: cursor = self.conn.cursor() - # Query uses indices idx_responses_dataset and idx_responses_unique_id - # Need to create idx_responses_dataset if not exists, or rely on model_dataset index scan - # Let's add CREATE INDEX IF NOT EXISTS idx_responses_dataset ON responses(dataset); in _ensure_indices - # --> Added index idx_responses_model_dataset which covers (dataset, unique_id) lookups too. + # Uses idx_responses_model_dataset or idx_responses_unique_id scan cursor.execute(""" - SELECT * FROM responses + SELECT * FROM responses -- INDEXED BY ??? WHERE dataset = ? AND unique_id = ? ORDER BY model_name, response_id """, (dataset.lower(), problem_id)) response_rows = cursor.fetchall() responses = [dict(r) for r in response_rows] if response_rows else [] - self._response_cache[resp_cache_key] = responses # Cache the result + self._response_cache[resp_cache_key] = responses except sqlite3.Error as e: print(f"DB error fetching all responses for dataset {dataset}, problem {problem_id}: {e}") - responses = None # Indicate error + responses = None return problem, responses def get_model_responses(self, selected_models, dataset, problem_id): - """获取多个模型对特定问题的响应 (optimized for in-memory)""" + """Get responses for multiple models, bulk query preferred.""" + # (Keeping the bulk query logic using IN clause is still good for disk access) if not self.conn: return None, {} - # Sanitize inputs if hasattr(dataset, 'value'): dataset = dataset.value if hasattr(problem_id, 'value'): problem_id = problem_id.value if not selected_models or not dataset or not problem_id: return None, {} - # Get problem data first (uses cache/fast in-memory lookup) problem, _ = self.get_problem_data(None, dataset, problem_id) if not problem: print(f"Problem data not found for {problem_id} in get_model_responses") return None, {} model_responses_data = {} - # Get the *real* model names from the display names using the stored map - real_model_names_map = {} # Map display name -> real name + real_model_names_map = {} real_names_list = [] for model_display in selected_models: model_display_val = model_display.value if hasattr(model_display, 'value') else model_display - # Use comp_model_display_to_real if available, otherwise model_display_to_real real_name = self.comp_model_display_to_real.get(model_display_val) or self.model_display_to_real.get(model_display_val) - # Fallback if map lookup fails (try parsing) if not real_name: raw_name_part = model_display_val.split(" (")[0] - # Reverse lookup MODEL_TRANS for db_name, display_lookup in MODEL_TRANS.items(): - if display_lookup == raw_name_part: - real_name = db_name - break - if 
not real_name: # If still not found, assume display name *is* real name (less accuracy suffix) - real_name = raw_name_part - print(f"Warning: Could not map display name '{model_display_val}' to real name via maps. Using inferred '{real_name}'.") - - if real_name: # Ensure we have a name to query + if display_lookup == raw_name_part: real_name = db_name; break + if not real_name: real_name = raw_name_part + print(f"Warning: Using fallback lookup/parsing for model name: '{model_display_val}' -> '{real_name}'.") + + if real_name: real_model_names_map[model_display_val] = real_name - if real_name not in real_names_list: # Avoid duplicates in IN clause - real_names_list.append(real_name) + if real_name not in real_names_list: real_names_list.append(real_name) if not real_names_list: print("No valid real model names found to query.") - return problem, {} # Return problem data but empty responses + return problem, {} # Optimized: Fetch all relevant responses in a single query try: cursor = self.conn.cursor() placeholders = ','.join('?' * len(real_names_list)) + # Rely on index idx_responses_model_dataset for the IN clause + other filters query = f""" - SELECT * FROM responses + SELECT * FROM responses -- INDEXED BY ??? (idx_responses_model_dataset likely used) WHERE model_name IN ({placeholders}) AND dataset = ? AND unique_id = ? - ORDER BY model_name, correctness DESC, response_id -- Prioritize correct responses + ORDER BY model_name, correctness DESC, response_id """ params = real_names_list + [dataset.lower(), problem_id] cursor.execute(query, params) all_fetched_responses = cursor.fetchall() - # Group responses by *real* model name, keeping only the best (correct first, then by ID) responses_by_real_model = {} for resp_row in all_fetched_responses: resp_dict = dict(resp_row) model = resp_dict['model_name'] - if model not in responses_by_real_model: # Only store the first one encountered (due to ORDER BY) + if model not in responses_by_real_model: # Only store the first (best) one responses_by_real_model[model] = resp_dict - # Populate the result dictionary using display names as keys for display_name, real_name in real_model_names_map.items(): - model_responses_data[display_name] = responses_by_real_model.get(real_name) # Will be None if no response found + model_responses_data[display_name] = responses_by_real_model.get(real_name) except sqlite3.Error as e: - print(f"Database error in bulk get_model_responses: {e}. Falling back to individual fetches.") - # Fallback to individual fetching using get_problem_data (which uses cache) + print(f"Database error in bulk get_model_responses: {e}. 
Falling back.") + # Fallback to individual fetches (uses cache) for display_name, real_name in real_model_names_map.items(): _ , responses_for_model = self.get_problem_data(real_name, dataset, problem_id) if responses_for_model: - # Find correct one first, otherwise take first response correct_resp = next((r for r in responses_for_model if r.get('correctness') == 1), None) model_responses_data[display_name] = correct_resp if correct_resp else responses_for_model[0] - else: - model_responses_data[display_name] = None + else: model_responses_data[display_name] = None return problem, model_responses_data def clear_cache(self, section=None): - """Clear specified cache sections.""" + # (Unchanged - Python cache clearing is still relevant) print(f"Clearing cache section: {section if section else 'All'}") cleared_something = False if section == 'main' or section is None: @@ -535,7 +471,6 @@ class ModelDatabase: self._response_cache = {} print(f"Cleared response cache ({count} items).") cleared_something = True - # Clear model/dataset list caches if section == 'models' or section is None: if hasattr(self, '_models_cache') and self._models_cache is not None: self._models_cache = None @@ -545,36 +480,28 @@ class ModelDatabase: self._datasets_cache = None print("Cleared datasets list cache.") cleared_something = True - - if cleared_something: - print("Running garbage collection...") - gc.collect() # Explicitly trigger garbage collection - else: - print("Cache section(s) already empty or invalid section specified.") - + if cleared_something: print("Running garbage collection..."); gc.collect() + else: print("Cache section(s) already empty or invalid section specified.") def close(self): - """Close the database connection.""" + # (Unchanged - closing the disk connection) print("Closing database connection...") if hasattr(self, 'conn') and self.conn: try: - # Optional: Backup in-memory changes to disk if needed (not in this scenario) - # Optional: Run final pragmas like optimize before closing if desired + # Maybe run optimize before closing large WAL file? Might take time. + # print("Running PRAGMA optimize...") # self.conn.execute("PRAGMA optimize;") self.conn.close() - self.conn = None # Ensure the attribute is None after closing - print("In-memory database connection closed.") + self.conn = None + print("Database connection closed.") except sqlite3.Error as e: print(f"Error closing database connection: {e}") - else: - print("Database connection already closed or never established.") - # Clear caches on close as well + else: print("Database connection already closed or never established.") self.clear_cache() -# <<< Keep helper functions: format_latex, format_markdown_with_math, >>> -# <<< get_gradient_color, get_contrasting_text_color, format_sample_metadata, >>> -# <<< format_sample_response >>> +# --- Helper functions (format_*, get_color_*, etc.) --- +# (These remain unchanged as they don't depend on the DB access method) def format_latex(text): if text is None: return "" text = text.replace('\n', '
') @@ -583,1170 +510,332 @@ def format_latex(text): def format_markdown_with_math(text): if text is None: return "" text = text.replace('\r\n', '\n').replace('\r', '\n') - # Ensure math delimiters are properly handled by Gradio's Markdown component - # No need for complex regex if Gradio handles $, $$, \(, \), \[, \] return text def get_gradient_color(accuracy, color_map='RdYlGn'): if accuracy is None or not isinstance(accuracy, (int, float)) or not (0.0 <= accuracy <= 1.0): - return "#808080" # Use gray for invalid/missing accuracy + return "#808080" try: - # Use the specified colormap cmap = plt.colormaps.get_cmap(color_map) - # Apply a power transform to make colors darker/more distinct, especially greens - # Power < 1 darkens low values more, Power > 1 darkens high values more power_adjust = 0.7 rgba = cmap(accuracy ** power_adjust) - # Convert RGBA to Hex hex_color = mpl.colors.rgb2hex(rgba) return hex_color except Exception as e: print(f"Error generating gradient color for accuracy {accuracy}: {e}") - return "#808080" # Fallback gray + return "#808080" def get_contrasting_text_color(bg_color_hex): - """Calculate contrasting text color (black or white) for a given hex background.""" try: if not bg_color_hex or not bg_color_hex.startswith('#') or len(bg_color_hex) != 7: - return "#000000" # Default to black for invalid input - - # Convert hex to RGB - r = int(bg_color_hex[1:3], 16) - g = int(bg_color_hex[3:5], 16) - b = int(bg_color_hex[5:7], 16) - - # Calculate luminance using the WCAG formula (more accurate than YIQ for accessibility) - # Normalize RGB values to 0-1 range + return "#000000" + r = int(bg_color_hex[1:3], 16); g = int(bg_color_hex[3:5], 16); b = int(bg_color_hex[5:7], 16) rgb = [val / 255.0 for val in (r, g, b)] - # Apply gamma correction approximation rgb_corrected = [((val / 12.92) if val <= 0.03928 else ((val + 0.055) / 1.055) ** 2.4) for val in rgb] - # Calculate relative luminance luminance = 0.2126 * rgb_corrected[0] + 0.7152 * rgb_corrected[1] + 0.0722 * rgb_corrected[2] - - # WCAG contrast ratio threshold is complex, but a luminance threshold works well for black/white text - # Threshold of 0.179 is often cited, but empirical testing might be needed - # Let's use a slightly higher threshold towards white text for better readability on mid-tones - return "#000000" if luminance > 0.22 else "#FFFFFF" # Black text on lighter backgrounds, White on darker - + return "#000000" if luminance > 0.22 else "#FFFFFF" except Exception as e: print(f"Error calculating contrasting color for {bg_color_hex}: {e}") - return "#000000" # Default to black on error - + return "#000000" def format_sample_metadata(sample, show_correctness=True): - """Generates HTML for sample metadata display.""" if sample is None: return "
<div>No sample data provided.</div>
" - # Ensure sample is a dictionary sample_dict = dict(sample) if hasattr(sample, 'keys') else {} if not sample_dict: return "
<div>Empty sample data.</div>
" - - # Use .get() with defaults for safety extracted = sample_dict.get('extracted', '') - correctness = sample_dict.get('correctness', None) # Keep None distinct from False/0 + correctness = sample_dict.get('correctness', None) output_tokens = sample_dict.get('output_tokens') reasoning_tokens = sample_dict.get('reasoning_tokens') - - # Correctness display - if correctness == 1: - correctness_label = "✓ Correct" - correctness_color = "var(--color-acc-green, #28a745)" # Use CSS variable with fallback - elif correctness == 0: - correctness_label = "✗ Incorrect" - correctness_color = "var(--color-acc-red, #dc3545)" - else: # None or other values - correctness_label = "? Unknown" - correctness_color = "var(--color-acc-grey, #6c757d)" - - # Build HTML using f-string for clarity - html = f""" -
-        <div style="display: flex; align-items: center; gap: 12px;">
-    """  # Use flexbox for alignment
-
-    if show_correctness:
-        html += f"<span style='color: {correctness_color}; font-weight: bold;'>{correctness_label}</span>"
-
+    if correctness == 1: correctness_label = "✓ Correct"; correctness_color = "var(--color-acc-green, #28a745)"
+    elif correctness == 0: correctness_label = "✗ Incorrect"; correctness_color = "var(--color-acc-red, #dc3545)"
+    else: correctness_label = "? Unknown"; correctness_color = "var(--color-acc-grey, #6c757d)"
+    html = f"<div style='display: flex; align-items: center; gap: 12px;'>"
+    if show_correctness: html += f"<span style='color: {correctness_color}; font-weight: bold;'>{correctness_label}</span>"
+    if output_tokens is not None: html += f"<span>Output tokens: {output_tokens}</span>"
+    if reasoning_tokens is not None: html += f"<span>Reasoning tokens: {reasoning_tokens}</span>"
+    html += "</div>
" return html - def format_sample_response(sample): - """Generates Markdown-compatible string for the sample response.""" if sample is None: return "No response data." sample_dict = dict(sample) if hasattr(sample, 'keys') else {} if not sample_dict: return "Empty response data." - response = sample_dict.get('response', '') if not response: return "*(Empty Response)*" - - # Escape HTML tags that might interfere with Markdown rendering - # Focus on < > & characters. - response = response.replace('&', '&') - response = response.replace('<', '<') - response = response.replace('>', '>') - - # Gradio's Markdown component should handle LaTeX delimiters like $...$, $$...$$ - # No need for manual replacement here if delimiters are set correctly in gr.Markdown + response = response.replace('&', '&').replace('<', '<').replace('>', '>') return response -# <<< Keep handler functions: handle_sample_select, handle_first_sample, >>> -# <<< handle_comparison_problem_update, handle_problem_select >>> -# <<< Ensure they use the global db instance and handle potential None values >>> -# <<< And use the updated format_ functions >>> - +# --- Handler functions (handle_sample_select, handle_first_sample, etc.) --- +# (These remain unchanged as they interact with the DB class interface) def handle_sample_select(sample_number_str, samples_data_state): - """Handles selection of a specific sample index.""" - # Extract list from state samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) - if not samples_list or not isinstance(samples_list, list): err_msg = "**Error:** No sample data available or invalid format." - return err_msg, "" # Return error for metadata, empty for response - - try: - sample_idx = int(sample_number_str) # Convert input string to int - except (ValueError, TypeError): - err_msg = f"**Error:** Invalid sample number '{sample_number_str}'. Must be an integer." - return err_msg, "" - - if not (0 <= sample_idx < len(samples_list)): - err_msg = f"**Error:** Sample index {sample_idx} out of range (0 to {len(samples_list) - 1})." return err_msg, "" - + try: sample_idx = int(sample_number_str) + except (ValueError, TypeError): return f"**Error:** Invalid sample number '{sample_number_str}'. Must be an integer.", "" + if not (0 <= sample_idx < len(samples_list)): return f"**Error:** Sample index {sample_idx} out of range (0 to {len(samples_list) - 1}).", "" try: selected_sample = samples_list[sample_idx] - # Ensure the selected sample is a dict before formatting - if not isinstance(selected_sample, dict): - selected_sample = dict(selected_sample) if hasattr(selected_sample, 'keys') else {} - + if not isinstance(selected_sample, dict): selected_sample = dict(selected_sample) if hasattr(selected_sample, 'keys') else {} formatted_metadata = format_sample_metadata(selected_sample) formatted_response = format_sample_response(selected_sample) return formatted_metadata, formatted_response except Exception as e: print(f"Error formatting sample {sample_idx}: {e}") err_msg = f"**Error displaying sample {sample_idx}:** {str(e)}" - # Return error message in metadata, keep response empty return f"
<div>{err_msg}</div>
", "" def handle_first_sample(samples_data_state): - """Handles displaying the first sample (index 0) from the state.""" - # Delegate to handle_sample_select with index 0 - # Provide default empty display if no samples samples_list = samples_data_state if isinstance(samples_data_state, list) else (samples_data_state.value if hasattr(samples_data_state, 'value') else []) - if not samples_list: - return format_sample_metadata(None), format_sample_response(None) # Display "No data" messages - else: - # Use the main handler to display sample 0 - return handle_sample_select("0", samples_data_state) + if not samples_list: return format_sample_metadata(None), format_sample_response(None) + else: return handle_sample_select("0", samples_data_state) def handle_comparison_problem_update(problem_id_state, dataset_state): - """Updates only the Problem/Answer display in the comparison tab.""" global db if not db or not db.conn: return "Database not initialized.", "Error" - dataset_name = dataset_state.value if hasattr(dataset_state, 'value') else dataset_state problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - - # Allow entering just the number part of the ID + original_problem_id = problem_id if problem_id and problem_id.isdigit() and dataset_name: parts = dataset_name.split('-') - if len(parts) == 2: - language, difficulty = parts - problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" - else: - print(f"Warning: Cannot reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") - # Proceed with the entered value, might fail if not a full ID - - if not problem_id or not dataset_name: - return "Please select dataset and enter problem ID.", "N/A" - + if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" + else: print(f"Warning: Cannot reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") + if not problem_id or not dataset_name: return "Please select dataset and enter problem ID.", "N/A" try: - # Fetch only problem data, no responses needed here problem_data, _ = db.get_problem_data(None, dataset_name, problem_id) - if not problem_data: - # Check if just the number was entered and failed reconstruction - if problem_id.isdigit(): - return f"Problem number {problem_id} not found for {dataset_name}. 
Enter full ID or check dataset.", "N/A" - else: - return f"Problem ID '{problem_id}' not found for {dataset_name}.", "N/A" - - # Ensure problem_data is a dictionary + return f"Problem ID '{original_problem_id}' not found for {dataset_name}.", "N/A" problem_dict = dict(problem_data) if problem_data else {} - - # Format problem statement for Markdown rendering problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) - - # Format answer, handling LaTeX and ensuring $...$ answer_text = problem_dict.get('answer', '*(Answer not available)*') - # Simplify $$...$$ to $...$ answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - # Add $...$ if missing (basic check) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): - answer_text = f"${answer_text}$" + if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): answer_text = f"${answer_text}$" answer_content = format_markdown_with_math(answer_text) - return problem_content, answer_content - except Exception as e: - print(f"Error in handle_comparison_problem_update for {problem_id}, {dataset_name}: {e}") - return f"Error fetching problem details: {e}", "Error" - + except Exception as e: print(f"Error in handle_comparison_problem_update for {problem_id}, {dataset_name}: {e}"); return f"Error fetching problem details: {e}", "Error" def handle_problem_select(problem_id_state, current_model_state, current_dataset_state, mode='default'): - """Handles problem selection, fetching details, responses, and generating sample grid.""" global db - if not db or not db.conn: - return "DB Error.", "DB Error.", gr.HTML("

<div>Database Connection Error</div>

"), gr.State([]) - - # --- Get values from Gradio State objects --- + if not db or not db.conn: return "DB Error.", "DB Error.", gr.HTML("

<div>Database Connection Error</div>
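<!-- Shown when the global db instance is missing or its connection failed. -->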

"), gr.State([]) model_name = current_model_state.value if hasattr(current_model_state, 'value') else current_model_state dataset_name = current_dataset_state.value if hasattr(current_dataset_state, 'value') else current_dataset_state problem_id = problem_id_state.value if hasattr(problem_id_state, 'value') else problem_id_state - - # --- Input Validation --- - if not dataset_name: - return "Please select a dataset first.", "N/A", gr.HTML(""), gr.State([]) - if not problem_id: - return "Problem ID is missing.", "N/A", gr.HTML(""), gr.State([]) - # Model name is required unless in comparison mode initial load - if not model_name and mode != 'comparison_initial_problem_load': - # In single mode, model is always required here - if mode == 'default': - return "Please select a model first.", "N/A", gr.HTML(""), gr.State([]) - # In comparison mode, if model state is None, means user selected problem before model - # We can still show problem/answer, but no samples for that side yet. - # Let's handle this by fetching problem/answer but returning empty samples for the specific call. - pass # Allow proceeding without model name in comparison mode to fetch problem/answer - - # --- Reconstruct full problem ID if only number is entered --- - original_problem_id = problem_id # Keep original for messages + original_problem_id = problem_id + if not dataset_name: return "Select dataset.", "N/A", gr.HTML(""), gr.State([]) + if not problem_id: return "Enter problem ID.", "N/A", gr.HTML(""), gr.State([]) + if not model_name and mode == 'default': return "Select model.", "N/A", gr.HTML(""), gr.State([]) if problem_id.isdigit(): - parts = dataset_name.split('-') - if len(parts) == 2: - language, difficulty = parts - problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}" - print(f"Reconstructed problem ID: {problem_id}") - else: - # Cannot reconstruct, use the entered value but it might fail - print(f"Warning: Could not reconstruct full ID from number '{problem_id}' and dataset '{dataset_name}'") - - - # --- Fetch Data --- + parts = dataset_name.split('-'); + if len(parts) == 2: language, difficulty = parts; problem_id = f"OlymMATH-{difficulty}-{problem_id}-{language}"; print(f"Reconstructed ID: {problem_id}") + else: print(f"Warning: Could not reconstruct ID from {original_problem_id} and {dataset_name}") try: - # Fetch problem details and responses for the specific model (if provided) problem_data, responses_data = db.get_problem_data(model_name, dataset_name, problem_id) - - if not problem_data: - return f"Problem ID '{original_problem_id}' not found for dataset '{dataset_name}'.", "N/A", gr.HTML(f"

<div>Problem ID '{original_problem_id}' not found.</div>

"), gr.State([]) - - # Ensure problem_data is a dict + if not problem_data: return f"Problem '{original_problem_id}' not found.", "N/A", gr.HTML(f"

<div>Problem ID '{original_problem_id}' not found.</div>

"), gr.State([]) problem_dict = dict(problem_data) if problem_data else {} - - # --- Format Problem and Answer --- - problem_content = format_markdown_with_math(problem_dict.get('problem', '*(Problem text not available)*')) - answer_text = problem_dict.get('answer', '*(Answer not available)*') + problem_content = format_markdown_with_math(problem_dict.get('problem', 'N/A')) + answer_text = problem_dict.get('answer', 'N/A') answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL) - if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*('): - answer_text = f"${answer_text}$" + if '$' not in answer_text and answer_text.strip() and not answer_text.startswith('*(') and answer_text != 'N/A': answer_text = f"${answer_text}$" answer_content = format_markdown_with_math(answer_text) - - # --- Handle Responses and Generate Sample Grid --- - if responses_data is None: # Indicates a DB error occurred fetching responses - samples_grid_html = gr.HTML("

<div>Error fetching model responses.</div>

") - samples_data_for_state = gr.State([]) # Empty state on error - elif not responses_data: # Empty list means no responses found - samples_grid_html = gr.HTML("

<div>No responses found for this model on this problem.</div>

") - samples_data_for_state = gr.State([]) # Empty state + if responses_data is None: samples_grid_html = gr.HTML("

<div>Error fetching responses.</div>
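<!-- responses_data is None signals a SQLite error during the fetch; an empty list means no rows matched. -->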

"); samples_data_for_state = gr.State([]) + elif not responses_data: samples_grid_html = gr.HTML("

<div>No responses found.</div>

"); samples_data_for_state = gr.State([]) else: - # responses_data should already be a list of dicts from get_problem_data - samples_data = responses_data # Use directly - samples_data_for_state = gr.State(samples_data) # Store the list in state - + samples_data = responses_data + samples_data_for_state = gr.State(samples_data) correct_count = sum(1 for r in samples_data if r.get('correctness') == 1) - total_samples = len(samples_data) - accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 - - # --- Generate Sample Grid HTML (with onclick) --- - displayed_samples = samples_data[:64] # Limit display - actual_display_count = len(displayed_samples) - # Determine grid columns based on mode + total_samples = len(samples_data); accuracy_on_problem = correct_count / total_samples if total_samples > 0 else 0 + displayed_samples = samples_data[:64]; actual_display_count = len(displayed_samples) samples_per_row = 16 if mode.startswith('comparison') else 32 - num_rows = math.ceil(actual_display_count / samples_per_row) - grid_html_content = "" - - # Determine the correct JS function call based on mode - js_mode = "'default'" # Default for single model tab - if mode == 'comparison_left': js_mode = "'comparison_left'" - elif mode == 'comparison_right': js_mode = "'comparison_right'" - + num_rows = math.ceil(actual_display_count / samples_per_row); grid_html_content = "" + js_mode = "'comparison_left'" if mode == 'comparison_left' else "'comparison_right'" if mode == 'comparison_right' else "'default'" for row_idx in range(num_rows): grid_html_content += f'
' - start_idx = row_idx * samples_per_row - end_idx = min(start_idx + samples_per_row, actual_display_count) + start_idx = row_idx * samples_per_row; end_idx = min(start_idx + samples_per_row, actual_display_count) row_samples = displayed_samples[start_idx:end_idx] - for i, resp in enumerate(row_samples): - actual_idx = start_idx + i - correctness = resp.get('correctness', None) # Handle None correctness - # Get background color based on correctness (1=green, 0=red, None=grey) + actual_idx = start_idx + i; correctness = resp.get('correctness', None) if correctness == 1: bg_color = get_gradient_color(1.0) elif correctness == 0: bg_color = get_gradient_color(0.0) - else: bg_color = "#808080" # Grey for unknown + else: bg_color = "#808080" text_color = get_contrasting_text_color(bg_color) - - # Add onclick event to call the JavaScript handler - grid_html_content += f""" - - """ - # Fill remaining columns in the row with gray placeholders if needed - for _ in range(len(row_samples), samples_per_row): - grid_html_content += "
" + grid_html_content += f"""""" + for _ in range(len(row_samples), samples_per_row): grid_html_content += "
" grid_html_content += '
' - - # Add filler rows if less than the max were generated (e.g., 4 for comparison, 2 for single) max_rows = 4 if mode.startswith('comparison') else 2 for _ in range(num_rows, max_rows): - grid_html_content += f'
' - for _ in range(samples_per_row): - grid_html_content += "
" - grid_html_content += '
' - - - # Assemble the final HTML for the samples section - samples_grid_html = gr.HTML(f""" -
-

<div>Samples ({actual_display_count} shown) – Model Accuracy: {correct_count}/{total_samples} ({accuracy_on_problem:.1%})</div>

-
{grid_html_content}
-
- - """) - - - # Return all results + grid_html_content += f'
'; + for _ in range(samples_per_row): grid_html_content += "
"; grid_html_content += '
' + samples_grid_html = gr.HTML(f"""

<div>Samples ({actual_display_count} shown) – Accuracy: {accuracy_on_problem:.1%}</div>

{grid_html_content}
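            <!-- Each colored cell in grid_html_content carries an onclick that calls
                 window.handleSampleClick(index, mode); the handler writes the index into a
                 hidden Gradio Textbox and fires input/change events so the backend updates. -->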
""") return problem_content, answer_content, samples_grid_html, samples_data_for_state - - except Exception as e: - print(f"Unexpected error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}") - # Log the full traceback for debugging if possible - import traceback - traceback.print_exc() - error_msg = f"**Internal Error processing problem {original_problem_id}:** {str(e)}" - return error_msg, "Error", gr.HTML(f"

<div>{error_msg}</div>

"), gr.State([]) + except Exception as e: print(f"Error in handle_problem_select for {problem_id}, model {model_name}, dataset {dataset_name}: {e}"); import traceback; traceback.print_exc(); error_msg = f"**Internal Error:** {str(e)}"; return error_msg, "Error", gr.HTML(f"

<div>{error_msg}</div>

"), gr.State([]) -# <<< Keep create_problem_grid_html, modified to use onclick >>> +# --- UI Creation function (create_problem_grid_html, create_ui) --- +# (These remain unchanged as they interact with the DB class interface and handlers) def create_problem_grid_html(problems, mode='default'): - """Create HTML for problem grid buttons with onclick handlers.""" - if not problems: - return "
<div>No problems found for this model/dataset.</div>
" - + if not problems: return "
<div>No problems found.</div>
" html_buttons = "" - # Sort problems based on the numeric part of the ID try: - id_extractor = re.compile(r'\d+') - def get_sort_key(p): - # p is expected to be a tuple/list like (unique_id, accuracy, problem_text) - match = id_extractor.search(str(p[0])) - return int(match.group(0)) if match else 0 - - # Ensure problems is a list of tuples/lists before sorting - if isinstance(problems, list) and all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in problems): - # Convert accuracy to float, handle None or potential errors - processed_problems = [] - for p in problems: - try: - pid = str(p[0]) - acc = float(p[1]) if p[1] is not None else 0.0 - processed_problems.append((pid, acc)) - except (IndexError, TypeError, ValueError) as conv_err: - print(f"Skipping problem entry due to conversion error: {p} - {conv_err}") - sorted_problems = sorted(processed_problems, key=get_sort_key) - else: - print(f"Warning: Problem data format unexpected in create_problem_grid_html (mode={mode}). Skipping sort.") - # Attempt to process anyway if possible, otherwise return error message - if isinstance(problems, list): - processed_problems = [] - for p in problems: - try: - pid = str(p[0]) - # Try to get accuracy, default to 0.0 if fails - acc = 0.0 - if len(p) > 1: - try: - acc = float(p[1]) if p[1] is not None else 0.0 - except (TypeError, ValueError): pass # Keep acc as 0.0 - processed_problems.append((pid, acc)) - except (IndexError, TypeError, ValueError): - print(f"Skipping invalid problem entry: {p}") - sorted_problems = processed_problems # No sort if format is wrong initially - else: - return "
<div>Error: Invalid problem data format.</div>
" - - except Exception as e: - print(f"Error sorting/processing problems in create_problem_grid_html (mode={mode}): {e}") - return f"
<div>Error displaying problems: {e}</div>
" - - id_extractor = re.compile(r'\d+') # Re-use compiled regex - # Determine JS mode argument based on Python mode - js_mode_arg = "'comparison'" if mode == 'comparison' else "'default'" - + id_extractor = re.compile(r'\d+'); get_sort_key = lambda p: int(id_extractor.search(str(p[0])).group(0)) if id_extractor.search(str(p[0])) else 0 + processed_problems = [] + if isinstance(problems, list): + for p in problems: + try: pid = str(p[0]); acc = float(p[1]) if len(p)>1 and p[1] is not None else 0.0; processed_problems.append((pid, acc)) + except (IndexError, TypeError, ValueError) as conv_err: print(f"Skipping problem entry: {p} - {conv_err}") + sorted_problems = sorted(processed_problems, key=get_sort_key) + else: print(f"Problem data format unexpected: {type(problems)}"); return "
<div>Error: Invalid problem data format.</div>
" + except Exception as e: print(f"Error sorting/processing problems: {e}"); return f"
<div>Error displaying problems: {e}</div>
" + id_extractor = re.compile(r'\d+'); js_mode_arg = "'comparison'" if mode == 'comparison' else "'default'" for pid, accuracy in sorted_problems: - match = id_extractor.search(pid) - num_display = match.group(0) if match else pid[:6] # Fallback display if no number found - acc_pct = int(accuracy * 100) - bg_color = get_gradient_color(accuracy) - text_color = get_contrasting_text_color(bg_color) # Calculate contrast - - # Add onclick event calling the global JavaScript function - # Escape pid if it could contain quotes or special chars (unlikely but safer) + match = id_extractor.search(pid); num_display = match.group(0) if match else pid[:6] + acc_pct = int(accuracy * 100); bg_color = get_gradient_color(accuracy); text_color = get_contrasting_text_color(bg_color) escaped_pid = pid.replace("'", "\\'") - html_buttons += f""" - - """ - + html_buttons += f"""""" grid_cols = 20 if mode == 'comparison' else 10 - # Add CSS for the button layout within the grid item - grid_html = f""" -
- <div style="display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;">{html_buttons}</div>
- - """ + grid_html = f"""
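<!-- Problem-grid container: each button calls window.handleProblemClick(pid, mode)
     via onclick; grid_cols sets the column count (20 in comparison mode, 10 otherwise). -->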
<div style="display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;">{html_buttons}</div>
""" return grid_html - -# <<< Keep create_ui, modified for hidden state and JS >>> def create_ui(db_instance): - """Creates the Gradio UI application.""" - global db - db = db_instance # Use the passed-in, initialized DB instance - + global db; db = db_instance if not db or not db.conn: - print("Error: Database instance is not valid in create_ui.") - with gr.Blocks() as error_demo: - gr.Markdown("# Error: Database Initialization Failed\nThe application cannot start. Check logs for details.") - return error_demo - - AVAILABLE_DATASETS = db.get_available_datasets() # Fetch datasets from the initialized DB - if not AVAILABLE_DATASETS: - print("Warning: No datasets found in the database. Using fallback list.") - AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback - - # --- CSS --- (Add styles for metadata, sample grid buttons if not already present) - custom_css = """ - body, .gradio-container { font-family: sans-serif; font-size: 0.98em; line-height: 1.6; } /* Slightly larger base font */ - .gradio-tabs > div[role='tablist'] button { font-size: 0.95em; padding: 8px 14px; } - .gr-dropdown select { font-size: 0.95em; } - .gr-radio label span { font-size: 0.95em; } - .gr-checkboxgroup label span { font-size: 0.95em; } - .gr-button { font-size: 0.95em; padding: 8px 14px; } - .gr-dataframe table { font-size:0.9em; } /* Slightly larger dataframe font */ - .gr-markdown { font-size: 1.0em; line-height: 1.6; } /* Ensure markdown line height is good */ - /* Dark mode adjustments */ - .dark .dark-mode-compatible { background-color: var(--neutral-800); color: var(--neutral-100); border-color: var(--neutral-700); } - .dark .dark-mode-bg-secondary { background-color: var(--neutral-900); } - .dark .dataframe-container table { color: var(--neutral-100); border-color: var(--neutral-600); } - .dark .dataframe-container th { background-color: var(--neutral-700); } - /* Math rendering adjustments */ - .math-inline, .math-display { font-size: 105%; } /* Slightly smaller math */ - /* Custom classes for layout */ - .compact-row { margin-bottom: 5px !important; padding: 0 !important; } - /* Ensure hidden inputs don't take up space */ - .hidden-state > .svelte-kit-component { display: none !important; } - """ - # --- JavaScript for Button Clicks --- - javascript = """ - function handleProblemClick(problemId, mode) { - console.log(`Problem clicked: ${problemId}, Mode: ${mode}`); - // Determine the target hidden input element ID based on the mode - let targetInputId = (mode === 'comparison') ? 
'comp-problem-id-state' : 'problem-id-state'; - let targetInputElement = document.getElementById(targetInputId); - - if (targetInputElement) { - // Gradio Textbox often wraps the actual input (textarea) - let actualInput = targetInputElement.querySelector('textarea'); - if (actualInput) { - console.log(`Updating ${targetInputId} value to: ${problemId}`); - // Set the value of the hidden input - actualInput.value = problemId; - // Dispatch 'input' and 'change' events to notify Gradio backend - actualInput.dispatchEvent(new Event('input', { bubbles: true })); - actualInput.dispatchEvent(new Event('change', { bubbles: true })); - console.log(`Events dispatched for ${targetInputId}`); - } else { - console.error(`Could not find textarea within #${targetInputId}`); - } - } else { - console.error(`Target input element #${targetInputId} not found.`); - } - } - - function handleSampleClick(sampleIndex, mode) { - console.log(`Sample clicked: ${sampleIndex}, Mode: ${mode}`); - let targetInputId; - // Determine target based on which sample grid was clicked - if (mode === 'comparison_left') { - targetInputId = 'comp-sample-index-state-left'; - } else if (mode === 'comparison_right') { - targetInputId = 'comp-sample-index-state-right'; - } else { // Default single model mode - targetInputId = 'sample-index-state'; - } - - let targetInputElement = document.getElementById(targetInputId); - if (targetInputElement) { - let actualInput = targetInputElement.querySelector('textarea'); // Assuming Textbox uses textarea - if (actualInput) { - console.log(`Updating ${targetInputId} value to: ${sampleIndex}`); - actualInput.value = sampleIndex; // Set the index - // Dispatch events - actualInput.dispatchEvent(new Event('input', { bubbles: true })); - actualInput.dispatchEvent(new Event('change', { bubbles: true })); - console.log(`Events dispatched for ${targetInputId}`); - } else { - console.error(`Could not find textarea within #${targetInputId}`); - } - } else { - console.error(`Target input element #${targetInputId} not found.`); - } - } - - // Make functions globally available for onclick handlers - window.handleProblemClick = handleProblemClick; - window.handleSampleClick = handleSampleClick; - """ - + with gr.Blocks() as error_demo: gr.Markdown("# Error: DB Init Failed"); return error_demo + AVAILABLE_DATASETS = db.get_available_datasets(); + if not AVAILABLE_DATASETS: AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] + custom_css = """body, .gradio-container { font-family: sans-serif; font-size: 0.98em; line-height: 1.6; }.gradio-tabs > div[role='tablist'] button { font-size: 0.95em; padding: 8px 14px; }.gr-dropdown select, .gr-radio label span, .gr-checkboxgroup label span, .gr-button { font-size: 0.95em; }.gr-dataframe table { font-size:0.9em; }.gr-markdown { font-size: 1.0em; line-height: 1.6; }.dark .dark-mode-compatible { background-color: var(--neutral-800); color: var(--neutral-100); border-color: var(--neutral-700); }.dark .dark-mode-bg-secondary { background-color: var(--neutral-900); }.dark .dataframe-container table { color: var(--neutral-100); border-color: var(--neutral-600); }.dark .dataframe-container th { background-color: var(--neutral-700); }.math-inline, .math-display { font-size: 105%; }.compact-row { margin-bottom: 5px !important; padding: 0 !important; }.hidden-state > .svelte-kit-component { display: none !important; }""" + javascript = """function handleProblemClick(p,m){console.log(`Problem: ${p}, Mode: ${m}`);let 
     with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky), head=f"<script>{javascript}</script>", title="Model Performance Analyzer") as demo:
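+        # Hidden Textboxes below double as JS-writable state: their elem_id values must
+        # match the IDs that handleProblemClick/handleSampleClick query
+        # ('problem-id-state', 'sample-index-state', and the comp-* variants).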
-
-        # --- Define Hidden State Components ---
-        # These hold the actual state triggered by UI interactions (clicks, dropdowns)
-        # Single Model Tab States
         current_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "")
-        current_model_state = gr.State(value=None) # Holds the *real* model name
-        problem_id_state = gr.Textbox(elem_id="problem-id-state", visible=False, label="Selected Problem ID") # Updated by JS problem click
-        sample_index_state = gr.Textbox(value="0", elem_id="sample-index-state", visible=False, label="Selected Sample Index") # Updated by JS sample click
-        current_samples_data_state = gr.State(value=[]) # Holds the list of dicts for the current problem's samples
-
-        # Comparison Tab States
+        current_model_state = gr.State(value=None)
+        problem_id_state = gr.Textbox(elem_id="problem-id-state", visible=False, label="Selected Problem ID")
+        sample_index_state = gr.Textbox(value="0", elem_id="sample-index-state", visible=False, label="Selected Sample Index")
+        current_samples_data_state = gr.State(value=[])
         comp_dataset_state = gr.State(value=AVAILABLE_DATASETS[0] if AVAILABLE_DATASETS else "")
-        comp_problem_id_state = gr.Textbox(elem_id="comp-problem-id-state", visible=False, label="Selected Comparison Problem ID") # Updated by JS problem click
-        # Left Side
-        comp_model_state_left = gr.State(value=None)
-        comp_sample_index_state_left = gr.Textbox(value="0", elem_id="comp-sample-index-state-left", visible=False, label="Selected Left Sample Index")
-        comp_samples_data_state_left = gr.State(value=[])
-        # Right Side
-        comp_model_state_right = gr.State(value=None)
-        comp_sample_index_state_right = gr.Textbox(value="0", elem_id="comp-sample-index-state-right", visible=False, label="Selected Right Sample Index")
-        comp_samples_data_state_right = gr.State(value=[])
-        # Dummy state for outputs we need to provide but don't use
+        comp_problem_id_state = gr.Textbox(elem_id="comp-problem-id-state", visible=False, label="Selected Comparison Problem ID")
+        comp_model_state_left = gr.State(value=None); comp_sample_index_state_left = gr.Textbox(value="0", elem_id="comp-sample-index-state-left", visible=False, label="Selected Left Sample Index"); comp_samples_data_state_left = gr.State(value=[])
+        comp_model_state_right = gr.State(value=None); comp_sample_index_state_right = gr.Textbox(value="0", elem_id="comp-sample-index-state-right", visible=False, label="Selected Right Sample Index"); comp_samples_data_state_right = gr.State(value=[])
         dummy_state = gr.State(value=None)
-
-        # --- UI Layout ---
         with gr.Tabs():
-            # == Single Model Analysis Tab ==
             with gr.TabItem("Single Model Analysis"):
                 with gr.Row():
-                    # Column 1: Controls & Stats
                     with gr.Column(scale=1, min_width=300):
-                        gr.Markdown("### Controls")
-                        dataset_radio_single = gr.Radio(
-                            choices=AVAILABLE_DATASETS, value=current_dataset_state.value, # Use initial state value
-                            label="Select Dataset", interactive=True
-                        )
-                        model_dropdown = gr.Dropdown(
-                            choices=[], label="Select Model (Name + Acc%)", interactive=True
-                        )
-                        # Problem ID Input (for manual entry or display - could be combined with state?)
-                        # Let's keep it visible for manual entry as an alternative to clicking grid
-                        problem_id_input_display = gr.Textbox(
-                            label="Enter Problem ID (e.g., 42) or click grid",
-                            placeholder="Enter number or full ID",
-                            interactive=True,
-                        )
-
-                        gr.Markdown("### Problem Grid (Click to Select)")
-                        problem_grid_html_output = gr.HTML("Select model and dataset.") # Populated by callback
-
-                        gr.Markdown("### Model Statistics")
-                        model_stats_df = gr.DataFrame(headers=["Metric", "Value"], wrap=True)
-
-                    # Column 2: Problem Details & Samples
+                        gr.Markdown("### Controls"); dataset_radio_single = gr.Radio(choices=AVAILABLE_DATASETS, value=current_dataset_state.value, label="Select Dataset", interactive=True); model_dropdown = gr.Dropdown(choices=[], label="Select Model (Name + Acc%)", interactive=True); problem_id_input_display = gr.Textbox(label="Enter Problem ID (e.g., 42) or click grid", placeholder="Enter number or full ID", interactive=True); gr.Markdown("### Problem Grid (Click to Select)"); problem_grid_html_output = gr.HTML("Select model and dataset."); gr.Markdown("### Model Statistics"); model_stats_df = gr.DataFrame(headers=["Metric", "Value"], wrap=True)
                     with gr.Column(scale=3, min_width=500):
-                        gr.Markdown("### Problem Details")
-                        with gr.Tabs():
-                            with gr.TabItem("Problem Statement"):
-                                problem_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-                            with gr.TabItem("Reference Answer"):
-                                answer_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-
-                        gr.Markdown("### Model Responses (Click Grid Below to Select)")
-                        # Sample Grid HTML (generated by handle_problem_select)
-                        samples_grid_output = gr.HTML("")
-
-                        # Sample Display Area
-                        gr.Markdown("#### Selected Sample Details")
-                        sample_metadata_output = gr.HTML("Select a sample from the grid above.")
-                        sample_response_output = gr.Markdown("*(Response will appear here)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-
-
-            # == Model Comparison Tab ==
+                        gr.Markdown("### Problem Details")
+                        with gr.Tabs():
+                            with gr.TabItem("Problem Statement"): problem_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+                            with gr.TabItem("Reference Answer"): answer_markdown_output = gr.Markdown("Select a problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+                        gr.Markdown("### Model Responses (Click Grid Below to Select)"); samples_grid_output = gr.HTML(""); gr.Markdown("#### Selected Sample Details"); sample_metadata_output = gr.HTML("Select a sample from the grid above."); sample_response_output = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
             with gr.TabItem("Model Comparison"):
-                # Row 1: Shared Controls
-                with gr.Row(variant='compact'):
-                    comp_dataset_radio = gr.Radio(
-                        choices=AVAILABLE_DATASETS, value=comp_dataset_state.value,
-                        label="Select Dataset", interactive=True, scale=1
-                    )
-                    # Visible Problem ID input for comparison tab
-                    comp_problem_id_input_display = gr.Textbox(
-                        label="Enter Problem ID (e.g., 42) or click grid",
-                        placeholder="Enter number or full ID",
-                        interactive=True, scale=1
-                    )
-
-                # Row 2: Shared Problem/Answer Display
-                with gr.Row(variant='compact'):
-                    with gr.Column(scale=1):
-                        gr.Markdown("### Problem Details")
-                        with gr.Tabs():
-                            with gr.TabItem("Problem Statement"):
-                                comp_problem_markdown_output = gr.Markdown("Select models and problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-                            with gr.TabItem("Reference Answer"):
-                                comp_answer_markdown_output = gr.Markdown("Select problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-
-                # Row 3: Left vs Right Columns
-                with gr.Row(variant='compact', equal_height=False): # Allow columns to size independently
-                    # Left Column
-                    with gr.Column(scale=1, min_width=400):
-                        gr.Markdown("### Model 1")
-                        comp_model_dropdown_left = gr.Dropdown(choices=[], label="Select Model 1", interactive=True)
-                        gr.Markdown("#### Problem Grid (Model 1 - Click to Select Problem)")
-                        comp_problem_grid_html_output_left = gr.HTML("Select model 1.")
-                        gr.Markdown("#### Model 1 Responses (Click Grid Below)")
-                        comp_samples_grid_output_left = gr.HTML("")
-                        gr.Markdown("##### Selected Sample (Model 1)")
-                        comp_sample_metadata_output_left = gr.HTML("Select sample.")
-                        comp_sample_response_output_left = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-
-                    # Right Column
-                    with gr.Column(scale=1, min_width=400):
-                        gr.Markdown("### Model 2")
-                        comp_model_dropdown_right = gr.Dropdown(choices=[], label="Select Model 2", interactive=True)
-                        gr.Markdown("#### Problem Grid (Model 2 - Click to Select Problem)")
-                        comp_problem_grid_html_output_right = gr.HTML("Select model 2.")
-                        gr.Markdown("#### Model 2 Responses (Click Grid Below)")
-                        comp_samples_grid_output_right = gr.HTML("")
-                        gr.Markdown("##### Selected Sample (Model 2)")
-                        comp_sample_metadata_output_right = gr.HTML("Select sample.")
-                        comp_sample_response_output_right = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
-
-
-                # --- Event Handlers ---
-
+                with gr.Row(variant='compact'): comp_dataset_radio = gr.Radio(choices=AVAILABLE_DATASETS, value=comp_dataset_state.value, label="Select Dataset", interactive=True, scale=1); comp_problem_id_input_display = gr.Textbox(label="Enter Problem ID (e.g., 42) or click grid", placeholder="Enter number or full ID", interactive=True, scale=1)
+                with gr.Row(variant='compact'):
+                    with gr.Column(scale=1):
+                        gr.Markdown("### Problem Details")
+                        with gr.Tabs():
+                            with gr.TabItem("Problem Statement"): comp_problem_markdown_output = gr.Markdown("Select models and problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+                            with gr.TabItem("Reference Answer"): comp_answer_markdown_output = gr.Markdown("Select problem.", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+                with gr.Row(variant='compact', equal_height=False):
+                    with gr.Column(scale=1, min_width=400): gr.Markdown("### Model 1"); comp_model_dropdown_left = gr.Dropdown(choices=[], label="Select Model 1", interactive=True); gr.Markdown("#### Problem Grid (Model 1 - Click)"); comp_problem_grid_html_output_left = gr.HTML("Select model 1."); gr.Markdown("#### Model 1 Responses (Click Grid Below)"); comp_samples_grid_output_left = gr.HTML(""); gr.Markdown("##### Selected Sample (Model 1)"); comp_sample_metadata_output_left = gr.HTML("Select sample."); comp_sample_response_output_left = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+                    with gr.Column(scale=1, min_width=400): gr.Markdown("### Model 2"); comp_model_dropdown_right = gr.Dropdown(choices=[], label="Select Model 2", interactive=True); gr.Markdown("#### Problem Grid (Model 2 - Click)"); comp_problem_grid_html_output_right = gr.HTML("Select model 2."); gr.Markdown("#### Model 2 Responses (Click Grid Below)"); comp_samples_grid_output_right = gr.HTML(""); gr.Markdown("##### Selected Sample (Model 2)"); comp_sample_metadata_output_right = gr.HTML("Select sample."); comp_sample_response_output_right = gr.Markdown("*(Response)*", latex_delimiters=[{"left": "$", "right": "$", "display": False}, {"left": "$$", "right": "$$", "display": True}])
+        # --- Event Handlers (Remain the same, interact with DB interface) ---
         def update_available_models_for_dropdowns(selected_dataset):
-            # Fetches all models, gets accuracies for the selected dataset, formats choices
-            if not db or not db.conn:
-                print("Error: DB not available in update_available_models_for_dropdowns")
-                # Return empty updates for all dropdowns
-                return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
-
-            all_models = db.get_available_models() # Uses cache/in-memory DB
-            model_acc_map = {}
-            if selected_dataset and all_models:
-                # Fetch accuracies (uses cache/in-memory DB)
-                model_accs = db.get_all_model_accuracies(selected_dataset)
-                model_acc_map = {name: acc for name, acc in model_accs}
-
-            display_options = []
-            # Clear previous maps before repopulating
-            db.model_display_to_real = {}
-            db.comp_model_display_to_real = {}
-
-            # Sort models by accuracy (descending), handle missing accuracy (treat as -1)
+            if not db or not db.conn: print("Error: DB not available"); return gr.Dropdown(choices=[]), gr.Dropdown(choices=[]), gr.Dropdown(choices=[])
+            all_models = db.get_available_models(); model_acc_map = {}
+            if selected_dataset and all_models: model_accs = db.get_all_model_accuracies(selected_dataset); model_acc_map = {name: acc for name, acc in model_accs}
+            display_options = []; db.model_display_to_real = {}; db.comp_model_display_to_real = {}
             sorted_models = sorted(all_models, key=lambda m: model_acc_map.get(m, -1), reverse=True)
-            for name in sorted_models:
-                display_name = MODEL_TRANS.get(name, name) # Use translation map
-                acc = model_acc_map.get(name)
-                # Format accuracy nicely, handle None
-                acc_display = f" ({acc:.1%})" if acc is not None else " (N/A)"
-                display_text = f"{display_name}{acc_display}"
-                display_options.append(display_text)
-                # Store mapping from formatted name back to real name for both contexts
-                db.model_display_to_real[display_text] = name
-                db.comp_model_display_to_real[display_text] = name
-
-            # Return updates for all three dropdowns
-            # Use label argument to set labels dynamically if needed, or keep static
-            return gr.Dropdown(choices=display_options, value=None, label="Select Model (Name + Acc%)", interactive=True), \
-                   gr.Dropdown(choices=display_options, value=None, label="Select Model 1", interactive=True), \
-                   gr.Dropdown(choices=display_options, value=None, label="Select Model 2", interactive=True)
-
-
+            for name in sorted_models:
+                display_name = MODEL_TRANS.get(name, name); acc = model_acc_map.get(name); acc_display = f" ({acc:.1%})" if acc is not None else " (N/A)"; display_text = f"{display_name}{acc_display}"; display_options.append(display_text); db.model_display_to_real[display_text] = name; db.comp_model_display_to_real[display_text] = name
+            return gr.Dropdown(choices=display_options, value=None, label="Select Model (Name + Acc%)", interactive=True), gr.Dropdown(choices=display_options, value=None, label="Select Model 1", interactive=True), gr.Dropdown(choices=display_options, value=None, label="Select Model 2", interactive=True)
         def update_problem_grid_and_stats(selected_model_formatted, selected_dataset, mode='default'):
-            # Fetches stats/problems, returns stats DF, grid HTML, and *real* model name state
-            if not db or not db.conn:
-                print("Error: DB not available in update_problem_grid_and_stats")
-                return gr.DataFrame(value=[]), gr.HTML("<p>DB Error.</p>"), None
-            if not selected_model_formatted or not selected_dataset:
-                return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None
-
-            # Use the appropriate map (comp or default) to get the real model name
-            real_model_name = None
-            if mode == 'comparison':
-                real_model_name = db.comp_model_display_to_real.get(selected_model_formatted)
-                if not real_model_name: # Fallback to default map or parsing if comp map missed
-                    real_model_name = db.model_display_to_real.get(selected_model_formatted)
-
-            # If still not found via map, try parsing (less reliable)
+            if not db or not db.conn: print("Error: DB not available"); return gr.DataFrame(value=[]), gr.HTML("<p>DB Error.</p>"), None
+            if not selected_model_formatted or not selected_dataset: return gr.DataFrame(value=[]), gr.HTML("Select model and dataset."), None
+            real_model_name = None
+            if mode == 'comparison': real_model_name = db.comp_model_display_to_real.get(selected_model_formatted)
+            if not real_model_name: real_model_name = db.model_display_to_real.get(selected_model_formatted)
             if not real_model_name:
                 raw_name_part = selected_model_formatted.split(" (")[0]
                 for db_name, display_lookup in MODEL_TRANS.items():
-                    if display_lookup == raw_name_part:
-                        real_model_name = db_name
-                        break
-                if not real_model_name: real_model_name = raw_name_part # Assume it's the real name
-                print(f"Warning: Using fallback lookup/parsing for model name: '{selected_model_formatted}' -> '{real_model_name}'")
-
-            if not real_model_name: # Final check if name resolution failed
-                print(f"Error: Could not determine real model name for '{selected_model_formatted}'")
-                return gr.DataFrame(value=[]), gr.HTML("<p>Internal error resolving model name.</p>"), None
-
-            # Fetch data using the resolved real model name
-            stats_data = db.get_model_statistics(real_model_name, selected_dataset)
-            problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset)
-            grid_html = create_problem_grid_html(problem_list, mode=mode)
-
-            # Return stats DF, grid HTML, and the *real* model name for state update
+                    if display_lookup == raw_name_part: real_model_name = db_name; break
+                if not real_model_name: real_model_name = raw_name_part
+                print(f"Warning: Using fallback lookup for model name: '{selected_model_formatted}' -> '{real_model_name}'")
+            if not real_model_name: print(f"Error: Could not determine real model name for '{selected_model_formatted}'"); return gr.DataFrame(value=[]), gr.HTML("<p>Internal error resolving model name.</p>"), None
+            stats_data = db.get_model_statistics(real_model_name, selected_dataset); problem_list = db.get_problems_by_model_dataset(real_model_name, selected_dataset); grid_html = create_problem_grid_html(problem_list, mode=mode)
             return gr.DataFrame(value=stats_data), gr.HTML(value=grid_html), real_model_name
-
-
-        # Helper function to clear problem/sample displays
-        def clear_problem_outputs():
-            return "Select a problem.", "N/A", gr.HTML(""), gr.State([]), "Select a sample.", "*(Response)*", "0"
-
-        # Helper function to clear comparison problem/sample displays for one side
-        def clear_comparison_side_outputs():
-            # Doesn't clear the main problem/answer display, only the side's samples
-            return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0"
-
-
-        # --- Single Model Tab Event Connections ---
-
-        # Dataset selection changes model dropdown options and clears everything else
-        dataset_radio_single.change(
-            fn=update_available_models_for_dropdowns,
-            inputs=[dataset_radio_single],
-            outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right] # Update all model lists
-        ).then( # Chain: Reset dependent states and UI components
-            lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "",
-                        *clear_problem_outputs(),
-                        ""),
-            inputs=[dataset_radio_single],
-            outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state,
-                     problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state,
-                     sample_metadata_output, sample_response_output, sample_index_state,
-                     problem_id_input_display] # Also clear the visible input box
-        )
-
-        # Model selection updates stats, grid, and clears problem/sample display
-        model_dropdown.change(
-            fn=update_problem_grid_and_stats, # Mode defaults to 'default'
-            inputs=[model_dropdown, current_dataset_state],
-            outputs=[model_stats_df, problem_grid_html_output, current_model_state] # Update stats, grid, REAL model state
-        ).then( # Chain: Clear problem-specific outputs
-            lambda: ("",
-                     *clear_problem_outputs(),
-                     ""),
-            inputs=[],
-            outputs=[problem_id_state,
-                     problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state,
-                     sample_metadata_output, sample_response_output, sample_index_state,
-                     problem_id_input_display]
-        )
-
-        # --- Problem Selection Handling (Single Tab) ---
-        # Option 1: User types into the visible input box
-        problem_id_input_display.submit( # Trigger on Enter press
-            fn=lambda x: x,
-            inputs=[problem_id_input_display],
-            outputs=[problem_id_state] # This change triggers the next handler
-        )
-        # Option 2: User clicks problem grid (JS updates hidden problem_id_state)
-        # Main handler triggered by change in hidden problem_id_state
-        problem_id_state.change(
-            fn=handle_problem_select, # Fetches problem, answer, sample grid, sample data state
-            inputs=[problem_id_state, current_model_state, current_dataset_state], # Mode defaults to 'default'
-            outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]
-        ).then( # Chain: Display the first sample after data is loaded
-            fn=handle_first_sample,
-            inputs=[current_samples_data_state],
-            outputs=[sample_metadata_output, sample_response_output]
-        ).then( # Chain: Reset sample index state to 0
-            lambda: "0", inputs=[], outputs=[sample_index_state]
-        ).then( # Chain: Update the visible input box to reflect the selected ID (useful if clicked)
-            fn=lambda x: x.value if hasattr(x,'value') else x, # Get value from state object
-            inputs=[problem_id_state],
-            outputs=[problem_id_input_display]
-        )
-
-
-        # --- Sample Selection Handling (Single Tab) ---
-        # Triggered by change in hidden sample_index_state (updated by JS sample click)
-        sample_index_state.change(
-            fn=handle_sample_select,
-            inputs=[sample_index_state, current_samples_data_state],
-            outputs=[sample_metadata_output, sample_response_output]
-        )
-
-
-        # --- Comparison Tab Event Connections ---
-
-        # Dataset change updates dropdowns and clears everything
-        comp_dataset_radio.change(
-            fn=lambda ds: ds, # Update comparison dataset state first
-            inputs=[comp_dataset_radio],
-            outputs=[comp_dataset_state]
-        ).then(
-            fn=update_available_models_for_dropdowns,
-            inputs=[comp_dataset_state],
-            outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]
-        ).then( # Clear everything dependent on dataset/models/problem
-            lambda: (None, None,
-                     "Select models and problem.", "Select problem.",
-                     gr.HTML("Select model."), gr.HTML("Select model."),
-                     *clear_comparison_side_outputs(),
-                     *clear_comparison_side_outputs(),
-                     "",
-                     ""),
-            inputs=[],
-            outputs=[comp_model_state_left, comp_model_state_right,
-                     comp_problem_markdown_output, comp_answer_markdown_output,
-                     comp_problem_grid_html_output_left, comp_problem_grid_html_output_right,
-                     comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left,
-                     comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right,
-                     comp_problem_id_state, comp_problem_id_input_display]
-        )
-
-        # Left model selection
-        comp_model_dropdown_left.change(
-            # 1. Update grid and model state for left side
-            fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'),
-            inputs=[comp_model_dropdown_left, comp_dataset_state],
-            outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left]
-        ).then(
-            # 2. If a problem is already selected, refetch left samples
-            fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])),
-            inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state],
-            outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]
-        ).then(
-            # 3. Display first sample for left side (if data exists)
-            fn=handle_first_sample,
-            inputs=[comp_samples_data_state_left],
-            outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]
-        ).then(
-            # 4. Reset left sample index state
-            lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]
-        )
-
-        # Right model selection (mirrors left side logic)
-        comp_model_dropdown_right.change(
-            fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'),
-            inputs=[comp_model_dropdown_right, comp_dataset_state],
-            outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right]
-        ).then(
-            fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id.value and model_state.value else ("","",gr.HTML(""), gr.State([])),
-            inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state],
-            outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]
-        ).then(
-            fn=handle_first_sample,
-            inputs=[comp_samples_data_state_right],
-            outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]
-        ).then(
-            lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]
-        )
-
-        # --- Problem Selection Handling (Comparison Tab) ---
-        # Option 1: User types into the visible input box
-        comp_problem_id_input_display.submit(
-            fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state]
-        )
-        # Option 2: User clicks problem grid (JS updates hidden comp_problem_id_state)
-        # Main handler triggered by change in hidden comp_problem_id_state
-        comp_problem_id_state.change(
-            # 1. Update main problem/answer display (doesn't need model info)
-            fn=handle_comparison_problem_update,
-            inputs=[comp_problem_id_state, comp_dataset_state],
-            outputs=[comp_problem_markdown_output, comp_answer_markdown_output]
-        ).then(
-            # 2. Update left samples (if left model is selected)
-            fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state.value else ("","",gr.HTML(""), gr.State([])),
-            inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state],
-            outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]
-        ).then(
-            # 3. Display first left sample
-            fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]
-        ).then(
-            # 4. Reset left sample index
-            lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]
-        ).then(
-            # 5. Update right samples (if right model is selected)
-            fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state.value else ("","",gr.HTML(""), gr.State([])),
-            inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state],
-            outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]
-        ).then(
-            # 6. Display first right sample
-            fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]
-        ).then(
-            # 7. Reset right sample index
-            lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]
-        ).then(
-            # 8. Update the visible input box to reflect the selected ID
-            fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display]
-        )
-
-
-        # --- Sample Selection Handling (Comparison Tab) ---
-        # Left sample selection (triggered by JS updating hidden state)
-        comp_sample_index_state_left.change(
-            fn=handle_sample_select,
-            inputs=[comp_sample_index_state_left, comp_samples_data_state_left],
-            outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]
-        )
-        # Right sample selection (triggered by JS updating hidden state)
-        comp_sample_index_state_right.change(
-            fn=handle_sample_select,
-            inputs=[comp_sample_index_state_right, comp_samples_data_state_right],
-            outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]
-        )
-
-        # --- Initial Population on App Load ---
-        # Populate model dropdowns based on the initial dataset state
-        demo.load(
-            fn=update_available_models_for_dropdowns,
-            inputs=[current_dataset_state], # Use initial state for single tab dataset
-            outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]
-        )
-
+        def clear_problem_outputs(): return "Select problem.", "N/A", gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0"
+        def clear_comparison_side_outputs(): return gr.HTML(""), gr.State([]), "Select sample.", "*(Response)*", "0"
+        # Single Model Event Connections
+        dataset_radio_single.change(fn=update_available_models_for_dropdowns, inputs=[dataset_radio_single], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda ds: (gr.DataFrame(value=[]), gr.HTML("Select model."), None, ds, "", *clear_problem_outputs(), ""), inputs=[dataset_radio_single], outputs=[model_stats_df, problem_grid_html_output, current_model_state, current_dataset_state, problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display])
+        model_dropdown.change(fn=update_problem_grid_and_stats, inputs=[model_dropdown, current_dataset_state], outputs=[model_stats_df, problem_grid_html_output, current_model_state]).then(lambda: ("", *clear_problem_outputs(), ""), inputs=[], outputs=[problem_id_state, problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state, sample_metadata_output, sample_response_output, sample_index_state, problem_id_input_display])
+        problem_id_input_display.submit(fn=lambda x: x, inputs=[problem_id_input_display], outputs=[problem_id_state])
+        problem_id_state.change(fn=handle_problem_select, inputs=[problem_id_state, current_model_state, current_dataset_state], outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]).then(fn=handle_first_sample, inputs=[current_samples_data_state], outputs=[sample_metadata_output, sample_response_output]).then(lambda: "0", inputs=[], outputs=[sample_index_state]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[problem_id_state], outputs=[problem_id_input_display])
+        sample_index_state.change(fn=handle_sample_select, inputs=[sample_index_state, current_samples_data_state], outputs=[sample_metadata_output, sample_response_output])
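+        # Note: each .change()/.then() chain runs sequentially, so later steps can rely
+        # on states (e.g. current_samples_data_state) written by earlier steps.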
+        # Comparison Tab Event Connections
+        comp_dataset_radio.change(fn=lambda ds: ds, inputs=[comp_dataset_radio], outputs=[comp_dataset_state]).then(fn=update_available_models_for_dropdowns, inputs=[comp_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right]).then(lambda: (None, None, "Select models and problem.", "Select problem.", gr.HTML("Select model."), gr.HTML("Select model."), *clear_comparison_side_outputs(), *clear_comparison_side_outputs(), "", ""), inputs=[], outputs=[comp_model_state_left, comp_model_state_right, comp_problem_markdown_output, comp_answer_markdown_output, comp_problem_grid_html_output_left, comp_problem_grid_html_output_right, comp_samples_grid_output_left, comp_samples_data_state_left, comp_sample_metadata_output_left, comp_sample_response_output_left, comp_sample_index_state_left, comp_samples_grid_output_right, comp_samples_data_state_right, comp_sample_metadata_output_right, comp_sample_response_output_right, comp_sample_index_state_right, comp_problem_id_state, comp_problem_id_input_display])
+        comp_model_dropdown_left.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_left, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_left, comp_model_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if prob_id and model_state else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left])
+        comp_model_dropdown_right.change(fn=lambda model, ds: update_problem_grid_and_stats(model, ds, mode='comparison'), inputs=[comp_model_dropdown_right, comp_dataset_state], outputs=[dummy_state, comp_problem_grid_html_output_right, comp_model_state_right]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if prob_id and model_state else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right])
+        comp_problem_id_input_display.submit(fn=lambda x: x, inputs=[comp_problem_id_input_display], outputs=[comp_problem_id_state])
+        comp_problem_id_state.change(fn=handle_comparison_problem_update, inputs=[comp_problem_id_state, comp_dataset_state], outputs=[comp_problem_markdown_output, comp_answer_markdown_output]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_left') if model_state else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_left, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_left, comp_samples_data_state_left]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_left]).then(fn=lambda prob_id, model_state, ds_state: handle_problem_select(prob_id, model_state, ds_state, mode='comparison_right') if model_state else ("","",gr.HTML(""), gr.State([])), inputs=[comp_problem_id_state, comp_model_state_right, comp_dataset_state], outputs=[dummy_state, dummy_state, comp_samples_grid_output_right, comp_samples_data_state_right]).then(fn=handle_first_sample, inputs=[comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right]).then(lambda: "0", inputs=[], outputs=[comp_sample_index_state_right]).then(fn=lambda x: x.value if hasattr(x,'value') else x, inputs=[comp_problem_id_state], outputs=[comp_problem_id_input_display])
+        comp_sample_index_state_left.change(fn=handle_sample_select, inputs=[comp_sample_index_state_left, comp_samples_data_state_left], outputs=[comp_sample_metadata_output_left, comp_sample_response_output_left])
+        comp_sample_index_state_right.change(fn=handle_sample_select, inputs=[comp_sample_index_state_right, comp_samples_data_state_right], outputs=[comp_sample_metadata_output_right, comp_sample_response_output_right])
+        # Initial Load
+        demo.load(fn=update_available_models_for_dropdowns, inputs=[current_dataset_state], outputs=[model_dropdown, comp_model_dropdown_left, comp_model_dropdown_right])
     return demo

-# <<< Keep monitor_memory_usage function >>>
+
+# --- Memory Monitor (Adjust thresholds for 10GB limit) ---
 def monitor_memory_usage():
-    """Monitors memory usage and clears caches if thresholds are exceeded."""
+    """Monitors memory usage and clears caches if thresholds are exceeded (10GB Limit)."""
     global db
     try:
         process = psutil.Process(os.getpid())
-        # Use resident set size (RSS) as a measure of physical RAM usage
-        memory_info = process.memory_info()
-        memory_usage_mb = memory_info.rss / (1024 * 1024)
-        # Get total system memory for context
-        total_memory_gb = psutil.virtual_memory().total / (1024**3)
-        print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB available.")
-
-        # Define thresholds based on 18GB total available RAM
-        # Threshold 1: ~70% = 12.6 GB
-        threshold_1_mb = 18 * 1024 * 0.70
-        # Threshold 2: ~85% = 15.3 GB
-        threshold_2_mb = 18 * 1024 * 0.85
-
-        if db and db.conn: # Only clear caches if DB is active
-            if memory_usage_mb > threshold_2_mb:
-                print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_2_mb:.1f} MB. Clearing ALL caches.")
-                db.clear_cache() # Clear all Python-level caches
-            elif memory_usage_mb > threshold_1_mb:
-                print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) exceeds {threshold_1_mb:.1f} MB. Clearing response cache.")
-                db.clear_cache('response') # Clear less critical response cache first
-        else:
-            print("[Memory Monitor] DB not active, skipping cache check.")
+        memory_info = process.memory_info(); memory_usage_mb = memory_info.rss / (1024 * 1024)
+        total_memory_gb = 10.0 # Assumed total available
+        print(f"[Memory Monitor] Usage: {memory_usage_mb:.1f} MB / {total_memory_gb:.1f} GB limit.")
+        # Define thresholds based on 10GB limit
+        threshold_1_mb = 10 * 1024 * 0.70 # Warn at 7GB
+        threshold_2_mb = 10 * 1024 * 0.85 # Critical at 8.5GB
-        # Return status string (optional, could be used in UI if needed)
+        if db and db.conn:
+            if memory_usage_mb > threshold_2_mb:
+                print(f"[Memory Monitor] CRITICAL: Usage ({memory_usage_mb:.1f} MB) > {threshold_2_mb:.1f} MB. Clearing ALL caches.")
+                db.clear_cache()
+            elif memory_usage_mb > threshold_1_mb:
+                print(f"[Memory Monitor] WARNING: Usage ({memory_usage_mb:.1f} MB) > {threshold_1_mb:.1f} MB. Clearing response cache.")
+                db.clear_cache('response')
+        else: print("[Memory Monitor] DB not active.")
         return f"Memory OK: {memory_usage_mb:.1f} MB"
+    except Exception as e: print(f"[Memory Monitor] Error: {e}"); return "Memory monitor error"

-    except Exception as e:
-        print(f"[Memory Monitor] Error: {e}")
-        return "Memory monitor error"
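+# Optional sketch (an assumption, not wired in anywhere): poll monitor_memory_usage()
+# from a daemon thread. The 300s interval and the threading approach are illustrative
+# choices only; Gradio's `every=` parameter would be an alternative trigger.
+def start_memory_monitor(interval_seconds=300):
+    """Run monitor_memory_usage() periodically in a background daemon thread."""
+    import threading
+    def _loop():
+        while True:
+            time.sleep(interval_seconds)
+            monitor_memory_usage()
+    t = threading.Thread(target=_loop, daemon=True, name="memory-monitor")
+    t.start()
+    return t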
-# <<< Keep __main__ block, ensuring DB initialization happens before UI creation >>>
+# --- Main execution block ---
+# (Initialization connects to disk, UI launch remains the same)
 if __name__ == "__main__":
-    DB_FILE_NAME = "data.db"
-    # Determine expected path (e.g., in the current directory)
-    DB_PATH = os.path.abspath(DB_FILE_NAME)
-
-    # --- Database Download/Check ---
+    DB_FILE_NAME = "data.db"; DB_PATH = os.path.abspath(DB_FILE_NAME)
     if not os.path.exists(DB_PATH):
-        print(f"{DB_PATH} not found. Attempting to download from Hugging Face Hub...")
+        print(f"{DB_PATH} not found. Attempting download...")
        try:
-            # Attempt to get token from environment or local file
             hf_token = os.environ.get("HF_TOKEN")
-            if not hf_token:
-                token_path = Path.home() / ".huggingface" / "token"
-                if token_path.exists():
-                    print("Using token from ~/.huggingface/token")
-                    hf_token = token_path.read_text().strip()
-            # Optional: Add interactive prompt for token if needed
-            # else: hf_token = input("Enter your Hugging Face token: ").strip()
-
-            if not hf_token:
-                raise ValueError("Hugging Face token not found (set HF_TOKEN env var or place in ~/.huggingface/token).")
-
-            print("Downloading data.db (this might take time)...")
-            # Download directly to the expected DB_PATH location's directory
-            downloaded_path = hf_hub_download(
-                repo_id="CoderBak/OlymMATH-data",
-                filename=DB_FILE_NAME,
-                repo_type="dataset",
-                token=hf_token,
-                local_dir=os.path.dirname(DB_PATH), # Download to target directory
-                local_dir_use_symlinks=False # Force copy, avoid symlink issues
-            )
-            # Verify download path matches expected path
-            if os.path.abspath(downloaded_path) != DB_PATH:
-                print(f"Warning: Downloaded path '{downloaded_path}' differs from expected path '{DB_PATH}'.")
-                DB_PATH = os.path.abspath(downloaded_path)
-
-            print(f"Database downloaded successfully to: {DB_PATH}")
+            if not hf_token:
+                token_path = Path.home() / ".huggingface" / "token"
+                hf_token = token_path.read_text().strip() if token_path.exists() else None
+            if not hf_token: raise ValueError("HF token not found (set HF_TOKEN or ~/.huggingface/token).")
+            print("Downloading data.db...")
+            downloaded_path = hf_hub_download(repo_id="CoderBak/OlymMATH-data", filename=DB_FILE_NAME, repo_type="dataset", token=hf_token, local_dir=os.path.dirname(DB_PATH), local_dir_use_symlinks=False)
+            DB_PATH = os.path.abspath(downloaded_path); print(f"Download complete: {DB_PATH}")
         except Exception as e:
-            print(f"Error downloading database: {e}")
-            # Display error UI and exit
-            with gr.Blocks() as error_demo:
-                gr.Markdown(f"# Error: Database Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists in the script's directory, or provide a valid Hugging Face token and network access.")
-            error_demo.launch(server_name="0.0.0.0", show_error=True)
-            exit(1)
-
-    # --- Database Initialization (Loads into Memory) ---
-    print(f"Initializing ModelDatabase from: {DB_PATH} (loading data into memory)...")
+            print(f"Error downloading DB: {e}")
+            with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Download Failed\n`{str(e)}`\nEnsure `{DB_FILE_NAME}` exists or HF token is valid."); error_demo.launch(server_name="0.0.0.0"); exit(1)
+
+    print(f"Initializing ModelDatabase from disk: {DB_PATH}...")
     start_init = time.time()
     try:
-        # Instantiate the database class - this performs the load
-        db = ModelDatabase(DB_PATH) # db becomes the global instance
+        db = ModelDatabase(DB_PATH) # Connects to disk
     except Exception as e:
-        print(f"Fatal Error during ModelDatabase initialization: {e}")
-        import traceback
-        traceback.print_exc() # Print full traceback for debugging
-        # Display error UI and exit
-        with gr.Blocks() as error_demo:
-            gr.Markdown(f"# Error: Database Initialization Failed\n`{str(e)}`\nCheck database file integrity, permissions, and available memory ({psutil.virtual_memory().available / (1024**3):.1f} GB free).")
-        error_demo.launch(server_name="0.0.0.0", show_error=True)
-        exit(1)
-    end_init = time.time()
-    print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.")
-    # Initial memory check after load
-    monitor_memory_usage()
-
-    # --- Cleanup Registration ---
-    def cleanup():
-        global db
-        print("\nRunning cleanup before exit...")
-        if db:
-            db.close() # Ensures connection is closed properly
-        print("Cleanup finished.")
-    atexit.register(cleanup) # Register cleanup to run on script exit
-
-    # --- Create and Launch UI ---
-    print("Creating Gradio UI...")
-    # Pass the initialized (in-memory) db instance to the UI creation function
-    main_demo = create_ui(db)
-
-    # --- Optional: Periodic Memory Monitor ---
-    # Use Gradio's `every` parameter on a dummy component if you want checks tied to UI activity,
-    # or run a separate thread (carefully). For simplicity, we'll rely on checks during requests (if implemented)
-    # or manual monitoring based on the initial check.
-
-    print("Launching Gradio application...")
-    # Use queue() for better performance under load
-    main_demo.queue().launch(
-        server_name="0.0.0.0", # Listen on all interfaces
-        share=os.environ.get("GRADIO_SHARE", False), # Allow sharing via env var if needed
-        inbrowser=False, # Don't automatically open browser
-        show_error=True, # Show Python errors in browser console
-    )
-
-    # Server runs here until interrupted (Ctrl+C)
-    print("Application is running. Press Ctrl+C to stop.")
-    # Keep the main thread alive if needed (though launch() usually blocks)
-    # while True: time.sleep(1)
\ No newline at end of file
+        print(f"Fatal Error during DB initialization: {e}"); import traceback; traceback.print_exc()
+        with gr.Blocks() as error_demo: gr.Markdown(f"# Error: DB Init Failed\n`{str(e)}`\nCheck file/permissions."); error_demo.launch(server_name="0.0.0.0"); exit(1)
+    end_init = time.time(); print(f"ModelDatabase initialized in {end_init - start_init:.2f} seconds.")
+    monitor_memory_usage() # Initial check
+
+    def cleanup():
+        global db
+        print("\nRunning cleanup...")
+        if db: db.close()
+        print("Cleanup finished.")
+    atexit.register(cleanup)
+
+    print("Creating Gradio UI..."); main_demo = create_ui(db)
+    print("Launching Gradio application...")
+    main_demo.queue().launch(server_name="0.0.0.0", share=os.environ.get("GRADIO_SHARE", False), inbrowser=False, show_error=True)
+    print("Application running. Press Ctrl+C to stop.")
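+# Usage note (assumed defaults): running `python app.py` serves on 0.0.0.0 at Gradio's
+# default port 7860. GRADIO_SHARE set to any non-empty value enables a public share
+# link; HF_TOKEN (or ~/.huggingface/token) allows the initial data.db download.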