Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,6 +22,14 @@ SUBJECT_TRANS = {
|
|
22 |
"组合": "Combinatorics"
|
23 |
}
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
MODEL_TRANS = {
|
26 |
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
|
27 |
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
|
@@ -65,6 +73,70 @@ DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"]
|
|
65 |
# 全局数据库实例
|
66 |
db = None
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
class ModelDatabase:
|
69 |
"""Database access class"""
|
70 |
|
@@ -360,6 +432,82 @@ class ModelDatabase:
|
|
360 |
# 清理所有缓存
|
361 |
self.clear_cache()
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
def format_latex(text):
|
364 |
if text is None: return ""
|
365 |
# Process the text for proper LaTeX rendering with KaTeX
|
@@ -372,12 +520,24 @@ def format_latex(text):
|
|
372 |
def format_markdown_with_math(text):
|
373 |
if text is None: return ""
|
374 |
|
375 |
-
#
|
376 |
-
#
|
|
|
|
|
|
|
|
|
|
|
377 |
|
378 |
# Convert newlines for markdown
|
379 |
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
381 |
# Return the cleaned text for Gradio's markdown component to render
|
382 |
return text
|
383 |
|
@@ -584,16 +744,9 @@ def handle_comparison_problem_update(problem_id, dataset_state):
|
|
584 |
# Use format_markdown_with_math for proper rendering
|
585 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
586 |
|
587 |
-
#
|
588 |
answer_text = problem_dict.get('answer', '')
|
589 |
-
|
590 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
591 |
-
|
592 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
593 |
-
if '$' not in answer_text and answer_text.strip():
|
594 |
-
answer_text = f"${answer_text}$"
|
595 |
-
|
596 |
-
answer_content = format_markdown_with_math(answer_text)
|
597 |
|
598 |
return problem_content, answer_content
|
599 |
except Exception as e:
|
@@ -634,16 +787,9 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
634 |
# Process problem and answer text for Markdown rendering
|
635 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
636 |
|
637 |
-
#
|
638 |
answer_text = problem_dict.get('answer', '')
|
639 |
-
|
640 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
641 |
-
|
642 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
643 |
-
if '$' not in answer_text and answer_text.strip():
|
644 |
-
answer_text = f"${answer_text}$"
|
645 |
-
|
646 |
-
answer_content = format_markdown_with_math(answer_text)
|
647 |
|
648 |
# For comparison without model, we don't have samples to display
|
649 |
return problem_content, answer_content, "", gr.State([])
|
@@ -673,16 +819,9 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
673 |
# Process problem and answer text for Markdown rendering
|
674 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
675 |
|
676 |
-
#
|
677 |
answer_text = problem_dict.get('answer', '')
|
678 |
-
|
679 |
-
answer_text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', answer_text, flags=re.DOTALL)
|
680 |
-
|
681 |
-
# 检查答案是否已经包含美元符号,如果没有则添加
|
682 |
-
if '$' not in answer_text and answer_text.strip():
|
683 |
-
answer_text = f"${answer_text}$"
|
684 |
-
|
685 |
-
answer_content = format_markdown_with_math(answer_text)
|
686 |
|
687 |
# Rest of the function remains the same
|
688 |
if not responses_data:
|
@@ -709,7 +848,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
709 |
samples_per_row = 16 if mode == 'comparison' else 32
|
710 |
|
711 |
# 第一行: 样本 0-samples_per_row
|
712 |
-
samples_grid_html = f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
713 |
|
714 |
for i, resp in enumerate(displayed_samples[:samples_per_row]):
|
715 |
correctness = resp.get('correctness', 0)
|
@@ -737,7 +876,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
737 |
# 如果有更多样本,显示第二行
|
738 |
if actual_display_count > samples_per_row:
|
739 |
row_samples = displayed_samples[samples_per_row:2*samples_per_row]
|
740 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
741 |
|
742 |
for i, resp in enumerate(row_samples):
|
743 |
actual_idx = i + samples_per_row
|
@@ -767,7 +906,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
767 |
# 第三行
|
768 |
row_samples = displayed_samples[2*samples_per_row:3*samples_per_row]
|
769 |
if row_samples:
|
770 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
771 |
|
772 |
for i, resp in enumerate(row_samples):
|
773 |
actual_idx = i + 2*samples_per_row
|
@@ -796,7 +935,7 @@ def handle_problem_select(problem_id_from_js, current_model_state, current_datas
|
|
796 |
if actual_display_count > 3*samples_per_row:
|
797 |
row_samples = displayed_samples[3*samples_per_row:4*samples_per_row]
|
798 |
if row_samples:
|
799 |
-
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;
|
800 |
|
801 |
for i, resp in enumerate(row_samples):
|
802 |
actual_idx = i + 3*samples_per_row
|
@@ -886,6 +1025,54 @@ def create_ui(db_path):
|
|
886 |
global db
|
887 |
db = ModelDatabase(db_path)
|
888 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
889 |
AVAILABLE_DATASETS = db.get_available_datasets()
|
890 |
if not AVAILABLE_DATASETS:
|
891 |
AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback
|
@@ -896,9 +1083,9 @@ def create_ui(db_path):
|
|
896 |
body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; }
|
897 |
.sample-btn { transition: all 0.15s ease-in-out; }
|
898 |
.sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
|
899 |
-
.problem-grid-container { overflow
|
900 |
-
.math-content { overflow
|
901 |
-
.sample-response { overflow
|
902 |
h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); }
|
903 |
.gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; }
|
904 |
.gr-dropdown select { font-size: 0.9em; }
|
@@ -964,6 +1151,68 @@ def create_ui(db_path):
|
|
964 |
border: 1px solid #ddd;
|
965 |
padding: 4px 8px;
|
966 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
967 |
"""
|
968 |
|
969 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
|
@@ -989,6 +1238,64 @@ def create_ui(db_path):
|
|
989 |
# 创建占位符State组件替代None
|
990 |
dummy_state = gr.State(value=None)
|
991 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
992 |
with gr.Tabs():
|
993 |
with gr.TabItem("Single Model Analysis"):
|
994 |
with gr.Row(variant='compact'):
|
@@ -1228,6 +1535,83 @@ def create_ui(db_path):
|
|
1228 |
]
|
1229 |
)
|
1230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1231 |
# --- Event Handlers ---
|
1232 |
def update_available_models_for_dropdowns(selected_dataset):
|
1233 |
# This function can be used to update model lists if they are dataset-dependent
|
@@ -1549,6 +1933,37 @@ def create_ui(db_path):
|
|
1549 |
outputs=[sample_number_input]
|
1550 |
)
|
1551 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1552 |
return demo
|
1553 |
|
1554 |
def monitor_memory_usage():
|
@@ -1575,6 +1990,273 @@ def monitor_memory_usage():
|
|
1575 |
except Exception as e:
|
1576 |
return "Memory monitor error"
|
1577 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1578 |
# 修改主函数以使用优化策略
|
1579 |
if __name__ == "__main__":
|
1580 |
DB_PATH = "data.db"
|
@@ -1582,22 +2264,15 @@ if __name__ == "__main__":
|
|
1582 |
# 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载
|
1583 |
if not os.path.exists(DB_PATH):
|
1584 |
try:
|
1585 |
-
# 从环境变量获取 HF_TOKEN
|
1586 |
-
hf_token = os.environ.get("HF_TOKEN")
|
1587 |
-
if not hf_token:
|
1588 |
-
raise ValueError("HF_TOKEN environment variable is not set")
|
1589 |
-
|
1590 |
-
# 从 Hugging Face 下载数据库文件
|
1591 |
DB_PATH = hf_hub_download(
|
1592 |
repo_id="CoderBak/OlymMATH-data",
|
1593 |
filename="data.db",
|
1594 |
-
repo_type="dataset"
|
1595 |
-
token=hf_token
|
1596 |
)
|
1597 |
except Exception as e:
|
1598 |
# 创建一个显示错误信息的简单 Gradio 应用
|
1599 |
with gr.Blocks() as error_demo:
|
1600 |
-
gr.Markdown(f"# Error: Database Download Failed\n{str(e)}
|
1601 |
error_demo.launch(server_name="0.0.0.0")
|
1602 |
exit(1)
|
1603 |
|
|
|
22 |
"组合": "Combinatorics"
|
23 |
}
|
24 |
|
25 |
+
# 英文到中文的翻译表
|
26 |
+
SUBJECT_TRANS_EN_TO_ZH = {
|
27 |
+
"Algebra": "代数",
|
28 |
+
"Number Theory": "数论",
|
29 |
+
"Geometry": "几何",
|
30 |
+
"Combinatorics": "组合"
|
31 |
+
}
|
32 |
+
|
33 |
MODEL_TRANS = {
|
34 |
"acemath-rl-nemotron-7b": "AceMath-RL-Nemotron-7B",
|
35 |
"deepseek-r1-distill-qwen-1.5b": "DeepSeek-R1-Distill-Qwen-1.5B",
|
|
|
73 |
# 全局数据库实例
|
74 |
db = None
|
75 |
|
76 |
+
# Global cache for the Reference Solutions tab.
# Maps problem id -> {"EN": mean accuracy, "ZH": mean accuracy}.
reference_accuracy_cache = {}


def _mean_accuracy_over_models(db, models, dataset, unique_id):
    """Average one problem's per-model accuracy within a single dataset.

    For every model, ``db.get_problem_data(model, dataset, unique_id)`` is
    queried and the mean of the responses' ``'correctness'`` values is taken.
    Models with no responses are ignored; models whose lookup raises are
    skipped and counted.

    Returns:
        A ``(mean, failures)`` tuple: the mean of the per-model accuracies
        (``0.0`` when no model produced data) and the number of lookups that
        raised.
    """
    per_model = []
    failures = 0
    for model in models:
        try:
            _, responses = db.get_problem_data(model, dataset, unique_id)
            if responses:
                per_model.append(
                    sum(r['correctness'] for r in responses) / len(responses)
                )
        except Exception:
            # Best effort: one broken model must not abort the whole pass,
            # but the failure is counted so it can be reported afterwards.
            failures += 1
    mean = sum(per_model) / len(per_model) if per_model else 0.0
    return mean, failures


def precompute_reference_accuracies(db, reference_loader):
    """Pre-compute the average accuracy of every reference problem.

    For each problem id returned by ``reference_loader.get_all_problem_ids()``
    the EN-HARD and ZH-HARD accuracies are averaged over all models reported
    by ``db.get_available_models()`` and stored in the module-level
    ``reference_accuracy_cache`` as ``{pid: {"EN": float, "ZH": float}}``.

    The result is assembled in a local dict and only then copied into the
    cache, so the existing dict object is updated in place and stays valid
    for anything that already holds a reference to it.

    Args:
        db: Database object exposing ``get_available_models()`` and
            ``get_problem_data(model, dataset, unique_id)``.
        reference_loader: Object exposing ``get_all_problem_ids()``.

    Returns:
        None. Does nothing when either argument is falsy.
    """
    if not db or not reference_loader:
        return

    print("Pre-computing reference problem accuracies...")
    start_time = time.time()

    problem_ids = reference_loader.get_all_problem_ids()

    # Fetch the model list once instead of once per problem.
    all_models = db.get_available_models()
    print(f"Computing accuracies for {len(problem_ids)} problems across {len(all_models)} models...")

    fresh_cache = {}
    failed_lookups = 0
    for i, pid in enumerate(problem_ids):
        if i % 5 == 0:  # report progress every 5 problems
            print(f"Processing problem {i+1}/{len(problem_ids)}: {pid}")

        en_avg, en_failed = _mean_accuracy_over_models(
            db, all_models, "EN-HARD", f"OlymMATH-HARD-{pid}-EN"
        )
        zh_avg, zh_failed = _mean_accuracy_over_models(
            db, all_models, "ZH-HARD", f"OlymMATH-HARD-{pid}-ZH"
        )
        fresh_cache[pid] = {"EN": en_avg, "ZH": zh_avg}
        failed_lookups += en_failed + zh_failed

    # Publish the finished result in one step.
    reference_accuracy_cache.clear()
    reference_accuracy_cache.update(fresh_cache)

    if failed_lookups:
        print(f"Warning: {failed_lookups} model lookups failed and were skipped")

    elapsed_time = time.time() - start_time
    print(f"✅ Pre-computation completed in {elapsed_time:.2f} seconds")
    print(f"✅ Cached accuracies for {len(reference_accuracy_cache)} problems")
|
139 |
+
|
140 |
class ModelDatabase:
|
141 |
"""Database access class"""
|
142 |
|
|
|
432 |
# 清理所有缓存
|
433 |
self.clear_cache()
|
434 |
|
435 |
+
class ReferenceDataLoader:
    """Load and serve reference-solution records from a JSON Lines file.

    Each non-blank line of the file is parsed as one JSON object and stored
    under the value of its ``'unique_id'`` key. Loading is best effort: the
    constructor never raises for a missing/unreadable file or a bad line, it
    reports the problem on stdout and keeps whatever could be loaded.
    """

    def __init__(self, jsonl_path):
        """Remember *jsonl_path* and load its records immediately."""
        self.jsonl_path = jsonl_path
        # unique_id -> parsed record (the whole JSON object of that line)
        self.reference_data = {}
        self._load_data()

    def _load_data(self):
        """Populate ``self.reference_data`` from ``self.jsonl_path``.

        Blank lines are ignored. A line that is not valid JSON, or whose
        object has no ``'unique_id'``, is reported and skipped so that a
        single bad line does not discard the records that follow it.
        """
        try:
            with open(self.jsonl_path, 'r', encoding='utf-8') as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                        unique_id = data['unique_id']
                    except (json.JSONDecodeError, KeyError, TypeError) as e:
                        print(f"Skipping invalid reference data on line {line_no}: {e!r}")
                        continue
                    self.reference_data[unique_id] = data
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error loading reference data: {e}")

    def get_problem_data(self, unique_id):
        """Return the record stored under *unique_id*, or ``None`` if absent."""
        return self.reference_data.get(unique_id)

    def get_all_problem_ids(self):
        """Return all loaded ``unique_id`` values as a sorted list."""
        return sorted(self.reference_data.keys())
|
461 |
+
|
462 |
+
def calculate_reference_problem_accuracy(db, unique_id):
    """Compute a reference problem's average accuracy over all models.

    The EN and ZH variants are looked up as ``OlymMATH-HARD-<unique_id>-EN``
    in dataset ``EN-HARD`` and ``OlymMATH-HARD-<unique_id>-ZH`` in dataset
    ``ZH-HARD``. For each model the mean of the responses' ``'correctness'``
    values is taken, and those per-model means are then averaged.

    Args:
        db: Database object exposing ``get_available_models()`` and
            ``get_problem_data(model, dataset, unique_id)``.
        unique_id: Problem identifier as used in the reference data.

    Returns:
        ``(en_avg, zh_avg)``. A language with no data yields ``0.0``; any
        unexpected error yields ``(0.0, 0.0)``. Progress and errors are
        printed to stdout.
    """
    languages = ("EN", "ZH")
    try:
        # Build the dataset-specific ids for both language versions.
        ids = {lang: f"OlymMATH-HARD-{unique_id}-{lang}" for lang in languages}
        print(f"Calculating accuracy for problem {unique_id}: EN={ids['EN']}, ZH={ids['ZH']}")

        accuracies = {lang: [] for lang in languages}

        all_models = db.get_available_models()
        print(f"Found {len(all_models)} models in database")

        for model in all_models:
            for lang in languages:
                # A failure for one model/language is reported and skipped so
                # the remaining models still contribute to the average.
                try:
                    _, responses = db.get_problem_data(model, f"{lang}-HARD", ids[lang])
                    if responses:
                        avg_accuracy = sum(r['correctness'] for r in responses) / len(responses)
                        accuracies[lang].append(avg_accuracy)
                        print(f" Model {model} {lang}: {avg_accuracy:.2%} ({len(responses)} responses)")
                except Exception as e:
                    print(f" Error getting {lang} data for model {model}: {e}")

        en_avg = sum(accuracies["EN"]) / len(accuracies["EN"]) if accuracies["EN"] else 0.0
        zh_avg = sum(accuracies["ZH"]) / len(accuracies["ZH"]) if accuracies["ZH"] else 0.0

        print(f"Final averages for problem {unique_id}: EN={en_avg:.2%} (from {len(accuracies['EN'])} models), ZH={zh_avg:.2%} (from {len(accuracies['ZH'])} models)")

        return en_avg, zh_avg
    except Exception as e:
        # Deliberate best effort: callers always receive a pair of floats.
        print(f"Error calculating accuracy for problem {unique_id}: {e}")
        return 0.0, 0.0
|
510 |
+
|
511 |
def format_latex(text):
|
512 |
if text is None: return ""
|
513 |
# Process the text for proper LaTeX rendering with KaTeX
|
|
|
520 |
def format_markdown_with_math(text):
    """Prepare *text* for a Markdown component that renders LaTeX.

    Steps, in order:
      1. Normalise ``\\r\\n`` and bare ``\\r`` line endings to ``\\n``.
      2. Rewrite ``$$...$$`` as ``\\[...\\]`` (display math, may span lines).
      3. Rewrite ``$...$`` as ``\\(...\\)`` (inline math, single line only;
         an opening ``$`` preceded by a backslash is left alone).
      4. Collapse runs of three or more newlines into one blank line.

    Args:
        text: Source string, or ``None``.

    Returns:
        The converted string; ``""`` when *text* is ``None``.
    """
    if text is None:
        return ""

    # Normalise line endings first so the single-line restriction of the
    # inline-math pattern below also holds for bare '\r' line breaks.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Convert $$xxx$$ to \[xxx\] (display math).
    text = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', text, flags=re.DOTALL)

    # Convert $xxx$ to \(xxx\) (inline math). Display math was already
    # rewritten above and contains no '$', so no extra guard is required; a
    # lookahead rejecting a following ']' would only prevent inputs such as
    # "[$a$, $b$]" from being converted consistently.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Clean up excessive newlines.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)

    # Debug: report when an aligned environment is present.
    if '\\begin{aligned}' in text:
        print(f"LaTeX aligned environment detected in text (first 200 chars): {text[:200]}...")

    # Return the cleaned text for Gradio's markdown component to render.
    return text
|
543 |
|
|
|
744 |
# Use format_markdown_with_math for proper rendering
|
745 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
746 |
|
747 |
+
# Use special answer formatting
|
748 |
answer_text = problem_dict.get('answer', '')
|
749 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
750 |
|
751 |
return problem_content, answer_content
|
752 |
except Exception as e:
|
|
|
787 |
# Process problem and answer text for Markdown rendering
|
788 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
789 |
|
790 |
+
# Use special answer formatting
|
791 |
answer_text = problem_dict.get('answer', '')
|
792 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
|
794 |
# For comparison without model, we don't have samples to display
|
795 |
return problem_content, answer_content, "", gr.State([])
|
|
|
819 |
# Process problem and answer text for Markdown rendering
|
820 |
problem_content = format_markdown_with_math(problem_dict.get('problem', ''))
|
821 |
|
822 |
+
# Use special answer formatting
|
823 |
answer_text = problem_dict.get('answer', '')
|
824 |
+
answer_content = format_answer_with_math(answer_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
825 |
|
826 |
# Rest of the function remains the same
|
827 |
if not responses_data:
|
|
|
848 |
samples_per_row = 16 if mode == 'comparison' else 32
|
849 |
|
850 |
# 第一行: 样本 0-samples_per_row
|
851 |
+
samples_grid_html = f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
852 |
|
853 |
for i, resp in enumerate(displayed_samples[:samples_per_row]):
|
854 |
correctness = resp.get('correctness', 0)
|
|
|
876 |
# 如果有更多样本,显示第二行
|
877 |
if actual_display_count > samples_per_row:
|
878 |
row_samples = displayed_samples[samples_per_row:2*samples_per_row]
|
879 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
880 |
|
881 |
for i, resp in enumerate(row_samples):
|
882 |
actual_idx = i + samples_per_row
|
|
|
906 |
# 第三行
|
907 |
row_samples = displayed_samples[2*samples_per_row:3*samples_per_row]
|
908 |
if row_samples:
|
909 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
910 |
|
911 |
for i, resp in enumerate(row_samples):
|
912 |
actual_idx = i + 2*samples_per_row
|
|
|
935 |
if actual_display_count > 3*samples_per_row:
|
936 |
row_samples = displayed_samples[3*samples_per_row:4*samples_per_row]
|
937 |
if row_samples:
|
938 |
+
samples_grid_html += f'<div style="display: grid; grid-template-columns: repeat({samples_per_row}, 1fr); gap: 2px; margin-bottom: 4px;">'
|
939 |
|
940 |
for i, resp in enumerate(row_samples):
|
941 |
actual_idx = i + 3*samples_per_row
|
|
|
1025 |
global db
|
1026 |
db = ModelDatabase(db_path)
|
1027 |
|
1028 |
+
# Initialize reference data loader with better path handling
|
1029 |
+
reference_loader = None
|
1030 |
+
# Try multiple possible paths for extra.jsonl
|
1031 |
+
possible_paths = [
|
1032 |
+
os.path.join(os.path.dirname(db_path), "extra.jsonl"),
|
1033 |
+
os.path.join(os.getcwd(), "extra.jsonl"),
|
1034 |
+
"extra.jsonl"
|
1035 |
+
]
|
1036 |
+
|
1037 |
+
for extra_jsonl_path in possible_paths:
|
1038 |
+
if os.path.exists(extra_jsonl_path):
|
1039 |
+
try:
|
1040 |
+
reference_loader = ReferenceDataLoader(extra_jsonl_path)
|
1041 |
+
print(f"Successfully loaded reference data from: {extra_jsonl_path}")
|
1042 |
+
break
|
1043 |
+
except Exception as e:
|
1044 |
+
print(f"Error loading reference data from {extra_jsonl_path}: {e}")
|
1045 |
+
continue
|
1046 |
+
|
1047 |
+
# If not found locally, try to download from Hugging Face
|
1048 |
+
if not reference_loader:
|
1049 |
+
try:
|
1050 |
+
print("Attempting to download extra.jsonl from Hugging Face...")
|
1051 |
+
extra_jsonl_path = hf_hub_download(
|
1052 |
+
repo_id="CoderBak/OlymMATH-data",
|
1053 |
+
filename="extra.jsonl",
|
1054 |
+
repo_type="dataset"
|
1055 |
+
)
|
1056 |
+
reference_loader = ReferenceDataLoader(extra_jsonl_path)
|
1057 |
+
print(f"Successfully downloaded and loaded reference data from: {extra_jsonl_path}")
|
1058 |
+
except Exception as e:
|
1059 |
+
print(f"Failed to download extra.jsonl from Hugging Face: {e}")
|
1060 |
+
|
1061 |
+
if not reference_loader:
|
1062 |
+
print("Warning: extra.jsonl not found in any of the expected locations:")
|
1063 |
+
for path in possible_paths:
|
1064 |
+
print(f" - {path}")
|
1065 |
+
print("Reference Solutions tab will not be available.")
|
1066 |
+
else:
|
1067 |
+
# Test the reference data availability
|
1068 |
+
test_reference_data_availability(db, reference_loader)
|
1069 |
+
|
1070 |
+
# Pre-compute reference problem accuracies for fast loading
|
1071 |
+
precompute_reference_accuracies(db, reference_loader)
|
1072 |
+
|
1073 |
+
# Test LaTeX formatting
|
1074 |
+
test_latex_formatting()
|
1075 |
+
|
1076 |
AVAILABLE_DATASETS = db.get_available_datasets()
|
1077 |
if not AVAILABLE_DATASETS:
|
1078 |
AVAILABLE_DATASETS = ["EN-HARD", "EN-EASY", "ZH-HARD", "ZH-EASY"] # Fallback
|
|
|
1083 |
body, .gradio-container { font-family: sans-serif; font-size: 0.95em; line-height: 1.6; }
|
1084 |
.sample-btn { transition: all 0.15s ease-in-out; }
|
1085 |
.sample-btn:hover { transform: translateY(-1px); box-shadow: 0 2px 5px rgba(0,0,0,0.1); }
|
1086 |
+
.problem-grid-container { overflow: visible !important; }
|
1087 |
+
.math-content { overflow: visible !important; padding: 5px; }
|
1088 |
+
.sample-response { overflow: visible !important; max-height: none !important; height: auto !important; }
|
1089 |
h1, h2, h3, h4, h5 { margin-top: 0.8em; margin-bottom: 0.4em; color: var(--color-text); }
|
1090 |
.gradio-tabs > div[role='tablist'] button { font-size: 0.9em; padding: 8px 12px; }
|
1091 |
.gr-dropdown select { font-size: 0.9em; }
|
|
|
1151 |
border: 1px solid #ddd;
|
1152 |
padding: 4px 8px;
|
1153 |
}
|
1154 |
+
|
1155 |
+
/* 隐藏滚动条但保留功能 */
|
1156 |
+
::-webkit-scrollbar {
|
1157 |
+
display: none !important;
|
1158 |
+
width: 0px !important;
|
1159 |
+
height: 0px !important;
|
1160 |
+
}
|
1161 |
+
|
1162 |
+
/* 主容器禁用滚动 */
|
1163 |
+
.gradio-container {
|
1164 |
+
overflow-x: hidden !important;
|
1165 |
+
}
|
1166 |
+
|
1167 |
+
/* Gradio组件容器 */
|
1168 |
+
.gradio-row, .gradio-column {
|
1169 |
+
overflow: visible !important;
|
1170 |
+
max-height: none !important;
|
1171 |
+
}
|
1172 |
+
|
1173 |
+
/* HTML组件 */
|
1174 |
+
.gr-html {
|
1175 |
+
overflow: visible !important;
|
1176 |
+
max-height: none !important;
|
1177 |
+
}
|
1178 |
+
|
1179 |
+
/* Markdown组件保持可见 */
|
1180 |
+
.gr-markdown {
|
1181 |
+
overflow: visible !important;
|
1182 |
+
max-height: none !important;
|
1183 |
+
}
|
1184 |
+
|
1185 |
+
/* 特定的问题网格容器 */
|
1186 |
+
#ref-problem-grid-container, #problem-grid-container, #comp-problem-grid-container-left, #comp-problem-grid-container-right {
|
1187 |
+
overflow: visible !important;
|
1188 |
+
max-height: none !important;
|
1189 |
+
height: auto !important;
|
1190 |
+
}
|
1191 |
+
|
1192 |
+
/* 样本网格 */
|
1193 |
+
.sample-grid-btn {
|
1194 |
+
overflow: visible !important;
|
1195 |
+
}
|
1196 |
+
|
1197 |
+
/* 确保内容区域不会产生滚动条 */
|
1198 |
+
.gr-form, .gr-box {
|
1199 |
+
overflow: visible !important;
|
1200 |
+
max-height: none !important;
|
1201 |
+
}
|
1202 |
+
|
1203 |
+
/* Reference Solutions - 禁止Solution部分的滚动 */
|
1204 |
+
#ref-solution {
|
1205 |
+
overflow: hidden !important;
|
1206 |
+
max-height: none !important;
|
1207 |
+
height: auto !important;
|
1208 |
+
}
|
1209 |
+
|
1210 |
+
/* 确保Solution内容容器也禁止滚动 */
|
1211 |
+
#ref-solution .gr-markdown {
|
1212 |
+
overflow: hidden !important;
|
1213 |
+
max-height: none !important;
|
1214 |
+
height: auto !important;
|
1215 |
+
}
|
1216 |
"""
|
1217 |
|
1218 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky)) as demo:
|
|
|
1238 |
# 创建占位符State组件替代None
|
1239 |
dummy_state = gr.State(value=None)
|
1240 |
|
1241 |
+
# Add JavaScript for handling problem grid clicks
|
1242 |
+
demo.load(lambda: None, js="""
|
1243 |
+
() => {
|
1244 |
+
// Handle problem button clicks for single model tab
|
1245 |
+
function setupProblemGridListeners() {
|
1246 |
+
document.addEventListener('click', function(e) {
|
1247 |
+
if (e.target.closest('.problem-btn')) {
|
1248 |
+
const problemBtn = e.target.closest('.problem-btn');
|
1249 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
1250 |
+
if (problemId) {
|
1251 |
+
const problemInput = document.getElementById('problem-state-input');
|
1252 |
+
if (problemInput) {
|
1253 |
+
problemInput.querySelector('input').value = problemId;
|
1254 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
1255 |
+
}
|
1256 |
+
}
|
1257 |
+
}
|
1258 |
+
|
1259 |
+
// Handle comparison problem button clicks
|
1260 |
+
if (e.target.closest('#comp-problem-grid-container-left .problem-btn') ||
|
1261 |
+
e.target.closest('#comp-problem-grid-container-right .problem-btn')) {
|
1262 |
+
const problemBtn = e.target.closest('.problem-btn');
|
1263 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
1264 |
+
if (problemId) {
|
1265 |
+
const problemInput = document.getElementById('comp-problem-state-input');
|
1266 |
+
if (problemInput) {
|
1267 |
+
problemInput.querySelector('input').value = problemId;
|
1268 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
1269 |
+
}
|
1270 |
+
}
|
1271 |
+
}
|
1272 |
+
|
1273 |
+
// Handle reference problem button clicks
|
1274 |
+
if (e.target.closest('#ref-problem-grid-container .ref-problem-btn')) {
|
1275 |
+
const problemBtn = e.target.closest('.ref-problem-btn');
|
1276 |
+
const problemId = problemBtn.getAttribute('data-problem-id');
|
1277 |
+
if (problemId) {
|
1278 |
+
const problemInput = document.getElementById('ref-problem-state-input');
|
1279 |
+
if (problemInput) {
|
1280 |
+
problemInput.querySelector('input').value = problemId;
|
1281 |
+
problemInput.querySelector('input').dispatchEvent(new Event('input', {bubbles: true}));
|
1282 |
+
}
|
1283 |
+
}
|
1284 |
+
}
|
1285 |
+
});
|
1286 |
+
}
|
1287 |
+
|
1288 |
+
// Set up listeners initially and after any DOM changes
|
1289 |
+
setupProblemGridListeners();
|
1290 |
+
|
1291 |
+
// Re-setup listeners whenever the DOM changes (for dynamic content)
|
1292 |
+
const observer = new MutationObserver(function(mutations) {
|
1293 |
+
setupProblemGridListeners();
|
1294 |
+
});
|
1295 |
+
observer.observe(document.body, {childList: true, subtree: true});
|
1296 |
+
}
|
1297 |
+
""")
|
1298 |
+
|
1299 |
with gr.Tabs():
|
1300 |
with gr.TabItem("Single Model Analysis"):
|
1301 |
with gr.Row(variant='compact'):
|
|
|
1535 |
]
|
1536 |
)
|
1537 |
|
1538 |
+
with gr.TabItem("Reference Solutions"):
|
1539 |
+
with gr.Row(variant='compact'):
|
1540 |
+
with gr.Column(scale=1, min_width=280):
|
1541 |
+
ref_problem_state_input = gr.Textbox(
|
1542 |
+
value="",
|
1543 |
+
elem_id="ref-problem-state-input",
|
1544 |
+
visible=True,
|
1545 |
+
label="Enter Problem ID",
|
1546 |
+
container=True,
|
1547 |
+
interactive=True,
|
1548 |
+
every=0.5
|
1549 |
+
)
|
1550 |
+
|
1551 |
+
with gr.Column(scale=3, min_width=400):
|
1552 |
+
gr.Markdown("#### Problem Grid (OlymMATH-HARD: All models avg. acc. - Top: EN, Bottom: ZH)")
|
1553 |
+
ref_problem_grid_html_output = gr.HTML(
|
1554 |
+
value="<div>Loading reference data...</div>",
|
1555 |
+
elem_id="ref-problem-grid-container"
|
1556 |
+
)
|
1557 |
+
|
1558 |
+
# 问题内容显示区域 - 左右分布
|
1559 |
+
with gr.Row(variant='compact'):
|
1560 |
+
# 左侧:问题信息
|
1561 |
+
with gr.Column(scale=1):
|
1562 |
+
gr.Markdown("#### Problem (EN)")
|
1563 |
+
ref_problem_en_output = gr.Markdown(
|
1564 |
+
"Please select a problem.",
|
1565 |
+
latex_delimiters=[
|
1566 |
+
{"left": "$", "right": "$", "display": False},
|
1567 |
+
{"left": "$$", "right": "$$", "display": True},
|
1568 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
1569 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
1570 |
+
]
|
1571 |
+
)
|
1572 |
+
|
1573 |
+
gr.Markdown("#### Problem (ZH)")
|
1574 |
+
ref_problem_zh_output = gr.Markdown(
|
1575 |
+
"Please select a problem.",
|
1576 |
+
latex_delimiters=[
|
1577 |
+
{"left": "$", "right": "$", "display": False},
|
1578 |
+
{"left": "$$", "right": "$$", "display": True},
|
1579 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
1580 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
1581 |
+
]
|
1582 |
+
)
|
1583 |
+
|
1584 |
+
gr.Markdown("#### Subject")
|
1585 |
+
ref_subject_output = gr.Markdown("Please select a problem.")
|
1586 |
+
|
1587 |
+
gr.Markdown("#### Answer")
|
1588 |
+
ref_answer_output = gr.Markdown(
|
1589 |
+
"Please select a problem.",
|
1590 |
+
latex_delimiters=[
|
1591 |
+
{"left": "$", "right": "$", "display": False},
|
1592 |
+
{"left": "$$", "right": "$$", "display": True},
|
1593 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
1594 |
+
{"left": "\\[", "right": "\\]", "display": True}
|
1595 |
+
]
|
1596 |
+
)
|
1597 |
+
|
1598 |
+
# 右侧:解答
|
1599 |
+
with gr.Column(scale=1):
|
1600 |
+
gr.Markdown("#### Solution")
|
1601 |
+
ref_solution_output = gr.Markdown(
|
1602 |
+
"Please select a problem.",
|
1603 |
+
elem_id="ref-solution",
|
1604 |
+
latex_delimiters=[
|
1605 |
+
{"left": "$", "right": "$", "display": False},
|
1606 |
+
{"left": "$$", "right": "$$", "display": True},
|
1607 |
+
{"left": "\\(", "right": "\\)", "display": False},
|
1608 |
+
{"left": "\\[", "right": "\\]", "display": True},
|
1609 |
+
{"left": "\\begin{align}", "right": "\\end{align}", "display": True},
|
1610 |
+
{"left": "\\begin{aligned}", "right": "\\end{aligned}", "display": True},
|
1611 |
+
{"left": "\\begin{equation}", "right": "\\end{equation}", "display": True}
|
1612 |
+
]
|
1613 |
+
)
|
1614 |
+
|
1615 |
# --- Event Handlers ---
|
1616 |
def update_available_models_for_dropdowns(selected_dataset):
|
1617 |
# This function can be used to update model lists if they are dataset-dependent
|
|
|
1933 |
outputs=[sample_number_input]
|
1934 |
)
|
1935 |
|
1936 |
+
# 为引用解决方案标签页添加处理器
|
1937 |
+
# 初始化引用问题网格
|
1938 |
+
demo.load(
|
1939 |
+
fn=lambda: create_reference_problem_grid_html(reference_loader, db),
|
1940 |
+
inputs=[],
|
1941 |
+
outputs=[ref_problem_grid_html_output]
|
1942 |
+
)
|
1943 |
+
|
1944 |
+
# 引用问题选择事件
|
1945 |
+
ref_problem_state_input.change(
|
1946 |
+
fn=handle_reference_problem_select,
|
1947 |
+
inputs=[ref_problem_state_input, gr.State(reference_loader)],
|
1948 |
+
outputs=[ref_problem_en_output, ref_problem_zh_output, ref_subject_output, ref_answer_output, ref_solution_output]
|
1949 |
+
)
|
1950 |
+
|
1951 |
+
# This is the crucial link: problem_state_input is changed by user, triggers this Python callback.
|
1952 |
+
problem_state_input.change(
|
1953 |
+
fn=handle_problem_select,
|
1954 |
+
inputs=[problem_state_input, current_model_state, current_dataset_state],
|
1955 |
+
outputs=[problem_markdown_output, answer_markdown_output, samples_grid_output, current_samples_data_state]
|
1956 |
+
).then(
|
1957 |
+
# 重置Sample Number为0
|
1958 |
+
fn=lambda: "0",
|
1959 |
+
inputs=[],
|
1960 |
+
outputs=[sample_number_input]
|
1961 |
+
).then(
|
1962 |
+
fn=handle_first_sample,
|
1963 |
+
inputs=[current_samples_data_state],
|
1964 |
+
outputs=[sample_metadata_output, sample_response_output]
|
1965 |
+
)
|
1966 |
+
|
1967 |
return demo
|
1968 |
|
1969 |
def monitor_memory_usage():
|
|
|
1990 |
except Exception as e:
|
1991 |
return "Memory monitor error"
|
1992 |
|
1993 |
+
def _reference_button_html(pid, lang, accuracy):
    """Render one clickable grid cell for problem ``pid`` in language ``lang``.

    ``accuracy`` is passed to ``get_gradient_color`` for the background color
    and shown as a whole-number percentage (presumably a fraction in [0, 1];
    confirm against whatever fills ``reference_accuracy_cache``).
    """
    bg_color = get_gradient_color(accuracy)
    # Floor to a whole percent, but absorb binary float error first:
    # 0.29 * 100 == 28.999999999999996 and would otherwise display as 28%.
    acc_pct = int(accuracy * 100 + 1e-9)
    return f"""
    <div
        data-problem-id="{pid}"
        class="ref-problem-btn"
        title="ID: {pid} ({lang}) - Avg Acc: {acc_pct}%"
        style='background-color: {bg_color}; color: white !important;
               border-radius: 4px; padding: 5px; text-align: center; font-size: 0.7em;
               min-height: 36px; user-select: none; width: 100%;
               display: flex; flex-direction: column; justify-content: center;
               overflow: hidden; text-overflow: ellipsis; white-space: nowrap; cursor: pointer;'>
        <div style="font-weight: bold; color: white !important;">{pid}</div>
        <div style="color: white !important;">{acc_pct}%</div>
    </div>
    """


