File size: 7,920 Bytes
086766e
 
 
 
 
 
 
8da1280
086766e
7ec5b17
086766e
8da1280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ec5b17
 
 
086766e
 
8da1280
 
 
 
 
 
7ec5b17
 
 
8da1280
 
 
086766e
 
 
c24087d
8da1280
 
7ec5b17
 
1676c6e
8da1280
d13085d
8da1280
 
 
 
 
1676c6e
8da1280
 
 
1676c6e
8da1280
 
 
 
 
 
 
 
 
598cad3
c24087d
8da1280
 
 
c24087d
8da1280
 
 
 
eb1e40e
8da1280
 
 
eb1e40e
 
8da1280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb1e40e
8da1280
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download

# --- Constants ---
# Hugging Face Hub repository and the two artifacts fetched from it at startup.
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
# Target size handed to PIL's Image.resize during preprocessing.
IMAGE_SIZE = 512, 512
# Probability cutoff for any category that has no entry in metadata["category_thresholds"].
DEFAULT_THRESHOLD = 0.35


# --- Helper Functions ---
def download_model_and_metadata(repo_id: str, model_filename: str, meta_filename: str, cache_dir: str = "."):
    """Fetch the ONNX model file and its metadata JSON from the Hugging Face Hub.

    The model is downloaded first, then the metadata.

    Args:
        repo_id: Hub repository to download from.
        model_filename: Name of the ONNX model file inside the repository.
        meta_filename: Name of the metadata JSON file inside the repository.
        cache_dir: Directory passed to ``hf_hub_download`` as its cache location.

    Returns:
        A ``(model_path, meta_path)`` tuple of local file paths.
    """
    return tuple(
        hf_hub_download(repo_id=repo_id, filename=wanted, cache_dir=cache_dir)
        for wanted in (model_filename, meta_filename)
    )

def load_model_session(model_path: str) -> ort.InferenceSession:
    """Create an ONNX Runtime inference session restricted to the CPU provider.

    Args:
        model_path: Filesystem path to the ``.onnx`` model.

    Returns:
        An ``onnxruntime.InferenceSession`` ready for ``run`` calls.
    """
    cpu_only = ["CPUExecutionProvider"]
    session = ort.InferenceSession(model_path, providers=cpu_only)
    return session

def load_metadata(meta_path: str) -> dict:
    """Read the tagger metadata JSON file (UTF-8) and return the parsed object.

    Args:
        meta_path: Path to the metadata JSON file.

    Returns:
        The decoded JSON content.
    """
    with open(meta_path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed

def preprocess_image(pil_image: Image.Image, image_size: tuple = IMAGE_SIZE) -> np.ndarray:
    """Convert a PIL image into a batched, channel-first float32 array.

    The image is converted to RGB, resized to ``image_size`` and scaled from
    [0, 255] to [0, 1].

    Args:
        pil_image: Source image in any PIL mode.
        image_size: Size passed straight to ``Image.resize``.

    Returns:
        A float32 array of shape ``(1, 3, H, W)``.
    """
    rgb = pil_image.convert("RGB").resize(image_size)
    scaled = np.asarray(rgb, dtype=np.float32) / 255.0
    # HWC -> CHW, then add a leading batch axis of size 1.
    channels_first = scaled.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]

def apply_sigmoid(logits: np.ndarray) -> np.ndarray:
    """Map logits to probabilities with the logistic sigmoid, without overflow.

    ``1 / (1 + exp(-x))`` is evaluated as ``exp(-log(1 + exp(-x)))``. The inner
    term is computed with ``np.logaddexp(0, -x)``, which stays finite for
    large-magnitude logits. The naive ``np.exp(-x)`` overflows there and emits a
    RuntimeWarning.

    Args:
        logits: Array (or scalar / array-like) of raw model scores, any shape.

    Returns:
        Probabilities in [0, 1] with the same shape as ``logits``. Float32 input
        stays float32.
    """
    return np.exp(-np.logaddexp(0.0, -np.asarray(logits)))

