andrejrad committed
Commit 8d3d460 · verified · 1 Parent(s): dcdd99b

Update app.py

Files changed (1):
  1. app.py  +31 -80
app.py CHANGED
@@ -9,10 +9,10 @@ from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, Aut
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
 HF_TOKEN = os.environ.get("HF_TOKEN")
 TEMP = 0.1
-MAX_NEW_TOKENS = 768  # faster demo; raise later if needed
+MAX_NEW_TOKENS = 768  # safe for demo; raise if needed
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
-# -------- Prompts (yours) --------
+# -------- Prompts --------
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -27,37 +27,33 @@ Your job is to extract detailed **factual elements directly visible** in the ima
 Return JSON in this structure:
 
 {
-  "description": "A detailed, factual account of what is visibly happening (4 sentences max). Only mention concrete elements or actions that are clearly shown. Do not include anything about how the image is styled, shot, or composed. Do not lead the description with something like 'This image shows' or 'this keyframe is...', just get right into the details.",
-  "objects": ["object1 with relevant visual details", "object2 with relevant visual details", ...],
-  "actions": ["action1 with participants and context", "action2 with participants and context", ...],
-  "environment": "Detailed factual description of the setting and atmosphere based on visible cues (e.g., interior of a classroom with fluorescent lighting, or outdoor forest path with snow-covered trees).",
-  "content_type": "The type of content it is, e.g. 'real-world footage', 'video game', 'animation', 'cartoon', 'CGI', 'VTuber', etc.",
-  "specific_style": "Specific genre, aesthetic, or platform style (e.g., anime, 3D animation, mobile gameplay, vlog, tutorial, news broadcast, etc.)",
-  "production_quality": "Visible production level: e.g., 'professional studio', 'amateur handheld', 'webcam recording', 'TV broadcast', etc.",
-  "summary": "One clear, comprehensive sentence summarizing the visual content of the frame. Like the description, get right to the point.",
-  "logos": ["logo1 with visual description", "logo2 with visual description", ...]
+  "description": "...",
+  "objects": ["..."],
+  "actions": ["..."],
+  "environment": "...",
+  "content_type": "...",
+  "specific_style": "...",
+  "production_quality": "...",
+  "summary": "...",
+  "logos": ["..."]
 }
 
 Rules:
 - Be specific and literal. Focus on what is explicitly visible.
 - Do NOT include interpretations of emotion, mood, or narrative unless it's visually explicit.
 - No artistic or cinematic analysis.
-- Always include the language of any text in the image if present as an object, e.g. "English text", "Japanese text", "Russian text", etc.
+- Always include the language of any text in the image if present as an object.
 - Maximum 10 objects and 5 actions.
-- Return an empty array for 'logos' if none are present.
-- Always output strictly valid JSON with proper escaping.
-- Output **only the JSON**, no extra text or explanation.
+- Return [] for 'logos' if none are present.
+- Strictly valid JSON only.
 """
 
 # -------- Utils --------
 def extract_top_level_json(s: str):
-    """Parse JSON; if extra text around it, extract the first balanced {...} block."""
-    # Fast path
     try:
         return json.loads(s)
     except Exception:
         pass
-    # Brace-stack extraction
     start = None
     depth = 0
     for i, ch in enumerate(s):
@@ -73,11 +69,10 @@ def extract_top_level_json(s: str):
                 try:
                     return json.loads(chunk)
                 except Exception:
-                    # continue scanning for the next candidate
                     start = None
     return None
 
-def build_messages(image: Image.Image):
+def build_messages(image):
     return [
         {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
         {"role": "user", "content": [{"type": "image", "image": image},
@@ -85,8 +80,7 @@ def build_messages(image: Image.Image):
     ]
 
 def downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
-    if pil is None:
-        return pil
+    if pil is None: return pil
     w, h = pil.size
     m = max(w, h)
     if m <= max_side:
@@ -94,41 +88,26 @@ def downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
     s = max_side / m
     return pil.convert("RGB").resize((int(w*s), int(h*s)), Image.BICUBIC)
 
-# -------- Load model (A100) --------
+# -------- Load model --------
 processor = tokenizer = model = None
 LOAD_ERROR = None
-
 try:
     cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
     if "clip" in cfg.__class__.__name__.lower():
         raise RuntimeError(f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder repo; need a causal VLM.")
 
-    print("[boot] loading processor…", flush=True)
-    try:
-        processor = AutoProcessor.from_pretrained(
-            MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-        )
-    except TypeError:
-        processor = AutoProcessor.from_pretrained(
-            MODEL_ID, token=HF_TOKEN, trust_remote_code=True
-        )
-
-    print("[boot] loading model…", flush=True)
-    # Force full-precision path on A100 first; add quantized path later if desired
+    processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
        device_map="auto",
         torch_dtype=DTYPE,
         trust_remote_code=True,
-        # quantization_config=None,  # keep commented if you want to honor repo quant; uncomment to force dequant
+        # quantization_config=None,  # uncomment if you want to force full precision
     )
-
     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
     )
-    print("[boot] ready.", flush=True)
-
 except Exception as e:
     LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
 
@@ -141,82 +120,54 @@ def generate(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
 
     image = downscale_if_huge(image)
 
-    # Build prompt
     if hasattr(processor, "apply_chat_template"):
         prompt = processor.apply_chat_template(build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
-        # fallback join (rare)
         prompt = USER_PROMPT
 
-    # Tokenize with vision
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
 
-    # Common gen kwargs
-    eos = getattr(model.config, "eos_token_id", None)
-
     tried = []
-
-    # (1) Greedy, no sampling (most stable, no temperature arg)
+    # (1) Greedy
     try:
         g = dict(do_sample=False, max_new_tokens=MAX_NEW_TOKENS)
+        eos = getattr(model.config, "eos_token_id", None)
         if eos is not None:
             g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text = (processor.decode(out[0], skip_special_tokens=True)
-                if hasattr(processor, "decode")
-                else tokenizer.decode(out[0], skip_special_tokens=True))
+        text = processor.decode(out[0], skip_special_tokens=True)
         parsed = extract_top_level_json(text)
         if isinstance(parsed, dict):
             return json.dumps(parsed, indent=2), parsed, True
-        tried.append(("greedy", "parsed-failed"))
+        tried.append(("greedy", "parse-failed"))
     except Exception as e:
         tried.append(("greedy", f"err={e}"))
 
-    # (2) Sampling with temperature=0.1
+    # (2) Sampling
    try:
         g = dict(do_sample=True, temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
+        eos = getattr(model.config, "eos_token_id", None)
         if eos is not None:
             g["eos_token_id"] = eos
         with torch.inference_mode():
             out = model.generate(**inputs, **g)
-        text = (processor.decode(out[0], skip_special_tokens=True)
-                if hasattr(processor, "decode")
-                else tokenizer.decode(out[0], skip_special_tokens=True))
+        text = processor.decode(out[0], skip_special_tokens=True)
         parsed = extract_top_level_json(text)
         if isinstance(parsed, dict):
             return json.dumps(parsed, indent=2), parsed, True
-        tried.append(("sample_t0.1", "parsed-failed"))
+        tried.append(("sample", "parse-failed"))
     except Exception as e:
-        tried.append(("sample_t0.1", f"err={e}"))
+        tried.append(("sample", f"err={e}"))
 
-    # (3) Shorter greedy
-    try:
-        g = dict(do_sample=False, max_new_tokens=min(512, MAX_NEW_TOKENS))
-        if eos is not None:
-            g["eos_token_id"] = eos
-        with torch.inference_mode():
-            out = model.generate(**inputs, **g)
-        text = (processor.decode(out[0], skip_special_tokens=True)
-                if hasattr(processor, "decode")
-                else tokenizer.decode(out[0], skip_special_tokens=True))
-        parsed = extract_top_level_json(text)
-        if isinstance(parsed, dict):
-            return json.dumps(parsed, indent=2), parsed, True
-        tried.append(("greedy_short", "parsed-failed"))
-    except Exception as e:
-        tried.append(("greedy_short", f"err={e}"))
-
-    # Debug info if all fail
-    return "Generation failed.\nTried: " + "\n".join([f"{t[0]} -> {t[1]}" for t in tried]), None, False
+    return "Generation failed.\n" + str(tried), None, False
 
 # -------- UI --------
-with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM · A100)") as demo:
-    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · A100)\nUpload an image to get **strict JSON** annotations.")
+with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="ClipTagger (VLM)") as demo:
+    gr.Markdown("# ClipTagger\nUpload an image to get **strict JSON** annotations.")
     if LOAD_ERROR:
         with gr.Accordion("Startup Error Details", open=False):
             gr.Markdown(f"```\n{LOAD_ERROR}\n```")
-
     with gr.Row():
         with gr.Column(scale=1):
             image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
 
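The commit strips the docstring and comments from extract_top_level_json, and the diff only shows fragments of its scanning loop. As a reference for what the fallback does (parse the whole string first, otherwise take the first balanced top-level {...} block), here is a standalone sketch of the same idea; it illustrates the technique and is not a copy of app.py.

# Minimal sketch of brace-balanced JSON extraction, as used to salvage a JSON
# object from model output that has extra text around it. Like the app's
# version, it is naive about braces inside quoted strings.
import json

def first_balanced_json(s: str):
    try:
        return json.loads(s)  # fast path: the output is already pure JSON
    except Exception:
        pass
    start, depth = None, 0
    for i, ch in enumerate(s):
        if ch == "{":
            if depth == 0:
                start = i
            depth += 1
        elif ch == "}" and depth > 0:
            depth -= 1
            if depth == 0 and start is not None:
                try:
                    return json.loads(s[start:i + 1])
                except Exception:
                    start = None  # keep scanning for the next candidate block
    return None

raw = 'Sure, here is the annotation:\n{"summary": "a dog runs on a beach", "logos": []}\nHope this helps!'
print(first_balanced_json(raw))  # -> {'summary': 'a dog runs on a beach', 'logos': []}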