NexaEvals / app.py
Allanatrix's picture
Update app.py
94c2f22 verified
raw
history blame
7.63 kB
import gradio as gr
import plotly.graph_objects as go
import json
# Benchmark scores for the tabular (non-LLM) models.
# Layout: {domain name: {model name: score}}. Domain order is the order the
# keys are listed here, which is also the order the UI dropdown shows them.
TABULAR_MODEL_EVALS = {
    # Protein structure prediction
    "Proteins": {
        "Nexa Bio1 (Secondary)": 0.71,
        "Porter6 (Secondary)": 0.8456,
        "DeepCNF (Secondary)": 0.85,
        "AlphaFold2 (Tertiary GDT-TS)": 0.924,
        "Nexa Bio2 (Tertiary)": 0.90,
    },
    # Astrophysics
    "Astro": {
        "Nexa Astro": 0.97,
        "Baseline CNN": 0.89,
    },
    # Materials science
    "Materials": {
        "Nexa Materials": 0.9999,
        "Random Forest Baseline": 0.92,
    },
    "QST": {
        "Nexa PIN Model": 0.80,
        "Quantum TomoNet": 0.85,
    },
    "HEP": {
        "Nexa HEP Model": 0.91,
        "CMSNet": 0.94,
    },
    "CFD": {
        "Nexa CFD Model": 0.92,
        "FlowNet": 0.89,
    },
}
# Benchmark scores for the language models.
# Layout: {benchmark name: {model name: score}}, mirroring the structure of
# the tabular table so both can be rendered by the same plotting helper.
LLM_MODEL_EVALS = {
    # General-purpose comparison
    "LLM (General OSIR)": {
        "Nexa Mistral Sci-7B": 0.61,
        "Llama-3-8B-Instruct": 0.39,
        "Mixtral-8x7B-Instruct-v0.1": 0.41,
        "Claude-3-Sonnet": 0.64,
        "GPT-4-Turbo": 0.68,
        "GPT-4o": 0.71,
    },
    # Field-specific comparison (adapters and domain-tuned variants)
    "LLM (Field-Specific OSIR)": {
        "Nexa Bio Adapter": 0.66,
        "Nexa Astro Adapter": 0.70,
        "GPT-4o (Biomed)": 0.69,
        "Claude-3-Opus (Bio)": 0.67,
        "Llama-3-8B-Bio": 0.42,
        "Mixtral-8x7B-BioTune": 0.43,
    },
}
# Per-metric evaluation of Nexa Mistral Sci-7B (the original note says these
# values were taken from a provided image, which is not part of this file).
# Layout: {model name: {metric: {"OSIR (General)": score,
#                                "OSIR-Field (Physics)": score}}}.
# Each row below is (metric, general score, physics score).
NEXA_MISTRAL_EVALS = {
    "Nexa Mistral Sci-7B": {
        metric: {"OSIR (General)": general, "OSIR-Field (Physics)": physics}
        for metric, general, physics in (
            ("Scientific Utility", 7.0, 8.5),
            ("Symbolism & Math Logic", 6.0, 7.5),
            ("Citation & Structure", 5.5, 6.0),
            ("Thematic Grounding", 7.0, 8.0),
            ("Hypothesis Framing", 6.0, 7.0),
            ("Internal Consistency", 9.0, 9.5),
            ("Entropy / Novelty", 6.5, 6.0),
        )
    }
}
# Universal plotting function with highlighted Nexa models
def plot_horizontal_bar(domain, data, highlight_keyword="Nexa", highlight_color='indigo', default_color='lightgray'):
    """Build a horizontal bar chart of model scores for one domain.

    Args:
        domain: Label for the chart; it is only used in the figure title.
        data: Mapping of model name to numeric score. May be empty, in which
            case an empty chart (title and axes only) is returned.
        highlight_keyword: Models whose name contains this substring are
            drawn in ``highlight_color``.
        highlight_color: Bar color for highlighted models.
        default_color: Bar color for every other model.

    Returns:
        A ``plotly.graph_objects.Figure`` with a single horizontal bar trace
        and the x-axis fixed to the range 0-1.
    """
    # Order the entries by score, highest first.
    ranked = sorted(data.items(), key=lambda item: item[1], reverse=True)
    # zip(*[]) produces nothing, so unpacking it into two names would raise
    # ValueError; fall back to two empty sequences for empty input.
    models, scores = zip(*ranked) if ranked else ((), ())
    colors = [highlight_color if highlight_keyword in model else default_color for model in models]

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=scores,
        y=models,
        orientation='h',
        marker_color=colors,
    ))
    fig.update_layout(
        # The em dash was previously stored mis-encoded and rendered as junk.
        title=f"Model Benchmark Scores — {domain}",
        xaxis_title="Score",
        yaxis_title="Model",
        xaxis_range=[0, 1.0],
        template="plotly_white",
        height=500,
        margin=dict(l=120, r=20, t=40, b=40),
        yaxis=dict(automargin=True),
    )
    return fig
# Plotting function for Nexa Mistral Sci-7B Evaluation
def plot_mistral_eval(metric):
    """Build a horizontal bar chart for one Nexa Mistral Sci-7B metric.

    Args:
        metric: A key of ``NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]``.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar per entry stored for
        the metric and the x-axis fixed to 0-10, or ``None`` when ``metric``
        is not a known metric.
    """
    metric_table = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]
    if metric not in metric_table:
        # The success path returns a single figure, so the failure path must
        # return a single value as well. It used to return a
        # (None, message) tuple, which the caller would have passed on as
        # if it were the plot.
        return None

    data = metric_table[metric]
    models = list(data.keys())
    scores = list(data.values())

    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=scores,
        y=models,
        orientation='h',
        marker_color=['yellow', 'orange']  # Matching the provided image colors
    ))
    fig.update_layout(
        title=f"Nexa Mistral Sci-7B Evaluation: {metric}",
        xaxis_title="Score (1-10)",
        yaxis_title="Model",
        xaxis_range=[0, 10],
        template="plotly_white",
        height=400,
        margin=dict(l=120, r=20, t=40, b=40),
        yaxis=dict(automargin=True),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
    )
    return fig