def extract_tags_from_probabilities(probs: np.ndarray, metadata: dict, threshold: float = DEFAULT_THRESHOLD) -> tuple:
    """Select the tags whose probability clears their category's threshold.

    Args:
        probs: 1-D sequence of per-tag probabilities. Position ``i`` corresponds
            to ``metadata["idx_to_tag"][str(i)]``.
        metadata: Parsed metadata with these keys:
            - ``"idx_to_tag"`` (required): maps string index to tag.
            - ``"tag_to_category"`` (optional): maps tag to category. Tags
              missing from it are filed under ``"unknown"``.
            - ``"category_thresholds"`` (optional): maps category to cutoff.
        threshold: Cutoff for any category without an entry in
            ``"category_thresholds"``.

    Returns:
        A ``(results_by_cat, all_artist_tags_probs)`` tuple.
        ``results_by_cat`` maps each category to ``(tag, prob)`` pairs with
        ``prob >= cutoff``, in index order. ``all_artist_tags_probs`` lists every
        ``"artist"`` tag with its probability, regardless of threshold.

    Raises:
        KeyError: if ``probs`` has a position with no entry in ``idx_to_tag``.
    """
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    results_by_cat = {}
    all_artist_tags_probs = []

    for idx, raw_prob in enumerate(probs):
        # Convert once: numpy scalars become plain floats in the returned pairs.
        prob = float(raw_prob)
        tag = idx_to_tag[str(idx)]  # JSON object keys are strings
        cat = tag_to_category.get(tag, "unknown")
        if cat == "artist":
            # Kept unfiltered so the formatters can fall back to the best artist guess.
            all_artist_tags_probs.append((tag, prob))
        if prob >= category_thresholds.get(cat, threshold):
            results_by_cat.setdefault(cat, []).append((tag, prob))

    return results_by_cat, all_artist_tags_probs

def format_prompt_style_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
    r"""Render predictions as a comma-separated, prompt-style tag string.

    Only the ``artist``, ``character`` and ``general`` categories are emitted, in
    that order. Each category is sorted by descending probability. If no artist
    tag cleared its threshold, the single highest-probability artist candidate is
    prepended instead.

    Underscores become spaces. Parentheses are escaped with a doubled backslash
    (``\\(`` / ``\\)``), presumably so that ``\(`` is what remains visible after
    the Markdown output box renders it.

    Args:
        results_by_cat: Category -> list of ``(tag, prob)`` pairs. Not modified.
        all_artist_tags_probs: Every artist ``(tag, prob)`` pair, used for the
            fallback.

    Returns:
        The joined tag string, or ``"No tags predicted."`` when nothing qualifies.
    """
    def to_prompt_tag(tag):
        return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")

    prompt_tags = []
    for cat in ("artist", "character", "general"):
        # sorted() instead of list.sort() so the caller's lists are left untouched.
        ranked = sorted(results_by_cat.get(cat, []), key=lambda pair: pair[1], reverse=True)
        prompt_tags.extend(to_prompt_tag(tag) for tag, _ in ranked)

    # No artist cleared its threshold: fall back to the single most likely artist.
    if not results_by_cat.get("artist") and all_artist_tags_probs:
        best_artist_tag, _ = max(all_artist_tags_probs, key=lambda pair: pair[1])
        if best_artist_tag:
            prompt_tags.insert(0, to_prompt_tag(best_artist_tag))

    if not prompt_tags:
        return "No tags predicted."
    return ", ".join(prompt_tags)