def create_reference_problem_grid_html(reference_loader, db):
    """Create the HTML for the reference problem grid (EN row above ZH row).

    Accuracies are read from the module-level ``reference_accuracy_cache``;
    nothing is computed here.

    Args:
        reference_loader: Object exposing ``get_all_problem_ids()``.
        db: Database handle; only checked for truthiness in this function.

    Returns:
        An HTML string: either the two-row grid, or a short ``<div>`` message
        when the database, the loader, the problem list, or the accuracy
        cache is missing/empty.
    """
    if not db:
        return "<div>Database not available.</div>"

    if not reference_loader:
        return "<div><strong>No reference data available.</strong><br>Please ensure <code>extra.jsonl</code> file is in the same directory as the database file or in the current working directory.</div>"

    problem_ids = reference_loader.get_all_problem_ids()
    if not problem_ids:
        return "<div>No reference problems found in extra.jsonl file.</div>"

    # An empty cache means accuracies are not available yet: show a loading hint.
    if not reference_accuracy_cache:
        return "<div><strong>Computing problem accuracies...</strong><br>This may take a moment on first load.</div>"

    print(f"Using cached accuracies for {len(problem_ids)} reference problems")

    custom_style = "<style>.ref-problem-btn, .ref-problem-btn div { color: white !important; }</style>"

    # Sort numerically so that e.g. "10" comes after "9".
    sorted_problem_ids = sorted(problem_ids, key=int)

    en_cells = []
    zh_cells = []
    for pid in sorted_problem_ids:
        # Problems missing from the cache are shown with 0% accuracy.
        accuracy_data = reference_accuracy_cache.get(pid, {"EN": 0.0, "ZH": 0.0})
        en_cells.append(_reference_button_html(pid, "EN", accuracy_data["EN"]))
        zh_cells.append(_reference_button_html(pid, "ZH", accuracy_data["ZH"]))

    html_en = "".join(en_cells)
    html_zh = "".join(zh_cells)

    # One column per problem, capped at 30 columns per row.
    grid_cols = min(len(sorted_problem_ids), 30)

    # Two stacked grids: first the English row, then the Chinese row.
    grid_html = f"""
    {custom_style}
    <div style='margin-bottom: 10px;'>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_en}</div>
    </div>
    <div>
        <div style='display: grid; grid-template-columns: repeat({grid_cols}, 1fr); gap: 2px;'>{html_zh}</div>
    </div>
    """
    return grid_html
