Spaces:
Running
Running
File size: 7,920 Bytes
086766e 8da1280 086766e 7ec5b17 086766e 8da1280 7ec5b17 086766e 8da1280 7ec5b17 8da1280 086766e c24087d 8da1280 7ec5b17 1676c6e 8da1280 d13085d 8da1280 1676c6e 8da1280 1676c6e 8da1280 598cad3 c24087d 8da1280 c24087d 8da1280 eb1e40e 8da1280 eb1e40e 8da1280 eb1e40e 8da1280 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download
# --- Constants ---
# Hugging Face Hub repository passed as ``repo_id`` when downloading the files below.
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
# File name of the ONNX model inside MODEL_REPO.
MODEL_FILE = "camie_tagger_initial.onnx"
# File name of the JSON metadata inside MODEL_REPO; the code reads the keys
# "idx_to_tag", "tag_to_category" and "category_thresholds" from it.
META_FILE = "metadata.json"
# Default size handed to PIL's ``resize`` in preprocess_image.
IMAGE_SIZE = (512, 512)
# Fallback probability cut-off for categories that have no entry in the
# metadata's "category_thresholds" mapping.
DEFAULT_THRESHOLD = 0.35
# --- Helper Functions ---
def download_model_and_metadata(repo_id: str, model_filename: str, meta_filename: str, cache_dir: str = "."):
    """Fetch the ONNX model file and its metadata file from a Hugging Face Hub repo.

    Both files are requested from ``repo_id`` through ``hf_hub_download``,
    using ``cache_dir`` as the download cache (the current directory by
    default). The model file is requested first, then the metadata file.

    Returns:
        A ``(model_path, meta_path)`` tuple holding the values returned by
        ``hf_hub_download`` for each file.
    """
    local_model, local_meta = (
        hf_hub_download(repo_id=repo_id, filename=wanted, cache_dir=cache_dir)
        for wanted in (model_filename, meta_filename)
    )
    return local_model, local_meta
def load_model_session(model_path: str) -> ort.InferenceSession:
    """Build an ONNX Runtime inference session for the model at ``model_path``.

    Only the CPU execution provider is requested.
    """
    execution_providers = ["CPUExecutionProvider"]
    inference_session = ort.InferenceSession(model_path, providers=execution_providers)
    return inference_session
