alielfilali01 commited on
Commit
692e84c
·
verified ·
1 Parent(s): 7ce9f0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -5
app.py CHANGED
@@ -50,7 +50,8 @@ def generate_heatmap_image(model_entry):
50
  # Create a mask for the upper triangle (keeping the diagonal visible).
51
  mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
52
 
53
- plt.figure(figsize=(6, 5))
 
54
  sns.heatmap(matrix,
55
  mask=mask,
56
  annot=True,
@@ -66,11 +67,18 @@ def generate_heatmap_image(model_entry):
66
 
67
  # Save the plot to a bytes buffer.
68
  buf = BytesIO()
69
- plt.savefig(buf, format="png")
70
  plt.close()
71
  buf.seek(0)
 
72
  # Convert the buffer into a PIL Image.
73
  image = Image.open(buf).convert("RGB")
 
 
 
 
 
 
74
  return image
75
 
76
  def generate_heatmaps(selected_model_names):
@@ -88,7 +96,13 @@ def generate_heatmaps(selected_model_names):
88
  # -------------------------------
89
  # 3. Build the Gradio Interface
90
  # -------------------------------
91
- with gr.Blocks() as demo:
 
 
 
 
 
 
92
  gr.Markdown("## 3C3H Heatmap Generator")
93
  gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
94
 
@@ -96,10 +110,123 @@ with gr.Blocks() as demo:
96
  model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
97
 
98
  generate_btn = gr.Button("Generate Heatmaps")
99
- # Use the 'columns' parameter to set a grid layout in the gallery.
100
- gallery = gr.Gallery(label="Heatmaps", columns=2)
 
 
 
 
 
 
101
 
102
  generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
103
 
104
  # Launch the Gradio app
105
  demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Create a mask for the upper triangle (keeping the diagonal visible).
51
  mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
52
 
53
+ # Set a consistent figure size that will work well in the gallery
54
+ plt.figure(figsize=(6, 5), dpi=100)
55
  sns.heatmap(matrix,
56
  mask=mask,
57
  annot=True,
 
67
 
68
  # Save the plot to a bytes buffer.
69
  buf = BytesIO()
70
+ plt.savefig(buf, format="png", bbox_inches="tight")
71
  plt.close()
72
  buf.seek(0)
73
+
74
  # Convert the buffer into a PIL Image.
75
  image = Image.open(buf).convert("RGB")
76
+
77
+ # Resize the image to a reasonable fixed size for the gallery
78
+ # This helps maintain consistency and prevent oversized images
79
+ max_size = (800, 600)
80
+ image.thumbnail(max_size, Image.Resampling.LANCZOS)
81
+
82
  return image
83
 
84
  def generate_heatmaps(selected_model_names):
 
96
  # -------------------------------
97
  # 3. Build the Gradio Interface
98
  # -------------------------------
99
+ with gr.Blocks(css="""
100
+ .gallery-item img {
101
+ max-width: 100% !important;
102
+ max-height: 100% !important;
103
+ object-fit: contain !important;
104
+ }
105
+ """) as demo:
106
  gr.Markdown("## 3C3H Heatmap Generator")
107
  gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
108
 
 
110
  model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
111
 
112
  generate_btn = gr.Button("Generate Heatmaps")
113
+
114
+ # Set height and columns for better display
115
+ gallery = gr.Gallery(
116
+ label="Heatmaps",
117
+ columns=2,
118
+ height="auto",
119
+ object_fit="contain"
120
+ )
121
 
122
  generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
123
 
124
  # Launch the Gradio app
125
  demo.launch()
126
+
127
+
128
+ # import gradio as gr
129
+ # import json
130
+ # import os
131
+ # import numpy as np
132
+ # import matplotlib.pyplot as plt
133
+ # import seaborn as sns
134
+ # from io import BytesIO
135
+ # from PIL import Image
136
+
137
+ # # -------------------------------
138
+ # # 1. Load Results from Local File
139
+ # # -------------------------------
140
+ # def load_results():
141
+ # # Get the directory of the current file
142
+ # current_dir = os.path.dirname(os.path.abspath(__file__))
143
+ # # Construct the path to the JSON file (assumes file is stored in "files/aragen_v1_results.json")
144
+ # results_file = os.path.join(current_dir, "files", "aragen_v1_results.json")
145
+ # with open(results_file, "r") as f:
146
+ # data = json.load(f)
147
+ # # Filter out any non-model entries (e.g., timestamp entries)
148
+ # model_data = [entry for entry in data if "Meta" in entry]
149
+ # return model_data
150
+
151
+ # # Load the JSON data once when the app starts
152
+ # DATA = load_results()
153
+
154
+ # # Extract model names for the dropdown from the JSON "Meta" field
155
+ # def get_model_names(data):
156
+ # model_names = [entry["Meta"]["Model Name"] for entry in data]
157
+ # return model_names
158
+
159
+ # MODEL_NAMES = get_model_names(DATA)
160
+
161
+ # # -------------------------------
162
+ # # 2. Define Metrics and Heatmap Generation Functions
163
+ # # -------------------------------
164
+ # # Define the six metrics in the desired order.
165
+ # METRICS = ["Correctness", "Completeness", "Conciseness", "Helpfulness", "Honesty", "Harmlessness"]
166
+
167
+ # def generate_heatmap_image(model_entry):
168
+ # """
169
+ # For a given model entry, extract the six metrics and compute a 6x6 similarity matrix
170
+ # using the definition: similarity = 1 - |v_i - v_j|, then return the heatmap as a PIL image.
171
+ # """
172
+ # scores = model_entry["claude-3.5-sonnet Scores"]["3C3H Scores"]
173
+ # # Create a vector with the metrics in the defined order.
174
+ # v = np.array([scores[m] for m in METRICS])
175
+ # # Compute the 6x6 similarity matrix.
176
+ # matrix = 1 - np.abs(np.subtract.outer(v, v))
177
+ # # Create a mask for the upper triangle (keeping the diagonal visible).
178
+ # mask = np.triu(np.ones_like(matrix, dtype=bool), k=1)
179
+
180
+ # plt.figure(figsize=(6, 5))
181
+ # sns.heatmap(matrix,
182
+ # mask=mask,
183
+ # annot=True,
184
+ # fmt=".2f",
185
+ # cmap="viridis",
186
+ # xticklabels=METRICS,
187
+ # yticklabels=METRICS,
188
+ # cbar_kws={"label": "Similarity"})
189
+ # plt.title(f"Confusion Matrix for Model: {model_entry['Meta']['Model Name']}")
190
+ # plt.xlabel("Metrics")
191
+ # plt.ylabel("Metrics")
192
+ # plt.tight_layout()
193
+
194
+ # # Save the plot to a bytes buffer.
195
+ # buf = BytesIO()
196
+ # plt.savefig(buf, format="png")
197
+ # plt.close()
198
+ # buf.seek(0)
199
+ # # Convert the buffer into a PIL Image.
200
+ # image = Image.open(buf).convert("RGB")
201
+ # return image
202
+
203
+ # def generate_heatmaps(selected_model_names):
204
+ # """
205
+ # Filter the global DATA for entries matching the selected model names,
206
+ # generate a heatmap for each, and return a list of PIL images.
207
+ # """
208
+ # filtered_entries = [entry for entry in DATA if entry["Meta"]["Model Name"] in selected_model_names]
209
+ # images = []
210
+ # for entry in filtered_entries:
211
+ # img = generate_heatmap_image(entry)
212
+ # images.append(img)
213
+ # return images
214
+
215
+ # # -------------------------------
216
+ # # 3. Build the Gradio Interface
217
+ # # -------------------------------
218
+ # with gr.Blocks() as demo:
219
+ # gr.Markdown("## 3C3H Heatmap Generator")
220
+ # gr.Markdown("Select the models you want to compare and generate their heatmaps below.")
221
+
222
+ # with gr.Row():
223
+ # model_dropdown = gr.Dropdown(choices=MODEL_NAMES, label="Select Model(s)", multiselect=True, value=MODEL_NAMES[:3])
224
+
225
+ # generate_btn = gr.Button("Generate Heatmaps")
226
+ # # Use the 'columns' parameter to set a grid layout in the gallery.
227
+ # gallery = gr.Gallery(label="Heatmaps", columns=2)
228
+
229
+ # generate_btn.click(fn=generate_heatmaps, inputs=model_dropdown, outputs=gallery)
230
+
231
+ # # Launch the Gradio app
232
+ # demo.launch()