Commit
·
88fd33d
1
Parent(s):
e82dda8
up
Browse files- __pycache__/app.cpython-310.pyc +0 -0
- app.py +205 -15
- requirements.txt +1 -1
- verify.py +18 -0
__pycache__/app.cpython-310.pyc
ADDED
|
Binary file (6.4 kB). View file
|
|
|
app.py
CHANGED
|
@@ -1,39 +1,77 @@
|
|
| 1 |
from datasets import load_dataset
|
| 2 |
from collections import Counter
|
| 3 |
-
from random import
|
|
|
|
|
|
|
|
|
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
parti_prompt_results = []
|
| 7 |
ORG = "diffusers-parti-prompts"
|
| 8 |
SUBMISSIONS = {
|
| 9 |
-
"
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
# "Kadinsky":
|
| 14 |
}
|
| 15 |
-
NUM_QUESTIONS =
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
submission_names = list(SUBMISSIONS.keys())
|
| 18 |
num_images = len(SUBMISSIONS[submission_names[0]])
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def start():
|
| 21 |
ids = {id: 0 for id in range(num_images)}
|
| 22 |
|
| 23 |
-
# submissions = load_dataset(os.path.join(ORG, "submissions"))
|
| 24 |
# submitted_ids = Counter(submissions["ids"])
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
| 27 |
ids = {**ids, **submitted_ids}
|
| 28 |
|
| 29 |
# sort by count
|
| 30 |
-
ids = sorted(ids)
|
|
|
|
| 31 |
|
| 32 |
# get lowest count ids
|
| 33 |
-
id_candidates = ids[:(
|
| 34 |
|
| 35 |
# get random `NUM_QUESTIONS` ids to check
|
| 36 |
-
image_ids =
|
| 37 |
images = {}
|
| 38 |
|
| 39 |
for i in range(NUM_QUESTIONS):
|
|
@@ -41,12 +79,164 @@ def start():
|
|
| 41 |
shuffle(order)
|
| 42 |
|
| 43 |
id = image_ids[i]
|
|
|
|
| 44 |
images[i] = {
|
| 45 |
-
"prompt":
|
|
|
|
| 46 |
"id": id,
|
| 47 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
-
return images
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from datasets import load_dataset
|
| 2 |
from collections import Counter
|
| 3 |
+
from random import sample, shuffle
|
| 4 |
+
import datasets
|
| 5 |
+
from pandas import DataFrame
|
| 6 |
+
from huggingface_hub import list_datasets
|
| 7 |
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
|
| 10 |
+
import secrets
|
| 11 |
+
|
| 12 |
|
| 13 |
parti_prompt_results = []
|
| 14 |
ORG = "diffusers-parti-prompts"
|
| 15 |
SUBMISSIONS = {
|
| 16 |
+
"sd-v1-5": load_dataset(os.path.join(ORG, "sd-v1-5"))["train"],
|
| 17 |
+
"sd-v2-1": load_dataset(os.path.join(ORG, "sd-v2.1"))["train"],
|
| 18 |
+
"if-v1-0": load_dataset(os.path.join(ORG, "karlo-v1"))["train"],
|
| 19 |
+
"karlo": load_dataset(os.path.join(ORG, "if-v-1.0"))["train"],
|
| 20 |
# "Kadinsky":
|
| 21 |
}
|
| 22 |
+
NUM_QUESTIONS = 10
|
| 23 |
+
MODEL_KEYS = "-".join(SUBMISSIONS.keys())
|
| 24 |
+
SUBMISSION_ORG = f"results-{MODEL_KEYS}"
|
| 25 |
+
|
| 26 |
|
| 27 |
submission_names = list(SUBMISSIONS.keys())
|
| 28 |
num_images = len(SUBMISSIONS[submission_names[0]])
|
| 29 |
|
| 30 |
+
|
| 31 |
+
def generate_random_hash(length=8):
|
| 32 |
+
"""
|
| 33 |
+
Generates a random hash of specified length.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
length (int): The length of the hash to generate.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
str: A random hash of specified length.
|
| 40 |
+
"""
|
| 41 |
+
if length % 2 != 0:
|
| 42 |
+
raise ValueError("Length should be an even number.")
|
| 43 |
+
|
| 44 |
+
num_bytes = length // 2
|
| 45 |
+
random_bytes = secrets.token_bytes(num_bytes)
|
| 46 |
+
random_hash = secrets.token_hex(num_bytes)
|
| 47 |
+
|
| 48 |
+
return random_hash
|
| 49 |
+
|
| 50 |
+
|
| 51 |
def start():
|
| 52 |
ids = {id: 0 for id in range(num_images)}
|
| 53 |
|
|
|
|
| 54 |
# submitted_ids = Counter(submissions["ids"])
|
| 55 |
+
all_datasets = list_datasets(author=SUBMISSION_ORG)
|
| 56 |
+
relevant_ids = [d.id for d in all_datasets]
|
| 57 |
+
|
| 58 |
+
submitted_ids = []
|
| 59 |
+
for _id in relevant_ids:
|
| 60 |
+
ds = load_dataset(_id)["train"]
|
| 61 |
+
submitted_ids += ds["id"]
|
| 62 |
|
| 63 |
+
submitted_ids = Counter(submitted_ids)
|
| 64 |
ids = {**ids, **submitted_ids}
|
| 65 |
|
| 66 |
# sort by count
|
| 67 |
+
ids = sorted(ids.items(), key=lambda x: x[1])
|
| 68 |
+
ids = [i[0] for i in ids]
|
| 69 |
|
| 70 |
# get lowest count ids
|
| 71 |
+
id_candidates = ids[: (10 * NUM_QUESTIONS)]
|
| 72 |
|
| 73 |
# get random `NUM_QUESTIONS` ids to check
|
| 74 |
+
image_ids = sample(id_candidates, k=NUM_QUESTIONS)
|
| 75 |
images = {}
|
| 76 |
|
| 77 |
for i in range(NUM_QUESTIONS):
|
|
|
|
| 79 |
shuffle(order)
|
| 80 |
|
| 81 |
id = image_ids[i]
|
| 82 |
+
row = SUBMISSIONS[submission_names[0]][id]
|
| 83 |
images[i] = {
|
| 84 |
+
"prompt": row["Prompt"],
|
| 85 |
+
"result": "",
|
| 86 |
"id": id,
|
| 87 |
+
"Challenge": row["Challenge"],
|
| 88 |
+
"Category": row["Category"],
|
| 89 |
+
"Note": row["Note"],
|
| 90 |
+
}
|
| 91 |
+
for n, m in enumerate(order):
|
| 92 |
+
images[i][f"choice_{n}"] = m
|
| 93 |
+
|
| 94 |
+
images_frame = DataFrame.from_dict(images, orient="index")
|
| 95 |
+
return images_frame
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def process(dataframe, row_number=0):
|
| 99 |
+
if row_number == NUM_QUESTIONS:
|
| 100 |
+
return None, ""
|
| 101 |
+
|
| 102 |
+
image_id = dataframe.iloc[row_number]["id"]
|
| 103 |
+
choices = [
|
| 104 |
+
submission_names[dataframe.iloc[row_number][f"choice_{i}"]]
|
| 105 |
+
for i in range(len(SUBMISSIONS))
|
| 106 |
+
]
|
| 107 |
+
images = [SUBMISSIONS[c][int(image_id)]["images"] for c in choices]
|
| 108 |
+
|
| 109 |
+
prompt = SUBMISSIONS[choices[0]][int(image_id)]["Prompt"]
|
| 110 |
+
prompt = f"Prompt {row_number + 1}/{NUM_QUESTIONS}: '{prompt}'"
|
| 111 |
+
|
| 112 |
+
return images, prompt
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def write_result(user_choice, row_number, dataframe, prompt):
|
| 116 |
+
if row_number == NUM_QUESTIONS:
|
| 117 |
+
return row_number, dataframe
|
| 118 |
+
|
| 119 |
+
user_choice = int(user_choice)
|
| 120 |
+
chosen_model = submission_names[dataframe.iloc[row_number][f"choice_{user_choice}"]]
|
| 121 |
+
|
| 122 |
+
dataframe.loc[row_number, "result"] = chosen_model
|
| 123 |
+
return row_number + 1, dataframe
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def get_index(evt: gr.SelectData) -> int:
|
| 127 |
+
return evt.index
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def change_view(row_number, dataframe):
|
| 131 |
+
if row_number == NUM_QUESTIONS:
|
| 132 |
+
|
| 133 |
+
favorite_model = dataframe["result"].value_counts().idxmax()
|
| 134 |
+
dataset = datasets.Dataset.from_pandas(dataframe)
|
| 135 |
+
dataset = dataset.remove_columns(set(dataset.column_names) - set(["id", "result"]))
|
| 136 |
+
hash = generate_random_hash()
|
| 137 |
+
repo_id = os.path.join(SUBMISSION_ORG, hash)
|
| 138 |
+
|
| 139 |
+
dataset.push_to_hub(repo_id, token=os.getenv("HF_TOKEN"))
|
| 140 |
+
return {
|
| 141 |
+
result: f"You are of type: {favorite_model}!",
|
| 142 |
+
result_view: gr.update(visible=True),
|
| 143 |
+
gallery_view: gr.update(visible=False),
|
| 144 |
+
}
|
| 145 |
+
else:
|
| 146 |
+
return {
|
| 147 |
+
result: "",
|
| 148 |
+
result_view: gr.update(visible=False),
|
| 149 |
+
gallery_view: gr.update(visible=True),
|
| 150 |
}
|
| 151 |
|
|
|
|
| 152 |
|
| 153 |
+
if True:
|
| 154 |
+
TITLE = "Open-Source Parti Prompts"
|
| 155 |
+
DESCRIPTION = "An interactive 'Which Generative AI' game to evaluate open-source generative AI models"
|
| 156 |
+
GALLERY_COLUMN_NUM = len(SUBMISSIONS)
|
| 157 |
+
|
| 158 |
+
with gr.Blocks(css="style.css") as demo:
|
| 159 |
+
gr.Markdown(TITLE)
|
| 160 |
+
gr.Markdown(DESCRIPTION)
|
| 161 |
+
start_button = gr.Button("Start").style(full_width=False)
|
| 162 |
+
|
| 163 |
+
headers = ["prompt", "result", "id", "Challenge", "Category", "Note"] + [
|
| 164 |
+
f"choice_{i}" for i in range(len(SUBMISSIONS))
|
| 165 |
+
]
|
| 166 |
+
datatype = ["str", "str", "number", "str", "str", "str"] + len(SUBMISSIONS) * [
|
| 167 |
+
"number"
|
| 168 |
+
]
|
| 169 |
+
|
| 170 |
+
with gr.Column(visible=False):
|
| 171 |
+
row_number = gr.Number(
|
| 172 |
+
label="Current row selection index",
|
| 173 |
+
value=0,
|
| 174 |
+
precision=0,
|
| 175 |
+
interactive=False,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Create Data Frame
|
| 179 |
+
with gr.Column(visible=False) as result_view:
|
| 180 |
+
result = gr.Markdown("")
|
| 181 |
+
dataframe = gr.Dataframe(
|
| 182 |
+
headers=headers,
|
| 183 |
+
datatype=datatype,
|
| 184 |
+
row_count=NUM_QUESTIONS,
|
| 185 |
+
col_count=(6 + len(SUBMISSIONS), "fixed"),
|
| 186 |
+
interactive=False,
|
| 187 |
+
)
|
| 188 |
+
gr.Markdown("Click on start to play again!")
|
| 189 |
+
|
| 190 |
+
with gr.Column(visible=True) as gallery_view:
|
| 191 |
+
gr.Markdown("Pick your the photo that best corresponds to the prompt.")
|
| 192 |
+
prompt = gr.Markdown(f"Prompt 1/{NUM_QUESTIONS}: ")
|
| 193 |
+
gallery = gr.Gallery(
|
| 194 |
+
label="All images", show_label=False, elem_id="gallery"
|
| 195 |
+
).style(columns=GALLERY_COLUMN_NUM, object_fit="contain")
|
| 196 |
+
|
| 197 |
+
next_button = gr.Button("Select").style(full_width=False)
|
| 198 |
+
|
| 199 |
+
with gr.Column(visible=False):
|
| 200 |
+
selected_image = gr.Number(label="Selected index", value=-1, precision=0)
|
| 201 |
+
|
| 202 |
+
start_button.click(
|
| 203 |
+
fn=start,
|
| 204 |
+
inputs=[],
|
| 205 |
+
outputs=dataframe
|
| 206 |
+
).then(
|
| 207 |
+
fn=lambda x: 0 if x == NUM_QUESTIONS else x,
|
| 208 |
+
inputs=[row_number],
|
| 209 |
+
outputs=[row_number],
|
| 210 |
+
).then(
|
| 211 |
+
fn=change_view,
|
| 212 |
+
inputs=[row_number, dataframe],
|
| 213 |
+
outputs=[result_view, gallery_view, result]
|
| 214 |
+
).then(
|
| 215 |
+
fn=process, inputs=[dataframe], outputs=[gallery, prompt]
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
gallery.select(
|
| 219 |
+
fn=get_index,
|
| 220 |
+
outputs=selected_image,
|
| 221 |
+
queue=False,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
next_button.click(
|
| 225 |
+
fn=write_result,
|
| 226 |
+
inputs=[selected_image, row_number, dataframe, prompt],
|
| 227 |
+
outputs=[row_number, dataframe],
|
| 228 |
+
).then(
|
| 229 |
+
fn=process,
|
| 230 |
+
inputs=[dataframe, row_number],
|
| 231 |
+
outputs=[gallery, prompt]
|
| 232 |
+
).then(
|
| 233 |
+
fn=change_view,
|
| 234 |
+
inputs=[row_number, dataframe],
|
| 235 |
+
outputs=[result_view, gallery_view, result]
|
| 236 |
+
).then(
|
| 237 |
+
fn=lambda x: 0 if x == NUM_QUESTIONS else x,
|
| 238 |
+
inputs=[row_number],
|
| 239 |
+
outputs=[row_number],
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
demo.launch()
|
requirements.txt
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
-
requests
|
| 2 |
datasets
|
|
|
|
|
|
|
|
|
| 1 |
datasets
|
| 2 |
+
pandas
|
verify.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from collections import Counter
|
| 4 |
+
from random import choices, shuffle
|
| 5 |
+
from pandas import DataFrame
|
| 6 |
+
import os
|
| 7 |
+
import gradio as gr
|
| 8 |
+
|
| 9 |
+
parti_prompt_results = []
|
| 10 |
+
ORG = "diffusers-parti-prompts"
|
| 11 |
+
SUBMISSIONS = {
|
| 12 |
+
"sd_v1_5": load_dataset(os.path.join(ORG, "sd-v1-5"))["train"],
|
| 13 |
+
"sd_v2_1": load_dataset(os.path.join(ORG, "sd-v2.1"))["train"],
|
| 14 |
+
"if_v1_0": load_dataset(os.path.join(ORG, "karlo-v1"))["train"],
|
| 15 |
+
"karlo": load_dataset(os.path.join(ORG, "if-v-1.0"))["train"],
|
| 16 |
+
# "Kadinsky":
|
| 17 |
+
}
|
| 18 |
+
import ipdb; ipdb.set_trace()
|