import os, json, traceback
from typing import Any, Dict, Tuple
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig
# ===== Env / params =====
MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
HF_TOKEN = os.environ.get("HF_TOKEN")
# Latency/quality knobs (tuned for A100-80GB)
TEMP = 0.1 # per model docs
MAX_NEW_TOKENS = 384 # fast + sufficient for schema (raise to 512/768 if needed)
VISION_LONG_SIDE = 896 # matches vision_config.image_size
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
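# Local run sketch (assumptions: a CUDA GPU with enough VRAM and access to the
# model repo): export HF_TOKEN (and optionally MODEL_ID) before `python app.py`;
# on Spaces the same variables are typically configured as secrets.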
# ===== Prompts (schema-only; no example output) =====
SYSTEM_PROMPT = (
"You are an image annotation API trained to analyze YouTube video keyframes. "
"You will be given instructions on the output format, what to caption, and how to perform your job. "
"Follow those instructions. For descriptions and summaries, provide them directly and do not lead them with "
"'This image shows' or 'This keyframe displays...', just get right into the details."
)
USER_PROMPT = """You are an image annotation API trained to analyze YouTube video keyframes. You must respond with a valid JSON object matching the exact structure below.
Your job is to extract detailed **factual elements directly visible** in the image. Do not speculate or interpret artistic intent, camera focus, or composition. Do not include phrases like "this appears to be", "this looks like", or anything about the image itself. Describe what **is physically present in the frame**, and nothing more.
Return JSON in this structure:
{
  "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
  "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
  "actions": ["action1 with participants and context", "action2 with participants and context", ...],
  "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
  "content_type": "The type of content it is, e.g., 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
  "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
  "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
  "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
  "logos": ["logo1 with visual description", "logo2 with visual description", ...]
}
Rules:
- Be specific and literal. Focus on what is explicitly visible.
- Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
- No artistic or cinematic analysis.
- Always include the language of any text in the image if present as an object, e.g., "English text", "Japanese text", "Russian text", etc.
- Maximum 10 objects and 5 actions.
- Return an empty array for 'logos' if none are present.
- Always output strictly valid JSON with proper escaping.
- Output **only the JSON**, no extra text or explanation.
- Do **not** copy any example strings from the instructions or use ellipses ('...'). Produce concrete values drawn from the image only.
"""
# ===== Utils =====
def extract_last_json(s: str):
    """
    Return the last balanced {...} JSON object found in the string.
    This avoids grabbing the schema block from the prompt if it echoes.
    """
    last = None
    start, depth = None, 0
    for i, ch in enumerate(s):
        if ch == '{':
            if depth == 0:
                start = i
            depth += 1
        elif ch == '}':
            if depth > 0:
                depth -= 1
                if depth == 0 and start is not None:
                    chunk = s[start:i+1]
                    try:
                        last = json.loads(chunk)
                    except Exception:
                        pass
                    start = None
    return last
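# Illustrative sketch (not part of the app flow): extract_last_json keeps the
# *last* balanced object that parses, so a schema echoed earlier in the text is
# ignored. Hypothetical inputs:
#   extract_last_json('schema: {"description": "..."} answer: {"a": 1}')  -> {"a": 1}
#   extract_last_json('no json here')                                     -> None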
def build_messages(image: Image.Image):
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": USER_PROMPT},
        ]},
    ]
def resize_to_vision(pil: Image.Image, long_side: int = VISION_LONG_SIDE) -> Image.Image:
    if pil is None:
        return pil
    w, h = pil.size
    m = max(w, h)
    if m <= long_side:
        return pil.convert("RGB")
    s = long_side / m
    return pil.convert("RGB").resize((int(w*s), int(h*s)), Image.BICUBIC)
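# Example (assumed sizes, for illustration): a 1792x1008 frame has a long side
# of 1792 > 896, so it is scaled by 0.5 down to 896x504; an 800x600 frame is
# only converted to RGB and kept at its original size.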
# ===== Load model (A100) =====
processor = tokenizer = model = None
LOAD_ERROR = None
try:
    cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    if "clip" in cfg.__class__.__name__.lower():
        raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")
    print("[boot] loading processor…", flush=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    print("[boot] loading model…", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map="cuda",  # keep on A100
        torch_dtype=DTYPE,
        trust_remote_code=True,
        # quantization_config=None,  # uncomment to force full precision if you removed quant in repo
    )
    tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
        MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
    )
    print("[boot] ready.", flush=True)
except Exception as e:
    LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
# ===== Inference =====
def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
    if image is None:
        return "Please upload an image.", None, False
    if model is None or processor is None:
        return f"❌ Load error:\n{LOAD_ERROR}", None, False
    image = resize_to_vision(image, VISION_LONG_SIDE)

    # Build chat prompt
    if hasattr(processor, "apply_chat_template"):
        prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
    else:
        prompt = USER_PROMPT

    # Tokenize with vision
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
    eos = getattr(model.config, "eos_token_id", None)

    def _decode_only_new(out_ids):
        """
        Decode only the newly generated tokens (exclude prompt tokens),
        so we don't accidentally parse the schema block from the prompt.
        """
        input_len = inputs["input_ids"].shape[1]
        gen_ids = out_ids[0][input_len:]
        # Prefer processor.decode if available (some VLMs customize decoding)
        if hasattr(processor, "decode"):
            return processor.decode(gen_ids, skip_special_tokens=True)
        return tokenizer.decode(gen_ids, skip_special_tokens=True)

    tried = []

    # (1) Greedy (fast, stable)
    try:
        g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
        if eos is not None:
            g["eos_token_id"] = eos
        with torch.inference_mode():
            out = model.generate(**inputs, **g)
        text = _decode_only_new(out)
        parsed = extract_last_json(text)
        if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
            return json.dumps(parsed, indent=2), parsed, True
        tried.append(("greedy", "parse-failed-or-ellipses"))
    except Exception as e:
        tried.append(("greedy", f"err={e}"))

    # (2) Short sampled retry
    try:
        g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
        if eos is not None:
            g["eos_token_id"] = eos
        with torch.inference_mode():
            out = model.generate(**inputs, **g)
        text = _decode_only_new(out)
        parsed = extract_last_json(text)
        if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
            return json.dumps(parsed, indent=2), parsed, True
        tried.append(("sample_t0.1", "parse-failed-or-ellipses"))
    except Exception as e:
        tried.append(("sample_t0.1", f"err={e}"))

    return "Generation failed.\n" + str(tried), None, False
# ===== UI =====
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="ClipTagger (VLM)") as demo:
gr.Markdown("# ClipTagger\nUpload an image to get **strict JSON** annotations.")
if LOAD_ERROR:
with gr.Accordion("Startup Error Details", open=False):
gr.Markdown(f"```\n{LOAD_ERROR}\n```")
with gr.Row():
with gr.Column(scale=1):
image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
btn = gr.Button("Annotate", variant="primary")
with gr.Column(scale=1):
out_text = gr.Code(label="Output (JSON or error)")
out_json = gr.JSON(label="Parsed JSON")
ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
btn.click(generate, inputs=[image], outputs=[out_text, out_json, ok_flag])
demo.queue(max_size=32).launch()