Stremly committed on
Commit 65f9291 · verified · 1 Parent(s): bb332d3

Update app.py

Files changed (1)
  1. app.py  +103 -72
app.py CHANGED
@@ -11,16 +11,10 @@ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info  # include this file in your repo if not pip-installable
 
 # ---- model & processor loaded on CPU ----
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    "ByteDance-Seed/UI-TARS-1.5-7B",
-    device_map="auto",
-    torch_dtype=torch.float16,  # CPU-friendly
-)
-processor = AutoProcessor.from_pretrained(
-    "ByteDance-Seed/UI-TARS-1.5-7B",
-    size={"shortest_edge": 100 * 28 * 28, "longest_edge": 16384 * 28 * 28},
-    use_fast=True,
-)
+
+# ─── lazy-load cache ──────────────────────────────────────────
+_MODEL = None       # will hold the quantised weights
+_PROCESSOR = None   # will hold the resized processor
 
 
 def draw_point(image: Image.Image, point=None, radius: int = 5):
@@ -46,72 +40,109 @@ def navigate(screenshot, task: str, platform: str, history):
     actual Python list (via gr.JSON) or a JSON/Python-literal string.
     """
 
-    # ───────────────────── normalise history input ──────────────────────────
-    messages=[]
-
-    if isinstance(history, str):
-        try:
-            messages= ast.literal_eval(history)
-        except Exception as exc:
-            raise ValueError("`history` must be a JSON/Python list: " + str(exc))
-    else:
-        messages = history
-
-    prompt_header = (
-        "You are a GUI agent. You are given a task and your action history, with screenshots."
-        "You need to perform the next action to complete the task. \n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n## Action Space\n\nclick(start_box='<|box_start|>(x1, y1)<|box_end|>')\nleft_double(start_box='<|box_start|>(x1, y1)<|box_end|>')\nright_single(start_box='<|box_start|>(x1, y1)<|box_end|>')\ndrag(start_box='<|box_start|>(x1, y1)<|box_end|>', end_box='<|box_start|>(x3, y3)<|box_end|>')\nhotkey(key='')\ntype(content='') #If you want to submit your input, use \"\\n\" at the end of `content`.\nscroll(start_box='<|box_start|>(x1, y1)<|box_end|>', direction='down or up or right or left')\nwait() #Sleep for 5s and take a screenshot to check for any changes.\nfinished(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format.\n\n\n## Note\n- Use English in `Thought` part.\n- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.\n\n"
-        f"## User Instruction\n{task}"
-    )
-    current = {"role":"user","content":[{"type":"text","text":prompt_header},{"type": "image_url", "image_url":screenshot}]}
-
-    messages.append(current)
-
-    # ─────────────────────────── model forward ─────────────────────────────
-
-    images, videos = process_vision_info(messages)
-    i=0
-    for message in messages:
-        if message['role'] == 'user' and isinstance(message.get('content'), list):
-            for item in message['content']:
-                if item.get('type') == 'image_url' and isinstance(item.get('image_url'), str):
-                    item['image_url'] = images[i]
-                    i+=1
-
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    print("\nimages\n:",images)
-    print("\ntext\n",text)
-    print("\nmessages\n",messages)
-    inputs = processor(
-        text=[text],
-        images=images,
-        videos=videos,
-        padding=True,
-        return_tensors="pt",
-    ).to("cuda")
-
-    generated = model.generate(**inputs, max_new_tokens=128)
-    trimmed = [
-        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
-    ]
-    raw_out = processor.batch_decode(
-        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0]
-
-    # ─────── draw predicted click for quick visual verification (optional) ──────
-    try:
-        actions = ast.literal_eval(raw_out)
-        for act in actions if isinstance(actions, list) else [actions]:
-            pos = act.get("position")
-            if pos and isinstance(pos, list) and len(pos) == 2:
-                screenshot = draw_point(screenshot, pos)
-    except Exception:
-        # decoding failed → just return original screenshot
-        pass
-
-    return screenshot, raw_out, messages
+    # ------- on-demand model / processor load -------------------------
+    if _MODEL is None:
+        from transformers import BitsAndBytesConfig
+
+        # 4-bit quantisation (~6 GB on H200)
+        bnb_cfg = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+        )
+
+        _MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            "ByteDance-Seed/UI-TARS-1.5-7B",
+            quantization_config=bnb_cfg,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+        )
+
+        _PROCESSOR = AutoProcessor.from_pretrained(
+            "ByteDance-Seed/UI-TARS-1.5-7B",
+            size={"shortest_edge": 512, "longest_edge": 1344},  # sane res
+            use_fast=True,
+        )
+
+        # use mem-efficient attention kernels
+        torch.backends.cuda.enable_flash_sdp(False)
+        torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+    model = _MODEL
+    processor = _PROCESSOR
+
+    # ───────────────────── normalise history input ──────────────────────────
+    try:
+        messages=[]
+
+        if isinstance(history, str):
+            try:
+                messages= ast.literal_eval(history)
+            except Exception as exc:
+                raise ValueError("`history` must be a JSON/Python list: " + str(exc))
+        else:
+            messages = history
+
+        prompt_header = (
+            "You are a GUI agent. You are given a task and your action history, with screenshots."
+            "You need to perform the next action to complete the task. \n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n## Action Space\n\nclick(start_box='<|box_start|>(x1, y1)<|box_end|>')\nleft_double(start_box='<|box_start|>(x1, y1)<|box_end|>')\nright_single(start_box='<|box_start|>(x1, y1)<|box_end|>')\ndrag(start_box='<|box_start|>(x1, y1)<|box_end|>', end_box='<|box_start|>(x3, y3)<|box_end|>')\nhotkey(key='')\ntype(content='') #If you want to submit your input, use \"\\n\" at the end of `content`.\nscroll(start_box='<|box_start|>(x1, y1)<|box_end|>', direction='down or up or right or left')\nwait() #Sleep for 5s and take a screenshot to check for any changes.\nfinished(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format.\n\n\n## Note\n- Use English in `Thought` part.\n- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.\n\n"
+            f"## User Instruction\n{task}"
+        )
+        current = {"role":"user","content":[{"type":"text","text":prompt_header},{"type": "image_url", "image_url":screenshot}]}
+
+        messages.append(current)
+
+        # ─────────────────────────── model forward ─────────────────────────────
+
+        images, videos = process_vision_info(messages)
+        i=0
+        for message in messages:
+            if message['role'] == 'user' and isinstance(message.get('content'), list):
+                for item in message['content']:
+                    if item.get('type') == 'image_url' and isinstance(item.get('image_url'), str):
+                        item['image_url'] = images[i]
+                        i+=1
+
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        print("\nimages\n:",images)
+        print("\ntext\n",text)
+        print("\nmessages\n",messages)
+        inputs = processor(
+            text=[text],
+            images=images,
+            videos=videos,
+            padding=True,
+            return_tensors="pt",
+        ).to("cuda")
+
+        generated = model.generate(**inputs, max_new_tokens=128)
+        trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
+        ]
+        raw_out = processor.batch_decode(
+            trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+
+        # ─────── draw predicted click for quick visual verification (optional) ──────
+        try:
+            actions = ast.literal_eval(raw_out)
+            for act in actions if isinstance(actions, list) else [actions]:
+                pos = act.get("position")
+                if pos and isinstance(pos, list) and len(pos) == 2:
+                    screenshot = draw_point(screenshot, pos)
+        except Exception:
+            # decoding failed → just return original screenshot
+            pass
+
+        return screenshot, raw_out, messages
+
+    finally:                        # ← always executed
+        torch.cuda.empty_cache()    # free unused blocks
+        torch.cuda.ipc_collect()    # defrag for next call
 
 
 # ────────────────────────── Gradio interface ───────────────────────────────
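
Note on the lazy-load cache: `_MODEL` and `_PROCESSOR` are module-level names, but `navigate()` assigns to them, so Python treats them as locals unless the unchanged part of the function already declares `global _MODEL, _PROCESSOR`; without that declaration the `if _MODEL is None:` check would raise `UnboundLocalError` on the first call. Below is a minimal, self-contained sketch of the pattern the diff appears to aim for; the helper name `_get_model_and_processor` is illustrative and not part of the commit, and the sketch assumes `transformers`, `bitsandbytes`, and a CUDA device are available.

import torch
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Qwen2_5_VLForConditionalGeneration,
)

_MODEL = None
_PROCESSOR = None


def _get_model_and_processor():
    """Load the 4-bit-quantised model and processor once, then reuse them."""
    global _MODEL, _PROCESSOR  # rebind the module-level cache, not locals
    if _MODEL is None:
        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,                     # 4-bit NF4 weights
            bnb_4bit_compute_dtype=torch.float16,  # run matmuls in fp16
            bnb_4bit_use_double_quant=True,        # quantise the quant constants too
        )
        _MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "ByteDance-Seed/UI-TARS-1.5-7B",
            quantization_config=bnb_cfg,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
        _PROCESSOR = AutoProcessor.from_pretrained(
            "ByteDance-Seed/UI-TARS-1.5-7B",
            use_fast=True,
        )
    return _MODEL, _PROCESSOR

With this shape, each Gradio request can call `model, processor = _get_model_and_processor()`: the first request pays the load cost and later requests reuse the cached objects, while the `try/finally` with `torch.cuda.empty_cache()` in the diff releases cached allocator blocks after every call.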