johnbridges committed
Commit 06c29a2 · 1 Parent(s): edc4e1c

just copying app.py

Files changed (1)
  1. app-cpu-torch.py +301 -0
app-cpu-torch.py ADDED
@@ -0,0 +1,301 @@
+ import gradio as gr
+ import json
+ import os
+ from typing import Any, List, Dict
+ import spaces
+
+ from PIL import Image, ImageDraw
+ import requests
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
+ import torch
+ import re
+ import traceback
+
+ # --- Configuration ---
+ MODEL_ID = "Hcompany/Holo1-3B"
+
+ # --- Helpers (robust across different transformers versions) ---
+
+ def pick_device() -> str:
+     # Force CPU per request
+     return "cpu"
+
+ def apply_chat_template_compat(processor, messages: List[Dict[str, Any]]) -> str:
+     """
+     Works whether apply_chat_template lives on the processor or tokenizer,
+     or not at all (falls back to naive text join of 'text' contents).
+     """
+     tok = getattr(processor, "tokenizer", None)
+     if hasattr(processor, "apply_chat_template"):
+         return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     if tok is not None and hasattr(tok, "apply_chat_template"):
+         return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     # Fallback: concatenate visible text segments
+     texts = []
+     for m in messages:
+         for c in m.get("content", []):
+             if isinstance(c, dict) and c.get("type") == "text":
+                 texts.append(c.get("text", ""))
+     return "\n".join(texts)
+
+ def batch_decode_compat(processor, token_id_batches, **kw):
+     tok = getattr(processor, "tokenizer", None)
+     if tok is not None and hasattr(tok, "batch_decode"):
+         return tok.batch_decode(token_id_batches, **kw)
+     if hasattr(processor, "batch_decode"):
+         return processor.batch_decode(token_id_batches, **kw)
+     raise AttributeError("No batch_decode available on processor or tokenizer.")
+
+ def get_image_proc_params(processor) -> Dict[str, int]:
+     """
+     Safely access image processor params with defaults that work for Qwen2-VL family.
+     """
+     ip = getattr(processor, "image_processor", None)
+     return {
+         "patch_size": getattr(ip, "patch_size", 14),
+         "merge_size": getattr(ip, "merge_size", 1),
+         "min_pixels": getattr(ip, "min_pixels", 256 * 256),
+         "max_pixels": getattr(ip, "max_pixels", 1280 * 1280),
+     }
+
+ def trim_generated(generated_ids, inputs):
+     """
+     Trim prompt tokens from generated tokens when input_ids exist.
+     """
+     in_ids = getattr(inputs, "input_ids", None)
+     if in_ids is None and isinstance(inputs, dict):
+         in_ids = inputs.get("input_ids", None)
+     if in_ids is None:
+         return [out_ids for out_ids in generated_ids]
+     return [out_ids[len(in_seq):] for in_seq, out_ids in zip(in_ids, generated_ids)]
+
+ # --- Model and Processor Loading (Load once) ---
+ print(f"Loading model and processor for {MODEL_ID} (CPU only)...")
+ model = None
+ processor = None
+ model_loaded = False
+ load_error_message = ""
+
+ try:
+     # CPU-friendly dtype; bf16 on CPU is spotty, so prefer float32
+     model = AutoModelForImageTextToText.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.float32,
+         trust_remote_code=True
+     ).to(pick_device())
+     processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+     model_loaded = True
+     print("Model and processor loaded successfully.")
+ except Exception as e:
+     load_error_message = (
+         f"Error loading model/processor: {e}\n"
+         "This might be due to network issues, an incorrect model ID, or incompatible library versions.\n"
+         "Check the full traceback in the Space logs."
+     )
+     print(load_error_message)
+     traceback.print_exc()
+
+ # --- Prompt builder ---
+ def get_localization_prompt(pil_image: Image.Image, instruction: str) -> List[dict]:
+     guidelines: str = (
+         "Localize an element on the GUI image according to my instructions and "
+         "output a click position as Click(x, y) with x num pixels from the left edge "
+         "and y num pixels from the top edge."
+     )
+     return [
+         {
+             "role": "user",
+             "content": [
+                 {"type": "image", "image": pil_image},
+                 {"type": "text", "text": f"{guidelines}\n{instruction}"}
+             ],
+         }
+     ]
+
+ # --- Inference (CPU) ---
+ def run_inference_localization(
+     messages_for_template: List[Dict[str, Any]],
+     pil_image_for_processing: Image.Image
+ ) -> str:
+     """
+     CPU inference; robust to processor/tokenizer differences and logs full traceback on failure.
+     """
+     try:
+         model.to(pick_device())
+
+         # 1) Build prompt text via robust helper
+         text_prompt = apply_chat_template_compat(processor, messages_for_template)
+
+         # 2) Prepare inputs (text + image)
+         inputs = processor(
+             text=[text_prompt],
+             images=[pil_image_for_processing],
+             padding=True,
+             return_tensors="pt",
+         )
+
+         # Move tensor inputs to the same device as model (CPU)
+         if isinstance(inputs, dict):
+             for k, v in list(inputs.items()):
+                 if hasattr(v, "to"):
+                     inputs[k] = v.to(model.device)
+
+         # 3) Generate (deterministic)
+         generated_ids = model.generate(
+             **inputs,
+             max_new_tokens=128,
+             do_sample=False,
+         )
+
+         # 4) Trim prompt tokens if possible
+         generated_ids_trimmed = trim_generated(generated_ids, inputs)
+
+         # 5) Decode via robust helper
+         decoded_output = batch_decode_compat(
+             processor,
+             generated_ids_trimmed,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False
+         )
+
+         return decoded_output[0] if decoded_output else ""
+     except Exception as e:
+         print(f"Error during model inference: {e}")
+         traceback.print_exc()
+         raise
+
+ # --- Gradio processing function ---
+ def predict_click_location(input_pil_image: Image.Image, instruction: str):
+     if not model_loaded or not processor or not model:
+         return f"Model not loaded. Error: {load_error_message}", None
+     if not input_pil_image:
+         return "No image provided. Please upload an image.", None
+     if not instruction or instruction.strip() == "":
+         return "No instruction provided. Please type an instruction.", input_pil_image.copy().convert("RGB")
+
+     # 1) Resize according to image processor params (safe defaults if missing)
+     try:
+         ip = get_image_proc_params(processor)
+         resized_height, resized_width = smart_resize(
+             input_pil_image.height,
+             input_pil_image.width,
+             factor=ip["patch_size"] * ip["merge_size"],
+             min_pixels=ip["min_pixels"],
+             max_pixels=ip["max_pixels"],
+         )
+         resized_image = input_pil_image.resize(
+             size=(resized_width, resized_height),
+             resample=Image.Resampling.LANCZOS
+         )
+     except Exception as e:
+         print(f"Error resizing image: {e}")
+         traceback.print_exc()
+         return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")
+
+     # 2) Build messages with image + instruction
+     messages = get_localization_prompt(resized_image, instruction)
+
+     # 3) Run inference
+     try:
+         coordinates_str = run_inference_localization(messages, resized_image)
+     except Exception as e:
+         return f"Error during model inference: {e}", resized_image.copy().convert("RGB")
+
+     # 4) Parse coordinates and draw marker
+     output_image_with_click = resized_image.copy().convert("RGB")
+     match = re.search(r"Click\((\d+),\s*(\d+)\)", coordinates_str)
+     if match:
+         try:
+             x = int(match.group(1))
+             y = int(match.group(2))
+             draw = ImageDraw.Draw(output_image_with_click)
+             radius = max(5, min(resized_width // 100, resized_height // 100, 15))
+             bbox = (x - radius, y - radius, x + radius, y + radius)
+             draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
+             print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
+         except Exception as e:
+             print(f"Error drawing on image: {e}")
+             traceback.print_exc()
+     else:
+         print(f"Could not parse 'Click(x, y)' from model output: {coordinates_str}")
+
+     return coordinates_str, output_image_with_click
+
+ # --- Load Example Data ---
+ example_image = None
+ example_instruction = "Select July 14th as the check-out date"
+ try:
+     example_image_url = "https://huggingface.co/Hcompany/Holo1-7B/resolve/main/calendar_example.jpg"
+     example_image = Image.open(requests.get(example_image_url, stream=True).raw)
+ except Exception as e:
+     print(f"Could not load example image from URL: {e}")
+     traceback.print_exc()
+     try:
+         example_image = Image.new("RGB", (200, 150), color="lightgray")
+         draw = ImageDraw.Draw(example_image)
+         draw.text((10, 10), "Example image\nfailed to load", fill="black")
+     except Exception:
+         pass
+
+ # --- Gradio UI ---
+ title = "Holo1-3B: Action VLM Localization Demo (CPU)"
+ article = f"""
+ <p style='text-align: center'>
+ Model: <a href='https://huggingface.co/{MODEL_ID}' target='_blank'>{MODEL_ID}</a> by HCompany |
+ Paper: <a href='https://cdn.prod.website-files.com/67e2dbd9acff0c50d4c8a80c/683ec8095b353e8b38317f80_h_tech_report_v1.pdf' target='_blank'>HCompany Tech Report</a> |
+ Blog: <a href='https://www.hcompany.ai/surfer-h' target='_blank'>Surfer-H Blog Post</a>
+ </p>
+ """
+
+ if not model_loaded:
+     with gr.Blocks() as demo:
+         gr.Markdown(f"# <center>⚠️ Error: Model Failed to Load ⚠️</center>")
+         gr.Markdown(f"<center>{load_error_message}</center>")
+         gr.Markdown("<center>See Space logs for the full traceback.</center>")
+ else:
+     with gr.Blocks(theme=gr.themes.Soft()) as demo:
+         gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 input_image_component = gr.Image(type="pil", label="Input UI Image", height=400)
+                 instruction_component = gr.Textbox(
+                     label="Instruction",
+                     placeholder="e.g., Click the 'Login' button",
+                     info="Type the action you want the model to localize on the image."
+                 )
+                 submit_button = gr.Button("Localize Click", variant="primary")
+
+             with gr.Column(scale=1):
+                 output_coords_component = gr.Textbox(
+                     label="Predicted Coordinates (Format: Click(x, y))",
+                     interactive=False
+                 )
+                 output_image_component = gr.Image(
+                     type="pil",
+                     label="Image with Predicted Click Point",
+                     height=400,
+                     interactive=False
+                 )
+
+         if example_image:
+             gr.Examples(
+                 examples=[[example_image, example_instruction]],
+                 inputs=[input_image_component, instruction_component],
+                 outputs=[output_coords_component, output_image_component],
+                 fn=predict_click_location,
+                 cache_examples="lazy",
+             )
+
+         gr.Markdown(article)
+
+         submit_button.click(
+             fn=predict_click_location,
+             inputs=[input_image_component, instruction_component],
+             outputs=[output_coords_component, output_image_component]
+         )
+
+ if __name__ == "__main__":
+     # CPU Spaces can be slow; keep debug True for logs
+     demo.launch(debug=True)