Austing Dong
commited on
Commit
·
b019cc7
1
Parent(s):
63b5fc2
- app.py +6 -1
- evaluate/evaluate.py +52 -14
- evaluate/new_test.json +236 -0
- questions/New_test.py +73 -0
- saliency_map/__init__.py +0 -12
- saliency_map/cam.py +0 -75
app.py
CHANGED
@@ -9,6 +9,7 @@ from demo.modified_attn import ModifiedLlamaAttention, ModifiedGemmaAttention
|
|
9 |
from questions.mini_VLAT import mini_VLAT_questions
|
10 |
from questions.VLAT_old import VLAT_old_questions
|
11 |
from questions.VLAT import VLAT_questions
|
|
|
12 |
import numpy as np
|
13 |
import matplotlib.pyplot as plt
|
14 |
import gc
|
@@ -244,6 +245,10 @@ def test_change(test_selector):
|
|
244 |
return gr.Dataset(
|
245 |
samples=VLAT_questions,
|
246 |
)
|
|
|
|
|
|
|
|
|
247 |
else:
|
248 |
return gr.Dataset(
|
249 |
samples=VLAT_old_questions,
|
@@ -265,7 +270,7 @@ with gr.Blocks() as demo:
|
|
265 |
|
266 |
with gr.Column():
|
267 |
model_selector = gr.Dropdown(choices=["ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="ChartGemma-3B", label="model")
|
268 |
-
test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old"], value="mini-VLAT", label="test")
|
269 |
chart_type = gr.Textbox(label="Chart Type", value="Any")
|
270 |
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
|
271 |
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
|
|
|
9 |
from questions.mini_VLAT import mini_VLAT_questions
|
10 |
from questions.VLAT_old import VLAT_old_questions
|
11 |
from questions.VLAT import VLAT_questions
|
12 |
+
from questions.New_test import new_test_questions
|
13 |
import numpy as np
|
14 |
import matplotlib.pyplot as plt
|
15 |
import gc
|
|
|
245 |
return gr.Dataset(
|
246 |
samples=VLAT_questions,
|
247 |
)
|
248 |
+
elif test_selector == "New_test":
|
249 |
+
return gr.Dataset(
|
250 |
+
samples=new_test_questions,
|
251 |
+
)
|
252 |
else:
|
253 |
return gr.Dataset(
|
254 |
samples=VLAT_old_questions,
|
|
|
270 |
|
271 |
with gr.Column():
|
272 |
model_selector = gr.Dropdown(choices=["ChartGemma-3B", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B"], value="ChartGemma-3B", label="model")
|
273 |
+
test_selector = gr.Dropdown(choices=["mini-VLAT", "VLAT", "VLAT-old", "New_test"], value="mini-VLAT", label="test")
|
274 |
chart_type = gr.Textbox(label="Chart Type", value="Any")
|
275 |
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
|
276 |
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
|
evaluate/evaluate.py
CHANGED
@@ -1,15 +1,22 @@
|
|
1 |
import os
|
2 |
import torch
|
3 |
import base64
|
|
|
4 |
import numpy as np
|
5 |
from PIL import Image
|
6 |
from openai import OpenAI
|
7 |
from demo.model_utils import *
|
8 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
def set_seed(model_seed = 70):
|
11 |
torch.manual_seed(model_seed)
|
12 |
-
# np.random.seed(model_seed)
|
13 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
14 |
|
15 |
def clean():
|
@@ -26,7 +33,23 @@ def encode_image(image_path):
|
|
26 |
with open(image_path, "rb") as image_file:
|
27 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
def evaluate(model_type, num_eval = 10):
|
|
|
|
|
30 |
for eval_idx in range(num_eval):
|
31 |
clean()
|
32 |
set_seed(np.random.randint(0, 1000))
|
@@ -53,13 +76,16 @@ def evaluate(model_type, num_eval = 10):
|
|
53 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
54 |
|
55 |
for question_idx, question in enumerate(questions):
|
56 |
-
chart_type = question[
|
57 |
-
q = question[
|
58 |
-
img_path = question[
|
|
|
|
|
|
|
59 |
image = np.array(Image.open(img_path).convert("RGB"))
|
60 |
|
61 |
|
62 |
-
|
63 |
if model_type.split('-')[0] == "GPT":
|
64 |
base64_image = encode_image(img_path)
|
65 |
completion = client.chat.completions.create(
|
@@ -68,7 +94,7 @@ def evaluate(model_type, num_eval = 10):
|
|
68 |
{
|
69 |
"role": "user",
|
70 |
"content": [
|
71 |
-
{ "type": "text", "text": f"{
|
72 |
{
|
73 |
"type": "image_url",
|
74 |
"image_url": {
|
@@ -89,7 +115,7 @@ def evaluate(model_type, num_eval = 10):
|
|
89 |
{
|
90 |
"role": "user",
|
91 |
"content": [
|
92 |
-
{ "type": "text", "text": f"{
|
93 |
{
|
94 |
"type": "image_url",
|
95 |
"image_url": {
|
@@ -103,7 +129,8 @@ def evaluate(model_type, num_eval = 10):
|
|
103 |
answer = completion.choices[0].message.content
|
104 |
|
105 |
else:
|
106 |
-
|
|
|
107 |
temperature = 0.1
|
108 |
top_p = 0.95
|
109 |
|
@@ -116,19 +143,30 @@ def evaluate(model_type, num_eval = 10):
|
|
116 |
sequences = outputs.sequences.cpu().tolist()
|
117 |
answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
|
118 |
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
120 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
121 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
122 |
|
123 |
with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
|
124 |
f.write(answer)
|
125 |
f.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
|
|
127 |
|
|
|
|
|
128 |
|
129 |
-
if __name__ == '__main__':
|
130 |
-
|
131 |
-
# models = ["ChartGemma", "Janus-Pro-1B", "Janus-Pro-7B", "LLaVA-1.5-7B", "GPT-4o", "Gemini-2.0-flash"]
|
132 |
-
models = ["Janus-Pro-7B"]
|
133 |
for model_type in models:
|
134 |
evaluate(model_type=model_type, num_eval=10)
|
|
|
1 |
import os
|
2 |
import torch
|
3 |
import base64
|
4 |
+
import json
|
5 |
import numpy as np
|
6 |
from PIL import Image
|
7 |
from openai import OpenAI
|
8 |
from demo.model_utils import *
|
9 |
+
from pydantic import BaseModel
|
10 |
+
|
11 |
+
questions = json.load(open("evaluate/new_test.json", "r"))
|
12 |
+
judge_client = OpenAI(api_key=os.environ["GEMINI_HCI_API_KEY"],
|
13 |
+
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
14 |
+
|
15 |
+
class Judge_Result(BaseModel):
|
16 |
+
result: int
|
17 |
|
18 |
def set_seed(model_seed = 70):
|
19 |
torch.manual_seed(model_seed)
|
|
|
20 |
torch.cuda.manual_seed(model_seed) if torch.cuda.is_available() else None
|
21 |
|
22 |
def clean():
|
|
|
33 |
with open(image_path, "rb") as image_file:
|
34 |
return base64.b64encode(image_file.read()).decode("utf-8")
|
35 |
|
36 |
+
def llm_judge(answer, options, correct_answer):
|
37 |
+
completion = judge_client.beta.chat.completions.parse(
|
38 |
+
model="gemini-2.5-pro-preview-03-25",
|
39 |
+
messages=[
|
40 |
+
{ "role": "system", "content": "You are a judge that evaluates the correctness of answers to questions. The answer might not be the letter of correct option. You need to judge correctness of the answer, comparing to the options and correct option. Return the correctness in 1: Correct or 0: Incorrect." },
|
41 |
+
{ "role": "user", "content": f":Options: {options}\nAnswer:{answer},Correct Option: {correct_answer}" },
|
42 |
+
],
|
43 |
+
response_format=Judge_Result
|
44 |
+
)
|
45 |
+
answer = completion.choices[0].message.content
|
46 |
+
print(f"Judge Answer: {answer}")
|
47 |
+
return json.loads(answer)["result"]
|
48 |
+
|
49 |
+
|
50 |
def evaluate(model_type, num_eval = 10):
|
51 |
+
sum_correct = np.zeros(len(questions))
|
52 |
+
RESULTS_ROOT = "./evaluate/results"
|
53 |
for eval_idx in range(num_eval):
|
54 |
clean()
|
55 |
set_seed(np.random.randint(0, 1000))
|
|
|
76 |
base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
|
77 |
|
78 |
for question_idx, question in enumerate(questions):
|
79 |
+
chart_type = question["type"]
|
80 |
+
q = question["question"]
|
81 |
+
img_path = question["img_path"]
|
82 |
+
options = question.get("options", None)
|
83 |
+
correct_answer = question.get("correct_answer", None)
|
84 |
+
|
85 |
image = np.array(Image.open(img_path).convert("RGB"))
|
86 |
|
87 |
|
88 |
+
input_text = f"Options: {options}\nQuestion: {q}\n"
|
89 |
if model_type.split('-')[0] == "GPT":
|
90 |
base64_image = encode_image(img_path)
|
91 |
completion = client.chat.completions.create(
|
|
|
94 |
{
|
95 |
"role": "user",
|
96 |
"content": [
|
97 |
+
{ "type": "text", "text": f"{input_text}" },
|
98 |
{
|
99 |
"type": "image_url",
|
100 |
"image_url": {
|
|
|
115 |
{
|
116 |
"role": "user",
|
117 |
"content": [
|
118 |
+
{ "type": "text", "text": f"{input_text}" },
|
119 |
{
|
120 |
"type": "image_url",
|
121 |
"image_url": {
|
|
|
129 |
answer = completion.choices[0].message.content
|
130 |
|
131 |
else:
|
132 |
+
|
133 |
+
prepare_inputs = model_utils.prepare_inputs(input_text, image)
|
134 |
temperature = 0.1
|
135 |
top_p = 0.95
|
136 |
|
|
|
143 |
sequences = outputs.sequences.cpu().tolist()
|
144 |
answer = tokenizer.decode(sequences[0], skip_special_tokens=True)
|
145 |
|
146 |
+
# Judge the answer
|
147 |
+
result_judge = llm_judge(answer, options, correct_answer)
|
148 |
+
sum_correct[question_idx] += 1 if result_judge else 0
|
149 |
+
print(f"Model: {model_type}, Question: {question_idx + 1}, Answer: {answer}, Correct: {result_judge}")
|
150 |
+
|
151 |
+
# Save the results
|
152 |
FILES_ROOT = f"{RESULTS_ROOT}/{model_type}/{eval_idx}"
|
153 |
os.makedirs(FILES_ROOT, exist_ok=True)
|
154 |
|
155 |
with open(f"{FILES_ROOT}/Q{question_idx + 1}-{chart_type}.txt", "w") as f:
|
156 |
f.write(answer)
|
157 |
f.close()
|
158 |
+
accuracy = sum_correct / num_eval
|
159 |
+
print(f"Model: {model_type}, Accuracy: {accuracy}")
|
160 |
+
with open(f"{RESULTS_ROOT}/{model_type}/accuracy.txt", "w") as f:
|
161 |
+
for question_idx, question in enumerate(questions):
|
162 |
+
chart_type = question["type"]
|
163 |
+
f.write(f"Chart Type: {chart_type}, Accuracy: {accuracy[question_idx]}\n")
|
164 |
+
f.close()
|
165 |
|
166 |
+
if __name__ == '__main__':
|
167 |
|
168 |
+
# models = ["Janus-Pro-1B", "ChartGemma", "GPT-4o", "Gemini-2.0-flash", "Janus-Pro-7B", "LLaVA-1.5-7B"]
|
169 |
+
models = ["LLaVA-1.5-7B", "Gemini-2.0-flash", "Janus-Pro-7B", ]
|
170 |
|
|
|
|
|
|
|
|
|
171 |
for model_type in models:
|
172 |
evaluate(model_type=model_type, num_eval=10)
|
evaluate/new_test.json
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"type": "Area Chart",
|
4 |
+
"img_path": "images/New_test/AreaChart.png",
|
5 |
+
"question": "At which month the price of teAreaa is the highest?",
|
6 |
+
"options": [
|
7 |
+
{
|
8 |
+
"A": "1"
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"B": "3"
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"C": "4"
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"D": "7"
|
18 |
+
}
|
19 |
+
],
|
20 |
+
"correct_answer": "C"
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"type": "Bar Chart",
|
24 |
+
"img_path": "images/New_test/BarChart.png",
|
25 |
+
"question": "Which student has the highest score in midterm?",
|
26 |
+
"options": [
|
27 |
+
{
|
28 |
+
"A": "Bob"
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"B": "Elora"
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"C": "Kevin"
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"D": "David"
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"correct_answer": "D"
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"type": "Bubble Chart",
|
44 |
+
"img_path": "images/New_test/BubbleChart.png",
|
45 |
+
"question": "What is the number of employees of the company that has lowest annual income?",
|
46 |
+
"options": [
|
47 |
+
{
|
48 |
+
"A": "5"
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"B": "20"
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"C": "85"
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"D": "60"
|
58 |
+
}
|
59 |
+
],
|
60 |
+
"correct_answer": "C"
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"type": "Choropleth Map",
|
64 |
+
"img_path": "images/New_test/Choropleth.png",
|
65 |
+
"question": "Is the GDP of CA higher than NV?",
|
66 |
+
"options": [
|
67 |
+
{
|
68 |
+
"A": "True"
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"B": "False"
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"correct_answer": "A"
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"type": "Histogram",
|
78 |
+
"img_path": "images/New_test/Histogram.png",
|
79 |
+
"question": "Which range of distance of trip people prefers the most?",
|
80 |
+
"options": [
|
81 |
+
{
|
82 |
+
"A": "10-20km"
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"B": "30-40km"
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"C": "50-60km"
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"D": "70-80km"
|
92 |
+
}
|
93 |
+
],
|
94 |
+
"correct_answer": "C"
|
95 |
+
},
|
96 |
+
{
|
97 |
+
"type": "Line Chart",
|
98 |
+
"img_path": "images/New_test/LineChart.png",
|
99 |
+
"question": "What is the blood sugar level three hours after a meal",
|
100 |
+
"options": [
|
101 |
+
{
|
102 |
+
"A": "18"
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"B": "70"
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"C": "90"
|
109 |
+
},
|
110 |
+
{
|
111 |
+
"D": "50"
|
112 |
+
}
|
113 |
+
],
|
114 |
+
"correct_answer": "3"
|
115 |
+
},
|
116 |
+
{
|
117 |
+
"type": "Pie Chart",
|
118 |
+
"img_path": "images/New_test/PieChart.png",
|
119 |
+
"question": "Which stock has the smallest holdings in this portfolio?",
|
120 |
+
"options": [
|
121 |
+
{
|
122 |
+
"A": "AAPL"
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"B": "NVDA"
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"C": "TSLA"
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"D": "GOOG"
|
132 |
+
}
|
133 |
+
],
|
134 |
+
"correct_answer": "B"
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"type": "Scatter Chart",
|
138 |
+
"img_path": "images/New_test/Scatterplot.png",
|
139 |
+
"question": "What is the weight of the individual that has highest height?",
|
140 |
+
"options": [
|
141 |
+
{
|
142 |
+
"A": "96"
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"B": "72"
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"C": "55"
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"D": "25"
|
152 |
+
}
|
153 |
+
],
|
154 |
+
"correct_answer": "A"
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"type": "Stacked Area Chart",
|
158 |
+
"img_path": "images/New_test/StackedArea.png",
|
159 |
+
"question": "What was the ratio of boys named 'Justin' to boys named 'Kevin' in the 3rd month in the USA?",
|
160 |
+
"options": [
|
161 |
+
{
|
162 |
+
"A": "1:1"
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"B": "1:2"
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"C": "2:1"
|
169 |
+
},
|
170 |
+
{
|
171 |
+
"D": "4:1"
|
172 |
+
}
|
173 |
+
],
|
174 |
+
"correct_answer": "C"
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"type": "Stacked Bar Chart",
|
178 |
+
"img_path": "images/New_test/StackedBar.png",
|
179 |
+
"question": "What is the price of headphone in Japan?",
|
180 |
+
"options": [
|
181 |
+
{
|
182 |
+
"A": "10"
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"B": "30"
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"C": "50"
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"D": "70"
|
192 |
+
}
|
193 |
+
],
|
194 |
+
"correct_answer": "B"
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"type": "100% Stacked Bar Chart",
|
198 |
+
"img_path": "images/New_test/Stacked100.png",
|
199 |
+
"question": "Which country has the largest proportion of mouse price?",
|
200 |
+
"options": [
|
201 |
+
{
|
202 |
+
"A": "Canada"
|
203 |
+
},
|
204 |
+
{
|
205 |
+
"B": "China"
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"C": "Japan"
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"D": "Korea"
|
212 |
+
}
|
213 |
+
],
|
214 |
+
"correct_answer": "A"
|
215 |
+
},
|
216 |
+
{
|
217 |
+
"type": "Treemap",
|
218 |
+
"img_path": "images/New_test/TreeMap.png",
|
219 |
+
"question": "Which country has the largest population?",
|
220 |
+
"options": [
|
221 |
+
{
|
222 |
+
"A": "China"
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"B": " USA"
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"C": " Canada"
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"D": " UK"
|
232 |
+
}
|
233 |
+
],
|
234 |
+
"correct_answer": "A"
|
235 |
+
}
|
236 |
+
]
|
questions/New_test.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
new_test_questions=[
|
2 |
+
[
|
3 |
+
"LineChart",
|
4 |
+
"What is the blood sugar level three hours after a meal",
|
5 |
+
"images/New_test/LineChart.png"
|
6 |
+
],
|
7 |
+
|
8 |
+
[
|
9 |
+
"BarChart",
|
10 |
+
"Which student has the highest score in midterm?",
|
11 |
+
"images/New_test/BarChart.png"
|
12 |
+
],
|
13 |
+
|
14 |
+
[
|
15 |
+
"StackedBar",
|
16 |
+
"What is the price of headphone in Japan?",
|
17 |
+
"images/New_test/StackedBar.png"
|
18 |
+
],
|
19 |
+
|
20 |
+
[
|
21 |
+
"100%StackedBar",
|
22 |
+
"Which country has the largest proportion of mouse price?",
|
23 |
+
"images/New_test/Stacked100.png"
|
24 |
+
],
|
25 |
+
|
26 |
+
[
|
27 |
+
"PieChart",
|
28 |
+
"Which stock has the smallest holdings in this portfolio?",
|
29 |
+
"images/New_test/PieChart.png"
|
30 |
+
],
|
31 |
+
|
32 |
+
[
|
33 |
+
"Histogram",
|
34 |
+
"Which range of distance of trip people prefers the most?",
|
35 |
+
"images/New_test/Histogram.png"
|
36 |
+
],
|
37 |
+
|
38 |
+
[
|
39 |
+
"Scatterplot",
|
40 |
+
"What is the weight of the individual that has highest height?",
|
41 |
+
"images/New_test/Scatterplot.png"
|
42 |
+
],
|
43 |
+
|
44 |
+
[
|
45 |
+
"AreaChart",
|
46 |
+
"At which month the price of tea is the highest?",
|
47 |
+
"images/New_test/AreaChart.png"
|
48 |
+
],
|
49 |
+
|
50 |
+
[
|
51 |
+
"StackedArea",
|
52 |
+
"What was the ratio of boys named 'Justin' to boys named 'Kevin' in the 3rd month in the USA?",
|
53 |
+
"images/New_test/StackedArea.png"
|
54 |
+
],
|
55 |
+
|
56 |
+
[
|
57 |
+
"BubbleChart",
|
58 |
+
"What is the number of employees of the company that has lowest annual income?",
|
59 |
+
"images/New_test/BubbleChart.png"
|
60 |
+
],
|
61 |
+
|
62 |
+
[
|
63 |
+
"Choropleth",
|
64 |
+
"Is the GDP of CA higher than NV?",
|
65 |
+
"images/New_test/Choropleth.png"
|
66 |
+
],
|
67 |
+
|
68 |
+
[
|
69 |
+
"TreeMap",
|
70 |
+
"Which country has the largest population?",
|
71 |
+
"images/New_test/TreeMap.png"
|
72 |
+
]
|
73 |
+
]
|
saliency_map/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
|
2 |
-
import sys
|
3 |
-
|
4 |
-
if sys.version_info >= (3, 10):
|
5 |
-
print("Python version is above 3.10, patching the collections module.")
|
6 |
-
# Monkey patch collections
|
7 |
-
import collections
|
8 |
-
import collections.abc
|
9 |
-
|
10 |
-
for type_name in collections.abc.__all__:
|
11 |
-
setattr(collections, type_name, getattr(collections.abc, type_name))
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
saliency_map/cam.py
DELETED
@@ -1,75 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
import cv2
|
3 |
-
from PIL import Image
|
4 |
-
import numpy as np
|
5 |
-
|
6 |
-
class MultimodalGradCAM:
|
7 |
-
def __init__(self, model, processor):
|
8 |
-
self.model = model
|
9 |
-
self.processor = processor
|
10 |
-
self.activations = {}
|
11 |
-
self.gradients = {}
|
12 |
-
|
13 |
-
# Register hooks
|
14 |
-
self._register_hooks()
|
15 |
-
|
16 |
-
def _register_hooks(self):
|
17 |
-
# Hook the last vision transformer layer
|
18 |
-
def forward_hook(module, input, output):
|
19 |
-
self.activations['vision'] = output.last_hidden_state
|
20 |
-
|
21 |
-
def backward_hook(module, grad_input, grad_output):
|
22 |
-
self.gradients['vision'] = grad_output[0]
|
23 |
-
|
24 |
-
vision_encoder = self.model.get_vision_encoder()
|
25 |
-
vision_encoder.layers[-1].register_forward_hook(forward_hook)
|
26 |
-
vision_encoder.layers[-1].register_backward_hook(backward_hook)
|
27 |
-
|
28 |
-
def generate_saliency(self, image, question):
|
29 |
-
# Preprocess inputs
|
30 |
-
inputs = self.processor(
|
31 |
-
text=question,
|
32 |
-
images=image,
|
33 |
-
return_tensors="pt",
|
34 |
-
padding=True
|
35 |
-
)
|
36 |
-
|
37 |
-
# Forward pass
|
38 |
-
outputs = self.model(**inputs)
|
39 |
-
answer_ids = outputs.logits.argmax(dim=-1)
|
40 |
-
|
41 |
-
# Get target token (use last token for answer)
|
42 |
-
target_token_id = answer_ids[0, -1].item()
|
43 |
-
target = outputs.logits[0, -1, target_token_id]
|
44 |
-
|
45 |
-
# Backward pass
|
46 |
-
self.model.zero_grad()
|
47 |
-
target.backward()
|
48 |
-
|
49 |
-
# Process activations and gradients
|
50 |
-
activations = self.activations['vision'].detach()
|
51 |
-
gradients = self.gradients['vision'].detach()
|
52 |
-
|
53 |
-
# Grad-CAM calculation
|
54 |
-
weights = gradients.mean(dim=[1, 2], keepdim=True) # Global average pooling
|
55 |
-
cam = (weights * activations).sum(dim=-1, keepdims=True)
|
56 |
-
cam = torch.relu(cam)
|
57 |
-
|
58 |
-
# Reshape and normalize
|
59 |
-
cam = cam.squeeze().cpu().numpy()
|
60 |
-
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
|
61 |
-
|
62 |
-
return cam
|
63 |
-
|
64 |
-
def visualize(self, image, cam):
|
65 |
-
# Resize CAM to original image size
|
66 |
-
img_size = image.size[::-1] # (width, height) -> (height, width)
|
67 |
-
cam = cv2.resize(cam, img_size)
|
68 |
-
|
69 |
-
# Convert to heatmap
|
70 |
-
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
|
71 |
-
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
|
72 |
-
|
73 |
-
# Superimpose on original image
|
74 |
-
superimposed = np.array(image) * 0.4 + heatmap * 0.6
|
75 |
-
return Image.fromarray(np.uint8(superimposed))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|