lisabdunlap commited on
Commit
7631655
·
verified ·
1 Parent(s): adadbed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +197 -0
app.py CHANGED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ import argparse
4
+ import gradio as gr
5
+
6
+ # Load the JSONL file
7
+ def load_jsonl(file_path):
8
+ data = []
9
+ with open(file_path, 'r') as f:
10
+ for line in f:
11
+ data.append(json.loads(line))
12
+ return data
13
+
14
+ def display_pairwise_answer(data):
15
+ chat_mds = pairwise_to_gradio_chat_mds(data)
16
+
17
+ return chat_mds
18
+
19
+
20
+ newline_pattern1 = re.compile("\n\n(\d+\. )")
21
+ newline_pattern2 = re.compile("\n\n(- )")
22
+
23
+
24
+ def post_process_answer(x):
25
+ # """Fix Markdown rendering problems."""
26
+ # x = x.replace("\u2022", "- ")
27
+ # x = re.sub(newline_pattern1, "\n\g<1>", x)
28
+ # x = re.sub(newline_pattern2, "\n\g<1>", x)
29
+ return x
30
+
31
+
32
+ def pairwise_to_gradio_chat_mds(data):
33
+ end = data["turn"] * 3
34
+ ans_a = data["conversation_a"]
35
+ ans_b = data["conversation_b"]
36
+
37
+ mds = [""] * end
38
+ base = 0
39
+ for i in range(0, end, 3):
40
+ mds[i] = "## User Prompt\n" + data["conversation_a"][base]["content"].strip()
41
+ mds[i + 1] = f"## {data['model_a']}\n" + post_process_answer(
42
+ ans_a[base + 1]["content"].strip()
43
+ )
44
+ mds[i + 2] = f"## {data['model_b']}\n" + post_process_answer(
45
+ ans_b[base + 1]["content"].strip()
46
+ )
47
+ base += 2
48
+
49
+ winner = data["winner"] if "tie" in data["winner"] else data[data["winner"]]
50
+ mds += [f"## Winner: {winner}"]
51
+
52
+ mds += [""] * (16 - len(mds))
53
+
54
+ return mds
55
+
56
+ # Filtering functions
57
+ def filter_by_language(language):
58
+ return [item for item in data if item['language'] == language]
59
+
60
+ def filter_by_outcome(outcome, filtered_data):
61
+ return [item for item in filtered_data if item['outcome'] == outcome]
62
+
63
+ def filter_by_model(model, filtered_data):
64
+ if model == "anyone":
65
+ return [item for item in filtered_data]
66
+ return [item for item in filtered_data if item['opponent'] == model]
67
+
68
+ def filter_by_conversation_a_prefix(prefix, filtered_data):
69
+ return [item for item in filtered_data if item['conversation_a'][0]["content"][:128] == prefix]
70
+
71
+ # Create Gradio interface
72
+ def update_outcome_options(language):
73
+ filtered_data = filter_by_language(language)
74
+ outcomes = [item['outcome'] for item in filtered_data]
75
+ outcomes = list(dict.fromkeys(["gemini-1.5-pro-api-0514 Won"] + outcomes)) if "gemini-1.5-pro-api-0514 Won" in outcomes else list(set(outcomes))
76
+ filtered_data = filter_by_outcome(outcomes[0], filtered_data)
77
+ models = ["anyone"] + list(sorted(set(item['opponent'] for item in filtered_data)))
78
+ filtered_data = filter_by_model(models[0], filtered_data)
79
+ prefixes = [item['conversation_a'][0]["content"][:128] for item in filtered_data]
80
+ return gr.update(choices=outcomes, value=outcomes[0]), gr.update(choices=models, value=models[0]), gr.update(choices=prefixes, value=prefixes[0])
81
+
82
+
83
+ def update_model_opponent(language, outcome):
84
+ filtered_data = filter_by_language(language)
85
+ filtered_data = filter_by_outcome(outcome, filtered_data)
86
+ models = ["anyone"] + sorted(set(item['opponent'] for item in filtered_data))
87
+ filtered_data = filter_by_model(models[0], filtered_data)
88
+ prefixes = [item['conversation_a'][0]["content"][:128] for item in filtered_data]
89
+ return gr.update(choices=models, value=models[0]), gr.update(choices=prefixes, value=prefixes[0])
90
+
91
+
92
+ def update_question_options(language, outcome, model):
93
+ filtered_data = filter_by_language(language)
94
+ filtered_data = filter_by_outcome(outcome, filtered_data)
95
+ filtered_data = filter_by_model(model, filtered_data)
96
+ prefixes = [item['conversation_a'][0]["content"][:128] for item in filtered_data]
97
+ return gr.update(choices=prefixes, value=prefixes[0])
98
+
99
+
100
+ def display_filtered_data(language, outcome, model, prefix):
101
+ filtered_data = filter_by_language(language)
102
+ filtered_data = filter_by_outcome(outcome, filtered_data)
103
+ filtered_data = filter_by_model(model, filtered_data)
104
+ filtered_data = filter_by_conversation_a_prefix(prefix, filtered_data)
105
+ if len(filtered_data) == 0:
106
+ return [""] * 16
107
+ return pairwise_to_gradio_chat_mds(filtered_data[0])
108
+
109
+
110
+ def next_question(language, outcome, model, prefix):
111
+ filtered_data = filter_by_language(language)
112
+ filtered_data = filter_by_outcome(outcome, filtered_data)
113
+ filtered_data = filter_by_model(model, filtered_data)
114
+
115
+ all_items = [item['conversation_a'][0]["content"][:128] for item in filtered_data]
116
+ if prefix:
117
+ i = all_items.index(prefix) + 1
118
+ else:
119
+ i = 0
120
+
121
+ if i >= len(all_items):
122
+ return gr.update(choices=all_items, value=all_items[-1])
123
+
124
+ return gr.update(choices=all_items, value=all_items[i])
125
+
126
+
127
+ if __name__ == "__main__":
128
+ parser = argparse.ArgumentParser()
129
+ parser.add_argument("--host", type=str, default="0.0.0.0")
130
+ parser.add_argument("--port", type=int)
131
+ parser.add_argument("--share", action="store_true")
132
+ args = parser.parse_args()
133
+ print(args)
134
+
135
+ data = load_jsonl('data/gemini_battles.jsonl')
136
+
137
+ default_lang = "English"
138
+ default_opponent = "claude-3-5-sonnet-20240620"
139
+ default_outcome = "gemini-1.5-pro-api-0514 Won"
140
+ filter_data = filter_by_language(language=default_lang)
141
+ filter_data = filter_by_model(model=default_opponent, filtered_data=filter_data)
142
+ filter_data = filter_by_outcome(outcome=default_outcome, filtered_data=filter_data)
143
+ question_prefixes = [item['conversation_a'][0]["content"][:128] for item in filter_data]
144
+ #print(question_prefixes)
145
+ default_question = question_prefixes[2]
146
+
147
+ # Extract unique values for dropdowns
148
+ with gr.Blocks() as demo:
149
+ gr.Markdown(value="# Welcome to gemini-1.5-pro-api-0514 battles")
150
+ with gr.Row():
151
+ with gr.Column():
152
+ filter_data = filter_by_language(language=default_lang)
153
+ languages = ["English"] + list(sorted(set([item['language'] for item in data if item['language'] != "English"])))
154
+ language_dropdown = gr.Dropdown(label="Select Language", choices=languages, value=default_lang)
155
+ with gr.Column():
156
+ filter_data = filter_by_language(language=default_lang)
157
+ models = ["anyone"] + sorted(set(item['opponent'] for item in filter_data))
158
+ model_dropdown = gr.Dropdown(label="Opponent", choices=models, value=default_opponent)
159
+ with gr.Column():
160
+ filter_data = filter_by_language(language=default_lang)
161
+ filter_data = filter_by_model(model=default_opponent, filtered_data=filter_data)
162
+ outcomes = sorted(set(item['outcome'] for item in filter_data))
163
+ outcome_dropdown = gr.Dropdown(label="Outcome", choices=outcomes, value=default_outcome)
164
+
165
+ with gr.Row():
166
+ with gr.Column(scale=5):
167
+ filter_data = filter_by_outcome(outcome=default_outcome, filtered_data=filter_data)
168
+ question_prefixes = [item['conversation_a'][0]["content"][:128] for item in filter_data]
169
+ question_dropdown = gr.Dropdown(label="Select Question", choices=question_prefixes, value=default_question)
170
+ with gr.Column():
171
+ next_button = gr.Button("Next Question")
172
+
173
+ default_chat_mds = display_filtered_data(default_lang, default_outcome, default_opponent, default_question)
174
+ # Conversation
175
+ chat_mds = []
176
+ for i in range(5):
177
+ chat_mds.append(gr.Markdown(elem_id=f"user_question_{i+1}", value=default_chat_mds[len(chat_mds)]))
178
+ with gr.Row():
179
+ for j in range(2):
180
+ with gr.Column(scale=100):
181
+ chat_mds.append(gr.Markdown(value=default_chat_mds[len(chat_mds)]))
182
+
183
+ if j == 0:
184
+ with gr.Column(scale=1, min_width=8):
185
+ gr.Markdown()
186
+ chat_mds.append(gr.Markdown())
187
+
188
+ language_dropdown.change(fn=update_outcome_options, inputs=language_dropdown, outputs=[outcome_dropdown, model_dropdown, question_dropdown])
189
+ outcome_dropdown.change(fn=update_model_opponent, inputs=[language_dropdown, outcome_dropdown], outputs=[model_dropdown, question_dropdown])
190
+ model_dropdown.change(fn=update_question_options, inputs=[language_dropdown, outcome_dropdown, model_dropdown], outputs=question_dropdown)
191
+ next_button.click(fn=next_question, inputs=[language_dropdown, outcome_dropdown, model_dropdown, question_dropdown], outputs=question_dropdown)
192
+ question_dropdown.change(fn=display_filtered_data, inputs=[language_dropdown, outcome_dropdown, model_dropdown, question_dropdown], outputs=chat_mds)
193
+
194
+ question_dropdown = next_question(default_lang, default_outcome, default_opponent, default_question)
195
+ chat_mds = display_filtered_data(default_lang, default_outcome, default_opponent, default_question)
196
+
197
+ demo.launch(share=args.share)