import os, json, re, traceback
from typing import Any, Dict, Tuple

import gradio as gr
from PIL import Image
import torch
import spaces
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

# --------- ENV / PARAMS ----------
MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
HF_TOKEN = os.environ.get("HF_TOKEN")  # put this in Space -> Settings -> Variables & secrets
TEMP = 0.1
MAX_NEW_TOKENS = 2000
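# Illustrative only: both values can be overridden per deployment, e.g. via the Space's
# "Variables & secrets" panel or locally in the shell before launching (token is a placeholder):
#   MODEL_ID="inference-net/ClipTagger-12b" HF_TOKEN="hf_..." python app.py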

# Lazy globals (ZeroGPU-safe)
_processor: Any = None
_tokenizer: Any = None
_model: Any = None
_last_load_error: str | None = None

# --------- PROMPTS (yours) ----------
SYSTEM_PROMPT = (
    "You are an image annotation API trained to analyze YouTube video keyframes. "
    "You will be given instructions on the output format, what to caption, and how to perform your job. "
    "Follow those instructions. For descriptions and summaries, provide them directly and do not lead them with "
    "'This image shows' or 'This keyframe displays...', just get right into the details."
)

USER_PROMPT = """You are an image annotation API trained to analyze YouTube video keyframes. You must respond with a valid JSON object matching the exact structure below.

Your job is to extract detailed **factual elements directly visible** in the image. Do not speculate or interpret artistic intent, camera focus, or composition. Do not include phrases like "this appears to be", "this looks like", or anything about the image itself. Describe what **is physically present in the frame**, and nothing more.

Return JSON in this structure:
{
  "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
  "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
  "actions": ["action1 with participants and context", "action2 with participants and context", ...],
  "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
  "content_type": "The type of content it is, e.g. 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
  "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
  "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
  "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
  "logos": ["logo1 with visual description", "logo2 with visual description", ...]
}

Rules:
- Be specific and literal. Focus on what is explicitly visible.
- Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
- No artistic or cinematic analysis.
- Always include the language of any text in the image if present as an object, e.g. "English text", "Japanese text", "Russian text", etc.
- Maximum 10 objects and 5 actions.
- Return an empty array for 'logos' if none are present.
- Always output strictly valid JSON with proper escaping.
- Output **only the JSON**, no extra text or explanation.
"""

# --------- HELPERS ----------
def _json_extract(text: str):
    """Strict parse -> top-level {...} fallback.

    Note: Python's `re` module has no recursive patterns like (?R), so the fallback
    simply grabs the span from the first '{' to the last '}' and retries json.loads.
    """
    try:
        return json.loads(text)
    except Exception:
        m = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if m:
            try:
                return json.loads(m.group(0))
            except Exception:
                pass
    return None
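# Quick illustration of the parse/fallback behaviour (hypothetical strings):
#   _json_extract('{"summary": "ok"}')                       -> {"summary": "ok"}
#   _json_extract('Here you go: {"summary": "ok"} Thanks!')  -> {"summary": "ok"}  (outermost {...} recovered)
#   _json_extract('no JSON in this output')                  -> None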

def _build_messages(image: Image.Image):
    return [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image},
                                     {"type": "text", "text": USER_PROMPT}]},
    ]

# --------- ZERO-GPU LAZY LOADER ----------
def _ensure_loaded() -> str:
    """
    Load the model only when a ZeroGPU worker with a GPU is attached.
    Tries quantized path first (compressed-tensors), then falls back to unquantized.
    """
    global _processor, _tokenizer, _model, _last_load_error
    if _model is not None and _processor is not None:
        return "already_loaded"
    try:
        # Sanity: config should be gemma3 causal VLM (not CLIP)
        cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
        if "clip" in cfg.__class__.__name__.lower():
            raise RuntimeError(
                f"MODEL_ID '{MODEL_ID}' resolves to CLIP/encoder config; need a causal VLM checkpoint."
            )
        # Try quantized (as per your config)
        _processor = AutoProcessor.from_pretrained(
            MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
        )
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            token=HF_TOKEN,
            device_map="auto",
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
        )
        _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
            MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
        )
        _last_load_error = None
        return "ok_quant"
    except Exception as e:
        # Fallback: disable quantization (more VRAM)
        if "compressed_tensors" in str(e):
            try:
                _processor = AutoProcessor.from_pretrained(
                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
                )
                _model = AutoModelForCausalLM.from_pretrained(
                    MODEL_ID,
                    token=HF_TOKEN,
                    device_map="auto",
                    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                    quantization_config=None,  # force dequantized load
                )
                _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
                )
                _last_load_error = None
                return "ok_dequant"
            except Exception as e2:
                _last_load_error = f"{e}\n\nFallback failed:\n{e2}\n{traceback.format_exc()}"
                _processor = _tokenizer = _model = None
                return "fail"
        else:
            _last_load_error = f"{e}\n{traceback.format_exc()}"
            _processor = _tokenizer = _model = None
            return "fail"

# --------- INFERENCE ----------
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call (a no-op on dedicated hardware)
def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
    status = _ensure_loaded()
    if status == "fail":
        return f"❌ Load error:\n{_last_load_error}", None, False
    if image is None:
        return "Please upload an image.", None, False
    # Prompt assembly
    if hasattr(_processor, "apply_chat_template"):
        prompt = _processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
    else:
        msgs = _build_messages(image)
        prompt = ""
        for m in msgs:
            role = m["role"].upper()
            for chunk in m["content"]:
                if chunk["type"] == "text":
                    prompt += f"{role}: {chunk['text']}\n"
                elif chunk["type"] == "image":
                    prompt += f"{role}: [IMAGE]\n"

    inputs = _processor(text=prompt, images=image, return_tensors="pt").to(_model.device)
    gen_kwargs = dict(
        do_sample=True,  # temperature only takes effect when sampling is enabled
        temperature=TEMP,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    # respect multiple eos ids if present
    eos = getattr(_model.config, "eos_token_id", None)
    if eos is not None:
        gen_kwargs["eos_token_id"] = eos
    # Note: transformers' generate() has no JSON-mode / response_format option;
    # strict-JSON behaviour comes from the prompt plus _json_extract() below.
    with torch.inference_mode():
        out = _model.generate(**inputs, **gen_kwargs)

    text = (_processor.decode(out[0], skip_special_tokens=True)
            if hasattr(_processor, "decode")
            else _tokenizer.decode(out[0], skip_special_tokens=True))
    if USER_PROMPT in text:
        text = text.split(USER_PROMPT)[-1].strip()

    parsed = _json_extract(text)
    if isinstance(parsed, dict):
        return json.dumps(parsed, indent=2), parsed, True
    return text, None, False
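# Minimal local smoke test (illustrative; assumes a keyframe saved as "frame.jpg" next to this file):
#   from PIL import Image
#   raw, parsed, ok = annotate_image(Image.open("frame.jpg").convert("RGB"))
#   print(ok)        # True when the model returned valid JSON
#   print(raw[:300]) # pretty-printed JSON, or the raw text / load error otherwise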

# Optional: quick warmup to validate loading on first worker
def _warmup():
    try:
        return _ensure_loaded()
    except Exception as e:
        return f"warmup error: {e}"

# --------- UI ----------
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ZeroGPU)") as demo:
    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · ZeroGPU)\nUpload an image to get **strict JSON** annotations.")
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
            btn = gr.Button("Annotate", variant="primary")
        with gr.Column(scale=1):
            out_text = gr.Code(label="Output (JSON or error)")
            out_json = gr.JSON(label="Parsed JSON")
            ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
    btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])

# eager warmup at import time (synchronous; loads the model before the first request)
try:
    _ = _warmup()
except Exception:
    pass

demo.queue(max_size=32).launch()