|
2078 |
+
|
2079 |
+
def handle_reference_problem_select(problem_id, reference_loader):
    """Handle reference problem selection and build the five display strings.

    Args:
        problem_id: Selected problem id; anything ``int()`` accepts.
        reference_loader: Object exposing ``get_problem_data(int)``, which
            returns a dict-like record (read via ``.get``) or a falsy value
            when the problem is unknown.

    Returns:
        A 5-tuple ``(en_problem, zh_problem, subject, answer, solution)`` of
        strings. Placeholder/error messages are returned in the same 5-tuple
        shape when nothing is selected, the id is not a valid integer, or the
        problem is not found.
    """
    if not problem_id or not reference_loader:
        return ("Please select a problem.",) * 5

    try:
        problem_id_int = int(problem_id)
    except (TypeError, ValueError):
        # TypeError covers values int() cannot convert at all (e.g. a list),
        # which would otherwise escape as an unhandled callback error.
        return ("Please enter a valid problem ID.",) * 5

    reference_data = reference_loader.get_problem_data(problem_id_int)
    if not reference_data:
        error_msg = f"Problem {problem_id_int} not found in reference data."
        return (error_msg, error_msg, "No subject available.", "No answer available.", "Solution not available.")

    # Problem statements in both languages.
    en_problem = format_markdown_with_math(reference_data.get('en_problem', 'Problem (EN) not available.'))
    zh_problem = format_markdown_with_math(reference_data.get('zh_problem', 'Problem (ZH) not available.'))

    # Answers get their own formatter, which wraps them in math delimiters.
    answer_text = reference_data.get('answer', 'No answer available.')
    answer = format_answer_with_math(answer_text)

    # Subject is shown as "English / Chinese"; unknown subjects fall back to
    # the English name on both sides.
    subject_en = reference_data.get('subject', 'Unknown')
    subject_zh = SUBJECT_TRANS_EN_TO_ZH.get(subject_en, subject_en)
    subject_display = f"**{subject_en}** / **{subject_zh}**"

    # Only run the LaTeX preprocessing on a real solution, not on the
    # placeholder text.
    solution_text = reference_data.get('solution', 'Solution not available.')
    if solution_text != 'Solution not available.':
        solution = format_solution_latex(solution_text)
    else:
        solution = solution_text

    return (en_problem, zh_problem, subject_display, answer, solution)
