Spaces:
Runtime error
Runtime error
File size: 4,276 Bytes
9a5a85e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from enum import Enum
import os
import re
from io import BytesIO
import uuid
import gradio as gr
from pathlib import Path
from huggingface_hub import Repository
import json
from db import Database
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): not referenced anywhere in the visible code.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
DB_FOLDER = Path("diffusers-gallery-data")
# URL prefix prepended to the image file names listed in each model's data.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# Clone the dataset repo holding the annotation database, then pull so the
# local copy is current before the database is opened.
# HF_TOKEN is passed explicitly when it is set; without this the variable was
# read but never used, and the clone depended on whatever token
# huggingface_hub could locate on its own. Falling back to True keeps that
# previous lookup behavior when the variable is unset or empty.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=HF_TOKEN or True,
)
repo.git_pull()
database = Database(DB_FOLDER)
blocks = gr.Blocks()

# Choices offered by the UI; the selected values are what gets stored in each
# model's `flags` JSON.
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]
# Client-side (browser) preprocessor attached to the initial `blocks.load`
# event. Gradio hands it that event's input values before calling the Python
# handler. The first parameter is named `current_model`, but it actually
# receives the `query_params` JSON component's value, because that is the
# first entry of the load event's `inputs`.
# The function:
#   1. copies `?model_id=...` from the page URL onto that object (empty
#      string when absent);
#   2. replaces the current URL with "/" via history.replaceState, dropping
#      the query string (presumably so the deep link is consumed once);
#   3. returns the (possibly updated) inputs unchanged otherwise.
js_get_url_params = """
function (current_model, styles, nsfw) {
const params = new URLSearchParams(window.location.search);
current_model.model_id = params.get("model_id") || "";
window.history.replaceState({}, document.title, "/");
return [current_model, styles, nsfw]
}
"""
def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and build the values for the UI outputs.

    If ``query_params`` carries a non-empty ``"model_id"``, that model is
    loaded. Otherwise a random model that has images and no flags yet is
    picked.

    Args:
        query_params: Mapping that may contain a ``"model_id"`` key.
            ``None`` is treated like an empty mapping.
        styles: Only logged; replaced by the model's stored style flags.
        nsfw: Only logged; replaced by the model's stored NSFW flag.

    Returns:
        Tuple ``(images, title, styles, nsfw, current_model)``:
          - ``images``: the ``.jpg`` image URLs;
          - ``title``: a Markdown title with the number of models still
            waiting for annotation;
          - ``styles``: the stored style flags (default ``[]``);
          - ``nsfw``: the stored NSFW flag (default ``None``);
          - ``current_model``: a ``{"model_id": ...}`` dict for the state.

    Raises:
        gr.Error: If the requested model does not exist, or no unflagged
            model with images is left.
    """
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)
    # `total_unflagged` is computed with a scalar subquery over the whole
    # table. A window function such as SUM(...) OVER () only sees rows that
    # pass the outer WHERE, so for a single requested id it would report 0/1.
    with database.get_db() as db:
        if model_id:
            cursor = db.execute(
                """SELECT *,
                    (SELECT COUNT(*) FROM models AS m
                     WHERE json_array_length(m.data, '$.images') > 0
                       AND m.flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""",
                (model_id,),
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                """SELECT *,
                    (SELECT COUNT(*) FROM models AS m
                     WHERE json_array_length(m.data, '$.images') > 0
                       AND m.flags IS NULL) AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1"""
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

    total_unflagged = row["total_unflagged"]
    model_id = row["id"]
    data = json.loads(row["data"])
    # A model requested by id is not filtered on having images, so the key
    # may be missing or null.
    images = [
        ASSETS_URL + name
        for name in (data.get("images") or [])
        if name.endswith(".jpg")
    ]
    flags_data = json.loads(row["flags"] or "{}")
    styles = flags_data.get("styles", [])
    nsfw = flags_data.get("nsfw", None)
    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the current model, then load the next one.

    Args:
        current_model: State dict expected to hold ``"model_id"``.
        styles: Selected style labels; saved as-is.
        nsfw: Selected NSFW label, or ``None``; saved as-is.

    Returns:
        The result of ``next_model({}, styles, nsfw)``, i.e. a random
        unflagged model for the UI.

    Raises:
        gr.Error: If no model is currently selected, or if ``next_model``
            cannot find another model.
    """
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        # The state starts out as {} and is only filled by a successful load.
        # Without this guard, submitting before that would surface as an
        # opaque KeyError instead of a message the user can act on.
        raise gr.Error("No model selected to flag")
    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    return next_model({}, styles, nsfw)
with blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
''')
    model_title = gr.Markdown()
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls, info="Classify the image as one or more of the following classes")
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")
    # invisible inputs to store the query params
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})
    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # Every handler refreshes the same set of components.
    model_outputs = [gallery, model_title, styles, nsfw, current_model]

    # "Next" always asks for a random unflagged model, matching what
    # flag_model does after a submit. query_params is meant only for the
    # initial page load. The load-time JS writes `model_id` into that object;
    # if the mutation persists in the component's value, re-sending it here
    # would keep reloading the deep-linked model.
    next_btn.click(
        lambda styles_value, nsfw_value: next_model({}, styles_value, nsfw_value),
        inputs=[styles, nsfw],
        outputs=model_outputs,
    )
    submit_btn.click(
        flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=model_outputs,
    )
    # On page load, pick up ?model_id=... from the URL (see js_get_url_params).
    blocks.load(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=model_outputs,
        _js=js_get_url_params,
    )

blocks.launch(enable_queue=False)
|