def format_detailed_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
    r"""Render predictions as Markdown grouped by category, with probabilities.

    Categories appear in ``results_by_cat`` insertion order. Tags within each
    category are sorted by descending probability.

    If no ``"artist"`` category is present, the single highest-probability artist
    candidate is appended as an ``artist`` category. This happens before the
    emptiness check, so an artist guess is shown even when no other tag qualified.
    That matches ``format_prompt_style_output``. Underscores become spaces and
    parentheses are escaped with a doubled backslash (``\\(`` / ``\\)``).

    Args:
        results_by_cat: Category -> list of ``(tag, prob)`` pairs. Not modified.
        all_artist_tags_probs: Every artist ``(tag, prob)`` pair, used for the
            fallback.

    Returns:
        The Markdown text, or ``"No tags predicted for this image."`` when there
        is nothing to show.
    """
    # Work on sorted copies so the caller's dict and lists are left untouched.
    categories = {
        cat: sorted(tag_list, key=lambda pair: pair[1], reverse=True)
        for cat, tag_list in results_by_cat.items()
    }

    # No artist cleared its threshold: show the single most likely artist instead.
    if "artist" not in categories and all_artist_tags_probs:
        best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda pair: pair[1])
        if best_artist_tag:
            categories["artist"] = [(best_artist_tag, best_artist_prob)]

    if not categories:
        return "No tags predicted for this image."

    lines = ["**Predicted Tags by Category:**  \n"]
    for cat, tag_list in categories.items():
        lines.append(f"**Category: {cat}** – {len(tag_list)} tags")
        for tag, prob in tag_list:
            tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
            lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
        lines.append("")
    return "\n".join(lines)

# --- Inference Function ---
def tag_image(pil_image: Image.Image, output_format: str, session: ort.InferenceSession, metadata: dict) -> str:
    """Run the tagger on one image and render the predictions as text.

    Args:
        pil_image: Uploaded image, or ``None`` when nothing has been uploaded.
        output_format: ``"Prompt-style Tags"`` selects the comma-separated
            output. Any other value selects the detailed Markdown breakdown.
        session: ONNX Runtime session for the tagger model.
        metadata: Tag metadata, as consumed by ``extract_tags_from_probabilities``.

    Returns:
        The formatted tag text, or an upload prompt when ``pil_image`` is None.
    """
    if pil_image is None:
        return "Please upload an image."

    batch = preprocess_image(pil_image)
    feed_name = session.get_inputs()[0].name
    # The model yields two logit tensors; only the second (refined) one is used.
    _initial_logits, refined_logits = session.run(None, {feed_name: batch})
    # The batch holds a single image, so keep only its row of probabilities.
    probabilities = apply_sigmoid(refined_logits)[0]

    by_category, artist_candidates = extract_tags_from_probabilities(probabilities, metadata)

    render = format_prompt_style_output if output_format == "Prompt-style Tags" else format_detailed_output
    return render(by_category, artist_candidates)


# --- Gradio UI ---
def create_gradio_interface(session: ort.InferenceSession, metadata: dict) -> gr.Blocks:
    """Build the Gradio Blocks UI wired to the given model session and metadata.

    Args:
        session: ONNX Runtime session used for every tagging request.
        metadata: Tag metadata forwarded to ``tag_image``.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for ``launch()``.
    """
    demo = gr.Blocks(theme="gradio/soft")

    def _on_tag_click(pil_image, output_format):
        # Gradio event listeners have no ``extra_args`` parameter (passing one
        # raises TypeError). Bind the non-UI arguments through this closure so
        # that only component values flow in as ``inputs``.
        return tag_image(pil_image, output_format, session, metadata)

    with demo:
        gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
        gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")

        with gr.Row():
            with gr.Column():
                image_in = gr.Image(type="pil", label="Input Image")
                format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
                tag_button = gr.Button("🔍 Tag Image")
            with gr.Column():
                output_box = gr.Markdown("")

        tag_button.click(
            fn=_on_tag_click,
            inputs=[image_in, format_choice],
            outputs=output_box,
        )

        gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")
    return demo

# --- Main Script ---
if __name__ == "__main__":
    # Fetch the model artifacts, build the runtime objects, then serve the UI.
    model_path, meta_path = download_model_and_metadata(
        repo_id=MODEL_REPO,
        model_filename=MODEL_FILE,
        meta_filename=META_FILE,
    )
    session = load_model_session(model_path)
    metadata = load_metadata(meta_path)
    demo = create_gradio_interface(session, metadata)
    demo.launch()