|
2117 |
+
|
2118 |
+
def test_reference_data_availability(db, reference_loader):
    """Print a diagnostic report on the database and the reference loader.

    Returns ``False`` as soon as ``db`` or ``reference_loader`` is falsy,
    otherwise ``True``. A failure while inspecting the database schema is
    reported on stdout and does not abort the remaining checks.
    """

    def fetch_column(cur, sql):
        # First column of every row returned by ``sql``.
        cur.execute(sql)
        return [record[0] for record in cur.fetchall()]

    def fetch_scalar(cur, sql):
        # First column of the first row returned by ``sql``.
        cur.execute(sql)
        return cur.fetchone()[0]

    print("=== Reference Data Availability Test ===")

    # --- Database ---
    if not db:
        print("❌ Database is not available")
        return False

    # Schema inspection: tables, row counts, datasets, a few sample ids.
    try:
        cur = db.conn.cursor()

        table_names = fetch_column(cur, "SELECT name FROM sqlite_master WHERE type='table'")
        print(f"✅ Database tables: {table_names}")

        n_problems = fetch_scalar(cur, "SELECT COUNT(*) FROM problems")
        print(f"✅ Problems table: {n_problems} problems")

        n_responses = fetch_scalar(cur, "SELECT COUNT(*) FROM responses")
        print(f"✅ Responses table: {n_responses} responses")

        dataset_names = fetch_column(cur, "SELECT DISTINCT dataset FROM responses")
        print(f"✅ Available datasets: {dataset_names}")

        id_samples = fetch_column(cur, "SELECT unique_id FROM problems LIMIT 10")
        print(f"✅ Sample problem unique_ids: {id_samples}")

    except Exception as e:
        print(f"❌ Error checking database schema: {e}")

    available_models = db.get_available_models()
    print(f"✅ Database connected: {len(available_models)} models available")

    # --- Reference loader ---
    if not reference_loader:
        print("❌ Reference loader is not available (extra.jsonl not found)")
        return False

    problem_ids = reference_loader.get_all_problem_ids()
    print(f"✅ Reference loader: {len(problem_ids)} problems available: {problem_ids}")

    # --- Spot check: look the first reference problem up in the database ---
    if problem_ids:
        probe = problem_ids[0]
        en_uid = f"OlymMATH-HARD-{probe}-EN"
        zh_uid = f"OlymMATH-HARD-{probe}-ZH"

        print(f"Testing with constructed IDs: {en_uid}, {zh_uid}")

        en_problem, en_responses = db.get_problem_data(None, "EN-HARD", en_uid)
        zh_problem, zh_responses = db.get_problem_data(None, "ZH-HARD", zh_uid)

        print(f"Test problem {probe}:")
        print(f" EN problem exists: {en_problem is not None}")
        print(f" ZH problem exists: {zh_problem is not None}")
        if en_responses:
            print(f" EN responses: {len(en_responses)} found")
        if zh_responses:
            print(f" ZH responses: {len(zh_responses)} found")

    print("=== End Test ===")
    return True
