Spaces: Running on A100

Update app.py

app.py CHANGED
@@ -5,14 +5,17 @@ from PIL import Image
 import torch
 from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

-#
+# ===== Env / params =====
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
 HF_TOKEN = os.environ.get("HF_TOKEN")
-
-
+
+# Latency/quality knobs (tuned for A100)
+TEMP = 0.1               # per model docs
+MAX_NEW_TOKENS = 384     # fast + sufficient for schema (raise to 512/768 later if needed)
+VISION_LONG_SIDE = 896   # matches your vision_config.image_size
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

-#
+# ===== Prompts (exact, no example output) =====
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
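A note on the new knobs: the VISION_LONG_SIDE comment claims 896 matches the checkpoint's vision_config.image_size. A standalone way to check that claim before retuning it is sketched below (not part of app.py; whether the config actually exposes vision_config.image_size is itself an assumption about this checkpoint):

# Sketch: confirm the 896 resize target against the checkpoint's config.
import os
from transformers import AutoConfig

MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
HF_TOKEN = os.environ.get("HF_TOKEN")

cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
vision_cfg = getattr(cfg, "vision_config", None)
# Attribute name is an assumption; print whatever the checkpoint actually exposes.
print("vision image_size:", getattr(vision_cfg, "image_size", "not exposed"))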
@@ -27,39 +30,40 @@ Your job is to extract detailed **factual elements directly visible** in the image
 Return JSON in this structure:

 {
-  "description": "...",
-  "objects": ["...
-  "actions": ["...
-  "environment": "
-  "content_type": "
-  "specific_style": "
-  "production_quality": "
-  "summary": "
-  "logos": ["...
+  "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
+  "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
+  "actions": ["action1 with participants and context", "action2 with participants and context", ...],
+  "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
+  "content_type": "The type of content it is, e.g., 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
+  "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
+  "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
+  "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
+  "logos": ["logo1 with visual description", "logo2 with visual description", ...]
 }

 Rules:
 - Be specific and literal. Focus on what is explicitly visible.
 - Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
 - No artistic or cinematic analysis.
-- Always include the language of any text in the image if present as an object.
+- Always include the language of any text in the image if present as an object, e.g., "English text", "Japanese text", "Russian text", etc.
 - Maximum 10 objects and 5 actions.
-- Return
--
+- Return an empty array for 'logos' if none are present.
+- Always output strictly valid JSON with proper escaping.
+- Output **only the JSON**, no extra text or explanation.
+- Do not use placeholder strings or ellipses ('...'). Replace with concrete values directly observed in the image only.
 """

-#
+# ===== Utils =====
 def extract_top_level_json(s: str):
+    """Parse JSON; if there's surrounding text, extract the first balanced {...} block."""
     try:
         return json.loads(s)
     except Exception:
         pass
-    start = None
-    depth = 0
+    start, depth = None, 0
     for i, ch in enumerate(s):
         if ch == '{':
-            if depth == 0:
-                start = i
+            if depth == 0: start = i
             depth += 1
         elif ch == '}':
             if depth > 0:
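The balanced-brace fallback exists because, even with the "only the JSON" rule, a decode can wrap the object in prose. A small illustration of both paths, assuming extract_top_level_json returns the parsed dict from the first balanced block as its docstring says (the function's tail sits outside this hunk):

# Illustration only, not part of app.py.
clean = '{"summary": "A person types on a laptop.", "logos": []}'
noisy = 'Sure, here is the annotation: {"summary": "A person types on a laptop.", "logos": []} Done.'

print(extract_top_level_json(clean))  # json.loads succeeds on the first attempt
print(extract_top_level_json(noisy))  # falls through to the balanced {...} scan
# Expected in both cases: {'summary': 'A person types on a laptop.', 'logos': []}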
83 |
{"type": "text", "text": USER_PROMPT}]}
|
84 |
]
|
85 |
|
86 |
+
def resize_to_vision(pil: Image.Image, long_side: int = VISION_LONG_SIDE) -> Image.Image:
|
87 |
if pil is None: return pil
|
88 |
w, h = pil.size
|
89 |
m = max(w, h)
|
90 |
+
if m <= long_side:
|
91 |
return pil.convert("RGB")
|
92 |
+
s = long_side / m
|
93 |
return pil.convert("RGB").resize((int(w*s), int(h*s)), Image.BICUBIC)
|
94 |
|
95 |
+
# ===== Load model (A100) =====
|
96 |
processor = tokenizer = model = None
|
97 |
LOAD_ERROR = None
|
98 |
try:
|
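For intuition on resize_to_vision: it only downscales and it preserves aspect ratio, so a 1080p keyframe comes out at roughly 896x504 while smaller frames pass through unchanged. A quick check, not part of the app:

from PIL import Image

print(resize_to_vision(Image.new("RGB", (1920, 1080))).size)  # (896, 504): long side capped at 896
print(resize_to_vision(Image.new("RGB", (640, 360))).size)    # (640, 360): already within the limit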
@@ -100,10 +104,10 @@ try:
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
-        device_map="
+        device_map="cuda",          # keep on A100
         torch_dtype=DTYPE,
         trust_remote_code=True,
-        # quantization_config=None,  # uncomment
+        # quantization_config=None,  # uncomment to force full precision if you removed quant
     )
     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
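The commented-out quantization_config line points at the obvious variant for smaller GPUs. The Space itself keeps full bf16 on the A100; purely as a hedged sketch, a 4-bit bitsandbytes load could look like this (reusing MODEL_ID and HF_TOKEN from above, and requiring the bitsandbytes package):

# Sketch only: NF4 4-bit loading for GPUs that cannot hold the 12B weights in bf16.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    device_map="auto",            # let accelerate place the quantized weights
    quantization_config=bnb,
    trust_remote_code=True,
)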
115 |
except Exception as e:
|
116 |
LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
|
117 |
|
118 |
+
# ===== Inference =====
|
119 |
def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
|
120 |
if image is None:
|
121 |
return "Please upload an image.", None, False
|
122 |
if model is None or processor is None:
|
123 |
return f"❌ Load error:\n{LOAD_ERROR}", None, False
|
124 |
|
125 |
+
image = resize_to_vision(image, VISION_LONG_SIDE)
|
126 |
|
127 |
+
# Chat prompt
|
128 |
if hasattr(processor, "apply_chat_template"):
|
129 |
prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
|
130 |
else:
|
131 |
prompt = USER_PROMPT
|
132 |
|
133 |
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
|
134 |
+
eos = getattr(model.config, "eos_token_id", None)
|
135 |
|
136 |
tried = []
|
137 |
+
|
138 |
+
# (1) Greedy (fast, stable)
|
139 |
try:
|
140 |
g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
|
141 |
+
if eos is not None: g["eos_token_id"] = eos
|
|
|
|
|
142 |
with torch.inference_mode():
|
143 |
out = model.generate(**inputs, **g)
|
144 |
text = processor.decode(out[0], skip_special_tokens=True)
|
145 |
parsed = extract_top_level_json(text)
|
146 |
+
if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
|
147 |
return json.dumps(parsed, indent=2), parsed, True
|
148 |
+
tried.append(("greedy", "parse-failed-or-ellipses"))
|
149 |
except Exception as e:
|
150 |
tried.append(("greedy", f"err={e}"))
|
151 |
|
152 |
+
# (2) Short sampled retry
|
153 |
try:
|
154 |
g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
|
155 |
+
if eos is not None: g["eos_token_id"] = eos
|
|
|
|
|
156 |
with torch.inference_mode():
|
157 |
out = model.generate(**inputs, **g)
|
158 |
text = processor.decode(out[0], skip_special_tokens=True)
|
159 |
parsed = extract_top_level_json(text)
|
160 |
+
if isinstance(parsed, dict) and "..." not in json.dumps(parsed):
|
161 |
return json.dumps(parsed, indent=2), parsed, True
|
162 |
+
tried.append(("sample_t0.1", "parse-failed-or-ellipses"))
|
163 |
except Exception as e:
|
164 |
+
tried.append(("sample_t0.1", f"err={e}"))
|
165 |
|
166 |
return "Generation failed.\n" + str(tried), None, False
|
167 |
|
168 |
+
# ===== UI =====
|
169 |
with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="ClipTagger (VLM)") as demo:
|
170 |
gr.Markdown("# ClipTagger\nUpload an image to get **strict JSON** annotations.")
|
171 |
if LOAD_ERROR:
|