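"""Nexa Evals: a Gradio dashboard for scientific ML benchmarks.

Renders Plotly horizontal bar charts for three benchmark groups:
tabular models, LLMs (OSIR), and the Nexa Mistral Sci-7B evaluation.
"""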
import gradio as gr
import plotly.graph_objects as go
import json
# Data for tabular models
TABULAR_MODEL_EVALS = {
"Proteins": {
"Nexa Bio1 (Secondary)": 0.71,
"Porter6 (Secondary)": 0.8456,
"DeepCNF (Secondary)": 0.85,
"AlphaFold2 (Tertiary GDT-TS)": 0.924,
"Nexa Bio2 (Tertiary)": 0.90,
},
"Astro": {
"Nexa Astro": 0.97,
"Baseline CNN": 0.89,
},
"Materials": {
"Nexa Materials": 0.9999,
"Random Forest Baseline": 0.92,
},
"QST": {
"Nexa PIN Model": 0.80,
"Quantum TomoNet": 0.85,
},
"HEP": {
"Nexa HEP Model": 0.91,
"CMSNet": 0.94,
},
"CFD": {
"Nexa CFD Model": 0.92,
"FlowNet": 0.89,
},
}
# Data for LLMs
LLM_MODEL_EVALS = {
"LLM (General OSIR)": {
"Nexa Mistral Sci-7B": 0.61,
"Llama-3-8B-Instruct": 0.39,
"Mixtral-8x7B-Instruct-v0.1": 0.41,
"Claude-3-Sonnet": 0.64,
"GPT-4-Turbo": 0.68,
"GPT-4o": 0.71,
},
"LLM (Field-Specific OSIR)": {
"Nexa Bio Adapter": 0.66,
"Nexa Astro Adapter": 0.70,
"GPT-4o (Biomed)": 0.69,
"Claude-3-Opus (Bio)": 0.67,
"Llama-3-8B-Bio": 0.42,
"Mixtral-8x7B-BioTune": 0.43,
},
}
# Data for the Nexa Mistral Sci-7B evaluation (1-10 rubric; general OSIR vs. physics-specific OSIR-Field)
NEXA_MISTRAL_EVALS = {
"Nexa Mistral Sci-7B": {
"Scientific Utility": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.5},
"Symbolism & Math Logic": {"OSIR (General)": 6.0, "OSIR-Field (Physics)": 7.5},
"Citation & Structure": {"OSIR (General)": 5.5, "OSIR-Field (Physics)": 6.0},
"Thematic Grounding": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.0},
"Hypothesis Framing": {"OSIR (General)": 6.0, "OSIR-Field (Physics)": 7.0},
"Internal Consistency": {"OSIR (General)": 9.0, "OSIR-Field (Physics)": 9.5},
"Entropy / Novelty": {"OSIR (General)": 6.5, "OSIR-Field (Physics)": 6.0},
}
}
# Universal plotting function with highlighted Nexa models
def plot_horizontal_bar(domain, data, highlight_keyword="Nexa", highlight_color='indigo', default_color='lightgray'):
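    # Sort models by score (descending) and color any bar whose model name
    # contains highlight_keyword with highlight_color; all others use default_color.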
sorted_items = sorted(data.items(), key=lambda x: x[1], reverse=True)
models, scores = zip(*sorted_items)
colors = [highlight_color if highlight_keyword in model else default_color for model in models]
fig = go.Figure()
fig.add_trace(go.Bar(
x=scores,
y=models,
orientation='h',
marker_color=colors,
))
fig.update_layout(
title=f"Model Benchmark Scores — {domain}",
xaxis_title="Score",
yaxis_title="Model",
xaxis_range=[0, 1.0],
template="plotly_white",
height=500,
margin=dict(l=120, r=20, t=40, b=40),
yaxis=dict(automargin=True),
)
return fig
# Plotting function for Nexa Mistral Sci-7B Evaluation
def plot_mistral_eval(metric):
    if metric not in NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]:
        return None
data = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric]
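    # Dict insertion order keeps "OSIR (General)" first, so it pairs with the
    # first marker color below.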
models = list(data.keys())
scores = list(data.values())
fig = go.Figure()
fig.add_trace(go.Bar(
x=scores,
y=models,
orientation='h',
        marker_color=['yellow', 'orange']  # yellow = OSIR (General), orange = OSIR-Field (Physics)
))
fig.update_layout(
title=f"Nexa Mistral Sci-7B Evaluation: {metric}",
xaxis_title="Score (1-10)",
yaxis_title="Model",
xaxis_range=[0, 10],
template="plotly_white",
height=400,
margin=dict(l=120, r=20, t=40, b=40),
yaxis=dict(automargin=True),
)
return fig
# Display functions for each section
def display_tabular_eval(domain):
if domain not in TABULAR_MODEL_EVALS:
return None, "Invalid domain selected"
plot = plot_horizontal_bar(domain, TABULAR_MODEL_EVALS[domain], highlight_color='indigo', default_color='lightgray')
details = json.dumps(TABULAR_MODEL_EVALS[domain], indent=2)
return plot, details
def display_llm_eval(domain):
if domain not in LLM_MODEL_EVALS:
return None, "Invalid domain selected"
plot = plot_horizontal_bar(domain, LLM_MODEL_EVALS[domain], highlight_color='lightblue', default_color='gray')
details = json.dumps(LLM_MODEL_EVALS[domain], indent=2)
return plot, details
def display_mistral_eval(metric):
    if metric not in NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]:
        return None, "Invalid metric selected"
    plot = plot_mistral_eval(metric)
    details = json.dumps(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric], indent=2)
    return plot, details
# Gradio interface with improved styling
with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f0f0f0; color: #333;}") as demo:
gr.Markdown("""
# 🔬 Nexa Evals — Scientific ML Benchmark Suite
    A comprehensive benchmarking suite comparing Nexa models against state-of-the-art baselines.
""")
with gr.Tabs():
with gr.TabItem("Tabular Models"):
with gr.Row():
tabular_domain = gr.Dropdown(
choices=list(TABULAR_MODEL_EVALS.keys()),
label="Select Domain",
value="Proteins"
)
show_tabular_btn = gr.Button("Show Evaluation")
tabular_plot = gr.Plot(label="Benchmark Plot")
tabular_details = gr.Code(label="Raw Scores (JSON)", language="json")
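            # Clicking the button regenerates the plot and raw-score JSON for the selected domain.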
show_tabular_btn.click(
fn=display_tabular_eval,
inputs=tabular_domain,
outputs=[tabular_plot, tabular_details]
)
with gr.TabItem("LLMs"):
with gr.Row():
llm_domain = gr.Dropdown(
choices=list(LLM_MODEL_EVALS.keys()),
label="Select Domain",
value="LLM (General OSIR)"
)
show_llm_btn = gr.Button("Show Evaluation")
llm_plot = gr.Plot(label="Benchmark Plot")
llm_details = gr.Code(label="Raw Scores (JSON)", language="json")
show_llm_btn.click(
fn=display_llm_eval,
inputs=llm_domain,
outputs=[llm_plot, llm_details]
)
with gr.TabItem("Nexa Mistral Sci-7B"):
with gr.Row():
mistral_metric = gr.Dropdown(
choices=list(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"].keys()),
label="Select Metric",
value="Scientific Utility"
)
show_mistral_btn = gr.Button("Show Evaluation")
mistral_plot = gr.Plot(label="Benchmark Plot")
mistral_details = gr.Code(label="Raw Scores (JSON)", language="json")
show_mistral_btn.click(
fn=display_mistral_eval,
inputs=mistral_metric,
outputs=[mistral_plot, mistral_details]
)
gr.Markdown("""
---
### ℹ️ About
Nexa Evals provides benchmarks for tabular models, language models, and specific evaluations like Nexa Mistral Sci-7B:
- **Tabular Models**: Evaluated on domain-specific metrics across fields like Proteins and Astro.
- **LLMs**: Assessed using the SciEval benchmark under the OSIR initiative.
- **Nexa Mistral Sci-7B**: Compares general (OSIR) and physics-specific (OSIR-Field) performance across multiple metrics.
    Tabular and LLM scores are reported on a 0-1 scale; Nexa Mistral Sci-7B metrics are rubric scores from 1-10.
""")
demo.launch()