|
2190 |
+
|
2191 |
+
def test_latex_formatting():
    """Run ``format_markdown_with_math`` on a sample solution and print a report.

    The sample contains a ``$$ ... $$`` block wrapping an ``aligned``
    environment plus an inline ``\\boxed`` answer. Returns the formatted text.
    """
    marker = "\\begin{aligned}"
    sample_lines = [
        "",
        "易知,1, 4, 6, 7, 9 这五个数中的任意两个数之差均不为 4 或 7.",
        "",
        "$$",
        "\\begin{aligned}",
        "\\sum_{n=1}^{2023}f_{n} &= \\sum_{k=0}^{183}\\sum_{i=0}^{10}f_{11k+i} \\\\",
        "&= \\sum_{k=0}^{183}(11 \\times 5k+1+2+3+5 \\times 4+2 \\times 5) \\\\",
        "&= 55 \\times \\frac{183 \\times 184}{2}+184 \\times 36 \\\\",
        "&= 932604.",
        "\\end{aligned}",
        "$$",
        "",
        "故答案为:$\\boxed{932604}$.",
        "",
    ]
    test_text = "\n".join(sample_lines)

    formatted = format_markdown_with_math(test_text)

    # Report whether the aligned environment survived formatting.
    print("=== LaTeX Formatting Test ===")
    print(f"Original text contains {marker}:", marker in test_text)
    print(f"Formatted text contains {marker}:", marker in formatted)
    print("Formatted text (first 300 chars):", formatted[:300])
    print("=== End Test ===")
    return formatted
