# Author: KennyUTC — last commit 5f0b955 ("Update app.py").
import gradio as gr
import json
import requests
data_url = "http://opencompass.openxlab.space/utils/RiseVis/data.json"
# Bound the request so app startup cannot hang forever on an unresponsive
# server, and surface HTTP failures directly instead of as a JSON decode error.
_response = requests.get(data_url, timeout=60)
_response.raise_for_status()
# List of problem entries. The functions below read each entry's 'index',
# 'instruction', 'image' and 'results' (model name -> image file name).
data = _response.json()
if not data:
    raise RuntimeError(f'No entries found in {data_url}')
# Get model names from the first entry
model_names = list(data[0]['results'].keys())
HTML_HEAD = '<table class="center">'
HTML_TAIL = '</table>'
N_COL = 5  # number of gallery columns per table row
WIDTH = 100 // N_COL  # width of each gallery column, in percent
def get_image_gallery(idx, models):
    """Render an HTML table with each selected model's output for one problem.

    Models are sorted by name and laid out ``N_COL`` per table row pair: a
    header row with the model names, then a row with each model's image.
    Short final rows are padded with empty cells.

    Args:
        idx: The problem's ``index`` value, as stored in ``data``.
        models: Iterable of model names to display.

    Returns:
        The HTML string, wrapped in ``HTML_HEAD`` / ``HTML_TAIL``.

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
    """
    from html import escape

    if not isinstance(idx, str):
        raise TypeError(f'problem index must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one problem with index {idx!r}, found {len(matches)}'
        )
    results = matches[0]['results']

    names = sorted(models)
    cell_open = f'<td width={WIDTH}% style="text-align:center;">'
    empty_cell = cell_open + '</td>'
    rows = []
    for start in range(0, len(names), N_COL):
        chunk = names[start:start + N_COL]
        padding = [empty_cell] * (N_COL - len(chunk))
        # Escape names/URLs: they come from remote JSON and are inserted as HTML.
        headers = [f'{cell_open}<h3>{escape(name)}</h3></td>' for name in chunk]
        images = []
        for name in chunk:
            file_name = results.get(name)
            if file_name is None:
                # model_names comes from the first entry only, so a model may
                # have no output for this problem; show a placeholder instead
                # of failing the whole gallery with a KeyError.
                images.append(f'{cell_open}N/A</td>')
            else:
                src = escape(URL_BASE + file_name, quote=True)
                images.append(f'{cell_open}<img src="{src}"></td>')
        rows.append(
            '<tr>' + ''.join(headers + padding) + '</tr>'
            '<tr>' + ''.join(images + padding) + '</tr>'
        )
    return HTML_HEAD + ''.join(rows) + HTML_TAIL
URL_BASE = 'https://opencompass.openxlab.space/utils/RiseVis/'


def get_origin_image(idx, model='original'):
    """Return the image URL for one problem.

    Args:
        idx: The problem's ``index`` value, as stored in ``data``.
        model: ``'original'`` for the problem's input image, otherwise the
            name of a model whose generated image should be returned.

    Returns:
        ``URL_BASE`` joined with the selected image's file name.

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
        KeyError: If ``model`` has no result for this problem.
    """
    if not isinstance(idx, str):
        raise TypeError(f'problem index must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one problem with index {idx!r}, found {len(matches)}'
        )
    item = matches[0]
    # Look up the requested model (previously the literal key 'model' was used).
    file_name = item['image'] if model == 'original' else item['results'][model]
    return URL_BASE + file_name
def read_instruction(idx):
    """Return the text instruction of the problem whose ``index`` is ``idx``.

    Raises:
        TypeError: If ``idx`` is not a string.
        ValueError: If ``idx`` does not match exactly one entry in ``data``.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under -O, which
    # would turn a miss into an IndexError and silently accept duplicates.
    if not isinstance(idx, str):
        raise TypeError(f'problem index must be a str, got {type(idx).__name__}')
    matches = [x for x in data if x['index'] == idx]
    if len(matches) != 1:
        raise ValueError(
            f'expected exactly one problem with index {idx!r}, found {len(matches)}'
        )
    return matches[0]['instruction']
def on_prev(state):
    """Step to the problem before ``state`` in ``data``, wrapping to the last.

    Args:
        state: The currently displayed problem index.

    Returns:
        The new index twice: once for the hidden state component and once for
        the visible problem-index textbox.
    """
    indices = [item['index'] for item in data]
    if not indices:
        # Nothing to navigate to; keep the current value.
        return state, state
    try:
        position = indices.index(state)
    except ValueError:
        # Unknown index: enter the cycle from the start, so "previous" is the
        # last problem (previously this landed on the second-to-last one).
        position = 0
    target = indices[(position - 1) % len(indices)]
    return target, target
def on_next(state):
    """Step to the problem after ``state`` in ``data``, wrapping to the first.

    Mirrors ``on_prev``, which already wraps around at the start of the list.
    Previously, pressing NEXT on the last problem raised IndexError.

    Args:
        state: The currently displayed problem index.

    Returns:
        The new index twice: once for the hidden state component and once for
        the visible problem-index textbox.
    """
    indices = [item['index'] for item in data]
    if not indices:
        # Nothing to navigate to; keep the current value.
        return state, state
    try:
        position = indices.index(state)
    except ValueError:
        # Unknown index: "next" is the first problem.
        position = -1
    target = indices[(position + 1) % len(indices)]
    return target, target
DEFAULT_INDEX = 'causal_reasoning_1'  # problem shown when the app opens

with gr.Blocks() as demo:
    gr.Markdown("# Gallery of Generation Results on RISEBench")
    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                prev_button = gr.Button("PREV")
                next_button = gr.Button("NEXT")
            problem_index = gr.Textbox(value=DEFAULT_INDEX, label='Problem Index', interactive=True, visible=True)
            # Hidden component holding the problem currently displayed; the
            # instruction, input image and gallery below take it as input.
            state = gr.Markdown(value=DEFAULT_INDEX, label='Current Problem Index', visible=False)

            def update_state(typed_index, current_index):
                """Switch to a typed problem index, ignoring unknown ones.

                Returns the new value for the hidden state and for the textbox.
                """
                requested = (typed_index or '').strip()
                if any(item['index'] == requested for item in data):
                    return requested, requested
                # Unknown index: keep the current problem and reset the textbox,
                # rather than storing an index every dependent callback rejects.
                return current_index, current_index

            problem_index.submit(fn=update_state, inputs=[problem_index, state], outputs=[state, problem_index])
            prev_button.click(fn=on_prev, inputs=[state], outputs=[state, problem_index])
            next_button.click(fn=on_next, inputs=[state], outputs=[state, problem_index])
            model_checkboxes = gr.CheckboxGroup(label="Select Models", choices=model_names, value=model_names)
        with gr.Column(scale=2):
            instruction = gr.Textbox(label="Instruction", interactive=False, value=read_instruction, inputs=[state])
        with gr.Column(scale=1):
            image = gr.Image(label="Input Image", value=get_origin_image, inputs=[state])
    gallery = gr.HTML(value=get_image_gallery, inputs=[state, model_checkboxes])

if __name__ == "__main__":
    demo.launch(server_name='0.0.0.0')