# cyber-tagger / app.py
# Last commit by CyberWaifu: b403fe7 "Update app.py" (verified), 6.62 kB.
# (Header copied from the Hugging Face file page; "raw", "history" and "blame"
# were page links.)
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# Download the ONNX model and its tag metadata from the Hugging Face Hub once at
# startup; hf_hub_download stores the files under cache_dir (".").
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
# Only the CPU execution provider is requested.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Read the metadata inside a context manager so the file handle is closed
# deterministically (the previous json.load(open(...)) leaked it). tag_image
# reads the "idx_to_tag", "tag_to_category" and "category_thresholds" keys.
with open(meta_path, "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)
# Image preprocessing: PIL image -> model input tensor.
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Build the model input tensor for one image.

    The image is forced to RGB and resized to 512x512 (aspect ratio is not
    preserved). Pixel values are scaled from 0..255 to 0..1 as float32. The
    array is reordered from height-width-channel to channel-height-width and
    given a leading batch axis, giving shape (1, 3, 512, 512).
    """
    resized = pil_image.convert("RGB").resize((512, 512))
    pixels = np.asarray(resized, dtype=np.float32) / 255.0
    channels_first = pixels.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]
# --- Tag inference ----------------------------------------------------------
# Probability cut-off for categories missing from metadata["category_thresholds"].
_DEFAULT_THRESHOLD = 0.35

# Categories included in the prompt-style output, in output order.
_PROMPT_CATEGORIES = ("artist", "character", "general")


def _escape_tag(tag: str) -> str:
    r"""Turn a raw tag name into its display form.

    Underscores become spaces, and each parenthesis is prefixed with two literal
    backslashes (``\\(`` / ``\\)``). The text is shown in a Markdown component,
    where ``\\`` renders as a single backslash, so the reader sees ``\(``.
    NOTE(review): presumably this is meant to yield the escaped-parenthesis form
    used by prompt syntaxes; confirm that is the intent.
    """
    return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")


def _collect_predictions(probs, idx_to_tag, tag_to_category, category_thresholds):
    """Group tags whose probability reaches their category threshold.

    Args:
        probs: 1-D sequence of per-tag probabilities. Position ``i`` belongs to
            ``idx_to_tag[str(i)]``.
        idx_to_tag: mapping from stringified index to tag name.
        tag_to_category: mapping from tag name to category. Tags missing from it
            are put in ``"unknown"``.
        category_thresholds: per-category cut-offs. Categories missing from it
            use ``_DEFAULT_THRESHOLD``.

    Returns:
        A pair ``(by_category, best_artist)``.

        ``by_category`` maps each category, in order of first appearance, to
        its ``(tag, prob)`` pairs in index order.

        ``best_artist`` is the highest-probability ``"artist"`` tag regardless
        of threshold (the first one wins on ties), or ``None`` if no tag maps
        to ``"artist"``.
    """
    by_category = {}
    best_artist = None
    for idx, raw_prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        score = float(raw_prob)
        # Track the best artist in this same pass instead of storing every
        # artist tag and scanning the list again afterwards.
        if cat == "artist" and (best_artist is None or score > best_artist[1]):
            best_artist = (tag, score)
        if score >= category_thresholds.get(cat, _DEFAULT_THRESHOLD):
            by_category.setdefault(cat, []).append((tag, score))
    return by_category, best_artist


def _format_prompt(by_category, best_artist):
    """Render artist, character and general tags as one comma-separated line.

    Each category is sorted by descending probability. When no artist tag
    reached its threshold, the top artist tag is used as the artist entry.
    """
    prompt_tags = []
    for cat in _PROMPT_CATEGORIES:
        entries = by_category.get(cat, [])
        if cat == "artist" and not entries and best_artist is not None:
            entries = [best_artist]
        ranked = sorted(entries, key=lambda item: item[1], reverse=True)
        prompt_tags.extend(_escape_tag(tag) for tag, _prob in ranked)
    if not prompt_tags:
        return "No tags predicted."
    return ", ".join(prompt_tags)


def _format_detailed(by_category, best_artist):
    """Render every category that has tags as a Markdown list, highest prob first."""
    if "artist" not in by_category and best_artist is not None:
        # Same artist fallback as the prompt-style output. It is applied BEFORE
        # the emptiness check so that an image whose tags all fall below their
        # thresholds still shows its top artist, as prompt-style mode does.
        # Built as a new dict so the caller's mapping is not mutated; "artist"
        # ends up as the last category.
        by_category = {**by_category, "artist": [best_artist]}
    if not by_category:
        return "No tags predicted for this image."
    # The trailing "\n" plus the "\n" join below leaves a blank line, which
    # separates this heading paragraph from the first category in Markdown.
    lines = ["**Predicted Tags by Category:** \n"]
    for cat, tag_list in by_category.items():
        ranked = sorted(tag_list, key=lambda item: item[1], reverse=True)
        lines.append(f"**Category: {cat}** – {len(ranked)} tags")
        lines.extend(f"- {_escape_tag(tag)} (Prob: {prob:.3f})" for tag, prob in ranked)
        lines.append("")  # blank line between categories
    return "\n".join(lines)


def tag_image(pil_image: Image.Image, output_format: str) -> str:
    """Tag an image and return the predictions as Markdown text.

    Args:
        pil_image: image from the ``gr.Image(type="pil")`` input. ``None`` when
            the button is clicked without an uploaded image.
        output_format: ``"Prompt-style Tags"`` produces one comma-separated line
            of artist/character/general tags. Any other value produces the
            detailed per-category breakdown.

    Returns:
        Markdown string for the output panel.
    """
    if pil_image is None:
        return "Please upload an image first."
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    # The session yields two outputs; only the refined logits are used.
    _initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    # Very negative logits overflow exp() to inf, which still gives the correct
    # probability of 0. Silence the RuntimeWarning this would otherwise log.
    with np.errstate(over="ignore"):
        probs = 1 / (1 + np.exp(-refined_logits))
    by_category, best_artist = _collect_predictions(
        probs[0],  # the input tensor holds a single image
        metadata["idx_to_tag"],
        metadata.get("tag_to_category", {}),
        metadata.get("category_thresholds", {}),
    )
    if output_format == "Prompt-style Tags":
        return _format_prompt(by_category, best_artist)
    return _format_detailed(by_category, best_artist)
# Build the Gradio Blocks UI.
# A theme string that is not a built-in name (such as "gradio/soft") makes
# Gradio download the theme from the Hugging Face Hub at startup.
# gr.themes.Soft() is the bundled equivalent and needs no network access.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    # Header section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/output section
    with gr.Row():
        # Left column: image input, output-format selector and trigger button.
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(
                choices=["Prompt-style Tags", "Detailed Output"],
                value="Prompt-style Tags",
                label="Output Format",
            )
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: Markdown panel that displays tag_image's return value.
        with gr.Column():
            output_box = gr.Markdown("")
    # tag_image receives (image, format choice); its string fills output_box.
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer. The stray zero-width space and ":contentReference[...]" citation
    # artifact that was displayed to users has been removed.
    gr.Markdown(
        "----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • "
        "**Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • "
        "**ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*"
    )

# Start the server only when this file is run as a script, which is how Spaces
# launches it. Importing the module then does not start a server.
if __name__ == "__main__":
    demo.launch()