# cliptagger-12b / app.py
import os, json, re
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
MODEL_ID = os.environ.get("MODEL_ID", "GrassData/cliptagger-12b")
HF_TOKEN = os.environ.get("HF_TOKEN")
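# HF_TOKEN is only required when the model repo is gated or private; public checkpoints load without it.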
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
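# bfloat16 roughly halves the memory footprint of the 12B weights on GPU; CPU falls back to float32.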
# Load processor & model
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=DTYPE,
    device_map="auto",
    trust_remote_code=True,
)
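# device_map="auto" lets accelerate place the weights across whatever GPUs (and CPU RAM) are available.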
# Prompts (system + user, as given)
SYSTEM_PROMPT = (
    "You are an image annotation API trained to analyze YouTube video keyframes. "
    "You will be given instructions on the output format, what to caption, and how to perform your job. "
    "Follow those instructions. For descriptions and summaries, provide them directly and do not lead them with "
    "'This image shows' or 'This keyframe displays...', just get right into the details."
)
USER_PROMPT = """You are an image annotation API trained to analyze YouTube video keyframes. You must respond with a valid JSON object matching the exact structure below.
Your job is to extract detailed **factual elements directly visible** in the image. Do not speculate or interpret artistic intent, camera focus, or composition. Do not include phrases like "this appears to be", "this looks like", or anything about the image itself. Describe what **is physically present in the frame**, and nothing more.
Return JSON in this structure:
{
"description": "...",
"objects": ["..."],
"actions": ["..."],
"environment": "...",
"content_type": "...",
"specific_style": "...",
"production_quality": "...",
"summary": "...",
"logos": ["..."]
}
Rules:
- Be specific and literal.
- No mood/emotion/narrative unless explicit.
- No artistic/cinematic analysis.
- Include the language of any visible text (e.g., "English text").
- ≤10 objects, ≤5 actions.
- 'logos' must be [] if none are present.
- Strictly valid JSON, properly escaped.
- Output only JSON, no extra text.
"""
def run_inference(image: Image.Image):
    # Build the chat messages: system prompt, then the image plus the user instructions
    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": USER_PROMPT}]},
    ]
    prompt_text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = processor(text=prompt_text, images=image, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            do_sample=False,  # greedy decoding for deterministic, reproducible annotations
            max_new_tokens=2000,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = processor.decode(gen_tokens, skip_special_tokens=True).strip()
    # Parse the output as JSON; fall back to extracting the first {...} block if extra text slipped in
    try:
        parsed = json.loads(text)
    except Exception:
        match = re.search(r"\{.*\}", text, re.DOTALL)
        if not match:
            return text, {"error": "Invalid JSON"}
        try:
            parsed = json.loads(match.group(0))
        except Exception:
            return text, {"error": "Invalid JSON"}
    return json.dumps(parsed, indent=2), parsed
def ui_submit(img):
    if img is None:
        return "Please upload an image.", None
    return run_inference(img)
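# Hypothetical programmatic use outside the UI (filename is illustrative):
#   img = Image.open("keyframe_0001.jpg").convert("RGB")
#   pretty_json, parsed = run_inference(img)
#   print(parsed.get("objects", []))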
# ---- UI ----
with gr.Blocks(title="ClipTagger-12B Keyframe Annotator") as demo:
    gr.Markdown("# ClipTagger-12B Keyframe Annotator\nUpload a photo to get structured JSON annotations.")
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
            btn = gr.Button("Annotate", variant="primary")
        with gr.Column(scale=1):
            out_text = gr.Code(label="Model Output (JSON)")
            out_json = gr.JSON(label="Parsed JSON")
    btn.click(ui_submit, inputs=[image], outputs=[out_text, out_json])
# Queue requests so they are processed one at a time on the single worker
demo.queue(max_size=32).launch()
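# Note: to try this app locally (assuming access to the model repo), something like
#   HF_TOKEN=... python app.py
# should work; the token can be omitted for public, non-gated checkpoints.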