johnbridges committed (verified)
Commit dc49ef7 · 1 Parent(s): e15bd89

Update app.py

Files changed (1): app.py +40 -29
app.py CHANGED
@@ -1,4 +1,4 @@
 import gradio as gr
 import json
 import os
 from typing import Any, List, Dict
@@ -23,6 +23,7 @@ def locate_text_backbone(model):
     Tries common attribute names used by VLMs to find the LLM/text stack.
     Falls back to the whole model if unknown.
     """
+    # common in Qwen-like / custom repos
     for name in [
         "language_model",  # e.g., model.language_model
         "text_model",      # e.g., model.text_model
@@ -33,20 +34,31 @@ def locate_text_backbone(model):
         m = getattr(model, name, None)
         if m is not None:
             return m, name
+
+    # last resort: look for a child that has an lm_head or tied weights
     for name, child in model.named_children():
         if hasattr(child, "lm_head") or hasattr(child, "get_input_embeddings"):
             return child, name
+
+    # if still not found, return the model itself
     return model, None

+
 def pick_device() -> str:
-    return "cpu"  # force CPU
+    # Force CPU per request
+    return "cpu"

 def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
+    """
+    Works whether apply_chat_template lives on the processor or tokenizer,
+    or not at all (falls back to naive text join of 'text' contents).
+    """
     tok = getattr(processor, "tokenizer", None)
     if hasattr(processor, "apply_chat_template"):
         return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     if tok is not None and hasattr(tok, "apply_chat_template"):
         return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    # Fallback: concatenate visible text segments
     texts = []
     for m in messages:
         for c in m.get("content", []):
@@ -63,6 +75,9 @@ def batch_decode_compat(processor, token_id_batches, **kw):
     raise AttributeError("No batch_decode available on processor or tokenizer.")

 def get_image_proc_params(processor) -> Dict[str, int]:
+    """
+    Safely access image processor params with defaults that work for Qwen2-VL family.
+    """
     ip = getattr(processor, "image_processor", None)
     return {
         "patch_size": getattr(ip, "patch_size", 14),
@@ -72,6 +87,9 @@ def get_image_proc_params(processor) -> Dict[str, int]:
     }

 def trim_generated(generated_ids, inputs):
+    """
+    Trim prompt tokens from generated tokens when input_ids exist.
+    """
     in_ids = getattr(inputs, "input_ids", None)
     if in_ids is None and isinstance(inputs, dict):
         in_ids = inputs.get("input_ids", None)
@@ -87,34 +105,13 @@ model_loaded = False
 load_error_message = ""

 try:
+    # CPU-friendly dtype; fp16 on CPU is spotty, so prefer bfloat16
     model = AutoModelForImageTextToText.from_pretrained(
         MODEL_ID,
-        torch_dtype=torch.bfloat16,  # CPU-friendly
+        torch_dtype=torch.bfloat16,
         trust_remote_code=True
     ).to(pick_device())
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-
-    # >>> INT8 QUANT START -----------------------------------------------------
-    # Quantize only the text/LLM backbone (nn.Linear layers) to dynamic INT8.
-    text_backbone, attr_name = locate_text_backbone(model)
-    print("[INT8] Quantizing text backbone with dynamic INT8...")
-    quantized_llm = quantize_dynamic(
-        text_backbone,
-        {torch.nn.Linear},
-        dtype=torch.qint8
-    )
-    if attr_name is not None:
-        setattr(model, attr_name, quantized_llm)
-    else:
-        for name, child in list(model.named_children()):
-            if child is text_backbone:
-                setattr(model, name, quantized_llm)
-                break
-    torch.set_num_threads(max(1, os.cpu_count() or 1))
-    model.eval()
-    print("[INT8] Done.")
-    # <<< INT8 QUANT END -------------------------------------------------------
-
     model_loaded = True
     print("Model and processor loaded successfully.")
 except Exception as e:
@@ -148,11 +145,16 @@ def run_inference_localization(
     messages_for_template: List[dict[str, Any]],
     pil_image_for_processing: Image.Image
 ) -> str:
+    """
+    CPU inference; robust to processor/tokenizer differences and logs full traceback on failure.
+    """
     try:
         model.to(pick_device())

+        # 1) Build prompt text via robust helper
         text_prompt = apply_chat_template_compat(processor, messages_for_template)

+        # 2) Prepare inputs (text + image)
         inputs = processor(
             text=[text_prompt],
             images=[pil_image_for_processing],
@@ -160,19 +162,23 @@ def run_inference_localization(
             return_tensors="pt",
         )

+        # Move tensor inputs to the same device as model (CPU)
         if isinstance(inputs, dict):
             for k, v in list(inputs.items()):
                 if hasattr(v, "to"):
                     inputs[k] = v.to(model.device)

+        # 3) Generate (deterministic)
         generated_ids = model.generate(
             **inputs,
             max_new_tokens=128,
             do_sample=False,
         )

+        # 4) Trim prompt tokens if possible
         generated_ids_trimmed = trim_generated(generated_ids, inputs)

+        # 5) Decode via robust helper
         decoded_output = batch_decode_compat(
             processor,
             generated_ids_trimmed,
@@ -195,6 +201,7 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
     if not instruction or instruction.strip() == "":
         return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB")

+    # 1) Resize according to image processor params (safe defaults if missing)
     try:
         ip = get_image_proc_params(processor)
         resized_height, resized_width = smart_resize(
@@ -213,18 +220,22 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
         traceback.print_exc()
         return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")

+    # 2) Build messages with image + instruction
     messages = get_localization_prompt(resized_image, instruction)

+    # 3) Run inference
     try:
         coordinates_str = run_inference_localization(messages, resized_image)
     except Exception as e:
         return f"Error during model inference: {e}", resized_image.copy().convert("RGB")

+    # 4) Parse coordinates and draw marker
     output_image_with_click = resized_image.copy().convert("RGB")
     match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
     if match:
         try:
-            x = int(match.group(1)); y = int(match.group(2))
+            x = int(match.group(1))
+            y = int(match.group(2))
             draw = ImageDraw.Draw(output_image_with_click)
             radius = max(5, min(resized_width // 100, resized_height // 100, 15))
             bbox = (x - radius, y - radius, x + radius, y + radius)
@@ -255,7 +266,7 @@ except Exception as e:
     pass

 # --- Gradio UI ---
-title = "Holo1-3B: Action VLM Localization Demo (CPU + INT8 text)"
+title = "Holo1-3B: Action VLM Localization Demo (CPU)"
 article = f"""
 <p style='text-align: center'>
 Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
@@ -313,5 +324,5 @@ else:
     )

 if __name__ == "__main__":
-    demo.launch(debug=True)
-
+    # CPU Spaces can be slow; keep debug True for logs
+    demo.launch(debug=True)
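Note on the removed INT8 path: the deleted block applied torch dynamic INT8 quantization to the nn.Linear layers of the text backbone before this commit fell back to plain bfloat16 on CPU. A minimal sketch of that approach, reusing locate_text_backbone() from app.py; the import path torch.quantization and the helper name quantize_text_backbone_int8 are assumptions for illustration, since the original import is not shown in this diff:

import torch
from torch.quantization import quantize_dynamic  # dynamic INT8 for nn.Linear on CPU

def quantize_text_backbone_int8(model):
    # Locate the LLM/text stack, mirroring locate_text_backbone() in app.py.
    backbone, attr_name = locate_text_backbone(model)
    # Replace nn.Linear modules in the backbone with dynamically quantized INT8 versions.
    quantized = quantize_dynamic(backbone, {torch.nn.Linear}, dtype=torch.qint8)
    if attr_name is not None:
        # Re-attach the quantized stack under the attribute it came from.
        setattr(model, attr_name, quantized)
    return model

Dynamic quantization only touches Linear weights and runs on CPU, which is why the removed code left the vision tower untouched and fit the CPU-only pick_device() setup.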