Stremly committed
Commit e995be0 · verified · 1 Parent(s): 4244b33

Update app.py

Files changed (1)
  1. app.py +89 -30
app.py CHANGED
@@ -5,71 +5,130 @@ import torch
from PIL import Image, ImageDraw
import gradio as gr

- from transformers import Qwen2_5_VLForConditionalGeneration
- from transformers import AutoProcessor
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # include this file in your repo if not pip-installable

# ---- model & processor loaded on CPU ----
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "ByteDance-Seed/UI-TARS-1.5-7B",
    device_map="auto",
-     torch_dtype=torch.float32,  # CPU-friendly
+     torch_dtype=torch.float32,  # CPU‑friendly
)
processor = AutoProcessor.from_pretrained(
    "ByteDance-Seed/UI-TARS-1.5-7B",
    size={"shortest_edge": 256 * 28 * 28, "longest_edge": 1344 * 28 * 28},
    use_fast=True,
-
)

+
- def draw_point(image: Image.Image, point=None, radius=5):
+ def draw_point(image: Image.Image, point=None, radius: int = 5):
+     """Overlay a red dot on the screenshot where the model clicked."""
    img = image.copy()
    if point:
        x, y = point[0] * img.width, point[1] * img.height
        ImageDraw.Draw(img).ellipse(
-             (x - radius, y - radius, x + radius, y + radius), fill='red'
+             (x - radius, y - radius, x + radius, y + radius), fill="red"
        )
    return img

+
@spaces.GPU
- def navigate(image, task, platform):
-     messages = [
-         {"role": "user", "content": [{"type": "text", "text": f"You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n## Action Space\n\nclick(start_box='"},
-             {"type": "image_url", "image_url": image}
-         ]}
+ def navigate(screenshot, task: str, platform: str, history):
+     """Run one inference step on the GUI‑reasoning model.
+
+     Args:
+         screenshot (PIL.Image): Latest UI screenshot.
+         task (str): Natural‑language task description.
+         platform (str): Either "web" or "phone" for prompt conditioning.
+         history (list | str | None): Previous messages list. Accepts either an
+             actual Python list (via gr.JSON) or a JSON/Python‑literal string.
+     """
+
+     # ───────────────────── normalise history input ──────────────────────────
+     if history in (None, ""):
+         history_list = []
+     else:
+         if isinstance(history, str):
+             try:
+                 history_list = ast.literal_eval(history)
+             except Exception as exc:
+                 raise ValueError("`history` must be a JSON/Python list: " + str(exc))
+         else:
+             history_list = history
+
+     if not isinstance(history_list, list):
+         raise ValueError("`history` must be a list of messages.")
+
+     # ─────────────────── construct current user message ─────────────────────
+     prompt_header = (
+         "You are a GUI agent. You are given a task and your action history, "
+         "with screenshots. You need to perform the next action to complete "
+         "the task.\n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n"
+         "## Action Space\nclick(start_box='...') / type(...)\n\n"
+         f"### Task\n{task}"
+     )
+
+     current_content = [
+         {"type": "text", "text": prompt_header},
+         {"type": "image_url", "image_url": screenshot},
    ]
-     # prepare inputs
-     text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     messages = history_list + [{"role": "user", "content": current_content}]
+
+     # ─────────────────────────── model forward ─────────────────────────────
+     text = processor.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
    images, videos = process_vision_info(messages)
-     inputs = processor(text=[text], images=images, videos=videos, padding=True, return_tensors="pt")
-     inputs = inputs.to("cuda")
+     inputs = processor(
+         text=[text],
+         images=images,
+         videos=videos,
+         padding=True,
+         return_tensors="pt",
+     ).to("cuda")

-     # generate
    generated = model.generate(**inputs, max_new_tokens=128)
-     trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)]
-     out = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+     trimmed = [
+         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated)
+     ]
+     raw_out = processor.batch_decode(
+         trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+     )[0]

-     # optionally parse JSON and draw point
+     # ─────── draw predicted click for quick visual verification (optional) ──────
    try:
-         actions = ast.literal_eval(out)
+         actions = ast.literal_eval(raw_out)
        for act in actions if isinstance(actions, list) else [actions]:
-             pos = act.get('position')
-             if pos and isinstance(pos, list) and len(pos)==2:
-                 image = draw_point(image, pos)
-         return image, out
-     except:
-         return image, out
+             pos = act.get("position")
+             if pos and isinstance(pos, list) and len(pos) == 2:
+                 screenshot = draw_point(screenshot, pos)
+     except Exception:
+         # decoding failed → just return original screenshot
+         pass
+
+     return screenshot, raw_out


+ # ────────────────────────── Gradio interface ───────────────────────────────
+
demo = gr.Interface(
    fn=navigate,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
-         gr.Textbox(lines=1, placeholder="e.g. Search the weather for New York", label="Task"),
+         gr.Textbox(
+             lines=1,
+             placeholder="e.g. Search the weather for New York",
+             label="Task",
+         ),
        gr.Dropdown(choices=["web", "phone"], value="web", label="Platform"),
+         gr.JSON(label="Conversation History (list)", value=[]),
+     ],
+     outputs=[
+         gr.Image(label="With Click Point"),
+         gr.Textbox(label="Raw Action JSON"),
    ],
-     outputs=[gr.Image(label="With Click Point"), gr.Textbox(label="Raw Action JSON")],
-     title="ShowUI-2B Navigation Demo",
+     title="ShowUI‑2B Navigation Demo",
)

demo.launch(
@@ -77,4 +136,4 @@ demo.launch(
    server_port=7860,
    share=False,  # or True if you need a public link
    ssr_mode=False,  # turn off experimental SSR so the process blocks
- )
+ )
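
For quick manual testing, a minimal client-side sketch of how the updated navigate endpoint could be called once the Space is running. The Space id "Stremly/ui-tars-demo" is a hypothetical placeholder, and the "/predict" API name assumes the default endpoint that gr.Interface exposes; the empty list mirrors the new gr.JSON history input, and the image output comes back as a local file path.

# Hedged sketch: "Stremly/ui-tars-demo" is hypothetical; point this at wherever
# the app is actually hosted.
from gradio_client import Client, handle_file

client = Client("Stremly/ui-tars-demo")  # hypothetical Space id

history = []  # first turn: no prior messages

annotated_path, raw_action = client.predict(
    handle_file("screenshot.png"),        # Screenshot (gr.Image input)
    "Search the weather for New York",    # Task
    "web",                                # Platform
    history,                              # Conversation History (gr.JSON input)
    api_name="/predict",
)
print(raw_action)  # expected shape: "Thought: ...\nAction: click(start_box='...')"

To continue a multi-step episode, the caller would append the previous turn's messages to history in the same role/content format that apply_chat_template expects.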