andrejrad committed
Commit 21b17c3 · verified · 1 Parent(s): 8ed3bc5

Update app.py

Files changed (1)
  1. app.py +121 -136
app.py CHANGED
@@ -4,22 +4,21 @@ import gradio as gr
 from PIL import Image
 import torch
 import spaces
 
-# --------------------------
-# Environment
-# --------------------------
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
-HF_TOKEN = os.environ.get("HF_TOKEN")
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
-
 TEMP = 0.1
 MAX_NEW_TOKENS = 2000
 
-# --------------------------
-# Prompts (yours)
-# --------------------------
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
@@ -56,77 +55,105 @@ Rules:
 - Output **only the JSON**, no extra text or explanation.
 """
 
-# --------------------------
-# Load full VLM (Gemma-3)
-# --------------------------
-from transformers import AutoConfig, AutoProcessor, AutoTokenizer, AutoModelForCausalLM
 
-processor = tokenizer = model = None
-LOAD_ERROR = None
 
-try:
-    cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
-    if "clip" in cfg.__class__.__name__.lower():
-        raise RuntimeError(
-            f"MODEL_ID '{MODEL_ID}' resolves to a CLIP/encoder config. "
-            "Point MODEL_ID to your full VLM checkpoint (this repo's config shows gemma3)."
-        )
 
-    # Processor (has vision + tokenizer routing)
     try:
-        processor = AutoProcessor.from_pretrained(
             MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
         )
-    except TypeError:
-        processor = AutoProcessor.from_pretrained(
-            MODEL_ID, token=HF_TOKEN, trust_remote_code=True
         )
 
-    # Model
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        token=HF_TOKEN,
-        device_map="auto",
-        torch_dtype=DTYPE,
-        trust_remote_code=True,
-    )
-
-    # Tokenizer (fall back in case processor doesn't expose it)
-    tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
-        MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-    )
-
-except Exception as e:
-    LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"
-
-# --------------------------
-# Inference
-# --------------------------
-def _build_messages(image: Image.Image):
-    return [
-        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-        {"role": "user", "content": [{"type": "image", "image": image},
-                                     {"type": "text", "text": USER_PROMPT}]}
-    ]
-
-def _run(image: Image.Image) -> Tuple[str, Dict[str, Any], bool]:
     if image is None:
         return "Please upload an image.", None, False
-    if model is None or processor is None:
-        msg = (
-            "❌ Model failed to load.\n\n"
-            f"{LOAD_ERROR or 'Unknown error.'}\n"
-            "Check: MODEL_ID, HF_TOKEN, and that the repo includes processor + model shards."
-        )
-        return msg, None, False
 
-    # Build chat input
-    if hasattr(processor, "apply_chat_template"):
-        prompt = processor.apply_chat_template(
-            _build_messages(image), add_generation_prompt=True, tokenize=False
-        )
     else:
-        # Conservative fallback
         msgs = _build_messages(image)
         prompt = ""
         for m in msgs:
@@ -137,93 +164,49 @@ def _run(image: Image.Image) -> Tuple[str, Dict[str, Any], bool]:
             elif chunk["type"] == "image":
                 prompt += f"{role}: [IMAGE]\n"
 
-    # Tokenize with vision
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
 
-    # Generation args
     gen_kwargs = dict(
         temperature=TEMP,
         max_new_tokens=MAX_NEW_TOKENS,
     )
-    # If your config has multiple eos ids (yours does: [1, 106]), pass them
-    eos_id = getattr(tokenizer, "eos_token_id", None)
-    try:
-        # prefer config’s eos_token_id if list-like
-        from transformers.utils import is_torch_available
-        cfg_eos = getattr(model.config, "eos_token_id", None)
-        if isinstance(cfg_eos, (list, tuple)):
-            gen_kwargs["eos_token_id"] = list(cfg_eos)
-        elif eos_id is not None:
-            gen_kwargs["eos_token_id"] = eos_id
-    except Exception:
-        if eos_id is not None:
-            gen_kwargs["eos_token_id"] = eos_id
 
-    # Ask model to emit strict JSON (supported in newer transformers for some models)
     try:
         gen_kwargs["response_format"] = {"type": "json_object"}
     except Exception:
         pass
 
     with torch.inference_mode():
-        out_ids = model.generate(**inputs, **gen_kwargs)
 
-    # Decode via processor if available (some VLMs override decode)
-    if hasattr(processor, "decode"):
-        text = processor.decode(out_ids[0], skip_special_tokens=True)
-    else:
-        text = tokenizer.decode(out_ids[0], skip_special_tokens=True)
 
-    # Trim any echoed prompt
     if USER_PROMPT in text:
         text = text.split(USER_PROMPT)[-1].strip()
 
-    # Strict parse, with fallback to top-level {...}
-    try:
-        parsed = json.loads(text)
         return json.dumps(parsed, indent=2), parsed, True
-    except Exception:
-        m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
-        if m:
-            try:
-                parsed = json.loads(m.group(0))
-                return json.dumps(parsed, indent=2), parsed, True
-            except Exception:
-                pass
-        # Return raw text to help debug prompt adherence if needed
-        return text, None, False
-
-# --------------------------
-# Spaces GPU entry + warmup
-# --------------------------
-@spaces.GPU
-def annotate_image(pil: Image.Image):
-    return _run(pil)
 
 @spaces.GPU(duration=60)
 def _warmup():
-    if model is None or processor is None:
-        return "skip"
     try:
-        dummy = Image.new("RGB", (64, 64), (127, 127, 127))
-        _ = _run(dummy)
-        return "ok"
     except Exception as e:
         return f"warmup error: {e}"
 
-try:
-    _ = _warmup()
-except Exception:
-    pass
-
-# --------------------------
-# UI
-# --------------------------
-with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM)") as demo:
-    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT)\nUpload an image to get **strict JSON** annotations.")
-    if LOAD_ERROR:
-        with gr.Accordion("Startup Error Details", open=False):
-            gr.Markdown(f"```\n{LOAD_ERROR}\n```")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -234,10 +217,12 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe
         out_json = gr.JSON(label="Parsed JSON")
         ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
 
-    def on_click(img):
-        text, js, ok = _run(img)
-        return text, js, ok
-
     btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])
 
 demo.queue(max_size=32).launch()
app.py (after):

 from PIL import Image
 import torch
 import spaces
+from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoConfig
 
+# --------- ENV / PARAMS ----------
 MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
+HF_TOKEN = os.environ.get("HF_TOKEN")  # put this in Space -> Settings -> Variables & secrets
 TEMP = 0.1
 MAX_NEW_TOKENS = 2000
 
+# Lazy globals (ZeroGPU-safe)
+_processor: Any = None
+_tokenizer: Any = None
+_model: Any = None
+_last_load_error: str | None = None
+
+# --------- PROMPTS (yours) ----------
 SYSTEM_PROMPT = (
     "You are an image annotation API trained to analyze YouTube video keyframes. "
     "You will be given instructions on the output format, what to caption, and how to perform your job. "
 
 - Output **only the JSON**, no extra text or explanation.
 """
 
+# --------- HELPERS ----------
+def _json_extract(text: str):
+    """Strict parse -> top-level {...} fallback."""
+    try:
+        return json.loads(text)
+    except Exception:
+        m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
+        if m:
+            try:
+                return json.loads(m.group(0))
+            except Exception:
+                pass
+    return None
 
+def _build_messages(image: Image.Image):
+    return [
+        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+        {"role": "user", "content": [{"type": "image", "image": image},
+                                     {"type": "text", "text": USER_PROMPT}]}
+    ]
 
+# --------- ZERO-GPU LAZY LOADER ----------
+@spaces.GPU
+def _ensure_loaded() -> str:
+    """
+    Load the model only when a ZeroGPU worker with a GPU is attached.
+    Tries quantized path first (compressed-tensors), then falls back to unquantized.
+    """
+    global _processor, _tokenizer, _model, _last_load_error
+    if _model is not None and _processor is not None:
+        return "already_loaded"
 
     try:
+        # Sanity: config should be gemma3 causal VLM (not CLIP)
+        cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+        if "clip" in cfg.__class__.__name__.lower():
+            raise RuntimeError(
+                f"MODEL_ID '{MODEL_ID}' resolves to CLIP/encoder config; need a causal VLM checkpoint."
+            )
+
+        # Try quantized (as per your config)
+        _processor = AutoProcessor.from_pretrained(
             MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
         )
+        _model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            token=HF_TOKEN,
+            device_map="auto",
+            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+            trust_remote_code=True,
         )
+        _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
+            MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+        )
+        _last_load_error = None
+        return "ok_quant"
+    except Exception as e:
+        # Fallback: disable quantization (more VRAM)
+        if "compressed_tensors" in str(e):
+            try:
+                _processor = AutoProcessor.from_pretrained(
+                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+                )
+                _model = AutoModelForCausalLM.from_pretrained(
+                    MODEL_ID,
+                    token=HF_TOKEN,
+                    device_map="auto",
+                    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+                    trust_remote_code=True,
+                    quantization_config=None,  # force dequantized load
+                )
+                _tokenizer = getattr(_processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
+                    MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
+                )
+                _last_load_error = None
+                return "ok_dequant"
+            except Exception as e2:
+                _last_load_error = f"{e}\n\nFallback failed:\n{e2}\n{traceback.format_exc()}"
+                _processor = _tokenizer = _model = None
+                return "fail"
+        else:
+            _last_load_error = f"{e}\n{traceback.format_exc()}"
+            _processor = _tokenizer = _model = None
+            return "fail"
+
+# --------- INFERENCE ----------
+@spaces.GPU
+def annotate_image(image: Image.Image) -> Tuple[str, Dict[str, Any] | None, bool]:
+    status = _ensure_loaded()
+    if status == "fail":
+        return f"❌ Load error:\n{_last_load_error}", None, False
 
     if image is None:
         return "Please upload an image.", None, False
 
+    # Prompt assembly
+    if hasattr(_processor, "apply_chat_template"):
+        prompt = _processor.apply_chat_template(_build_messages(image), add_generation_prompt=True, tokenize=False)
     else:
         msgs = _build_messages(image)
         prompt = ""
         for m in msgs:
 
             elif chunk["type"] == "image":
                 prompt += f"{role}: [IMAGE]\n"
 
+    inputs = _processor(text=prompt, images=image, return_tensors="pt").to(_model.device)
 
     gen_kwargs = dict(
         temperature=TEMP,
         max_new_tokens=MAX_NEW_TOKENS,
     )
+    # respect multiple eos ids if present
+    eos = getattr(_model.config, "eos_token_id", None)
+    if eos is not None:
+        gen_kwargs["eos_token_id"] = eos
 
+    # Try JSON-only output (if supported)
     try:
         gen_kwargs["response_format"] = {"type": "json_object"}
     except Exception:
         pass
 
     with torch.inference_mode():
+        out = _model.generate(**inputs, **gen_kwargs)
 
+    text = (_processor.decode(out[0], skip_special_tokens=True)
+            if hasattr(_processor, "decode")
+            else _tokenizer.decode(out[0], skip_special_tokens=True))
 
     if USER_PROMPT in text:
         text = text.split(USER_PROMPT)[-1].strip()
 
+    parsed = _json_extract(text)
+    if isinstance(parsed, dict):
         return json.dumps(parsed, indent=2), parsed, True
+    return text, None, False
 
+# Optional: quick warmup to validate loading on first worker
 @spaces.GPU(duration=60)
 def _warmup():
     try:
+        return _ensure_loaded()
     except Exception as e:
         return f"warmup error: {e}"
 
+# --------- UI ----------
+with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ZeroGPU)") as demo:
+    gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT · ZeroGPU)\nUpload an image to get **strict JSON** annotations.")
 
     with gr.Row():
         with gr.Column(scale=1):
 
         out_json = gr.JSON(label="Parsed JSON")
         ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)
 
     btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])
 
+# fire a non-blocking warmup
+try:
+    _ = _warmup()
+except Exception:
+    pass
+
 demo.queue(max_size=32).launch()
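
The heart of this commit is the ZeroGPU lazy-load pattern: globals stay empty at import time, and the weights are only instantiated inside a @spaces.GPU-decorated function, once a worker with a GPU is actually attached. A minimal, self-contained sketch of that idea follows; it is not part of the commit, and the entry-point name run_text_only, the text-only prompt handling, and the generation settings are illustrative assumptions rather than this Space's actual code.

# Sketch only: ZeroGPU-style lazy loading, mirroring the loader calls used in this Space's app.py.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "inference-net/ClipTagger-12b"
_processor = None
_model = None

@spaces.GPU  # GPU is attached only for the duration of this call
def run_text_only(prompt: str) -> str:
    """Load on the first call (on the GPU worker), then reuse the cached globals."""
    global _processor, _model
    if _model is None:
        _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    inputs = _processor(text=prompt, return_tensors="pt").to(_model.device)
    with torch.inference_mode():
        out = _model.generate(**inputs, max_new_tokens=64)
    return _processor.decode(out[0], skip_special_tokens=True)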