# radames's picture
# first
# 9a5a85e
# raw
# history blame
# 4.28 kB
from enum import Enum
import os
import re
from io import BytesIO
import uuid
import gradio as gr
from pathlib import Path
from huggingface_hub import Repository
import json
from db import Database
# --- Configuration -----------------------------------------------------------
# Token read from the environment (None when unset). It is not referenced
# again in the visible part of this file.
HF_TOKEN = os.getenv("HF_TOKEN")

# Folder names; DB_FOLDER is where the dataset repository below is checked out.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
DB_FOLDER = Path("diffusers-gallery-data")

# Prefix prepended to every image file name to build its public URL.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# --- Data and UI bootstrap ---------------------------------------------------
# Check out the annotation dataset into DB_FOLDER and bring it up to date
# before the database is opened on top of it.
repo = Repository(
    clone_from="huggingface-projects/diffusers-gallery-data",
    repo_type="dataset",
    local_dir=DB_FOLDER,
    use_auth_token=True,
)
repo.git_pull()

database = Database(DB_FOLDER)
blocks = gr.Blocks()

# Choices offered by the annotation widgets.
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]
# JavaScript snippet passed as `_js` to `blocks.load`. It receives the values
# of that event's inputs; its first argument is named `current_model` here but
# is fed from the `query_params` component. The snippet copies the `model_id`
# URL query parameter (or "" when absent) into that object, rewrites the
# address bar to "/" via `history.replaceState`, and returns the three values
# so they are forwarded to the Python handler.
js_get_url_params = """
function (current_model, styles, nsfw) {
const params = new URLSearchParams(window.location.search);
current_model.model_id = params.get("model_id") || "";
window.history.replaceState({}, document.title, "/");
return [current_model, styles, nsfw]
}
"""
def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and return the values for the UI outputs.

    When ``query_params`` carries a non-empty ``"model_id"`` that model is
    loaded; otherwise a random model with at least one image and no flags
    yet is selected.

    Args:
        query_params: Mapping that may contain a ``"model_id"`` entry. An
            empty mapping (or ``None``) means "pick a random model".
        styles: Currently selected styles. Only logged; the returned value
            comes from the flags stored for the loaded model.
        nsfw: Currently selected NSFW class. Only logged; the returned value
            comes from the flags stored for the loaded model.

    Returns:
        A 5-tuple ``(images, title, styles, nsfw, state)`` where ``images``
        is a list of ``.jpg`` asset URLs, ``title`` is a Markdown string,
        ``styles`` / ``nsfw`` are the stored flags (``[]`` / ``None`` when the
        model has none) and ``state`` is ``{"model_id": <id>}``.

    Raises:
        gr.Error: If the requested model does not exist, or no unflagged
            model with images is left.
    """
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)

    with database.get_db() as db:
        if model_id:
            # A window function would be evaluated *after* ``WHERE id = ?``
            # and could therefore only ever yield 0 or 1. A scalar subquery
            # counts the same set of rows the random branch below reports.
            cursor = db.execute(
                """SELECT *,
                    (SELECT COUNT(*)
                       FROM models
                      WHERE json_array_length(data, '$.images') > 0
                        AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""",
                (model_id,),
            )
            missing_message = "Cannot find model to annotate"
        else:
            # The window spans every candidate row because LIMIT is applied
            # after window functions are computed.
            cursor = db.execute(
                """SELECT *,
                    SUM(CASE WHEN flags IS NULL THEN 1 ELSE 0 END) OVER () AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1""")
            missing_message = "Cannot find any more models to annotate"

        row = cursor.fetchone()
        if row is None:
            raise gr.Error(missing_message)

        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        flags_data = json.loads(row["flags"] or "{}")

    # A model requested by id is not filtered on having images, so tolerate a
    # missing or empty "images" entry instead of raising KeyError.
    image_names = data.get("images") or []
    images = [ASSETS_URL + name for name in image_names if name.endswith(".jpg")]

    styles = flags_data.get("styles", [])
    nsfw = flags_data.get("nsfw", None)

    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the current model, then load another one.

    Args:
        current_model: State mapping expected to hold the ``"model_id"`` of
            the model being annotated.
        styles: Selected style classes to store.
        nsfw: Selected NSFW class to store.

    Returns:
        Whatever ``next_model({}, styles, nsfw)`` returns, i.e. the UI values
        for a newly picked model.

    Raises:
        gr.Error: If ``current_model`` holds no model id (for example when
            the state is still its initial empty mapping), or if
            ``next_model`` cannot find another model.
    """
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        # Report a readable error instead of a bare KeyError, and skip an
        # UPDATE that could not match any row.
        raise gr.Error("No model is currently selected to flag")

    print("Flagging model", model_id, styles, nsfw)
    flags_json = json.dumps({"styles": styles, "nsfw": nsfw})

    # NOTE(review): persistence relies on `database.get_db()` (defined in the
    # `db` module, not visible here) committing when the context exits.
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (flags_json, model_id),
        )

    return next_model({}, styles, nsfw)
with blocks:
    # Static instructions shown above the annotation widgets.
    gr.Markdown(
        "### Diffusers Gallery annotation tool\n"
        'Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.\n'
    )

    # Widgets filled in by `next_model` / `flag_model`.
    model_title = gr.Markdown()
    gallery = gr.Gallery(
        elem_id="gallery",
        show_label=False,
        label="Images",
    ).style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")

    # invisible inputs to store the query params
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})

    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # Every handler writes to the same five outputs, in this order.
    _annotation_outputs = [gallery, model_title, styles, nsfw, current_model]
    _loader_inputs = [query_params, styles, nsfw]

    next_btn.click(
        next_model,
        inputs=_loader_inputs,
        outputs=_annotation_outputs,
    )
    submit_btn.click(
        flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=_annotation_outputs,
    )
    # On load, the JS snippet runs first and injects the URL's `model_id`
    # into the first input before `next_model` is called.
    blocks.load(
        next_model,
        inputs=_loader_inputs,
        outputs=_annotation_outputs,
        _js=js_get_url_params,
    )

blocks.launch(enable_queue=False)