Spaces:
Runtime error
Runtime error
from enum import Enum | |
import os | |
import re | |
from io import BytesIO | |
import uuid | |
import gradio as gr | |
from pathlib import Path | |
from huggingface_hub import Repository | |
import json | |
from db import Database | |
# Runtime configuration.
# HF_TOKEN: Hugging Face access token taken from the environment (None if unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Local folder name; not referenced by the code shown in this file.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the annotation dataset; also handed to Database().
DB_FOLDER = Path("diffusers-gallery-data")
# Prefix prepended to each image file name to build the gallery image URLs.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
# Clone the annotation dataset into DB_FOLDER (or reuse an existing checkout)
# and pull the latest revision before opening the database stored in it.
# HF_TOKEN was read from the environment but never used. Passing it explicitly
# makes the Space secret take effect. Falling back to True keeps the previous
# behaviour (huggingface_hub's own token lookup) when the variable is unset.
# NOTE(review): whether use_auth_token=True alone consults the HF_TOKEN
# variable depends on the installed huggingface_hub version.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=HF_TOKEN or True,
)
repo.git_pull()
database = Database(DB_FOLDER)
# Root Gradio container; the UI is populated later inside `with blocks:`.
blocks = gr.Blocks()
# Choices offered by the annotation widgets:
# style classes for the multi-select CheckboxGroup ...
styles_cls = [
    "anime",
    "3D",
    "realistic",
    "other",
]
# ... and NSFW levels for the single-choice Radio.
nsfw_cls = [
    "safe",
    "suggestive",
    "explicit",
]
# Browser-side function run before `next_model` on page load (passed as `_js=`
# to blocks.load). It receives the values of the load inputs,
# [query_params, styles, nsfw], and its returned array is what the Python
# function receives.
# Behaviour:
#   - copies the ?model_id=... URL parameter into the query-params object;
#     an empty string means "pick a random model";
#   - strips the query string from the address bar.
# The first parameter used to be named `current_model` although it holds the
# query_params component value. It was also mutated in place.
# NOTE(review): if Gradio passes the live stored value, that mutation can make
# later "Next" clicks re-send the deep-linked id. A fresh object is returned
# instead.
js_get_url_params = """
function (query_params, styles, nsfw) {
    const params = new URLSearchParams(window.location.search);
    const next_params = { ...query_params, model_id: params.get("model_id") || "" };
    window.history.replaceState({}, document.title, "/");
    return [next_params, styles, nsfw]
}
"""
def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and return the values for the UI components.

    If ``query_params`` carries a non-empty ``"model_id"`` (taken from the
    page URL on load), that model is loaded. Otherwise a random model that
    has images and has not been flagged yet is picked.

    Args:
        query_params: JSON value; only its ``"model_id"`` key is read.
            ``None`` is treated like an empty dict.
        styles: Current style selection. It is only printed; the returned
            styles come from the model's stored flags.
        nsfw: Current NSFW selection. It is only printed, like ``styles``.

    Returns:
        A 5-tuple ``(images, title, styles, nsfw, current_model)``:
        ``images`` is a list of ``.jpg`` URLs under ``ASSETS_URL``;
        ``title`` is a Markdown heading plus the remaining-work counter;
        ``styles``/``nsfw`` are the stored flags (``[]``/``None`` if unflagged);
        ``current_model`` is ``{"model_id": <id>}``.

    Raises:
        gr.Error: If the requested model does not exist, or no unflagged
            model with images remains.
    """
    # An empty or missing model_id means "pick a random model".
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)
    with database.get_db() as db:
        # The remaining-work counter is an uncorrelated scalar subquery rather
        # than SUM(...) OVER (). Window functions are evaluated after WHERE,
        # so under "WHERE id = ?" the old query counted only the selected
        # row (0 or 1) instead of all models still waiting for annotation.
        if model_id:
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0
                 AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""", (model_id,))
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0
                 AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1""")
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")
        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        # A model requested by id is not filtered on having images, so the key
        # may be missing or null.
        images = [ASSETS_URL + x for x in (data.get("images") or []) if x.endswith(".jpg")]
        flags_data = json.loads(row["flags"] or "{}")
        styles = flags_data.get("styles", [])
        nsfw = flags_data.get("nsfw", None)
    title = f'''#### [Model {model_id}](https://huggingface.co/{model_id})
**Unflagged** {total_unflagged}'''
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the model on screen, then load another one.

    Args:
        current_model: Session state holding ``{"model_id": <id>}``, as
            produced by ``next_model``.
        styles: Selected style classes; stored under ``"styles"``.
        nsfw: Selected NSFW level; stored under ``"nsfw"``.

    Returns:
        The result of ``next_model({}, styles, nsfw)``, i.e. the UI values
        for a random not-yet-flagged model.

    Raises:
        gr.Error: If no model is currently loaded, or if ``next_model``
            cannot find another model.
    """
    # current_model starts as {} and stays that way until next_model succeeds
    # (e.g. after a failed load). Indexing it directly raised a bare KeyError.
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model loaded to flag")
    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    # Fetch the next model only after the update's DB context has exited.
    # NOTE(review): presumably get_db() commits/closes on exit; confirm in db.py.
    return next_model({}, styles, nsfw)
# Build the annotation UI. Components appear on the page in creation order.
with blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
''')
    model_title = gr.Markdown()
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery")
    gallery = gallery.style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")

    # Hidden per-session values:
    #   - query_params: passed through js_get_url_params on page load;
    #   - current_model: the model currently on screen.
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})

    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # Every handler returns next_model's 5-tuple, so they all refresh the same
    # components, in the same order.
    model_outputs = [gallery, model_title, styles, nsfw, current_model]

    next_btn.click(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=model_outputs,
    )
    submit_btn.click(
        flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=model_outputs,
    )
    blocks.load(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=model_outputs,
        _js=js_get_url_params,
    )

blocks.launch(enable_queue=False)