Allanatrix committed
Commit 08d1f1b · verified · 1 Parent(s): fb054fc

Update app.py

Files changed (1)
  1. app.py +69 -123
app.py CHANGED
@@ -1,59 +1,39 @@
 import gradio as gr
-import plotly.graph_objects as go
-import json
 
-# Data for tabular models
 TABULAR_MODEL_EVALS = {
     "Proteins": {
-        "Nexa Bio1 (Secondary)": 0.71,
-        "Porter6 (Secondary)": 0.8456,
-        "DeepCNF (Secondary)": 0.85,
-        "AlphaFold2 (Tertiary GDT-TS)": 0.924,
-        "Nexa Bio2 (Tertiary)": 0.90,
     },
     "Astro": {
-        "Nexa Astro": 0.97,
-        "Baseline CNN": 0.89,
     },
     "Materials": {
-        "Nexa Materials": 0.9999,
-        "Random Forest Baseline": 0.92,
     },
     "QST": {
-        "Nexa PIN Model": 0.80,
-        "Quantum TomoNet": 0.85,
     },
     "HEP": {
-        "Nexa HEP Model": 0.91,
-        "CMSNet": 0.94,
     },
     "CFD": {
-        "Nexa CFD Model": 0.92,
-        "FlowNet": 0.89,
     },
 }
 
-# Data for LLMs
-LLM_MODEL_EVALS = {
-    "LLM (General OSIR)": {
-        "Nexa Mistral Sci-7B": 0.61,
-        "Llama-3-8B-Instruct": 0.39,
-        "Mixtral-8x7B-Instruct-v0.1": 0.41,
-        "Claude-3-Sonnet": 0.64,
-        "GPT-4-Turbo": 0.68,
-        "GPT-4o": 0.71,
-    },
-    "LLM (Field-Specific OSIR)": {
-        "Nexa Bio Adapter": 0.66,
-        "Nexa Astro Adapter": 0.70,
-        "GPT-4o (Biomed)": 0.69,
-        "Claude-3-Opus (Bio)": 0.67,
-        "Llama-3-8B-Bio": 0.42,
-        "Mixtral-8x7B-BioTune": 0.43,
-    },
-}
-
-# Data for Nexa Mistral Sci-7B Evaluation (based on the provided image)
 NEXA_MISTRAL_EVALS = {
     "Nexa Mistral Sci-7B": {
         "Scientific Utility": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.5},
@@ -66,86 +46,55 @@ NEXA_MISTRAL_EVALS = {
     }
 }
 
-# Universal plotting function with highlighted Nexa models
-def plot_horizontal_bar(domain, data, highlight_keyword="Nexa", highlight_color='indigo', default_color='lightgray'):
-    sorted_items = sorted(data.items(), key=lambda x: x[1], reverse=True)
-    models, scores = zip(*sorted_items)
-    colors = [highlight_color if highlight_keyword in model else default_color for model in models]
-
-    fig = go.Figure()
-    fig.add_trace(go.Bar(
-        x=scores,
-        y=models,
-        orientation='h',
-        marker_color=colors,
-    ))
-
-    fig.update_layout(
-        title=f"Model Benchmark Scores — {domain}",
-        xaxis_title="Score",
-        yaxis_title="Model",
-        xaxis_range=[0, 1.0],
-        template="plotly_white",
-        height=500,
-        margin=dict(l=120, r=20, t=40, b=40),
-        yaxis=dict(automargin=True),
-    )
     return fig
 
-# Plotting function for Nexa Mistral Sci-7B Evaluation
-def plot_mistral_eval(metric):
-    if metric not in NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"]:
-        return None, "Invalid metric selected"
-    data = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric]
-    models = list(data.keys())
-    scores = list(data.values())
-
-    fig = go.Figure()
-    fig.add_trace(go.Bar(
-        x=scores,
-        y=models,
-        orientation='h',
-        marker_color=['yellow', 'orange'] # Matching the provided image colors
-    ))
-
-    fig.update_layout(
-        title=f"Nexa Mistral Sci-7B Evaluation: {metric}",
-        xaxis_title="Score (1-10)",
-        yaxis_title="Model",
-        xaxis_range=[0, 10],
-        template="plotly_black",
-        height=400,
-        margin=dict(l=120, r=20, t=40, b=40),
-        yaxis=dict(automargin=True),
-        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
-    )
-    return fig
-
-# Display functions for each section
 def display_tabular_eval(domain):
-    if domain not in TABULAR_MODEL_EVALS:
-        return None, "Invalid domain selected"
-    plot = plot_horizontal_bar(domain, TABULAR_MODEL_EVALS[domain], highlight_color='indigo', default_color='lightgray')
-    details = json.dumps(TABULAR_MODEL_EVALS[domain], indent=2)
-    return plot, details
 
 def display_llm_eval(domain):
-    if domain not in LLM_MODEL_EVALS:
-        return None, "Invalid domain selected"
-    plot = plot_horizontal_bar(domain, LLM_MODEL_EVALS[domain], highlight_color='lightblue', default_color='gray')
-    details = json.dumps(LLM_MODEL_EVALS[domain], indent=2)
-    return plot, details
 
 def display_mistral_eval(metric):
-    plot = plot_mistral_eval(metric)
-    details = json.dumps(NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric], indent=2)
-    return plot, details
 
-# Gradio interface with improved styling
-with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f0f0f0; color: #333;}") as demo:
     gr.Markdown("""
     # 🔬 Nexa Evals — Scientific ML Benchmark Suite
-    A comprehensive benchmarking suite comparing Nexa models against state-of-the-art models.
     """)
 
     with gr.Tabs():
@@ -158,11 +107,10 @@ with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f
             )
             show_tabular_btn = gr.Button("Show Evaluation")
             tabular_plot = gr.Plot(label="Benchmark Plot")
-            tabular_details = gr.Code(label="Raw Scores (JSON)", language="json")
             show_tabular_btn.click(
                 fn=display_tabular_eval,
                 inputs=tabular_domain,
-                outputs=[tabular_plot, tabular_details]
             )
 
         with gr.TabItem("LLMs"):
@@ -174,11 +122,10 @@ with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f
             )
             show_llm_btn = gr.Button("Show Evaluation")
             llm_plot = gr.Plot(label="Benchmark Plot")
-            llm_details = gr.Code(label="Raw Scores (JSON)", language="json")
             show_llm_btn.click(
                 fn=display_llm_eval,
                 inputs=llm_domain,
-                outputs=[llm_plot, llm_details]
             )
 
         with gr.TabItem("Nexa Mistral Sci-7B"):
@@ -190,21 +137,20 @@ with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #f
             )
             show_mistral_btn = gr.Button("Show Evaluation")
             mistral_plot = gr.Plot(label="Benchmark Plot")
-            mistral_details = gr.Code(label="Raw Scores (JSON)", language="json")
             show_mistral_btn.click(
                 fn=display_mistral_eval,
                 inputs=mistral_metric,
-                outputs=[mistral_plot, mistral_details]
             )
 
-    gr.Markdown("""
-    ---
-    ### ℹ️ About
-    Nexa Evals provides benchmarks for tabular models, language models, and specific evaluations like Nexa Mistral Sci-7B:
-    - **Tabular Models**: Evaluated on domain-specific metrics across fields like Proteins and Astro.
-    - **LLMs**: Assessed using the SciEval benchmark under the OSIR initiative.
-    - **Nexa Mistral Sci-7B**: Compares general (OSIR) and physics-specific (OSIR-Field) performance across multiple metrics.
-    Scores are normalized where applicable (0-1 for tabular/LLMs, 1-10 for Mistral).
-    """)
 
 demo.launch()
 
 import gradio as gr
+import matplotlib.pyplot as plt
+import numpy as np
 
+# Data for Tabular Models (normalized to 0-10 from original 0-1 data)
 TABULAR_MODEL_EVALS = {
     "Proteins": {
+        "Nexa Bio1 (Secondary)": 7.1,
+        "Porter6 (Secondary)": 8.5,
+        "DeepCNF (Secondary)": 8.5,
+        "AlphaFold2 (Tertiary GDT-TS)": 9.2,
+        "Nexa Bio2 (Tertiary)": 9.0,
     },
     "Astro": {
+        "Nexa Astro": 9.7,
+        "Baseline CNN": 8.9,
     },
     "Materials": {
+        "Nexa Materials": 10.0,
+        "Random Forest Baseline": 9.2,
     },
     "QST": {
+        "Nexa PIN Model": 8.0,
+        "Quantum TomoNet": 8.5,
     },
     "HEP": {
+        "Nexa HEP Model": 9.1,
+        "CMSNet": 9.4,
     },
     "CFD": {
+        "Nexa CFD Model": 9.2,
+        "FlowNet": 8.9,
     },
 }
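 
+# Data for LLMs. Restored here as an assumed fix, not part of the original
+# commit: plot_comparison and the LLMs tab below still reference
+# LLM_MODEL_EVALS, which this commit otherwise deletes, so the LLMs tab would
+# raise a NameError. Values are the removed 0-1 scores rescaled to the 0-10
+# scale used elsewhere in this version.
+LLM_MODEL_EVALS = {
+    "LLM (General OSIR)": {
+        "Nexa Mistral Sci-7B": 6.1,
+        "Llama-3-8B-Instruct": 3.9,
+        "Mixtral-8x7B-Instruct-v0.1": 4.1,
+        "Claude-3-Sonnet": 6.4,
+        "GPT-4-Turbo": 6.8,
+        "GPT-4o": 7.1,
+    },
+    "LLM (Field-Specific OSIR)": {
+        "Nexa Bio Adapter": 6.6,
+        "Nexa Astro Adapter": 7.0,
+        "GPT-4o (Biomed)": 6.9,
+        "Claude-3-Opus (Bio)": 6.7,
+        "Llama-3-8B-Bio": 4.2,
+        "Mixtral-8x7B-BioTune": 4.3,
+    },
+}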
 
+# Data for Nexa Mistral Sci-7B Evaluation (from your image)
 NEXA_MISTRAL_EVALS = {
     "Nexa Mistral Sci-7B": {
         "Scientific Utility": {"OSIR (General)": 7.0, "OSIR-Field (Physics)": 8.5},
 
     }
 }
 
+# Plotting function using Matplotlib
+def plot_comparison(domain, data_type):
+    if data_type == "mistral":
+        metric = domain
+        data = NEXA_MISTRAL_EVALS["Nexa Mistral Sci-7B"][metric]
+        models = list(data.keys())
+        scores = list(data.values())
+        fig, ax = plt.subplots(figsize=(8, 6), facecolor='#e0e0e0')
+        y_pos = np.arange(len(models))
+        width = 0.35
+        # One bar per model (fixes the grouped-bar offsets, which drew a
+        # duplicate bar at each y position)
+        ax.barh(y_pos[:1], scores[:1], width, label=models[0], color='yellow')
+        ax.barh(y_pos[1:], scores[1:], width, label=models[1], color='orange')
+    else:
+        data = TABULAR_MODEL_EVALS[domain] if data_type == "tabular" else LLM_MODEL_EVALS[domain]
+        models = list(data.keys())
+        scores = list(data.values())
+        fig, ax = plt.subplots(figsize=(8, 6), facecolor='#e0e0e0')
+        y_pos = np.arange(len(models))
+        width = 0.8
+        colors = ['indigo' if 'Nexa' in model else 'lightgray' if data_type == "tabular" else 'gray' for model in models]
+        ax.barh(y_pos, scores, width, color=colors)
+
+    ax.set_yticks(y_pos)
+    ax.set_yticklabels(models)
+    ax.set_xlabel('Score (1-10)')
+    ax.set_title(f"{('Nexa Mistral Sci-7B Evaluation: ' if data_type == 'mistral' else '')}{domain}")
+    ax.set_xlim(0, 10)
+    if data_type == "mistral":
+        ax.legend()
+    ax.grid(True, axis='x', linestyle='--', alpha=0.7)
+    plt.tight_layout()
     return fig
 
+# Display functions
 def display_tabular_eval(domain):
+    return plot_comparison(domain, "tabular")
 
 def display_llm_eval(domain):
+    return plot_comparison(domain, "llm")
 
 def display_mistral_eval(metric):
+    return plot_comparison(metric, "mistral")
 
+# Gradio interface
+with gr.Blocks(css="body {font-family: 'Inter', sans-serif; background-color: #e0e0e0; color: #333;}") as demo:
     gr.Markdown("""
     # 🔬 Nexa Evals — Scientific ML Benchmark Suite
+    A benchmarking suite for Nexa models across various domains.
     """)
 
     with gr.Tabs():
 
             )
             show_tabular_btn = gr.Button("Show Evaluation")
             tabular_plot = gr.Plot(label="Benchmark Plot")
             show_tabular_btn.click(
                 fn=display_tabular_eval,
                 inputs=tabular_domain,
+                outputs=tabular_plot
             )
 
         with gr.TabItem("LLMs"):
 
             )
             show_llm_btn = gr.Button("Show Evaluation")
             llm_plot = gr.Plot(label="Benchmark Plot")
             show_llm_btn.click(
                 fn=display_llm_eval,
                 inputs=llm_domain,
+                outputs=llm_plot
             )
 
         with gr.TabItem("Nexa Mistral Sci-7B"):
 
             )
             show_mistral_btn = gr.Button("Show Evaluation")
             mistral_plot = gr.Plot(label="Benchmark Plot")
             show_mistral_btn.click(
                 fn=display_mistral_eval,
                 inputs=mistral_metric,
+                outputs=mistral_plot
             )
 
+        with gr.TabItem("About"):
+            gr.Markdown("""
+            # ℹ️ About Nexa Evals
+            Nexa Evals benchmarks Nexa models across scientific domains:
+            - **Tabular Models**: Compares Nexa models against baselines.
+            - **LLMs**: Evaluates Nexa language models against competitors.
+            - **Nexa Mistral Sci-7B**: Compares general and physics-specific performance.
+            Scores are on a 1-10 scale.
+            """)
 
 demo.launch()