johnbridges committed on
Commit 7af0a49 · verified · 1 Parent(s): dc160b6

Update app.py

Files changed (1):
  1. app.py +91 -87
app.py CHANGED
@@ -1,8 +1,5 @@
-# requirements.txt stays fine, but for CUDA wheels you usually want:
-# pip install --index-url https://download.pytorch.org/whl/cu121 torch torchvision --upgrade
-
 import gradio as gr
-import json, os, re, traceback
+import json, os, re, traceback, contextlib
 from typing import Any, List, Dict
 
 import spaces
@@ -18,6 +15,10 @@ MODEL_ID = "Hcompany/Holo1-3B"
 # ---------------- Device / DType helpers ----------------
 
 def pick_device() -> str:
+    """
+    On HF Spaces (ZeroGPU), CUDA is only available inside @spaces.GPU calls.
+    We still honor FORCE_DEVICE for local testing.
+    """
     forced = os.getenv("FORCE_DEVICE", "").lower().strip()
     if forced in {"cpu", "cuda", "mps"}:
         return forced
@@ -29,11 +30,9 @@ def pick_device() -> str:
 
 def pick_dtype(device: str) -> torch.dtype:
     if device == "cuda":
-        major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) for A100
-        # Prefer bfloat16 on Ampere+ (>= 8.x). Otherwise float16.
-        return torch.bfloat16 if major >= 8 else torch.float16
+        major, _ = torch.cuda.get_device_capability()
+        return torch.bfloat16 if major >= 8 else torch.float16  # Ampere+ -> bf16
     if device == "mps":
-        # MPS autocast supports float16 well; bfloat16 is improving but use float16 for safety.
         return torch.float16
     return torch.float32  # CPU
 
@@ -44,7 +43,7 @@ def move_to_device(batch, device: str):
         return batch.to(device, non_blocking=True)
     return batch
 
-# --- Chat/template helpers (unchanged except minor tidy) ---
+# --- Chat/template helpers ---
 def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
     tok = getattr(processor, "tokenizer", None)
     if hasattr(processor, "apply_chat_template"):
@@ -83,40 +82,27 @@ def trim_generated(generated_ids, inputs):
         return [out_ids for out_ids in generated_ids]
     return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
 
-# --- Load model/processor once with correct device/dtype ---
-active_device = pick_device()
-active_dtype = pick_dtype(active_device)
-
-# Optional perf knobs for CUDA
-if active_device == "cuda":
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.set_float32_matmul_precision("high")  # better perf on Ampere+
-
-print(f"Loading model and processor for {MODEL_ID} on device={active_device}, dtype={active_dtype}...")
+# --- Load model/processor ON CPU at import time (required for ZeroGPU) ---
+print(f"Loading model and processor for {MODEL_ID} on CPU startup (ZeroGPU safe)...")
 model = None
 processor = None
 model_loaded = False
 load_error_message = ""
 
 try:
-    # Note: for single-GPU we explicitly set dtype then .to(device).
-    # If you want HF Accelerate sharding: set device_map="auto" and drop explicit .to().
     model = AutoModelForImageTextToText.from_pretrained(
         MODEL_ID,
-        torch_dtype=active_dtype if active_device != "cpu" else torch.float32,
+        torch_dtype=torch.float32,  # CPU-safe dtype at import
         trust_remote_code=True,
     )
     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-
-    # Move model to device and eval
-    model.to(active_device)
     model.eval()
     model_loaded = True
-    print("Model and processor loaded successfully.")
+    print("Model and processor loaded on CPU.")
 except Exception as e:
     load_error_message = (
         f"Error loading model/processor: {e}\n"
-        "This might be due to CUDA/MPS availability, model ID, or wheel incompatibility.\n"
+        "This might be due to network/model ID/library versions.\n"
         "Check the full traceback in the logs."
     )
     print(load_error_message)
@@ -139,68 +125,79 @@ def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[di
         }
     ]
 
-# --- Inference (device-agnostic; uses AMP on GPU) ---
+# --- Inference core (device passed in; AMP used when suitable) ---
 @torch.inference_mode()
 def run_inference_localization(
     messages_for_template: List[dict[str, Any]],
-    pil_image_for_processing: Image.Image
+    pil_image_for_processing: Image.Image,
+    device: str,
+    dtype: torch.dtype,
 ) -> str:
-    try:
-        # 1) Build prompt text
-        text_prompt = apply_chat_template_compat(processor, messages_for_template)
-
-        # 2) Prepare inputs (text + image)
-        inputs = processor(
-            text=[text_prompt],
-            images=[pil_image_for_processing],
-            padding=True,
-            return_tensors="pt",
-        )
-        inputs = move_to_device(inputs, active_device)
-
-        # 3) Generate (deterministic). Use autocast on GPU/MPS.
-        use_amp = active_device in {"cuda", "mps"}
-        amp_dtype = active_dtype if active_device == "cuda" else torch.float16
-
-        if use_amp:
-            with torch.cuda.amp.autocast(enabled=(active_device == "cuda"), dtype=amp_dtype):
-                generated_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=128,
-                    do_sample=False,
-                )
-        else:
-            generated_ids = model.generate(
-                **inputs,
-                max_new_tokens=128,
-                do_sample=False,
-            )
+    text_prompt = apply_chat_template_compat(processor, messages_for_template)
+
+    inputs = processor(
+        text=[text_prompt],
+        images=[pil_image_for_processing],
+        padding=True,
+        return_tensors="pt",
+    )
+    inputs = move_to_device(inputs, device)
 
-        # 4) Trim prompt tokens if possible
-        generated_ids_trimmed = trim_generated(generated_ids, inputs)
+    # AMP contexts
+    if device == "cuda":
+        amp_ctx = torch.autocast(device_type="cuda", dtype=dtype)
+    elif device == "mps":
+        amp_ctx = torch.autocast(device_type="mps", dtype=torch.float16)
+    else:
+        amp_ctx = contextlib.nullcontext()
 
-        # 5) Decode
-        decoded_output = batch_decode_compat(
-            processor,
-            generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False
+    with amp_ctx:
+        generated_ids = model.generate(
+            **inputs,
+            max_new_tokens=128,
+            do_sample=False,
         )
 
-        return decoded_output[0] if decoded_output else ""
-    except Exception as e:
-        print(f"Error during model inference: {e}")
-        traceback.print_exc()
-        raise
+    generated_ids_trimmed = trim_generated(generated_ids, inputs)
+    decoded_output = batch_decode_compat(
+        processor,
+        generated_ids_trimmed,
+        skip_special_tokens=True,
+        clean_up_tokenization_spaces=False
+    )
+    return decoded_output[0] if decoded_output else ""
 
-# --- Gradio processing function ---
+# --- Gradio processing function (ZeroGPU-visible) ---
+# Decorate the function Gradio calls so Spaces detects a GPU entry point.
+@spaces.GPU(duration=120)  # keep GPU attached briefly between calls (seconds)
 def predict_click_location(input_pil_image: Image.Image, instruction: str):
     if not model_loaded or not processor or not model:
-        return f"Model not loaded. Error: {load_error_message}", None
+        return f"Model not loaded. Error: {load_error_message}", None, "device: n/a | dtype: n/a"
     if not input_pil_image:
-        return "No image provided. Please upload an image.", None
+        return "No image provided. Please upload an image.", None, "device: n/a | dtype: n/a"
     if not instruction or instruction.strip() == "":
-        return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB")
+        return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB"), "device: n/a | dtype: n/a"
+
+    # Decide device/dtype *inside* the GPU-decorated call
+    device = pick_device()
+    dtype = pick_dtype(device)
+
+    # Optional perf knobs for CUDA
+    if device == "cuda":
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.set_float32_matmul_precision("high")
+
+    # If needed, move model now that GPU is available
+    try:
+        p = next(model.parameters())
+        cur_dev = p.device.type
+        cur_dtype = p.dtype
+    except StopIteration:
+        cur_dev, cur_dtype = "cpu", torch.float32
+
+    if cur_dev != device or cur_dtype != dtype:
+        model.to(device=device, dtype=dtype)
+        model.eval()
 
     # 1) Resize according to image processor params (safe defaults if missing)
     try:
@@ -217,25 +214,26 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
             resample=Image.Resampling.LANCZOS
         )
     except Exception as e:
-        print(f"Error resizing image: {e}")
         traceback.print_exc()
-        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")
+        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}"
 
     # 2) Build messages with image + instruction
     messages = get_localization_prompt(resized_image, instruction)
 
    # 3) Run inference
     try:
-        coordinates_str = run_inference_localization(messages, resized_image)
+        coordinates_str = run_inference_localization(messages, resized_image, device, dtype)
     except Exception as e:
-        return f"Error during model inference: {e}", resized_image.copy().convert("RGB")
+        traceback.print_exc()
+        return f"Error during model inference: {e}", resized_image.copy().convert("RGB"), f"device: {device} | dtype: {dtype}"
 
     # 4) Parse coordinates and draw marker
     output_image_with_click = resized_image.copy().convert("RGB")
     match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
     if match:
         try:
-            x = int(match.group(1)); y = int(match.group(2))
+            x = int(match.group(1))
+            y = int(match.group(2))
             draw = ImageDraw.Draw(output_image_with_click)
             radius = max(5, min(resized_width // 100, resized_height // 100, 15))
             bbox = (x - radius, y - radius, x + radius, y + radius)
@@ -247,7 +245,7 @@ def predict_click_location(input_pil_image: Image.Image, instruction: str):
     else:
         print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")
 
-    return coordinates_str, output_image_with_click
+    return coordinates_str, output_image_with_click, f"device: {device} | dtype: {str(dtype).replace('torch.', '')}"
 
 # --- Load Example Data ---
 example_image = None
@@ -266,13 +264,13 @@ except Exception as e:
     pass
 
 # --- Gradio UI ---
-title = "Holo1-3B: Holo1 Localization Demo"
+title = "Holo1-3B: Holo1 Localization Demo (ZeroGPU-ready)"
 article = f"""
 <p style='text-align: center'>
-    Device: <b>{active_device}</b> &nbsp;|&nbsp; DType: <b>{str(active_dtype).replace('torch.', '')}</b> &nbsp;|&nbsp;
     Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
     Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> |
-    Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a>
+    Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a><br/>
+    <small>GPU (if available) is requested only during inference via @spaces.GPU.</small>
 </p>
 """
 
@@ -307,12 +305,17 @@ else:
                 height=400,
                 interactive=False
             )
+            runtime_info = gr.Textbox(
+                label="Runtime Info",
+                value="device: n/a | dtype: n/a",
+                interactive=False
+            )
 
         if example_image:
             gr.Examples(
                 examples=[[example_image, example_instruction]],
                 inputs=[input_image_component, instruction_component],
-                outputs=[output_coords_component, output_image_component],
+                outputs=[output_coords_component, output_image_component, runtime_info],
                 fn=predict_click_location,
                 cache_examples="lazy",
             )
@@ -320,8 +323,9 @@ else:
        submit_button.click(
            fn=predict_click_location,
            inputs=[input_image_component, instruction_component],
-           outputs=[output_coords_component, output_image_component]
+           outputs=[output_coords_component, output_image_component, runtime_info]
        )
 
 if __name__ == "__main__":
+    # Do NOT pass 'concurrency_count' or ZeroGPU-specific launch args.
     demo.launch(debug=True)
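
For reference, the pattern this diff moves to can be reduced to a few lines. The sketch below is illustrative only: it uses a tiny torch.nn.Linear stand-in instead of Holo1-3B, and the function name infer is made up. It follows the same sequence app.py now uses: load weights on CPU at import time, request a GPU only inside a @spaces.GPU-decorated handler, move the model lazily, and run inference under torch.autocast.

import contextlib
import torch
import spaces  # provided on Hugging Face Spaces

# Stand-in model created on CPU at import, as app.py now does with Holo1-3B.
model = torch.nn.Linear(4, 2)

@spaces.GPU(duration=120)  # on ZeroGPU, a GPU is attached only while this call runs
def infer(values):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32
    model.to(device=device, dtype=dtype)  # lazy move, mirroring predict_click_location
    amp_ctx = torch.autocast(device_type="cuda", dtype=dtype) if device == "cuda" else contextlib.nullcontext()
    with torch.inference_mode(), amp_ctx:
        out = model(torch.tensor(values, device=device, dtype=dtype))
    return out.float().cpu().tolist()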
 
 
 
 
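Separately, the FORCE_DEVICE override documented in pick_device() is useful for local testing outside Spaces. A minimal sketch of its use, assuming the variable is set before pick_device() runs (values other than cpu, cuda, or mps are ignored):

import os
os.environ["FORCE_DEVICE"] = "cpu"  # force CPU even if CUDA or MPS is available
# pick_device() now returns "cpu", and pick_dtype("cpu") selects torch.float32.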