# NexaEvals / app.py
# (Hugging Face Space file header: uploaded by Allanatrix, commit b0ad3dc
# verified, 4.9 kB — scrape residue converted to a comment so the file parses.)
import gradio as gr
import plotly.graph_objects as go
import json
# Data for tabular models.
# Benchmark scores keyed by domain -> model name -> score. Scores are on a
# 0-1 scale, higher is better (per the About text rendered by the UI below).
# Parenthetical notes in model names label the task/metric variant, e.g.
# secondary vs. tertiary protein structure, or GDT-TS for AlphaFold2.
TABULAR_MODEL_EVALS = {
    "Proteins": {
        "Nexa Bio1 (Secondary)": 0.71,
        "Porter6 (Secondary)": 0.8456,
        "DeepCNF (Secondary)": 0.85,
        "AlphaFold2 (Tertiary GDT-TS)": 0.924,
        "Nexa Bio2 (Tertiary)": 0.90,
    },
    "Astro": {
        "Nexa Astro": 0.97,
        "Baseline CNN": 0.89,
    },
    "Materials": {
        "Nexa Materials": 0.9999,
        "Random Forest Baseline": 0.92,
    },
    "QST": {
        "Nexa PIN Model": 0.80,
        "Quantum TomoNet": 0.85,
    },
    "HEP": {
        "Nexa HEP Model": 0.91,
        "CMSNet": 0.94,
    },
    "CFD": {
        "Nexa CFD Model": 0.92,
        "FlowNet": 0.89,
    },
}
# Data for LLMs.
# Benchmark scores keyed by evaluation track -> model name -> score (0-1,
# higher is better). The About text below attributes these to the SciEval
# benchmark under the OSIR initiative; "General" vs "Field-Specific" are the
# two OSIR tracks exposed in the LLM tab's dropdown.
LLM_MODEL_EVALS = {
    "LLM (General OSIR)": {
        "Nexa Mistral Sci-7B": 0.61,
        "Llama-3-8B-Instruct": 0.39,
        "Mixtral-8x7B-Instruct-v0.1": 0.41,
        "Claude-3-Sonnet": 0.64,
        "GPT-4-Turbo": 0.68,
        "GPT-4o": 0.71,
    },
    "LLM (Field-Specific OSIR)": {
        "Nexa Bio Adapter": 0.66,
        "Nexa Astro Adapter": 0.70,
        "GPT-4o (Biomed)": 0.69,
        "Claude-3-Opus (Bio)": 0.67,
        "Llama-3-8B-Bio": 0.42,
        "Mixtral-8x7B-BioTune": 0.43,
    },
}
# Universal plotting function for horizontal bar charts.
def plot_horizontal_bar(domain, data, color):
    """Build a horizontal bar chart of model benchmark scores for one domain.

    Args:
        domain: Domain name, interpolated into the chart title.
        data: Mapping of model name -> score; scores are expected in [0, 1]
            (the x-axis is pinned to that range).
        color: Plotly marker color applied to all bars.

    Returns:
        A ``plotly.graph_objects.Figure`` with one bar per model, sorted by
        score in descending order (best model listed first).
    """
    # Sort models best-first so the chart reads top-down by rank.
    ranked = sorted(data.items(), key=lambda item: item[1], reverse=True)
    if ranked:
        models, scores = zip(*ranked)
    else:
        # Fix: zip(*[]) raises ValueError on an empty mapping; render an
        # empty (but valid) figure instead of crashing the callback.
        models, scores = (), ()
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=scores,
        y=models,
        orientation='h',
        marker_color=color,
    ))
    fig.update_layout(
        # Fix: title previously contained mojibake ("β€”") for the em dash.
        title=f"Model Benchmark Scores — {domain}",
        xaxis_title="Score",
        yaxis_title="Model",
        xaxis_range=[0, 1.0],
        template="plotly_white",
        height=500,
        # Wide left margin leaves room for long model names on the y-axis.
        margin=dict(l=120, r=20, t=40, b=40),
    )
    return fig
