Commit 30343a2 · verified · 1 Parent(s): 906a676
andrejrad committed

Update app.py

Files changed (1): app.py (+114 -151)
app.py CHANGED
@@ -3,23 +3,20 @@ from typing import Any, Dict, Tuple
 import gradio as gr
 from PIL import Image
 import torch
-import spaces
 from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig

-# ------------------ ENV ------------------
+# --------------------------
+# Env / params
+# --------------------------
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
-HF_TOKEN = os.environ.get("HF_TOKEN")
+HF_TOKEN = os.environ.get("HF_TOKEN")  # set in Space → Settings → Variables & secrets
 TEMP = 0.1
 MAX_NEW_TOKENS = 2000
 DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

-# ------------------ GLOBALS (lazy) ------------------
-_processor: Any = None
-_tokenizer: Any = None
-_model: Any = None
-_last_load_error: str | None = None
-
-# ------------------ PROMPTS ------------------
+# --------------------------
+# Prompts (yours)
+# --------------------------
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -56,8 +53,11 @@ Rules:
 - Output **only the JSON**, no extra text or explanation.
 """

-# ------------------ HELPERS ------------------
+# --------------------------
+# Utilities
+# --------------------------
 def _json_extract(text: str):
+    """Strict JSON parse with top-level {...} fallback."""
     try:
         return json.loads(text)
     except Exception:
@@ -76,8 +76,8 @@ def _build_messages(image: Image.Image):
              {"type": "text", "text": USER_PROMPT}]}
     ]

-def _downscale_if_huge(pil: Image.Image, max_side: int = 1280) -> Image.Image:
-    # keep aspect, cap longest side to max_side to avoid enormous tensors on ZeroGPU
+def _downscale_if_huge(pil: Image.Image, max_side: int = 1792) -> Image.Image:
+    """Cap longest side to keep memory predictable; A100 is roomy but this avoids extreme uploads."""
     if pil is None:
         return pil
     w, h = pil.size
@@ -88,127 +88,79 @@ def _downscale_if_huge(pil: Image.Image, max_side: int = 1280) -> Image.Image:
     new_w, new_h = int(w * scale), int(h * scale)
     return pil.convert("RGB").resize((new_w, new_h), Image.BICUBIC)

-# ------------------ ZERO-GPU LAZY LOADER ------------------
-@spaces.GPU
-def _ensure_loaded() -> str:
-    global _processor, _tokenizer, _model, _last_load_error
-    if _model is not None and _processor is not None:
-        return "already_loaded"
-    try:
-        cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
-        if "clip" in cfg.__class__.__name__.lower():
-            raise RuntimeError(
-                f"MODEL_ID '{MODEL_ID}' is a CLIP/encoder config; need a causal VLM."
-            )
+# --------------------------
+# Load model (dedicated GPU)
+# --------------------------
+processor = tokenizer = model = None
+LOAD_ERROR = None
+
+try:
+    cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+    if "clip" in cfg.__class__.__name__.lower():
+        raise RuntimeError(
+            f"MODEL_ID '{MODEL_ID}' resolves to a CLIP/encoder config; need a causal VLM checkpoint."
+        )

-        # Try quantized (as requested by your config)
-        _processor = AutoProcessor.from_pretrained(
+    # Try quantized path (compressed-tensors) per your config
+    try:
+        processor = AutoProcessor.from_pretrained(
             MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
         )
-        _model = AutoModelForCausalLM.from_pretrained(
+    except TypeError:
+        processor = AutoProcessor.from_pretrained(
+            MODEL_ID, token=HF_TOKEN, trust_remote_code=True
+        )
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID,
             token=HF_TOKEN,
             device_map="auto",
             torch_dtype=DTYPE,
             trust_remote_code=True,
         )
-        _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
-            MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-        )
-        _last_load_error = None
-        return "ok_quant"
     except Exception as e:
-        # If the worker image doesn't have compressed-tensors, fall back dequantized
+        # Fallback: disable quantization if the backend isn't available
         if "compressed_tensors" in str(e):
-            try:
-                _processor = AutoProcessor.from_pretrained(
-                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-                )
-                _model = AutoModelForCausalLM.from_pretrained(
-                    MODEL_ID,
-                    token=HF_TOKEN,
-                    device_map="auto",
-                    torch_dtype=DTYPE,
-                    trust_remote_code=True,
-                    quantization_config=None,  # force dequantized load
-                )
-                _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
-                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-                )
-                _last_load_error = None
-                return "ok_dequant"
-            except Exception as e2:
-                _last_load_error = f"{e}\n\nFallback failed:\n{e2}\n{traceback.format_exc()}"
-                _processor = _tokenizer = _model = None
-                return "fail"
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN,
+                device_map="auto",
+                torch_dtype=DTYPE,
+                trust_remote_code=True,
+                quantization_config=None,
+            )
         else:
-            _last_load_error = f"{e}\n{traceback.format_exc()}"
-            _processor = _tokenizer = _model = None
-            return "fail"
+            raise

-def _safe_generate(inputs, try_json: bool = True) -> Tuple[str, bool, str]:
-    """
-    Multi-try generation to dodge ZeroGPU/transformers edge cases:
-      1) with response_format=json_object (if supported)
-      2) no response_format
-      3) shorter output + temp 0.0
-    Returns: (text_or_error, ok, detail_tag)
-    """
-    gen_sets = []
+    tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
+        MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+    )

-    # (1) Preferred
-    g1 = dict(temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
-    eos = getattr(_model.config, "eos_token_id", None)
-    if eos is not None:
-        g1["eos_token_id"] = eos
-    if try_json:
-        g1["response_format"] = {"type": "json_object"}
-    gen_sets.append(("json_object", g1))
-
-    # (2) No response_format
-    g2 = dict(temperature=TEMP, max_new_tokens=MAX_NEW_TOKENS)
-    if eos is not None:
-        g2["eos_token_id"] = eos
-    gen_sets.append(("no_response_format", g2))
-
-    # (3) Shorter, deterministic
-    g3 = dict(temperature=0.0, max_new_tokens=min(512, MAX_NEW_TOKENS))
-    if eos is not None:
-        g3["eos_token_id"] = eos
-    gen_sets.append(("short_deterministic", g3))
-
-    last_err = None
-    for tag, g in gen_sets:
-        try:
-            with torch.inference_mode():
-                out = _model.generate(**inputs, **g)
-            if hasattr(_processor, "decode"):
-                text = _processor.decode(out[0], skip_special_tokens=True)
-            else:
-                text = _tokenizer.decode(out[0], skip_special_tokens=True)
-            return text, True, tag
-        except Exception as e:
-            last_err = f"{tag}: {e}\n{traceback.format_exc()}"
-            # continue to next strategy
-    return f"Generation failed.\n{last_err or ''}", False, "all_failed"
-
-# ------------------ INFERENCE ------------------
-@spaces.GPU
-def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
-    status = _ensure_loaded()
-    if status == "fail":
-        return f"❌ Load error:\n{_last_load_error}", None, False
+except Exception as e:
+    LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"

+# --------------------------
+# Inference
+# --------------------------
+def run(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
     if image is None:
         return "Please upload an image.", None, False
+    if model is None or processor is None:
+        msg = (
+            "❌ Model failed to load.\n\n"
+            f"{LOAD_ERROR or 'Unknown error.'}\n"
+            "Check MODEL_ID/HF_TOKEN and that the repo includes model + processor files."
+        )
+        return msg, None, False

-    image = _downscale_if_huge(image, max_side=1280)
+    image = _downscale_if_huge(image)

-    # Build prompt
-    if hasattr(_processor, "apply_chat_template"):
-        prompt = _processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
+    # Build chat prompt
+    if hasattr(processor, "apply_chat_template"):
+        prompt = processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
-        # conservative fallback (rarely used on Gemma-3)
+        # Very rare fallback path
        msgs = _build_messages(image)
        prompt = ""
        for m in msgs:
@@ -219,38 +171,51 @@ def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool
             elif chunk["type"] == "image":
                 prompt += f"{role}: [IMAGE]\n"

-    try:
-        inputs = _processor(text=prompt, images=image, return_tensors="pt").to(_model.device)
-    except Exception as e:
-        err = f"Preprocessing failed: {e}\n{traceback.format_exc()}"
-        return err, None, False
-
-    txt, ok, tag = _safe_generate(inputs, try_json=True)
-    if not ok:
-        return txt, None, False
+    # Tokenize with vision
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

-    # Trim echoed prompt if present
-    if USER_PROMPT in txt:
-        txt = txt.split(USER_PROMPT)[-1].strip()
-
-    parsed = _json_extract(txt)
-    if isinstance(parsed, dict):
-        return json.dumps(parsed, indent=2), parsed, True
-
-    # Show raw + tag to help debug ValueError causes
-    return f"(strategy={tag})\n" + txt, None, False
+    # Gen args
+    gen_kwargs = dict(
+        temperature=TEMP,
+        max_new_tokens=MAX_NEW_TOKENS,
+    )
+    eos = getattr(model.config, "eos_token_id", None)
+    if eos is not None:
+        gen_kwargs["eos_token_id"] = eos
+
+    # Try to enforce JSON; if unsupported, we'll retry without
+    tried = []
+    for tag, extra in [
+        ("json_object", {"response_format": {"type": "json_object"}}),
+        ("no_response_format", {}),
+        ("short_deterministic", {"temperature": 0.0, "max_new_tokens": min(512, MAX_NEW_TOKENS)}),
+    ]:
+        try:
+            with torch.inference_mode():
+                out = model.generate(**inputs, **{**gen_kwargs, **extra})
+            text = (processor.decode(out[0], skip_special_tokens=True)
+                    if hasattr(processor, "decode")
+                    else AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, use_fast=True).decode(out[0], skip_special_tokens=True))
+            if USER_PROMPT in text:
+                text = text.split(USER_PROMPT)[-1].strip()
+            parsed = _json_extract(text)
+            if isinstance(parsed, dict):
+                return json.dumps(parsed, indent=2), parsed, True
+            tried.append((tag, "parsed-failed"))
+        except Exception as e:
+            tried.append((tag, f"err={e}"))

-# Optional warmup to validate load on first worker
-@spaces.GPU(duration=60)
-def _warmup():
-    try:
-        return _ensure_loaded()
-    except Exception as e:
-        return f"warmup error: {e}"
+    # If all strategies failed, return debug info
+    return "Generation failed.\nTried: " + "\n".join([f"{t[0]} -> {t[1]}" for t in tried]), None, False

-# ------------------ UI ------------------
-with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ZeroGPU)") as demo:
-    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · ZeroGPU)\nUpload an image to get **strict JSON** annotations.")
+# --------------------------
+# UI
+# --------------------------
+with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM)") as demo:
+    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · A100)\nUpload an image to get **strict JSON** annotations.")
+    if LOAD_ERROR:
+        with gr.Accordion("Startup Error Details", open=False):
+            gr.Markdown(f"```\n{LOAD_ERROR}\n```")

     with gr.Row():
         with gr.Column(scale=1):
@@ -259,14 +224,12 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe
         with gr.Column(scale=1):
             out_text = gr.Code(label="Output (JSON or error)")
             out_json = gr.JSON(label="Parsed JSON")
-            ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
+            ok = gr.Checkbox(label="Valid JSON", value=False, interactive=False)

-    btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])
+    def on_click(img):
+        return run(img)

-# best-effort warmup
-try:
-    _ = _warmup()
-except Exception:
-    pass
+    btn.click(on_click, inputs=[image], outputs=[out_text, out_json, ok])

-demo.queue(max_size=32).launch()
+# Conservative concurrency to avoid OOM spikes; A100-80GB can increase this.
+demo.queue(max_size=32, max_concurrency=1).launch()
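
For a quick smoke test outside the web UI, one option is the gradio_client package. This is a sketch, not part of the commit: the Space id is a placeholder, and the endpoint name (likely "/on_click", since btn.click wires a handler of that name without an explicit api_name) should be confirmed with view_api().

# Sketch only: Space id is hypothetical; confirm the real endpoint name with view_api().
from gradio_client import Client, handle_file

client = Client("your-username/your-space")      # hypothetical Space id; pass hf_token=... if the Space is private
print(client.view_api())                         # lists the actual endpoint names and signatures

# Outputs mirror btn.click: raw JSON text, parsed JSON (or None), and the "Valid JSON" flag.
text, parsed, ok = client.predict(
    handle_file("keyframe.jpg"),                 # path to a local test image
    api_name="/on_click",                        # assumed from the handler name
)
print(ok)
print(text)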