import gradio as gr
import plotly.graph_objects as go
import json

# Data for tabular models
TABULAR_MODEL_EVALS = {
    "Proteins": {
        "Nexa Bio1 (Secondary)": 0.71,
        "Porter6 (Secondary)": 0.8456,
        "DeepCNF (Secondary)": 0.85,
        "AlphaFold2 (Tertiary GDT-TS)": 0.924,
        "Nexa Bio2 (Tertiary)": 0.90,
    },
    "Astro": {
        "Nexa Astro": 0.97,
        "Baseline CNN": 0.89,
    },
    "Materials": {
        "Nexa Materials": 0.9999,
        "Random Forest Baseline": 0.92,
    },
    "QST": {
        "Nexa PIN Model": 0.80,
        "Quantum TomoNet": 0.85,
    },
    "HEP": {
        "Nexa HEP Model": 0.91,
        "CMSNet": 0.94,
    },
    "CFD": {
        "Nexa CFD Model": 0.92,
        "FlowNet": 0.89,
    },
}

# Data for LLMs
LLM_MODEL_EVALS = {
    "LLM (General OSIR)": {
        "Nexa Mistral Sci-7B": 0.61,
        "Llama-3-8B-Instruct": 0.39,
        "Mixtral-8x7B-Instruct-v0.1": 0.41,
        "Claude-3-Sonnet": 0.64,
        "GPT-4-Turbo": 0.68,
        "GPT-4o": 0.71,
    },
    "LLM (Field-Specific OSIR)": {
        "Nexa Bio Adapter": 0.66,
        "Nexa Astro Adapter": 0.70,
        "GPT-4o (Biomed)": 0.69,
        "Claude-3-Opus (Bio)": 0.67,
        "Llama-3-8B-Bio": 0.42,
        "Mixtral-8x7B-BioTune": 0.43,
    },
}

# Data for the Nexa Mistral Sci-7B evaluation (based on the provided image; scores on a 1-10 scale)
NEXA_MISTRAL_EVALS = {
    "Nexa Mistral Sci-7B": {
        "Scientific Utility": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.5},
        "Symbolism & Math Logic": {"OSIR (General)": 6.0, "OSIR-Field (Physics)": 7.5},
        "Citation & Structure": {"OSIR (General)": 5.5, "OSIR-Field (Physics)": 6.0},
        "Thematic Grounding": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.0},
        "Hypothesis Framing": {"OSIR (General)": 6.0, "OSIR-Field (Physics)": 7.0},
        "Internal Consistency": {"OSIR (General)": 9.0, "OSIR-Field (Physics)": 9.5},
        "Entropy / Novelty": {"OSIR (General)": 6.5, "OSIR-Field (Physics)": 6.0},
    }
}
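# Sanity check (a minimal sketch added for illustration, not part of the UI):
# tabular and LLM scores are expected to lie in [0, 1], matching the x-axis
# range used by plot_horizontal_bar below.
for _scores in (*TABULAR_MODEL_EVALS.values(), *LLM_MODEL_EVALS.values()):
    assert all(0.0 <= _s <= 1.0 for _s in _scores.values()), "score outside [0, 1]"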
# Universal plotting function with highlighted Nexa models
def plot_horizontal_bar(domain, data, highlight_keyword="Nexa", highlight_color='indigo', default_color='lightgray'):
    sorted_items = sorted(data.items(), key=lambda x: x[1], reverse=True)
    models, scores = zip(*sorted_items)
    colors = [highlight_color if highlight_keyword in model else default_color for model in models]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=scores,
        y=models,
        orientation='h',
        marker_color=colors,
    ))
    fig.update_layout(
        title=f"Model Benchmark Scores – {domain}",
        xaxis_title="Score",
        yaxis_title="Model",
        xaxis_range=[0, 1.0],
        template="plotly_white",
        height=500,
        margin=dict(l=120, r=20, t=40, b=40),
        yaxis=dict(automargin=True),
    )
    return fig
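# Usage example (hypothetical, for local testing; commented out so it does not
# execute when the Space imports this module):
# fig = plot_horizontal_bar("Astro", TABULAR_MODEL_EVALS["Astro"])
# fig.show()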
# Plotting function for the Nexa Mistral Sci-7B evaluation
def plot_mistral_eval(metric):
    if metric not in NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]:
        return None
    data = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric]
    models = list(data.keys())
    scores = list(data.values())
    fig = go.Figure()
    fig.add_trace(go.Bar(
        x=scores,
        y=models,
        orientation='h',
        marker_color=['yellow', 'orange'],  # matching the provided image colors
    ))
    fig.update_layout(
        title=f"Nexa Mistral Sci-7B Evaluation: {metric}",
        xaxis_title="Score (1-10)",
        yaxis_title="Model",
        xaxis_range=[0, 10],
        template="plotly_white",
        height=400,
        margin=dict(l=120, r=20, t=40, b=40),
        yaxis=dict(automargin=True),
    )
    return fig
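# Usage example (hypothetical): a known metric yields a go.Figure, an unknown
# one yields None rather than raising.
# fig = plot_mistral_eval("Internal Consistency")
# assert plot_mistral_eval("No Such Metric") is None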
# Display functions for each section
def display_tabular_eval(domain):
    if domain not in TABULAR_MODEL_EVALS:
        return None, "Invalid domain selected"
    plot = plot_horizontal_bar(domain, TABULAR_MODEL_EVALS[domain], highlight_color='indigo', default_color='lightgray')
    details = json.dumps(TABULAR_MODEL_EVALS[domain], indent=2)
    return plot, details

def display_llm_eval(domain):
    if domain not in LLM_MODEL_EVALS:
        return None, "Invalid domain selected"
    plot = plot_horizontal_bar(domain, LLM_MODEL_EVALS[domain], highlight_color='lightblue', default_color='gray')
    details = json.dumps(LLM_MODEL_EVALS[domain], indent=2)
    return plot, details

def display_mistral_eval(metric):
    if metric not in NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]:
        return None, "Invalid metric selected"
    plot = plot_mistral_eval(metric)
    details = json.dumps(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric], indent=2)
    return plot, details
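# Each display_* function returns a (figure, JSON string) pair, which Gradio
# routes to the [gr.Plot, gr.Code] outputs wired up below. A hypothetical
# standalone call, for illustration only:
# fig, details = display_tabular_eval("Proteins")
# print(details)  # pretty-printed scores for the Proteins domain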
# Gradio interface with improved styling
with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f0f0f0; color: #333;}") as demo:
    gr.Markdown("""
    # 🔬 Nexa Evals – Scientific ML Benchmark Suite
    A comprehensive benchmarking suite comparing Nexa models against state-of-the-art models.
    """)
    with gr.Tabs():
        with gr.TabItem("Tabular Models"):
            with gr.Row():
                tabular_domain = gr.Dropdown(
                    choices=list(TABULAR_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="Proteins"
                )
                show_tabular_btn = gr.Button("Show Evaluation")
            tabular_plot = gr.Plot(label="Benchmark Plot")
            tabular_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_tabular_btn.click(
                fn=display_tabular_eval,
                inputs=tabular_domain,
                outputs=[tabular_plot, tabular_details]
            )
        with gr.TabItem("LLMs"):
            with gr.Row():
                llm_domain = gr.Dropdown(
                    choices=list(LLM_MODEL_EVALS.keys()),
                    label="Select Domain",
                    value="LLM (General OSIR)"
                )
                show_llm_btn = gr.Button("Show Evaluation")
            llm_plot = gr.Plot(label="Benchmark Plot")
            llm_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_llm_btn.click(
                fn=display_llm_eval,
                inputs=llm_domain,
                outputs=[llm_plot, llm_details]
            )
        with gr.TabItem("Nexa Mistral Sci-7B"):
            with gr.Row():
                mistral_metric = gr.Dropdown(
                    choices=list(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"].keys()),
                    label="Select Metric",
                    value="Scientific Utility"
                )
                show_mistral_btn = gr.Button("Show Evaluation")
            mistral_plot = gr.Plot(label="Benchmark Plot")
            mistral_details = gr.Code(label="Raw Scores (JSON)", language="json")
            show_mistral_btn.click(
                fn=display_mistral_eval,
                inputs=mistral_metric,
                outputs=[mistral_plot, mistral_details]
            )
    gr.Markdown("""
    ---
    ### ℹ️ About
    Nexa Evals provides benchmarks for tabular models, language models, and targeted evaluations such as Nexa Mistral Sci-7B:
    - **Tabular Models**: Evaluated on domain-specific metrics across fields such as Proteins and Astro.
    - **LLMs**: Assessed using the SciEval benchmark under the OSIR initiative.
    - **Nexa Mistral Sci-7B**: Compares general (OSIR) and physics-specific (OSIR-Field) performance across multiple metrics.
    Tabular and LLM scores are normalized to 0-1; the Mistral rubric uses a 1-10 scale.
    """)

demo.launch()
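# Note: on Hugging Face Spaces the plain launch() above is sufficient. When
# running locally, demo.launch(share=True) would additionally create a
# temporary public link (share is a standard gradio launch option).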