# Display functions for each section.
def display_tabular_eval(domain):
    """Return (figure, JSON string) for a tabular-model domain.

    If *domain* is not a known key of TABULAR_MODEL_EVALS, returns
    (None, "Invalid domain selected") so the UI shows an error message.
    """
    scores = TABULAR_MODEL_EVALS.get(domain)
    if scores is None:
        return None, "Invalid domain selected"
    figure = plot_horizontal_bar(domain, scores, 'indigo')
    raw_json = json.dumps(scores, indent=2)
    return figure, raw_json
def display_llm_eval(domain):
    """Return (figure, JSON string) for an LLM evaluation track.

    If *domain* is not a known key of LLM_MODEL_EVALS, returns
    (None, "Invalid domain selected") so the UI shows an error message.
    """
    scores = LLM_MODEL_EVALS.get(domain)
    if scores is None:
        return None, "Invalid domain selected"
    figure = plot_horizontal_bar(domain, scores, 'lightblue')
    raw_json = json.dumps(scores, indent=2)
    return figure, raw_json
# Gradio interface.
# NOTE(review): widget creation order inside each `with` context determines
# the rendered layout, so statement order below is load-bearing. Indentation
# was reconstructed from a flattened copy — the Button is presumed to sit
# inside the Row next to the Dropdown; confirm against the deployed Space.
# Markdown strings are kept byte-identical, including mojibake characters
# ("πŸ”¬", "β€”", "ℹ️") that appear to be mis-decoded UTF-8 emoji/dashes.
with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #fafafa;}") as demo:
    # Page header.
    gr.Markdown("""
# πŸ”¬ Nexa Evals β€” Scientific ML Benchmark Suite
A comprehensive benchmarking suite comparing Nexa models against state-of-the-art models across scientific domains and language models.
""")
    with gr.Tabs():
        # Tab 1: tabular scientific models, one dropdown entry per domain.
        with gr.TabItem("Tabular Models"):
            with gr.Row():
                tabular_domain = gr.Dropdown(
                    choices=list(TABULAR_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="Proteins"
                )
                show_tabular_btn = gr.Button("Show Evaluation")
            tabular_plot = gr.Plot(label="Benchmark Plot")
            tabular_details = gr.Code(label="Raw Scores (JSON)", language="json")
            # Button click re-renders the plot and the raw-score JSON for the
            # currently selected domain.
            show_tabular_btn.click(
                fn=display_tabular_eval,
                inputs=tabular_domain,
                outputs=[tabular_plot, tabular_details]
            )
        # Tab 2: language models, one dropdown entry per OSIR track.
        with gr.TabItem("LLMs"):
            with gr.Row():
                llm_domain = gr.Dropdown(
                    choices=list(LLM_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="LLM (General OSIR)"
                )
                show_llm_btn = gr.Button("Show Evaluation")
            llm_plot = gr.Plot(label="Benchmark Plot")
            llm_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_llm_btn.click(
                fn=display_llm_eval,
                inputs=llm_domain,
                outputs=[llm_plot, llm_details]
            )
    # Footer / about section.
    gr.Markdown("""
---
### ℹ️ About
Nexa Evals provides benchmarks for both tabular models and language models in scientific domains:
- **Tabular Models**: Evaluated on domain-specific metrics (e.g., accuracy, GDT-TS) across fields like Proteins, Astro, Materials, QST, HEP, and CFD.
- **Language Models**: Assessed using the SciEval benchmark under the OSIR initiative, focusing on scientific utility, information entropy, internal consistency, hypothesis framing, domain grounding, and math logic.
Scores range from 0 to 1, with higher values indicating better performance. Models are sorted by score in descending order for easy comparison.
""")
# Start the Gradio server (blocking call).
demo.launch()