Austing Dong committed
Commit b019cc7 · 1 Parent(s): 63b5fc2
app.py CHANGED
@@ -9,6 +9,7 @@ from demo.modified_attn import ModifiedLlamaAttention, ModifiedGemmaAttention
from questions.mini_VLAT import mini_VLAT_questions
from questions.VLAT_old import VLAT_old_questions
from questions.VLAT import VLAT_questions
+from questions.New_test import new_test_questions
import numpy as np
import matplotlib.pyplot as plt
import gc
@@ -244,6 +245,10 @@ def test_change(test_selector):
        return gr.Dataset(
            samples=VLAT_questions,
        )
+    elif test_selector == "New_test":
+        return gr.Dataset(
+            samples=new_test_questions,
+        )
    else:
        return gr.Dataset(
            samples=VLAT_old_questions,
@@ -265,7 +270,7 @@ with gr.Blocks() as demo:

        with gr.Column():
            model_selector = gr.Dropdown(choices=["ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="ChartGemma-3B", label="model")
-            test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old"], value="mini-VLAT", label="test")
+            test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old", "New_test"], value="mini-VLAT", label="test")
            chart_type = gr.Textbox(label="Chart Type", value="Any")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
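Note (editor's aside, not part of the commit): a minimal sketch of how the new "New_test" choice reaches test_change. The .change() wiring and the placeholder question lists below are assumptions for illustration; only the dropdown choices and the dispatch-by-test-name pattern mirror this commit.

import gradio as gr

# Hypothetical stand-in question banks; each sample row is [chart type, question, image path].
mini_VLAT_questions = [["BarChart", "Which student has the highest score in midterm?", "images/mini_VLAT/BarChart.png"]]
new_test_questions = [["AreaChart", "At which month the price of tea is the highest?", "images/New_test/AreaChart.png"]]

def test_change(test_selector):
    # Same dispatch pattern as the commit: return an updated Dataset for the selected test.
    if test_selector == "New_test":
        return gr.Dataset(samples=new_test_questions)
    return gr.Dataset(samples=mini_VLAT_questions)

with gr.Blocks() as demo:
    test_selector = gr.Dropdown(choices=["mini-VLAT", "New_test"], value="mini-VLAT", label="test")
    examples = gr.Dataset(
        components=[gr.Textbox(visible=False), gr.Textbox(visible=False), gr.Textbox(visible=False)],
        samples=mini_VLAT_questions,
    )
    test_selector.change(test_change, inputs=test_selector, outputs=examples)

# demo.launch()  # uncomment to try the dropdown locally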
evaluate/evaluate.py CHANGED
@@ -1,15 +1,22 @@
import os
import torch
import base64
+import json
import numpy as np
from PIL import Image
from openai import OpenAI
from demo.model_utils import *
-from evaluate.questions import questions
+from pydantic import BaseModel
+
+questions = json.load(open("evaluate/new_test.json", "r"))
+judge_client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
+                      base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
+
+class Judge_Result(BaseModel):
+    result: int

def set_seed(model_seed = 70):
    torch.manual_seed(model_seed)
-    # np.random.seed(model_seed)
    torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None

def clean():
@@ -26,7 +33,23 @@ def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

+def llm_judge(answer, options, correct_answer):
+    completion = judge_client.beta.chat.completions.parse(
+        model="gemini-2.5-pro-preview-03-25",
+        messages=[
+            { "role": "system", "content": "You are a judge that evaluates the correctness of answers to questions. The answer might not be the letter of correct option. You need to judge correctness of the answer, comparing to the options and correct option. Return the correctness in 1: Correct or 0: Incorrect." },
+            { "role": "user", "content": f":Options: {options}\nAnswer:{answer},Correct Option: {correct_answer}" },
+        ],
+        response_format=Judge_Result
+    )
+    answer = completion.choices[0].message.content
+    print(f"Judge Answer: {answer}")
+    return json.loads(answer)["result"]
+
+
def evaluate(model_type, num_eval = 10):
+    sum_correct = np.zeros(len(questions))
+    RESULTS_ROOT = "./evaluate/results"
    for eval_idx in range(num_eval):
        clean()
        set_seed(np.random.randint(0, 1000))
@@ -53,13 +76,16 @@ def evaluate(model_type, num_eval = 10):
                            base_url="https://generativelanguage.googleapis.com/v1beta/openai/")

        for question_idx, question in enumerate(questions):
-            chart_type = question[0]
-            q = question[1]
-            img_path = question[2]
+            chart_type = question["type"]
+            q = question["question"]
+            img_path = question["img_path"]
+            options = question.get("options", None)
+            correct_answer = question.get("correct_answer", None)
+
            image = np.array(Image.open(img_path).convert("RGB"))


-
+            input_text = f"Options: {options}\nQuestion: {q}\n"
            if model_type.split('-')[0] == "GPT":
                base64_image = encode_image(img_path)
                completion = client.chat.completions.create(
@@ -68,7 +94,7 @@
                    {
                        "role": "user",
                        "content": [
-                            { "type": "text", "text": f"{q}" },
+                            { "type": "text", "text": f"{input_text}" },
                            {
                                "type": "image_url",
                                "image_url": {
@@ -89,7 +115,7 @@
                    {
                        "role": "user",
                        "content": [
-                            { "type": "text", "text": f"{q}" },
+                            { "type": "text", "text": f"{input_text}" },
                            {
                                "type": "image_url",
                                "image_url": {
@@ -103,7 +129,8 @@
                answer = completion.choices[0].message.content

            else:
-                prepare_inputs = model_utils.prepare_inputs(q, image)
+
+                prepare_inputs = model_utils.prepare_inputs(input_text, image)
                temperature = 0.1
                top_p = 0.95

@@ -116,19 +143,30 @@
                sequences = outputs.sequences.cpu().tolist()
                answer = tokenizer.decode(sequences[0], skip_special_tokens=True)

-            RESULTS_ROOT = "./evaluate/results"
+            # Judge the answer
+            result_judge = llm_judge(answer, options, correct_answer)
+            sum_correct[question_idx] += 1 if result_judge else 0
+            print(f"Model: {model_type}, Question: {question_idx + 1}, Answer: {answer}, Correct: {result_judge}")
+
+            # Save the results
            FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
            os.makedirs(FILES_ROOT, exist_ok=True)

            with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
                f.write(answer)
                f.close()
+    accuracy = sum_correct / num_eval
+    print(f"Model: {model_type}, Accuracy: {accuracy}")
+    with open(f"{RESULTS_ROOT}/{model_type}/accuracy.txt", "w") as f:
+        for question_idx, question in enumerate(questions):
+            chart_type = question["type"]
+            f.write(f"Chart Type: {chart_type}, Accuracy: {accuracy[question_idx]}\n")
+        f.close()


-if __name__ == '__main__':
-
-    # models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
-    models = ["Janus-Pro-7B"]
+if __name__ == '__main__':

+    # models = ["Janus-Pro-1B", "ChartGemma", "GPT-4o", "Gemini-2.0-flash", "Janus-Pro-7B", "LLaVA-1.5-7B"]
+    models = ["LLaVA-1.5-7B", "Gemini-2.0-flash", "Janus-Pro-7B", ]
+
    for model_type in models:
        evaluate(model_type=model_type, num_eval=10)
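Note (editor's aside, not part of the commit): a self-contained sketch of the new judging path, isolated from the full evaluation loop. It needs GEMINI_HCI_API_KEY set and network access; the sample model answer is invented for illustration, while the model name, client setup, and structured-output parsing follow the diff above.

import os
import json
from openai import OpenAI
from pydantic import BaseModel

class Judge_Result(BaseModel):
    result: int  # 1 = correct, 0 = incorrect

# Same Gemini OpenAI-compatible endpoint as in evaluate.py.
judge_client = OpenAI(
    api_key=os.environ["GEMINI_HCI_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

def llm_judge(answer, options, correct_answer):
    # Ask the judge model for a structured {"result": 0|1} verdict on a free-form answer.
    completion = judge_client.beta.chat.completions.parse(
        model="gemini-2.5-pro-preview-03-25",
        messages=[
            {"role": "system", "content": "Judge whether the answer matches the correct option. Reply with result=1 for correct, 0 for incorrect."},
            {"role": "user", "content": f"Options: {options}\nAnswer: {answer}\nCorrect Option: {correct_answer}"},
        ],
        response_format=Judge_Result,
    )
    return json.loads(completion.choices[0].message.content)["result"]

if __name__ == "__main__":
    options = [{"A": "Bob"}, {"B": "Elora"}, {"C": "Kevin"}, {"D": "David"}]
    # Hypothetical free-form model answer; the judge should map it to option "D".
    print(llm_judge("The student with the highest midterm score is David.", options, "D"))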
evaluate/new_test.json ADDED
@@ -0,0 +1,236 @@
+[
+    {
+        "type": "Area Chart",
+        "img_path": "images/New_test/AreaChart.png",
+        "question": "At which month the price of tea is the highest?",
+        "options": [
+            {
+                "A": "1"
+            },
+            {
+                "B": "3"
+            },
+            {
+                "C": "4"
+            },
+            {
+                "D": "7"
+            }
+        ],
+        "correct_answer": "C"
+    },
+    {
+        "type": "Bar Chart",
+        "img_path": "images/New_test/BarChart.png",
+        "question": "Which student has the highest score in midterm?",
+        "options": [
+            {
+                "A": "Bob"
+            },
+            {
+                "B": "Elora"
+            },
+            {
+                "C": "Kevin"
+            },
+            {
+                "D": "David"
+            }
+        ],
+        "correct_answer": "D"
+    },
+    {
+        "type": "Bubble Chart",
+        "img_path": "images/New_test/BubbleChart.png",
+        "question": "What is the number of employees of the company that has lowest annual income?",
+        "options": [
+            {
+                "A": "5"
+            },
+            {
+                "B": "20"
+            },
+            {
+                "C": "85"
+            },
+            {
+                "D": "60"
+            }
+        ],
+        "correct_answer": "C"
+    },
+    {
+        "type": "Choropleth Map",
+        "img_path": "images/New_test/Choropleth.png",
+        "question": "Is the GDP of CA higher than NV?",
+        "options": [
+            {
+                "A": "True"
+            },
+            {
+                "B": "False"
+            }
+        ],
+        "correct_answer": "A"
+    },
+    {
+        "type": "Histogram",
+        "img_path": "images/New_test/Histogram.png",
+        "question": "Which range of distance of trip people prefers the most?",
+        "options": [
+            {
+                "A": "10-20km"
+            },
+            {
+                "B": "30-40km"
+            },
+            {
+                "C": "50-60km"
+            },
+            {
+                "D": "70-80km"
+            }
+        ],
+        "correct_answer": "C"
+    },
+    {
+        "type": "Line Chart",
+        "img_path": "images/New_test/LineChart.png",
+        "question": "What is the blood sugar level three hours after a meal",
+        "options": [
+            {
+                "A": "18"
+            },
+            {
+                "B": "70"
+            },
+            {
+                "C": "90"
+            },
+            {
+                "D": "50"
+            }
+        ],
+        "correct_answer": "3"
+    },
+    {
+        "type": "Pie Chart",
+        "img_path": "images/New_test/PieChart.png",
+        "question": "Which stock has the smallest holdings in this portfolio?",
+        "options": [
+            {
+                "A": "AAPL"
+            },
+            {
+                "B": "NVDA"
+            },
+            {
+                "C": "TSLA"
+            },
+            {
+                "D": "GOOG"
+            }
+        ],
+        "correct_answer": "B"
+    },
+    {
+        "type": "Scatter Chart",
+        "img_path": "images/New_test/Scatterplot.png",
+        "question": "What is the weight of the individual that has highest height?",
+        "options": [
+            {
+                "A": "96"
+            },
+            {
+                "B": "72"
+            },
+            {
+                "C": "55"
+            },
+            {
+                "D": "25"
+            }
+        ],
+        "correct_answer": "A"
+    },
+    {
+        "type": "Stacked Area Chart",
+        "img_path": "images/New_test/StackedArea.png",
+        "question": "What was the ratio of boys named 'Justin' to boys named 'Kevin' in the 3rd month in the USA?",
+        "options": [
+            {
+                "A": "1:1"
+            },
+            {
+                "B": "1:2"
+            },
+            {
+                "C": "2:1"
+            },
+            {
+                "D": "4:1"
+            }
+        ],
+        "correct_answer": "C"
+    },
+    {
+        "type": "Stacked Bar Chart",
+        "img_path": "images/New_test/StackedBar.png",
+        "question": "What is the price of headphone in Japan?",
+        "options": [
+            {
+                "A": "10"
+            },
+            {
+                "B": "30"
+            },
+            {
+                "C": "50"
+            },
+            {
+                "D": "70"
+            }
+        ],
+        "correct_answer": "B"
+    },
+    {
+        "type": "100% Stacked Bar Chart",
+        "img_path": "images/New_test/Stacked100.png",
+        "question": "Which country has the largest proportion of mouse price?",
+        "options": [
+            {
+                "A": "Canada"
+            },
+            {
+                "B": "China"
+            },
+            {
+                "C": "Japan"
+            },
+            {
+                "D": "Korea"
+            }
+        ],
+        "correct_answer": "A"
+    },
+    {
+        "type": "Treemap",
+        "img_path": "images/New_test/TreeMap.png",
+        "question": "Which country has the largest population?",
+        "options": [
+            {
+                "A": "China"
+            },
+            {
+                "B": " USA"
+            },
+            {
+                "C": " Canada"
+            },
+            {
+                "D": " UK"
+            }
+        ],
+        "correct_answer": "A"
+    }
+]
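Note (editor's aside, not part of the commit): a small validation sketch can catch schema slips in this file, such as the Line Chart entry whose "correct_answer" is "3" rather than one of the option letters A-D. The Question model below is a hypothetical helper, assuming pydantic v2.

import json
from typing import Dict, List
from pydantic import BaseModel, model_validator

class Question(BaseModel):
    type: str
    img_path: str
    question: str
    options: List[Dict[str, str]]
    correct_answer: str

    @model_validator(mode="after")
    def answer_is_an_option(self):
        # The correct answer should be one of the option letters.
        letters = {k for opt in self.options for k in opt}
        if self.correct_answer not in letters:
            raise ValueError(f"correct_answer {self.correct_answer!r} is not among {sorted(letters)}")
        return self

questions = json.load(open("evaluate/new_test.json", "r"))
for q in questions:
    try:
        Question(**q)
    except ValueError as err:
        print(f"{q['type']}: {err}")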
questions/New_test.py ADDED
@@ -0,0 +1,73 @@
+new_test_questions=[
+    [
+        "LineChart",
+        "What is the blood sugar level three hours after a meal",
+        "images/New_test/LineChart.png"
+    ],
+
+    [
+        "BarChart",
+        "Which student has the highest score in midterm?",
+        "images/New_test/BarChart.png"
+    ],
+
+    [
+        "StackedBar",
+        "What is the price of headphone in Japan?",
+        "images/New_test/StackedBar.png"
+    ],
+
+    [
+        "100%StackedBar",
+        "Which country has the largest proportion of mouse price?",
+        "images/New_test/Stacked100.png"
+    ],
+
+    [
+        "PieChart",
+        "Which stock has the smallest holdings in this portfolio?",
+        "images/New_test/PieChart.png"
+    ],
+
+    [
+        "Histogram",
+        "Which range of distance of trip people prefers the most?",
+        "images/New_test/Histogram.png"
+    ],
+
+    [
+        "Scatterplot",
+        "What is the weight of the individual that has highest height?",
+        "images/New_test/Scatterplot.png"
+    ],
+
+    [
+        "AreaChart",
+        "At which month the price of tea is the highest?",
+        "images/New_test/AreaChart.png"
+    ],
+
+    [
+        "StackedArea",
+        "What was the ratio of boys named 'Justin' to boys named 'Kevin' in the 3rd month in the USA?",
+        "images/New_test/StackedArea.png"
+    ],
+
+    [
+        "BubbleChart",
+        "What is the number of employees of the company that has lowest annual income?",
+        "images/New_test/BubbleChart.png"
+    ],
+
+    [
+        "Choropleth",
+        "Is the GDP of CA higher than NV?",
+        "images/New_test/Choropleth.png"
+    ],
+
+    [
+        "TreeMap",
+        "Which country has the largest population?",
+        "images/New_test/TreeMap.png"
+    ]
+]
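Note (editor's aside, not part of the commit): these entries reuse the [chart type, question, image path] triple layout of the other question banks that app.py feeds to gr.Dataset, while the richer records (options, correct answer) live only in evaluate/new_test.json. A hedged sketch of cross-checking the two sources by image path, which is an assumption for illustration rather than something the commit does:

import json

from questions.New_test import new_test_questions  # triples: [chart type, question, image path]

json_questions = json.load(open("evaluate/new_test.json", "r"))
by_img = {q["img_path"]: q for q in json_questions}

for chart, question, img_path in new_test_questions:
    record = by_img.get(img_path)
    if record is None:
        print(f"{chart}: no record in new_test.json for {img_path}")
    elif record["question"].strip() != question.strip():
        print(f"{chart}: question text differs between New_test.py and new_test.json")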
saliency_map/__init__.py DELETED
@@ -1,12 +0,0 @@
-
-import sys
-
-if sys.version_info >= (3, 10):
-    print("Python version is above 3.10, patching the collections module.")
-    # Monkey patch collections
-    import collections
-    import collections.abc
-
-    for type_name in collections.abc.__all__:
-        setattr(collections, type_name, getattr(collections.abc, type_name))
-
saliency_map/cam.py DELETED
@@ -1,75 +0,0 @@
-import torch
-import cv2
-from PIL import Image
-import numpy as np
-
-class MultimodalGradCAM:
-    def __init__(self, model, processor):
-        self.model = model
-        self.processor = processor
-        self.activations = {}
-        self.gradients = {}
-
-        # Register hooks
-        self._register_hooks()
-
-    def _register_hooks(self):
-        # Hook the last vision transformer layer
-        def forward_hook(module, input, output):
-            self.activations['vision'] = output.last_hidden_state
-
-        def backward_hook(module, grad_input, grad_output):
-            self.gradients['vision'] = grad_output[0]
-
-        vision_encoder = self.model.get_vision_encoder()
-        vision_encoder.layers[-1].register_forward_hook(forward_hook)
-        vision_encoder.layers[-1].register_backward_hook(backward_hook)
-
-    def generate_saliency(self, image, question):
-        # Preprocess inputs
-        inputs = self.processor(
-            text=question,
-            images=image,
-            return_tensors="pt",
-            padding=True
-        )
-
-        # Forward pass
-        outputs = self.model(**inputs)
-        answer_ids = outputs.logits.argmax(dim=-1)
-
-        # Get target token (use last token for answer)
-        target_token_id = answer_ids[0, -1].item()
-        target = outputs.logits[0, -1, target_token_id]
-
-        # Backward pass
-        self.model.zero_grad()
-        target.backward()
-
-        # Process activations and gradients
-        activations = self.activations['vision'].detach()
-        gradients = self.gradients['vision'].detach()
-
-        # Grad-CAM calculation
-        weights = gradients.mean(dim=[1, 2], keepdim=True)  # Global average pooling
-        cam = (weights * activations).sum(dim=-1, keepdims=True)
-        cam = torch.relu(cam)
-
-        # Reshape and normalize
-        cam = cam.squeeze().cpu().numpy()
-        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
-
-        return cam
-
-    def visualize(self, image, cam):
-        # Resize CAM to original image size
-        img_size = image.size[::-1]  # (width, height) -> (height, width)
-        cam = cv2.resize(cam, img_size)
-
-        # Convert to heatmap
-        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
-        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
-
-        # Superimpose on original image
-        superimposed = np.array(image) * 0.4 + heatmap * 0.6
-        return Image.fromarray(np.uint8(superimposed))