def load_metadata(meta_path: str) -> dict:
    """Read the UTF-8 JSON file at ``meta_path`` and return the parsed content."""
    with open(meta_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def preprocess_image(pil_image: Image.Image, image_size: tuple = IMAGE_SIZE) -> np.ndarray:
    """Turn a PIL image into the float32 batch array fed to the model.

    Steps, in order: convert to RGB, resize directly to ``image_size``,
    scale pixel values by 1/255, move the last (channel) axis to the front,
    and prepend a batch axis of length 1.
    """
    rgb_image = pil_image.convert("RGB")
    resized = rgb_image.resize(image_size)
    pixels = np.array(resized).astype(np.float32) / 255.0
    # Last axis (channels) first: (a, b, c) -> (c, a, b).
    channels_first = pixels.transpose(2, 0, 1)
    # Leading singleton axis: the model is run on one image at a time.
    return channels_first[np.newaxis, ...]
def apply_sigmoid(logits: np.ndarray) -> np.ndarray:
    """Map logits to probabilities with the logistic sigmoid, element-wise.

    The naive ``1 / (1 + exp(-x))`` overflows inside ``exp`` for strongly
    negative ``x`` and emits a RuntimeWarning. This version only ever
    exponentiates ``-|x|`` (a value <= 0), so ``exp`` stays within [0, 1],
    and picks the algebraically equivalent branch per element:

    * ``x >= 0``: ``1 / (1 + exp(-x))``
    * ``x <  0``: ``exp(x) / (1 + exp(x))``

    Args:
        logits: Array (or array-like) of raw scores. Floating dtypes are
            preserved; any other dtype is promoted to float64 first.

    Returns:
        An array of the same shape with values in [0, 1].
    """
    values = np.asarray(logits)
    if not np.issubdtype(values.dtype, np.floating):
        # Integer/bool input: promote so negation and exp behave sensibly.
        values = values.astype(np.float64)
    decay = np.exp(-np.abs(values))  # always in [0, 1]; cannot overflow
    return np.where(values >= 0, 1 / (1 + decay), decay / (1 + decay))
def extract_tags_from_probabilities(probs: np.ndarray, metadata: dict, threshold: float = DEFAULT_THRESHOLD) -> tuple:
    """Group tags whose probability clears its category threshold.

    Args:
        probs: One probability per tag; position ``i`` is looked up as
            ``metadata["idx_to_tag"][str(i)]``.
        metadata: Mapping that must contain ``"idx_to_tag"`` and may contain
            ``"tag_to_category"`` (tag -> category name; missing tags fall
            into ``"unknown"``) and ``"category_thresholds"`` (category ->
            cut-off).
        threshold: Cut-off used for categories without their own entry in
            ``"category_thresholds"``.

    Returns:
        A 2-tuple ``(results_by_cat, all_artist_tags_probs)``:

        * ``results_by_cat`` maps category -> list of ``(tag, prob)`` pairs
          with ``prob >= cut-off``, in index order.
        * ``all_artist_tags_probs`` lists ``(tag, prob)`` for every tag in
          the ``'artist'`` category regardless of threshold, so callers can
          fall back to the best artist guess.

    Raises:
        KeyError: If ``"idx_to_tag"`` is missing or has no entry for an index.
    """
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    results_by_cat = {}
    all_artist_tags_probs = []
    for idx, raw_prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        # Convert once to a plain Python float; reused for every use below.
        prob = float(raw_prob)
        if cat == 'artist':
            all_artist_tags_probs.append((tag, prob))
        if prob >= category_thresholds.get(cat, threshold):
            results_by_cat.setdefault(cat, []).append((tag, prob))
    return results_by_cat, all_artist_tags_probs
def _to_prompt_tag(tag: str) -> str:
    """Render a raw tag for display: underscores -> spaces, parentheses escaped."""
    # The raw-string literals keep BOTH backslashes, so each parenthesis ends
    # up prefixed by two backslash characters.
    # NOTE(review): presumably doubled so Markdown rendering shows one - confirm.
    return tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")


def format_prompt_style_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
    """Format predictions as one comma-separated, prompt-style string.

    Only the ``'artist'``, ``'character'`` and ``'general'`` categories are
    emitted, in that order, each sorted by descending probability. When no
    artist tag is present in ``results_by_cat`` but artist candidates exist,
    the highest-probability candidate is placed first.

    The inputs are not modified (sorting works on copies).

    Args:
        results_by_cat: Category -> list of ``(tag, prob)`` pairs.
        all_artist_tags_probs: ``(tag, prob)`` pairs for all artist tags.

    Returns:
        The joined tag string, or ``"No tags predicted."`` when nothing is left.
    """
    prompt_tags = []
    for category in ("artist", "character", "general"):
        ranked = sorted(results_by_cat.get(category, []), key=lambda item: item[1], reverse=True)
        prompt_tags.extend(_to_prompt_tag(tag) for tag, _prob in ranked)

    # Fallback: no artist cleared its threshold, so surface the best guess.
    if not results_by_cat.get("artist") and all_artist_tags_probs:
        best_artist_tag, _best_prob = max(all_artist_tags_probs, key=lambda item: item[1])
        if best_artist_tag:
            prompt_tags.insert(0, _to_prompt_tag(best_artist_tag))

    if not prompt_tags:
        return "No tags predicted."
    return ", ".join(prompt_tags)
def format_detailed_output(results_by_cat: dict, all_artist_tags_probs: list) -> str:
    """Format predictions as a Markdown listing grouped by category.

    Every category in ``results_by_cat`` is listed in the dict's own order,
    with its tags sorted by descending probability. If there is no
    ``'artist'`` key but artist candidates exist, the highest-probability
    candidate is appended as an extra ``artist`` category.

    The inputs are not modified: the fallback entry goes into a local copy
    and sorting produces new lists.

    Args:
        results_by_cat: Category -> list of ``(tag, prob)`` pairs.
        all_artist_tags_probs: ``(tag, prob)`` pairs for all artist tags.

    Returns:
        The Markdown text, or ``"No tags predicted for this image."`` when
        ``results_by_cat`` is empty.
    """
    if not results_by_cat:
        return "No tags predicted for this image."

    # Shallow copy so the fallback entry never leaks into the caller's dict.
    categories = dict(results_by_cat)
    if 'artist' not in categories and all_artist_tags_probs:
        best_artist_tag, best_artist_prob = max(all_artist_tags_probs, key=lambda item: item[1])
        if best_artist_tag:
            categories['artist'] = [(best_artist_tag, best_artist_prob)]

    lines = ["**Predicted Tags by Category:** \n"]
    for cat, tag_list in categories.items():
        ranked = sorted(tag_list, key=lambda item: item[1], reverse=True)
        lines.append(f"**Category: {cat}** – {len(ranked)} tags")
        for tag, prob in ranked:
            # Raw strings keep both backslashes before each parenthesis.
            tag_pretty = tag.replace("_", " ").replace("(", r"\\(").replace(")", r"\\)")
            lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
        lines.append("")  # blank line between categories
    return "\n".join(lines)
# --- Inference Function ---
def tag_image(pil_image: Image.Image, output_format: str, session: ort.InferenceSession, metadata: dict) -> str:
    """Run the tagger on one image and return the formatted predictions.

    Returns ``"Please upload an image."`` when ``pil_image`` is ``None``.
    Otherwise the image is preprocessed, fed to ``session`` under its first
    input name, and the second of the two model outputs is converted to
    probabilities. ``output_format == "Prompt-style Tags"`` selects the
    comma-separated form; any other value selects the detailed listing.
    """
    if pil_image is None:
        return "Please upload an image."

    model_input = preprocess_image(pil_image)
    first_input_name = session.get_inputs()[0].name
    # The session is expected to yield exactly two outputs; only the second is used.
    _initial, refined = session.run(None, {first_input_name: model_input})
    # Batch holds a single image, so take row 0 of the probabilities.
    probabilities = apply_sigmoid(refined)[0]
    by_category, artist_candidates = extract_tags_from_probabilities(probabilities, metadata)

    wants_prompt_style = output_format == "Prompt-style Tags"
    formatter = format_prompt_style_output if wants_prompt_style else format_detailed_output
    return formatter(by_category, artist_candidates)
# --- Gradio UI ---
def create_gradio_interface(session: ort.InferenceSession, metadata: dict) -> gr.Blocks:
    """Build the Gradio Blocks UI for the tagger.

    Layout: a two-column row with the image input, output-format radio and
    the tag button on the left, and a Markdown output on the right.

    Args:
        session: Inference session handed to ``tag_image`` on every click.
        metadata: Tag metadata handed to ``tag_image`` on every click.

    Returns:
        The assembled ``gr.Blocks`` app (not yet launched).
    """

    def tag_with_model(pil_image, output_format):
        """Click handler: forwards the two UI values plus the bound model state."""
        return tag_image(pil_image, output_format, session, metadata)

    demo = gr.Blocks(theme="gradio/soft")
    with demo:
        gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
        gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(type="pil", label="Input Image")
                format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
                tag_button = gr.Button("🔍 Tag Image")
            with gr.Column():
                output_box = gr.Markdown("")
        # Event listeners accept no `extra_args` keyword; session and metadata
        # are bound through the closure above instead.
        tag_button.click(
            fn=tag_with_model,
            inputs=[image_in, format_choice],
            outputs=output_box,
        )
        gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime) • **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags) • **ONNX Runtime:** for efficient CPU inference • *Demo built with Gradio Blocks.*")
    return demo
# --- Main Script ---
if __name__ == "__main__":
    # Script entry point: download the model files, load them, build the UI
    # and start serving it.
    model_path, meta_path = download_model_and_metadata(MODEL_REPO, MODEL_FILE, META_FILE)
    session = load_model_session(model_path)
    metadata = load_metadata(meta_path)
    demo = create_gradio_interface(session, metadata)
    demo.launch()