andrejrad committed
Commit 0557d7f · verified · 1 Parent(s): 501fb5d

Update app.py

Files changed (1)
  1. app.py +94 -86
app.py CHANGED
@@ -1,24 +1,24 @@
  import os, json, re, traceback
  import gradio as gr
  from PIL import Image
  import torch
  import spaces

  # --------------------------
- # Config (via Space secrets)
  # --------------------------
- # ADAPTER_ID: your fine-tune adapter repo (PEFT). Example: GrassData/cliptagger-12b
- # BASE_ID: the Gemma-3 VLM base you fine-tuned from. Example: google/gemma-3-12b-it (gated)
- # HF_TOKEN: user access token that has access to BASE_ID (if gated)
- ADAPTER_ID = os.environ.get("MODEL_ID", os.environ.get("ADAPTER_ID", "inference-net/ClipTagger-12b"))
- BASE_ID = os.environ.get("BASE_ID", "google/gemma-3-12b-it")
- HF_TOKEN = os.environ.get("HF_TOKEN")

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

  # --------------------------
- # Prompts (your spec)
  # --------------------------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
@@ -57,101 +57,77 @@ Rules:
  """

  # --------------------------
- # Load base + adapter (PEFT)
  # --------------------------
- def load_model_stack():
-     from transformers import AutoProcessor, AutoTokenizer, AutoConfig, AutoModelForCausalLM
-     from peft import PeftModel

-     # Prefer loading processor from BASE_ID (has preproc files). If you've vendored
-     # processor files into the adapter repo, you can switch to ADAPTER_ID here.
      try:
          processor = AutoProcessor.from_pretrained(
-             BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
          )
      except TypeError:
-         # Some processor classes don't accept use_fast
          processor = AutoProcessor.from_pretrained(
-             BASE_ID, token=HF_TOKEN, trust_remote_code=True
          )

-     # Sanity check: ADAPTER should not be CLIP-only
-     cfg = AutoConfig.from_pretrained(ADAPTER_ID, token=HF_TOKEN, trust_remote_code=True)
-     if cfg.__class__.__name__.lower().startswith("clip"):
-         raise RuntimeError(
-             f"MODEL_ID/ADAPTER_ID ({ADAPTER_ID}) resolves to a CLIP/encoder config "
-             "and cannot be used with AutoModelForCausalLM. Point to your PEFT adapter "
-             "repo (Gemma-3 VLM adapters) or a full causal VLM checkpoint."
-         )
-
-     base = AutoModelForCausalLM.from_pretrained(
-         BASE_ID,
          token=HF_TOKEN,
          device_map="auto",
          torch_dtype=DTYPE,
          trust_remote_code=True,
      )

-     model = PeftModel.from_pretrained(
-         base,
-         ADAPTER_ID,
-         token=HF_TOKEN,
      )

-     # Merge adapters for faster inference (optional)
-     try:
-         model = model.merge_and_unload()
-     except Exception:
-         # If merge isn’t supported, we keep PEFT wrapper
-         pass
-
-     tokenizer = getattr(processor, "tokenizer", None)
-     if tokenizer is None:
-         tokenizer = AutoTokenizer.from_pretrained(
-             BASE_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
-         )
-
-     return processor, tokenizer, model
-
- LOAD_ERROR = None
- processor = tokenizer = model = None
- try:
-     processor, tokenizer, model = load_model_stack()
  except Exception as e:
      LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"

  # --------------------------
  # Inference
  # --------------------------
- def build_messages(image: Image.Image):
      return [
          {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
-         {"role": "user", "content": [{"type": "image", "image": image},
-                                      {"type": "text", "text": USER_PROMPT}]}
      ]

- def generate_json(image: Image.Image):
      if image is None:
          return "Please upload an image.", None, False
-
      if model is None or processor is None:
          msg = (
              "❌ Model failed to load.\n\n"
-             f"{LOAD_ERROR or 'Unknown error. Check BASE_ID/ADAPTER_ID/HF_TOKEN.'}\n"
-             "• Ensure HF_TOKEN belongs to an account with access to the BASE_ID (if gated).\n"
-             "• Ensure MODEL_ID/ADAPTER_ID points to a Gemma-3 VLM PEFT adapter (not CLIP).\n"
-             "• Optionally vendor processor files into your adapter repo."
          )
          return msg, None, False

-     # Prepare chat prompt
      if hasattr(processor, "apply_chat_template"):
          prompt = processor.apply_chat_template(
-             build_messages(image), add_generation_prompt=True, tokenize=False
          )
      else:
-         # Fallback join (rare for Gemma-3)
-         msgs = build_messages(image)
          prompt = ""
          for m in msgs:
              role = m["role"].upper()
@@ -164,39 +140,49 @@ def generate_json(image: Image.Image):
      # Tokenize with vision
      inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

-     # Generate with fixed params
      gen_kwargs = dict(
-         max_new_tokens=2000,
-         temperature=0.1,
-         eos_token_id=getattr(tokenizer, "eos_token_id", None),
      )

-     # Ask for JSON-only if supported by the model head
-     # (Some trust_remote_code models accept response_format)
      try:
          gen_kwargs["response_format"] = {"type": "json_object"}
      except Exception:
          pass

      with torch.inference_mode():
-         out = model.generate(**inputs, **gen_kwargs)

-     # Decode
      if hasattr(processor, "decode"):
-         text = processor.decode(out[0], skip_special_tokens=True)
      else:
-         text = tokenizer.decode(out[0], skip_special_tokens=True)

-     # Best-effort: trim any preamble
      if USER_PROMPT in text:
          text = text.split(USER_PROMPT)[-1].strip()

-     # Parse JSON
      try:
          parsed = json.loads(text)
          return json.dumps(parsed, indent=2), parsed, True
      except Exception:
-         # Try to recover a top-level {...}
          m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
          if m:
              try:
@@ -204,14 +190,37 @@ def generate_json(image: Image.Image):
                  return json.dumps(parsed, indent=2), parsed, True
              except Exception:
                  pass
      return text, None, False

  # --------------------------
- # UI
  # --------------------------
- with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (ClipTagger + Adapter)") as demo:
-     gr.Markdown("# Keyframe Annotator (ClipTagger)\nUpload an image to get **strict JSON** annotations.")

      if LOAD_ERROR:
          with gr.Accordion("Startup Error Details", open=False):
              gr.Markdown(f"```\n{LOAD_ERROR}\n```")
@@ -219,17 +228,16 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe
      with gr.Row():
          with gr.Column(scale=1):
              image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
-             annotate_btn = gr.Button("Annotate", variant="primary")
          with gr.Column(scale=1):
-             out_code = gr.Code(label="Model Output (JSON or error text)")
              out_json = gr.JSON(label="Parsed JSON")
              ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)

-     @spaces.GPU # ensures a GPU task is registered
-     def on_submit(img):
-         text, js, ok = generate_json(img)
          return text, js, ok

-     annotate_btn.click(on_submit, inputs=[image], outputs=[out_code, out_json, ok_flag])

  demo.queue(max_size=32).launch()
 
@@ -1,24 +1,24 @@
  import os, json, re, traceback
+ from typing import Any, Dict, Tuple
  import gradio as gr
  from PIL import Image
  import torch
  import spaces

  # --------------------------
+ # Environment
  # --------------------------
+ MODEL_ID = os.environ.get("MODEL_ID", "inference-net/ClipTagger-12b")
+ HF_TOKEN = os.environ.get("HF_TOKEN")

  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

+ TEMP = 0.1
+ MAX_NEW_TOKENS = 2000
+
  # --------------------------
+ # Prompts (yours)
  # --------------------------
  SYSTEM_PROMPT = (
      "You are an image annotation API trained to analyze YouTube video keyframes. "
@@ -57,101 +57,77 @@ Rules:
  """

  # --------------------------
+ # Load full VLM (Gemma-3)
  # --------------------------
+ from transformers import AutoConfig, AutoProcessor, AutoTokenizer, AutoModelForCausalLM

+ processor = tokenizer = model = None
+ LOAD_ERROR = None
+
+ try:
+     cfg = AutoConfig.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
+     if "clip" in cfg.__class__.__name__.lower():
+         raise RuntimeError(
+             f"MODEL_ID '{MODEL_ID}' resolves to a CLIP/encoder config. "
+             "Point MODEL_ID to your full VLM checkpoint (this repo's config shows gemma3)."
+         )
+
+     # Processor (has vision + tokenizer routing)
      try:
          processor = AutoProcessor.from_pretrained(
+             MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
          )
      except TypeError:
          processor = AutoProcessor.from_pretrained(
+             MODEL_ID, token=HF_TOKEN, trust_remote_code=True
          )

+     # Model
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
          token=HF_TOKEN,
          device_map="auto",
          torch_dtype=DTYPE,
          trust_remote_code=True,
      )

+     # Tokenizer (fall back in case processor doesn't expose it)
+     tokenizer = getattr(processor, "tokenizer", None) or AutoTokenizer.from_pretrained(
+         MODEL_ID, token=HF_TOKEN, trust_remote_code=True, use_fast=True
      )

  except Exception as e:
      LOAD_ERROR = f"{e}\n\n{traceback.format_exc()}"

  # --------------------------
  # Inference
  # --------------------------
+ def _build_messages(image: Image.Image):
      return [
          {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
+         {"role": "user", "content": [{"type": "image", "image": image},
+                                      {"type": "text", "text": USER_PROMPT}]}
      ]

+ def _run(image: Image.Image) -> Tuple[str, Dict[str, Any], bool]:
      if image is None:
          return "Please upload an image.", None, False
      if model is None or processor is None:
          msg = (
              "❌ Model failed to load.\n\n"
+             f"{LOAD_ERROR or 'Unknown error.'}\n"
+             "Check: MODEL_ID, HF_TOKEN, and that the repo includes processor + model shards."
          )
          return msg, None, False

+     # Build chat input
      if hasattr(processor, "apply_chat_template"):
          prompt = processor.apply_chat_template(
+             _build_messages(image), add_generation_prompt=True, tokenize=False
          )
      else:
+         # Conservative fallback
+         msgs = _build_messages(image)
          prompt = ""
          for m in msgs:
              role = m["role"].upper()
@@ -164,39 +140,49 @@ def generate_json(image: Image.Image):
      # Tokenize with vision
      inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

+     # Generation args
      gen_kwargs = dict(
+         temperature=TEMP,
+         max_new_tokens=MAX_NEW_TOKENS,
      )
+     # If your config has multiple eos ids (yours does: [1, 106]), pass them
+     eos_id = getattr(tokenizer, "eos_token_id", None)
+     try:
+         # prefer config’s eos_token_id if list-like
+         from transformers.utils import is_torch_available
+         cfg_eos = getattr(model.config, "eos_token_id", None)
+         if isinstance(cfg_eos, (list, tuple)):
+             gen_kwargs["eos_token_id"] = list(cfg_eos)
+         elif eos_id is not None:
+             gen_kwargs["eos_token_id"] = eos_id
+     except Exception:
+         if eos_id is not None:
+             gen_kwargs["eos_token_id"] = eos_id

+     # Ask model to emit strict JSON (supported in newer transformers for some models)
      try:
          gen_kwargs["response_format"] = {"type": "json_object"}
      except Exception:
          pass

      with torch.inference_mode():
+         out_ids = model.generate(**inputs, **gen_kwargs)

+     # Decode via processor if available (some VLMs override decode)
      if hasattr(processor, "decode"):
+         text = processor.decode(out_ids[0], skip_special_tokens=True)
      else:
+         text = tokenizer.decode(out_ids[0], skip_special_tokens=True)

+     # Trim any echoed prompt
      if USER_PROMPT in text:
          text = text.split(USER_PROMPT)[-1].strip()

+     # Strict parse, with fallback to top-level {...}
      try:
          parsed = json.loads(text)
          return json.dumps(parsed, indent=2), parsed, True
      except Exception:
          m = re.search(r"\{(?:[^{}]|(?R))*\}", text, flags=re.DOTALL)
          if m:
              try:
@@ -204,14 +190,37 @@
                  return json.dumps(parsed, indent=2), parsed, True
              except Exception:
                  pass
+     # Return raw text to help debug prompt adherence if needed
      return text, None, False

  # --------------------------
+ # Spaces GPU entry + warmup
  # --------------------------
+ @spaces.GPU
+ def annotate_image(pil: Image.Image):
+     return _run(pil)

+ @spaces.GPU(duration=60)
+ def _warmup():
+     if model is None or processor is None:
+         return "skip"
+     try:
+         dummy = Image.new("RGB", (64, 64), (127, 127, 127))
+         _ = _run(dummy)
+         return "ok"
+     except Exception as e:
+         return f"warmup error: {e}"
+
+ try:
+     _ = _warmup()
+ except Exception:
+     pass
+
+ # --------------------------
+ # UI
+ # --------------------------
+ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe Annotator (Gemma-3 VLM)") as demo:
+     gr.Markdown("# Keyframe Annotator (Gemma-3-12B FT)\nUpload an image to get **strict JSON** annotations.")
      if LOAD_ERROR:
          with gr.Accordion("Startup Error Details", open=False):
              gr.Markdown(f"```\n{LOAD_ERROR}\n```")
@@ -219,17 +228,16 @@ with gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=False, title="Keyframe
      with gr.Row():
          with gr.Column(scale=1):
              image = gr.Image(type="pil", label="Upload Image", image_mode="RGB")
+             btn = gr.Button("Annotate", variant="primary")
          with gr.Column(scale=1):
+             out_text = gr.Code(label="Output (JSON or error)")
              out_json = gr.JSON(label="Parsed JSON")
              ok_flag = gr.Checkbox(label="Valid JSON", value=False, interactive=False)

+     def on_click(img):
+         text, js, ok = _run(img)
          return text, js, ok

+     btn.click(annotate_image, inputs=[image], outputs=[out_text, out_json, ok_flag])

  demo.queue(max_size=32).launch()
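
A note on the JSON-recovery fallback shared by generate_json and _run above: the pattern r"\{(?:[^{}]|(?R))*\}" relies on the recursive (?R) construct, which Python's built-in re module does not support, so re.search raises re.error at that point and the except-branch recovery cannot succeed as written. Below is a minimal sketch of an equivalent recovery step; the helper name extract_top_level_json is illustrative only, and the use of the optional third-party regex package (with a manual brace-matching fallback) is an assumption, not part of this commit.

# Hypothetical helper (not part of the commit): recover the first balanced
# top-level {...} object from model output and parse it as JSON.
import json

try:
    import regex  # third-party package; unlike stdlib `re`, it supports the recursive (?R) construct
except ImportError:
    regex = None

def extract_top_level_json(text: str):
    candidate = None
    if regex is not None:
        m = regex.search(r"\{(?:[^{}]|(?R))*\}", text, flags=regex.DOTALL)
        if m:
            candidate = m.group(0)
    else:
        # Dependency-free fallback: scan for the first balanced brace pair.
        depth, start = 0, None
        for i, ch in enumerate(text):
            if ch == "{":
                if depth == 0:
                    start = i
                depth += 1
            elif ch == "}" and depth:
                depth -= 1
                if depth == 0:
                    candidate = text[start:i + 1]
                    break
    if candidate is None:
        return None
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return None

Two smaller generation details may also be worth verifying against the installed transformers version: generate() does not document a response_format keyword (the surrounding try/except only guards a plain dict assignment, which never raises, so the kwarg is always passed and may be rejected by generate()'s kwarg validation), and temperature only takes effect when sampling is enabled with do_sample=True.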