File size: 9,753 Bytes
dcdd99b
0557d7f
ace6ed9
 
 
21b17c3
ace6ed9
b989be2
0557d7f
dcdd99b
b989be2
f617893
b989be2
f617893
 
9f13dde
0557d7f
f617893
ace6ed9
 
 
 
 
 
 
 
 
 
 
 
 
 
b989be2
 
 
 
 
 
 
 
 
ace6ed9
 
 
6a4178d
 
 
b989be2
6a4178d
b989be2
 
 
f617893
ace6ed9
 
b989be2
f617893
 
 
 
 
 
b989be2
dcdd99b
 
f617893
 
dcdd99b
 
 
 
 
 
 
f617893
dcdd99b
f617893
 
 
6a4178d
f617893
21b17c3
 
 
 
 
0557d7f
b989be2
8d3d460
9f13dde
 
b989be2
9f13dde
b989be2
dcdd99b
9f13dde
b989be2
30343a2
 
 
 
 
dcdd99b
21b17c3
f617893
8d3d460
f617893
 
dcdd99b
 
 
b989be2
dcdd99b
 
f617893
dcdd99b
f617893
30343a2
 
 
f617893
 
30343a2
 
6a4178d
b989be2
dcdd99b
6a4178d
 
30343a2
dcdd99b
ace6ed9
b989be2
9f13dde
f617893
30343a2
dcdd99b
6a4178d
dcdd99b
6a4178d
f617893
30343a2
b989be2
ace6ed9
f617893
 
 
 
 
 
 
 
 
 
 
 
30343a2
b989be2
 
dcdd99b
 
f617893
 
dcdd99b
 
f617893
 
b989be2
dcdd99b
b989be2
dcdd99b
 
 
b989be2
dcdd99b
 
f617893
 
dcdd99b
 
f617893
 
b989be2
dcdd99b
b989be2
dcdd99b
b989be2
dcdd99b
8d3d460
0557d7f
b989be2
8d3d460
 
30343a2
 
 
ace6ed9
 
 
0557d7f
ace6ed9
0557d7f
ace6ed9
dcdd99b
21b17c3
dcdd99b
1cffa06
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os, json, traceback
from typing import Any, Dict, Tuple
import gradio as gr
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

# ===== Env / params =====
MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Latency/quality knobs (defaults tuned for A100-80GB). Like MODEL_ID above,
# each can be overridden from the environment without editing the file.
# GEN_TEMPERATURE is namespaced on purpose: a bare "TEMP" variable is commonly
# set by the OS (e.g. the temp directory on Windows) and would not parse.
TEMP = float(os.environ.get("GEN_TEMPERATURE", "0.1"))              # per model docs
MAX_NEW_TOKENS = int(os.environ.get("MAX_NEW_TOKENS", "384"))      # fast + sufficient for schema (raise to 512/768 if needed)
VISION_LONG_SIDE = int(os.environ.get("VISION_LONG_SIDE", "896"))  # matches vision_config.image_size
# bf16 on GPU; fall back to fp32 when CUDA is unavailable.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# ===== Prompts (schema-only; no example output) =====
# System-turn instructions: establish the annotator role and forbid
# "This image shows..."-style lead-ins. Sentences are joined with one space.
SYSTEM_PROMPT = " ".join((
    "You are an image annotation API trained to analyze YouTube video keyframes.",
    "You will be given instructions on the output format, what to caption, and how to perform your job.",
    "Follow those instructions.",
    "For descriptions and summaries, provide them directly and do not lead them with "
    "'This image shows' or 'This keyframe displays...', just get right into the details.",
))

# User-turn instructions sent next to the image in the chat messages; also used
# as the entire prompt when the processor has no chat template.
# The schema below intentionally uses `...` as array placeholders. Because of
# that, generation treats any parsed output that still contains "..." as a
# copied placeholder and rejects it.
# NOTE: the exact text is the model's expected prompt; edit with care.
USER_PROMPT = """You are an image annotation API trained to analyze YouTube video keyframes. You must respond with a valid JSON object matching the exact structure below.

Your job is to extract detailed **factual elements directly visible** in the image. Do not speculate or interpret artistic intent, camera focus, or composition. Do not include phrases like "this appears to be", "this looks like", or anything about the image itself. Describe what **is physically present in the frame**, and nothing more.

Return JSON in this structure:

{
    "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
    "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
    "actions": ["action1 with participants and context", "action2 with participants and context", ...],
    "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
    "content_type": "The type of content it is, e.g., 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
    "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
    "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
    "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
    "logos": ["logo1 with visual description", "logo2 with visual description", ...]
}

Rules:
- Be specific and literal. Focus on what is explicitly visible.
- Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
- No artistic or cinematic analysis.
- Always include the language of any text in the image if present as an object, e.g., "English text", "Japanese text", "Russian text", etc.
- Maximum 10 objects and 5 actions.
- Return an empty array for 'logos' if none are present.
- Always output strictly valid JSON with proper escaping.
- Output **only the JSON**, no extra text or explanation.
- Do **not** copy any example strings from the instructions or use ellipses ('...'). Produce concrete values drawn from the image only.
"""

