KennyUTC commited on
Commit
eef0d11
·
verified ·
1 Parent(s): bd1d69b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import requests
4
+
5
+ data_url = "http://opencompass.openxlab.space/utils/RiseVis/data.json"
6
+ data = json.loads(requests.get(data_url).text)
7
+ # Get model names from the first entry
8
+ model_names = list(data[0]['results'].keys())
9
+
10
+ HTML_HEAD = '<table class="center">'
11
+ HTML_TAIL = '</table>'
12
+ N_COL = 4
13
+ WIDTH = 100 // N_COL
14
+
15
+ def get_image_gallery(idx, models):
16
+ assert isinstance(idx, str)
17
+ item = [x for x in data if x['index'] == idx]
18
+ assert len(item) == 1
19
+ item = item[0]
20
+ html = HTML_HEAD
21
+ models = list(models)
22
+ models.sort()
23
+ num_models = len(models)
24
+ for i in range((num_models - 1) // N_COL + 1):
25
+ sub_models = models[N_COL * i: N_COL * (i + 1)]
26
+ html += '<tr>'
27
+ for j in range(N_COL):
28
+ if j >= len(sub_models):
29
+ html += f'<td width={WIDTH}% style="text-align:center;"></td>'
30
+ else:
31
+ html += f'<td width={WIDTH}% style="text-align:center;">{sub_models[j]}</td>'
32
+ html += '</tr><tr>'
33
+ for j in range(N_COL):
34
+ if j >= len(sub_models):
35
+ html += f'<td width={WIDTH}% style="text-align:center;"></td>'
36
+ else:
37
+ html += f'<td width={WIDTH}% style="text-align:center;"><img src="{URL_BASE + item["results"][sub_models[j]]}"></td>'
38
+ html += '</tr>'
39
+ html += HTML_TAIL
40
+ return html
41
+
42
+ URL_BASE = 'https://opencompass.openxlab.space/utils/RiseVis/'
43
+
44
+ def get_origin_image(idx, model='original'):
45
+ assert isinstance(idx, str)
46
+ item = [x for x in data if x['index'] == idx]
47
+ assert len(item) == 1
48
+ item = item[0]
49
+ file_name = item['image'] if model == 'original' else item['results']['model']
50
+ url = URL_BASE + file_name
51
+ return url
52
+
53
+ def read_instruction(idx):
54
+ assert isinstance(idx, str)
55
+ item = [x for x in data if x['index'] == idx]
56
+ assert len(item) == 1
57
+ return item[0]['instruction']
58
+
59
+ with gr.Blocks() as demo:
60
+ gr.Markdown("# Gallery of Generation Results on RISEBench")
61
+
62
+ with gr.Row():
63
+ with gr.Column(scale=2):
64
+ with gr.Row():
65
+ prev_button = gr.Button("PREV")
66
+ next_button = gr.Button("NEXT")
67
+ problem_index = gr.Textbox(value='causal_reasoning_1', label='Problem Index', interactive=True, visible=True)
68
+ state = gr.Markdown(value='causal_reasoning_1', label='Current Problem Index', visible=False)
69
+ def update_state(problem_index):
70
+ return problem_index
71
+ problem_index.submit(fn=update_state, inputs=[problem_index], outputs=[state])
72
+
73
+ model_checkboxes = gr.CheckboxGroup(label="Select Models", choices=model_names, value=model_names)
74
+
75
+ with gr.Column(scale=2):
76
+ instruction = gr.Textbox(label="Instruction", interactive=False, value=read_instruction, inputs=[state])
77
+
78
+ with gr.Column(scale=1):
79
+ image = gr.Image(label="Input Image", value=get_origin_image, inputs=[state])
80
+
81
+ gallery = gr.HTML(value=get_image_gallery, inputs=[state, model_checkboxes])
82
+
83
+ if __name__ == "__main__":
84
+ demo.launch(server_name='0.0.0.0')