andrejrad committed
Commit 9f13dde (verified)
1 Parent(s): 62cada7

Update app.py

Files changed (1)
  1. app.py +92 -48
app.py CHANGED
@@ -6,19 +6,20 @@ import torch
  import spaces
  from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

- # --------- ENV / PARAMS ----------
  MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
- HF_TOKEN = os.environ.get("HF_TOKEN") # put this in Space -> Settings -> Variables & secrets
  TEMP = 0.1
  MAX_NEW_TOKENS = 2000

- # Lazy globals (ZeroGPU-safe)
  _processor: Any = None
  _tokenizer: Any = None
  _model: Any = None
  _last_load_error: str | None = None

- # --------- PROMPTS (yours) ----------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
      "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -55,9 +56,8 @@ Rules:
  - Output **only the JSON**, no extra text or explanation.
  """

- # --------- HELPERS ----------
  def _json_extract(text: str):
-     """Strict parse -> top-level {...} fallback."""
      try:
          return json.loads(text)
      except Exception:
@@ -76,26 +76,32 @@ def _build_messages(image: Image.Image):
              {"type": "text", "text": USER_PROMPT}]}
      ]

- # --------- ZERO-GPU LAZY LOADER ----------
  @spaces.GPU
  def _ensure_loaded() -> str:
-     """
-     Load the model only when a ZeroGPU worker with a GPU is attached.
-     Tries quantized path first (compressed-tensors), then falls back to unquantized.
-     """
      global _processor, _tokenizer, _model, _last_load_error
      if _model is not None and _processor is not None:
          return "already_loaded"
-
      try:
-         # Sanity: config should be gemma3 causal VLM (not CLIP)
          cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
          if "clip" in cfg.__class__.__name__.lower():
              raise RuntimeError(
-                 f"MODEL_ID '{MODEL_ID}' resolves to CLIP/encoder config; need a causal VLM checkpoint."
              )

-         # Try quantized (as per your config)
          _processor = AutoProcessor.from_pretrained(
              MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
          )
@@ -103,7 +109,7 @@ def _ensure_loaded() -> str:
              MODEL_ID,
              token=HF_TOKEN,
              device_map="auto",
-             torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
              trust_remote_code=True,
          )
          _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
@@ -112,7 +118,7 @@ def _ensure_loaded() -> str:
          _last_load_error = None
          return "ok_quant"
      except Exception as e:
-         # Fallback: disable quantization (more VRAM)
          if "compressed_tensors" in str(e):
              try:
                  _processor = AutoProcessor.from_pretrained(
@@ -122,7 +128,7 @@ def _ensure_loaded() -> str:
                      MODEL_ID,
                      token=HF_TOKEN,
                      device_map="auto",
-                     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
                      trust_remote_code=True,
                      quantization_config=None, # force dequantized load
                  )
@@ -140,7 +146,53 @@ def _ensure_loaded() -> str:
          _processor = _tokenizer = _model = None
          return "fail"

- # --------- INFERENCE ----------
  @spaces.GPU
  def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
      status = _ensure_loaded()
@@ -150,10 +202,13 @@ def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool
      if image is None:
          return "Please upload an image.", None, False

-     # Prompt assembly
      if hasattr(_processor, "apply_chat_template"):
          prompt = _processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
      else:
          msgs = _build_messages(image)
          prompt = ""
          for m in msgs:
@@ -164,39 +219,28 @@ def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool
                  elif chunk["type"] == "image":
                      prompt += f"{role}: [IMAGE]\n"

-     inputs = _processor(text=prompt, images=image, return_tensors="pt").to(_model.device)
-
-     gen_kwargs = dict(
-         temperature=TEMP,
-         max_new_tokens=MAX_NEW_TOKENS,
-     )
-     # respect multiple eos ids if present
-     eos = getattr(_model.config, "eos_token_id", None)
-     if eos is not None:
-         gen_kwargs["eos_token_id"] = eos
-
-     # Try JSON-only output (if supported)
      try:
-         gen_kwargs["response_format"] = {"type": "json_object"}
-     except Exception:
-         pass
-
-     with torch.inference_mode():
-         out = _model.generate(**inputs, **gen_kwargs)

-     text = (_processor.decode(out[0], skip_special_tokens=True)
-             if hasattr(_processor, "decode")
-             else _tokenizer.decode(out[0], skip_special_tokens=True))

-     if USER_PROMPT in text:
-         text = text.split(USER_PROMPT)[-1].strip()

-     parsed = _json_extract(text)
      if isinstance(parsed, dict):
          return json.dumps(parsed, indent=2), parsed, True
-     return text, None, False

- # Optional: quick warmup to validate loading on first worker
  @spaces.GPU(duration=60)
  def _warmup():
      try:
@@ -204,7 +248,7 @@ def _warmup():
      except Exception as e:
          return f"warmup error: {e}"

- # --------- UI ----------
  with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ZeroGPU)") as demo:
      gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · ZeroGPU)\nUpload an image to get **strict JSON** annotations.")

@@ -219,7 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe

      btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])

-     # fire a non-blocking warmup
      try:
          _ = _warmup()
      except Exception:
 
@@ -6,19 +6,20 @@ import torch
  import spaces
  from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

+ # ------------------ ENV ------------------
  MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
+ HF_TOKEN = os.environ.get("HF_TOKEN")
  TEMP = 0.1
  MAX_NEW_TOKENS = 2000
+ DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

+ # ------------------ GLOBALS (lazy) ------------------
  _processor: Any = None
  _tokenizer: Any = None
  _model: Any = None
  _last_load_error: str | None = None

+ # ------------------ PROMPTS ------------------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
      "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -55,9 +56,8 @@ Rules:
  - Output **only the JSON**, no extra text or explanation.
  """

+ # ------------------ HELPERS ------------------
  def _json_extract(text: str):
      try:
          return json.loads(text)
      except Exception:
@@ -76,26 +76,32 @@ def _build_messages(image: Image.Image):
              {"type": "text", "text": USER_PROMPT}]}
      ]

+ def _downscale_if_huge(pil: Image.Image, max_side: int = 1280) -> Image.Image:
+     # keep aspect, cap longest side to max_side to avoid enormous tensors on ZeroGPU
+     if pil is None:
+         return pil
+     w, h = pil.size
+     m = max(w, h)
+     if m <= max_side:
+         return pil.convert("RGB")
+     scale = max_side / m
+     new_w, new_h = int(w * scale), int(h * scale)
+     return pil.convert("RGB").resize((new_w, new_h), Image.BICUBIC)
+
+ # ------------------ ZERO-GPU LAZY LOADER ------------------
  @spaces.GPU
  def _ensure_loaded() -> str:
      global _processor, _tokenizer, _model, _last_load_error
      if _model is not None and _processor is not None:
          return "already_loaded"
      try:
          cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
          if "clip" in cfg.__class__.__name__.lower():
              raise RuntimeError(
+                 f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder config; need a causal VLM."
              )

+         # Try quantized (as requested by your config)
          _processor = AutoProcessor.from_pretrained(
              MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
          )
@@ -103,7 +109,7 @@ def _ensure_loaded() -> str:
              MODEL_ID,
              token=HF_TOKEN,
              device_map="auto",
+             torch_dtype=DTYPE,
              trust_remote_code=True,
          )
          _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
@@ -112,7 +118,7 @@ def _ensure_loaded() -> str:
          _last_load_error = None
          return "ok_quant"
      except Exception as e:
+         # If the worker image doesn't have compressed-tensors, fall back dequantized
          if "compressed_tensors" in str(e):
              try:
                  _processor = AutoProcessor.from_pretrained(
@@ -122,7 +128,7 @@ def _ensure_loaded() -> str:
                      MODEL_ID,
                      token=HF_TOKEN,
                      device_map="auto",
+                     torch_dtype=DTYPE,
                      trust_remote_code=True,
                      quantization_config=None, # force dequantized load
                  )
@@ -140,7 +146,53 @@ def _ensure_loaded() -> str:
          _processor = _tokenizer = _model = None
          return "fail"

+ def _safe_generate(inputs, try_json: bool = True) -> Tuple[str, bool, str]:
+     """
+     Multi-try generation to dodge ZeroGPU/transformers edge cases:
+       1) with response_format=json_object (if supported)
+       2) no response_format
+       3) shorter output + temp 0.0
+     Returns: (text_or_error, ok, detail_tag)
+     """
+     gen_sets = []
+
+     # (1) Preferred
+     g1 = dict(temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
+     eos = getattr(_model.config, "eos_token_id", None)
+     if eos is not None:
+         g1["eos_token_id"] = eos
+     if try_json:
+         g1["response_format"] = {"type": "json_object"}
+     gen_sets.append(("json_object", g1))
+
+     # (2) No response_format
+     g2 = dict(temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
+     if eos is not None:
+         g2["eos_token_id"] = eos
+     gen_sets.append(("no_response_format", g2))
+
+     # (3) Shorter, deterministic
+     g3 = dict(temperature=0.0, max_new_tokens=min(512, MAX_NEW_TOKENS))
+     if eos is not None:
+         g3["eos_token_id"] = eos
+     gen_sets.append(("short_deterministic", g3))
+
+     last_err = None
+     for tag, g in gen_sets:
+         try:
+             with torch.inference_mode():
+                 out = _model.generate(**inputs, **g)
+             if hasattr(_processor, "decode"):
+                 text = _processor.decode(out[0], skip_special_tokens=True)
+             else:
+                 text = _tokenizer.decode(out[0], skip_special_tokens=True)
+             return text, True, tag
+         except Exception as e:
+             last_err = f"{tag}: {e}\n{traceback.format_exc()}"
+             # continue to next strategy
+     return f"Generation failed.\n{last_err or ''}", False, "all_failed"
+
+ # ------------------ INFERENCE ------------------
  @spaces.GPU
  def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
      status = _ensure_loaded()
@@ -150,10 +202,13 @@ def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool
      if image is None:
          return "Please upload an image.", None, False

+     image = _downscale_if_huge(image, max_side=1280)
+
+     # Build prompt
      if hasattr(_processor, "apply_chat_template"):
          prompt = _processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
      else:
+         # conservative fallback (rarely used on Gemma-3)
          msgs = _build_messages(image)
          prompt = ""
          for m in msgs:
@@ -164,39 +219,28 @@ def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool
                  elif chunk["type"] == "image":
                      prompt += f"{role}: [IMAGE]\n"

      try:
+         inputs = _processor(text=prompt, images=image, return_tensors="pt").to(_model.device)
+     except Exception as e:
+         err = f"Preprocessing failed: {e}\n{traceback.format_exc()}"
+         return err, None, False

+     txt, ok, tag = _safe_generate(inputs, try_json=True)
+     if not ok:
+         return txt, None, False

+     # Trim echoed prompt if present
+     if USER_PROMPT in txt:
+         txt = txt.split(USER_PROMPT)[-1].strip()

+     parsed = _json_extract(txt)
      if isinstance(parsed, dict):
          return json.dumps(parsed, indent=2), parsed, True

+     # Show raw + tag to help debug ValueError causes
+     return f"(strategy={tag})\n" + txt, None, False
+
+ # Optional warmup to validate load on first worker
  @spaces.GPU(duration=60)
  def _warmup():
      try:
@@ -204,7 +248,7 @@ def _warmup():
      except Exception as e:
          return f"warmup error: {e}"

+ # ------------------ UI ------------------
  with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ZeroGPU)") as demo:
      gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · ZeroGPU)\nUpload an image to get **strict JSON** annotations.")

@@ -219,7 +263,7 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe

      btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])

+     # best-effort warmup
      try:
          _ = _warmup()
      except Exception:
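
Note: a minimal local smoke test of the updated annotate_image path might look like the sketch below. This is an assumption, not part of the commit: it presumes app.py is importable from the working directory, that HF_TOKEN is exported if the checkpoint is gated, and that the machine can actually hold the 12B model (a GPU, or enough memory for the dequantized fallback). Outside a ZeroGPU Space the @spaces.GPU decorator is a pass-through, so the functions run locally.

    # hypothetical smoke test -- "keyframe.jpg" is a placeholder test image
    import os
    from PIL import Image

    os.environ.setdefault("MODEL_ID", "inference-net/ClipTagger-12b")
    # os.environ["HF_TOKEN"] = "hf_..."  # only needed for gated checkpoints

    import app  # if app.py calls demo.launch() at import time, comment that out first

    frame = Image.open("keyframe.jpg")
    text, parsed, ok = app.annotate_image(frame)  # (raw text, parsed dict or None, ok flag)
    print("ok:", ok)
    print(text)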