# ===== Utils =====
def extract_last_json(s: str):
    """
    Return the last top-level, brace-balanced ``{...}`` JSON object in *s*.

    Every top-level brace-balanced span is tried with ``json.loads`` and the
    last one that parses wins. If the model echoes the schema block before its
    answer, the answer (which comes later) is the one returned.

    Braces inside JSON string literals (e.g. ``"a } b"``) do not affect the
    balancing, and escaped quotes inside those strings are honoured. Quotes
    are only tracked while inside an object, so stray ``"`` characters in
    surrounding prose cannot derail the scan.

    Args:
        s: Raw model output text.

    Returns:
        The last parseable top-level object (a dict), or None if none exists.
    """
    last = None
    start, depth = None, 0
    in_string = escaped = False
    for i, ch in enumerate(s):
        if in_string:
            # Inside a JSON string: only an unescaped quote can end it.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"' and depth > 0:
            in_string = True
        elif ch == "{":
            if depth == 0:
                start = i
            depth += 1
        elif ch == "}" and depth > 0:
            depth -= 1
            if depth == 0:
                try:
                    last = json.loads(s[start:i + 1])
                except (ValueError, RecursionError):
                    # Balanced but not valid JSON (e.g. the echoed schema with
                    # `...` placeholders); skip it and keep scanning.
                    pass
                start = None
    return last

def build_messages(image: Image.Image):
    """
    Build the chat-format message list passed to ``apply_chat_template``.

    Returns two turns: a system turn holding SYSTEM_PROMPT, and a user turn
    whose content is the image followed by the USER_PROMPT instructions.
    """
    system_turn = {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}
    user_parts = [
        {"type": "image", "image": image},
        {"type": "text", "text": USER_PROMPT},
    ]
    user_turn = {"role": "user", "content": user_parts}
    return [system_turn, user_turn]

def resize_to_vision(pil: Image.Image, long_side: int = VISION_LONG_SIDE) -> Image.Image:
    """
    Convert *pil* to RGB and downscale it so its longer side is at most *long_side*.

    Images already within the limit are only converted, never upscaled. Aspect
    ratio is preserved. Each new side is rounded to the nearest pixel and
    clamped to at least 1, so extreme aspect ratios cannot produce a
    zero-sized target. PIL's ``resize`` rejects zero-sized targets, and plain
    truncation could also shave a pixel off the long side through float error.

    Args:
        pil: Source image, or None.
        long_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        The RGB (possibly resized) image, or None when *pil* is None.
    """
    if pil is None:
        return pil
    rgb = pil.convert("RGB")
    width, height = rgb.size
    longest = max(width, height)
    if longest <= long_side:
        return rgb
    scale = long_side / longest
    new_size = (max(1, round(width * scale)), max(1, round(height * scale)))
    return rgb.resize(new_size, Image.BICUBIC)

# ===== Load model (A100) =====
# Runs once at import time. Any failure is stored in LOAD_ERROR (message plus
# traceback) instead of crashing, so the UI can still start and show it. The
# failure is also printed so it appears in the server logs.
processor = tokenizer = model = None
LOAD_ERROR = None
try:
    cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
    # Reject CLIP-style encoder checkpoints up front: they cannot generate text.
    if "clip" in cfg.__class__.__name__.lower():
        raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")

    print("[boot] loading processor…", flush=True)
    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)

    print("[boot] loading model…", flush=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        # Keep on the GPU (A100) when present. Use the same CUDA check as
        # DTYPE, so a CPU-only host gets an fp32 CPU load instead of failing
        # on a hard-coded "cuda" device map.
        device_map="cuda" if torch.cuda.is_available() else "cpu",
        torch_dtype=DTYPE,
        trust_remote_code=True,
        # To force full precision when the repo ships a quantization config,
        # pass quantization_config=None here.
    )

    # Prefer the processor's bundled tokenizer; load one separately otherwise.
    tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
        MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
    )
    print("[boot] ready.", flush=True)

