File size: 4,276 Bytes
9a5a85e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from enum import Enum
import os
import re
from io import BytesIO
import uuid
import gradio as gr
from pathlib import Path
from huggingface_hub import Repository
import json
from db import Database

# --- Module configuration and one-time start-up work -----------------------
# Everything below runs at import time, in this order: read configuration,
# clone/refresh the dataset repository, then open the database in it.

# Hugging Face access token taken from the environment (None when unset).
# NOTE(review): this name is not referenced anywhere else in the visible code;
# the Repository below is given `use_auth_token=True` instead -- confirm
# whether the token was meant to be passed explicitly.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): not referenced anywhere else in the visible code.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local folder that receives the clone of the dataset repository and is then
# handed to `Database`.
DB_FOLDER = Path("diffusers-gallery-data")
# URL prefix prepended to each image file name in `next_model` to build the
# gallery image URLs.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# Local working copy of the "huggingface-projects/diffusers-gallery-data"
# dataset repository, located at DB_FOLDER.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
# Bring the local copy up to date before the database is opened.
repo.git_pull()

# Project-local database wrapper (see `db.Database`); the functions below use
# `database.get_db()` as a context manager to run SQL against a `models` table.
# Presumably the database file lives inside DB_FOLDER -- confirm in db.py.
database = Database(DB_FOLDER)


# Root Gradio container; the UI is populated in the `with blocks:` section
# further down and served by `blocks.launch(...)`.
blocks = gr.Blocks()

# Choices offered by the "styles" checkbox group (several may be selected).
styles_cls = ["anime", "3D", "realistic", "other"]
# Choices offered by the "nsfw" radio group (exactly one may be selected).
nsfw_cls = ["safe", "suggestive", "explicit"]


# JavaScript passed to `blocks.load(..., _js=...)`. It receives the values of
# the three load inputs, copies the `model_id` URL query parameter (or "" when
# absent) onto the first value, resets the address bar to "/" with
# `history.replaceState`, and returns the three values.
# NOTE(review): the first JS parameter is named `current_model`, but the load
# event wires the `query_params` component as its first input -- the name is
# only local to this snippet.
js_get_url_params = """
    function (current_model, styles, nsfw) {
        const params = new URLSearchParams(window.location.search);
        current_model.model_id = params.get("model_id") || "";
        window.history.replaceState({}, document.title, "/");
        return [current_model, styles, nsfw]
        }
"""


def next_model(query_params, styles=None, nsfw=None):
    """Select the model to annotate next and build the values shown in the UI.

    When ``query_params`` carries a non-empty ``"model_id"``, that exact model
    is loaded (whether or not it has already been flagged). Otherwise a random
    model that has at least one image and no flags yet is chosen.

    Args:
        query_params: Mapping that may contain a ``"model_id"`` key. ``None``
            or an empty mapping means "pick a random unflagged model".
        styles: Current value of the styles checkbox group. Only logged; the
            returned styles come from the flags stored for the selected model.
        nsfw: Current value of the NSFW radio. Only logged; the returned value
            comes from the flags stored for the selected model.

    Returns:
        A 5-tuple matching the Gradio outputs wired to this callback:
        ``(image_urls, title_markdown, styles, nsfw, {"model_id": model_id})``.
        ``styles`` is ``[]`` and ``nsfw`` is ``None`` for a model that has no
        stored flags.

    Raises:
        gr.Error: If the requested model does not exist, or if there is no
            unflagged model with images left.
    """
    # Treat a missing mapping, a missing key and an empty string alike.
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)

    with database.get_db() as db:
        if model_id:
            # A window aggregate would only see the single row matched by
            # `id = ?`, so the backlog size is computed by a scalar subquery
            # that uses the same filter as the random branch below.
            cursor = db.execute(
                """SELECT *,
                    (SELECT COUNT(*)
                     FROM models AS m
                     WHERE json_array_length(m.data, '$.images') > 0
                       AND m.flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""", (model_id,))
            not_found_message = "Cannot find model to annotate"
        else:
            # The window is evaluated over every row passing the WHERE clause
            # (before LIMIT), so it yields the number of unflagged models that
            # have images.
            cursor = db.execute(
                """SELECT *,
                    SUM(CASE WHEN flags IS NULL THEN 1 ELSE 0 END) OVER () AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1""")
            not_found_message = "Cannot find any more models to annotate"

        row = cursor.fetchone()
        if row is None:
            raise gr.Error(not_found_message)

        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        # A model requested by id is not filtered on having images, so the
        # key may be absent; show an empty gallery instead of failing.
        image_names = data.get("images") or []
        images = [ASSETS_URL + name for name in image_names
                  if name.endswith(".jpg")]

        # `flags` is NULL until the model has been annotated.
        flags_data = json.loads(row["flags"] or "{}")
        stored_styles = flags_data.get("styles", [])
        stored_nsfw = flags_data.get("nsfw")

        title = f'''#### [Model {model_id}](https://huggingface.co/{model_id})
        **Unflagged** {total_unflagged}'''

        return images, title, stored_styles, stored_nsfw, {"model_id": model_id}


def flag_model(current_model, styles=None, nsfw=None):
    """Store the submitted annotation for the current model and move on.

    Args:
        current_model: State mapping holding the ``"model_id"`` of the model
            currently displayed.
        styles: Selected style labels to store.
        nsfw: Selected NSFW label to store.

    Returns:
        Whatever ``next_model`` returns for an empty query, i.e. the UI values
        for another model to annotate.

    Raises:
        KeyError: If ``current_model`` has no ``"model_id"`` entry.
    """
    target_id = current_model["model_id"]
    print("Flagging model", target_id, styles, nsfw)

    with database.get_db() as db:
        # The annotation is stored as one JSON document in the `flags` column.
        serialized_flags = json.dumps({"styles": styles, "nsfw": nsfw})
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (serialized_flags, target_id),
        )

    # An empty query makes next_model choose a model by itself rather than
    # reloading the one just flagged.
    return next_model({}, styles, nsfw)


# Build the page: instructions, model title, image gallery, the two label
# inputs, and the Next / Submit buttons, then wire the callbacks.
with blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
    Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
    ''')
    # Filled with the Markdown title returned by next_model / flag_model.
    model_title = gr.Markdown()
    # Shows the image URLs returned by next_model / flag_model.
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(grid=[3])
    # Multi-select style labels and single-select NSFW label.
    styles = gr.CheckboxGroup(
        styles_cls, info="Classify the image as one or more of the following classes")
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")
    # Hidden JSON input; its value is passed as the first argument of
    # next_model (and of the page-load JavaScript hook).
    query_params = gr.JSON(value={}, visible=False)
    # Per-session state holding {"model_id": ...} for the model on screen;
    # written by both callbacks and read by flag_model.
    current_model = gr.State({})
    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")
    # "Next": load a model without saving anything for the current one.
    next_btn.click(next_model, inputs=[query_params, styles, nsfw],
                   outputs=[gallery, model_title,  styles, nsfw, current_model])

    # "Submit": save the selected labels for the current model, then show
    # whatever flag_model returns (the next model to annotate).
    submit_btn.click(flag_model, inputs=[current_model, styles, nsfw], outputs=[
                     gallery, model_title, styles, nsfw, current_model])

    # On page load: js_get_url_params is supplied as the `_js` hook so the
    # `model_id` URL parameter can reach next_model through its first input.
    blocks.load(next_model, inputs=[query_params, styles, nsfw],
                outputs=[gallery, model_title, styles, nsfw, current_model], _js=js_get_url_params)

# Start the Gradio server with the request queue disabled.
blocks.launch(enable_queue=False)