|
2215 |
+
|
2216 |
+
def format_solution_latex(text):
    r"""Preprocess solution text by rewriting dollar math delimiters.

    ``$$...$$`` becomes ``\[...\]`` (display math) and ``$...$`` becomes
    ``\(...\)`` (inline math, single line only). Line endings are normalized
    to ``\n`` and runs of three or more (possibly whitespace-only) line
    breaks are collapsed to one blank line.

    Args:
        text: Solution text, or ``None``.

    Returns:
        The converted text; ``""`` when ``text`` is ``None``.
    """
    if text is None:
        return ""

    # Normalize line endings FIRST: the inline pattern below relies on "\n"
    # to keep a match on one line, so a bare "\r" must already be "\n" or
    # two unrelated "$" on different lines could be paired up.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Convert $$xxx$$ to \[xxx\] (display math); non-greedy, may span lines.
    text = re.sub(r'\$\$(.*?)\$\$', r'\\[\1\\]', text, flags=re.DOTALL)

    # Convert $xxx$ to \(xxx\) (inline math); skips escaped "\$".
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$(?!\])', r'\\(\1\\)', text)

    # Clean up excessive newlines.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)

    return text
|
2236 |
+
|
2237 |
+
def format_answer_with_math(text):
    r"""Format an answer field so it renders as inline math.

    ``$$...$$`` is reduced to ``$...$``, an answer containing no ``$`` at all
    is wrapped in ``$...$``, and every single-line ``$...$`` is then rewritten
    to ``\(...\)``. Line endings are normalized to ``\n`` and excessive blank
    lines are collapsed.

    Args:
        text: The answer. Non-string values (e.g. a number parsed from JSON)
            are converted with ``str()``.

    Returns:
        The formatted answer. ``None``, blank strings and the placeholder
        ``"No answer available."`` are returned unchanged.
    """
    if text is None:
        return text
    if not isinstance(text, str):
        # A bare number would otherwise fail on .strip() below.
        text = str(text)
    if text.strip() == "" or text == "No answer available.":
        return text

    # Convert newlines for markdown.
    text = text.replace('\r\n', '\n').replace('\r', '\n')

    # Convert $$xxx$$ to $xxx$ first so answers are always inline math.
    text = re.sub(r'\$\$(.*?)\$\$', r'$\1$', text, flags=re.DOTALL)

    # Answers without any dollar sign are treated as math in their entirety.
    if '$' not in text:
        text = f"${text}$"

    # Now convert $xxx$ to \(xxx\) for proper rendering.
    text = re.sub(r'(?<!\\)\$([^$\n]+?)\$', r'\\(\1\\)', text)

    # Clean up excessive newlines.
    text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text)

    return text
|
2259 |
+
|
2260 |
# 修改主函数以使用优化策略
|
2261 |
if __name__ == "__main__":
|
2262 |
DB_PATH = "data.db"
|
|
|
2264 |
# 检查数据库文件是否存在,如果不存在则从 Hugging Face 下载
|
2265 |
if not os.path.exists(DB_PATH):
|
2266 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
2267 |
DB_PATH = hf_hub_download(
|
2268 |
repo_id="CoderBak/OlymMATH-data",
|
2269 |
filename="data.db",
|
2270 |
+
repo_type="dataset"
|
|
|
2271 |
)
|
2272 |
except Exception as e:
|
2273 |
# 创建一个显示错误信息的简单 Gradio 应用
|
2274 |
with gr.Blocks() as error_demo:
|
2275 |
+
gr.Markdown(f"# Error: Database Download Failed\n{str(e)}")
|
2276 |
error_demo.launch(server_name="0.0.0.0")
|
2277 |
exit(1)
|
2278 |
|