except Exception as e:
    LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
    print(f"[boot] load failed:\n{LOAD_ERROR}", flush=True)

# ===== Inference =====
def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
    """
    Annotate one image with the loaded VLM and return results for the UI.

    Up to two decoding attempts are made: greedy first, then sampling at
    ``TEMP``. The first output that yields a JSON object with no literal
    ``...`` wins.

    Args:
        image: Uploaded PIL image, or None when nothing was uploaded.

    Returns:
        A ``(display_text, parsed, ok)`` tuple:

        - *display_text*: pretty-printed JSON with non-ASCII text left
          readable, or an error message.
        - *parsed*: the decoded dict, or None.
        - *ok*: whether a valid annotation was produced.
    """
    if image is None:
        return "Please upload an image.", None, False
    if model is None or processor is None:
        return f"❌ Load error:\n{LOAD_ERROR}", None, False

    image = resize_to_vision(image, VISION_LONG_SIDE)

    # Build the chat prompt and tokenize it together with the image. Errors
    # here are reported through the same return channel as generation errors.
    try:
        if hasattr(processor, "apply_chat_template"):
            prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
        else:
            # NOTE(review): this plain prompt carries no system turn and
            # presumably no image placeholder token; confirm the processor
            # accepts it.
            prompt = USER_PROMPT
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
        prompt_len = inputs["input_ids"].shape[1]
    except Exception as e:
        return f"Preprocessing failed: {e}", None, False

    eos = getattr(model.config, "eos_token_id", None)

    def _decode_only_new(out_ids) -> str:
        """Decode only tokens generated after the prompt, so the prompt's schema block is never parsed."""
        gen_ids = out_ids[0][prompt_len:]
        # Prefer processor.decode if available (some VLMs customize decoding).
        decoder = processor if hasattr(processor, "decode") else tokenizer
        return decoder.decode(gen_ids, skip_special_tokens=True)

    # (label, sampling kwargs): greedy is fast and stable; the short sampled
    # retry can escape a greedy output that failed validation.
    attempts = (
        ("greedy", {"do_sample": False}),
        (f"sample_t{TEMP}", {"do_sample": True, "temperature": TEMP}),
    )
    tried = []
    for label, sampling in attempts:
        gen_kwargs = {**sampling, "max_new_tokens": MAX_NEW_TOKENS}
        if eos is not None:
            gen_kwargs["eos_token_id"] = eos
        try:
            with torch.inference_mode():
                out = model.generate(**inputs, **gen_kwargs)
            parsed = extract_last_json(_decode_only_new(out))
        except Exception as e:
            tried.append((label, f"err={e}"))
            continue
        if isinstance(parsed, dict):
            # ensure_ascii=False keeps Japanese/Russian text readable. A literal
            # "..." means the model copied the schema placeholders; this is a
            # heuristic that also rejects genuine "..." text seen in the image.
            rendered = json.dumps(parsed, indent=2, ensure_ascii=False)
            if "..." not in rendered:
                return rendered, parsed, True
        tried.append((label, "parse-failed-or-ellipses"))

    return "Generation failed.\n" + str(tried), None, False

# ===== UI =====
# Gradio front end: one image input wired to generate(), whose three return
# values fill the text/code view, the parsed-JSON view and the validity flag.
# When model loading failed, the captured traceback is shown in a collapsed
# accordion.
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="ClipTagger (VLM)") as demo:
    gr.Markdown("# ClipTagger\nUpload an image to get **strict JSON** annotations.")
    if LOAD_ERROR:
        with gr.Accordion("Startup Error Details", open=False):
            gr.Markdown(f"```\n{LOAD_ERROR}\n```")
    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
            btn = gr.Button("Annotate", variant="primary")
        with gr.Column(scale=1):
            out_text = gr.Code(label="Output (JSON or error)", language="json")
            out_json = gr.JSON(label="Parsed JSON")
            ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)

    btn.click(generate, inputs=[image], outputs=[out_text, out_json, ok_flag])

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it.
if __name__ == "__main__":
    demo.queue(max_size=32).launch()