# Display functions for each section
def display_tabular_eval(domain):
    """Return the chart and raw JSON scores for a tabular-model domain.

    Args:
        domain: A key of ``TABULAR_MODEL_EVALS``.

    Returns:
        ``(figure, json_text)`` for a known domain, otherwise
        ``(None, "Invalid domain selected")``.
    """
    if domain in TABULAR_MODEL_EVALS:
        domain_scores = TABULAR_MODEL_EVALS[domain]
        figure = plot_horizontal_bar(
            domain,
            domain_scores,
            highlight_color='indigo',
            default_color='lightgray',
        )
        raw_json = json.dumps(domain_scores, indent=2)
        return figure, raw_json
    return None, "Invalid domain selected"
def display_llm_eval(domain):
    """Return the chart and raw JSON scores for an LLM benchmark.

    Args:
        domain: A key of ``LLM_MODEL_EVALS``.

    Returns:
        ``(figure, json_text)`` for a known benchmark, otherwise
        ``(None, "Invalid domain selected")``.
    """
    if domain in LLM_MODEL_EVALS:
        benchmark_scores = LLM_MODEL_EVALS[domain]
        figure = plot_horizontal_bar(
            domain,
            benchmark_scores,
            highlight_color='lightblue',
            default_color='gray',
        )
        raw_json = json.dumps(benchmark_scores, indent=2)
        return figure, raw_json
    return None, "Invalid domain selected"
def display_mistral_eval(metric):
    """Return the chart and raw JSON scores for one Nexa Mistral Sci-7B metric.

    Args:
        metric: A key of ``NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]``.

    Returns:
        ``(figure, json_text)`` for a known metric, otherwise
        ``(None, "Invalid metric selected")``.
    """
    metric_table = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]
    # Validate before doing any work, as the other display functions do.
    # Without this guard an unknown metric (or None, e.g. from an emptied
    # dropdown) raised KeyError on the lookup below.
    if metric not in metric_table:
        return None, "Invalid metric selected"
    plot = plot_mistral_eval(metric)
    details = json.dumps(metric_table[metric], indent=2)
    return plot, details
# Gradio interface with improved styling
# Build the UI: one tab per evaluation family. Every tab has the same shape:
# a dropdown plus a button in one row, then a plot and a JSON view that are
# filled in when the button is clicked.
with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f0f0f0; color: #333;}") as demo:
    # Page header. The emoji and em dash were previously stored mis-encoded
    # and rendered as junk characters. The Markdown text is kept at column 0
    # so that no line is indented inside the string.
    gr.Markdown("""
# 🔬 Nexa Evals — Scientific ML Benchmark Suite
A comprehensive benchmarking suite comparing Nexa models against state-of-the-art models.
""")
    with gr.Tabs():
        # Tab 1: tabular models, selected by domain.
        with gr.TabItem("Tabular Models"):
            with gr.Row():
                tabular_domain = gr.Dropdown(
                    choices=list(TABULAR_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="Proteins"
                )
                show_tabular_btn = gr.Button("Show Evaluation")
            tabular_plot = gr.Plot(label="Benchmark Plot")
            tabular_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_tabular_btn.click(
                fn=display_tabular_eval,
                inputs=tabular_domain,
                outputs=[tabular_plot, tabular_details]
            )
        # Tab 2: language models, selected by benchmark.
        with gr.TabItem("LLMs"):
            with gr.Row():
                llm_domain = gr.Dropdown(
                    choices=list(LLM_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="LLM (General OSIR)"
                )
                show_llm_btn = gr.Button("Show Evaluation")
            llm_plot = gr.Plot(label="Benchmark Plot")
            llm_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_llm_btn.click(
                fn=display_llm_eval,
                inputs=llm_domain,
                outputs=[llm_plot, llm_details]
            )
        # Tab 3: per-metric breakdown for Nexa Mistral Sci-7B.
        with gr.TabItem("Nexa Mistral Sci-7B"):
            with gr.Row():
                mistral_metric = gr.Dropdown(
                    choices=list(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"].keys()),
                    label="Select Metric",
                    value="Scientific Utility"
                )
                show_mistral_btn = gr.Button("Show Evaluation")
            mistral_plot = gr.Plot(label="Benchmark Plot")
            mistral_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_mistral_btn.click(
                fn=display_mistral_eval,
                inputs=mistral_metric,
                outputs=[mistral_plot, mistral_details]
            )
    # Footer shown below the tabs.
    gr.Markdown("""
---
### ℹ️ About
Nexa Evals provides benchmarks for tabular models, language models, and specific evaluations like Nexa Mistral Sci-7B:
- **Tabular Models**: Evaluated on domain-specific metrics across fields like Proteins and Astro.
- **LLMs**: Assessed using the SciEval benchmark under the OSIR initiative.
- **Nexa Mistral Sci-7B**: Compares general (OSIR) and physics-specific (OSIR-Field) performance across multiple metrics.
Scores are normalized where applicable (0-1 for tabular/LLMs, 1-10 for Mistral).
""")
# Start the Gradio server when this file is executed.
demo.launch()