# app.py — UI-TARS navigation demo (Hugging Face Space)
# (The "Spaces: Running" status banner from the Space page was scrape residue,
#  converted to this comment so the file is valid Python.)
# Standard library
import ast
import base64
from io import BytesIO

# Third-party
import gradio as gr
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Local helper — include this file in your repo if not pip-installable
from qwen_vl_utils import process_vision_info
# ───────────────────────── model & processor setup ─────────────────────────
# Single source of truth for the checkpoint id (was duplicated in both
# from_pretrained calls below).
MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"

_MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",          # place/shard automatically on available devices
    torch_dtype=torch.float16,  # half precision so the 7B model fits in memory
)

_PROCESSOR = AutoProcessor.from_pretrained(
    MODEL_ID,
    # Bound the visual token budget; values are multiples of the 28*28 patch
    # area — presumably the Qwen2.5-VL pixel-count convention (TODO confirm).
    size={"shortest_edge": 100 * 28 * 28, "longest_edge": 16384 * 28 * 28},
    use_fast=True,
)

# Public aliases used by the rest of the module.
model = _MODEL
processor = _PROCESSOR
def draw_point(image: Image.Image, point=None, radius: int = 5):
    """Return a copy of *image* with a red dot marking the clicked spot.

    Args:
        image: Source screenshot (left untouched; a copy is annotated).
        point: ``(x, y)`` pair in normalized [0, 1] coordinates (they are
            scaled by the image width/height), or falsy to skip drawing.
        radius: Dot radius in pixels.
    """
    annotated = image.copy()
    if not point:
        return annotated
    cx = point[0] * annotated.width
    cy = point[1] * annotated.height
    bbox = (cx - radius, cy - radius, cx + radius, cy + radius)
    ImageDraw.Draw(annotated).ellipse(bbox, fill="red")
    return annotated
def navigate(screenshot, task: str):
    """Run one inference step on the GUI-reasoning model.

    Args:
        screenshot (PIL.Image): Latest UI screenshot.
        task (str): Natural-language task description.

    Returns:
        tuple: (possibly annotated screenshot, raw decoded model output,
        the message list that was sent to the model).
    """
    # NOTE(review): `spaces` is imported at module level but unused; on a
    # ZeroGPU Space this function should carry @spaces.GPU — confirm runtime.
    messages = []
    prompt_header = (
        "You are a GUI agent. You are given a task and your action history, with screenshots."
        # Fixed the right_single line below: it previously read
        # '<|box_start|<(x1, y1)>|box_end|>', with the special tokens malformed
        # relative to every other action in this action space.
        "You need to perform the next action to complete the task. \n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n## Action Space\n\nclick(start_box='<|box_start|>(x1, y1)<|box_end|>')\nleft_double(start_box='<|box_start|>(x1, y1)<|box_end|>')\nright_single(start_box='<|box_start|>(x1, y1)<|box_end|>')\ndrag(start_box='<|box_start|>(x1, y1)<|box_end|>', end_box='<|box_start|>(x3, y3)<|box_end|>')\n\n\n## Note\n- Use English in `Thought` part.\n- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.\nONLY OUTPUT THE CLICKS ACTIONS(CLICK, RIGHT SINGLE, LEFT DOUBLE)\n\n"
        f"## User Instruction\n{task}"
    )
    # qwen_vl_utils accepts a PIL image under the "image_url" key as well as
    # "image" — presumably this version does; verify against qwen_vl_utils.
    messages.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt_header},
                {"type": "image_url", "image_url": screenshot},
            ],
        }
    )

    # ─────────────────────────── model forward ───────────────────────────
    images, videos = process_vision_info(messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model's placement rather than hard-coding "cuda"
    generated = model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens so only the newly generated ids are decoded.
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    raw_out = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # ─────── draw predicted click for quick visual verification (optional) ───────
    # NOTE(review): the prompt asks for "Thought: ...\nAction: ..." text, which
    # ast.literal_eval cannot parse, so this branch almost certainly always
    # falls through to the except — consider regex-parsing the Action line
    # (and confirm whether coordinates are normalized before calling draw_point).
    try:
        actions = ast.literal_eval(raw_out)
        for act in actions if isinstance(actions, list) else [actions]:
            pos = act.get("position")
            if pos and isinstance(pos, list) and len(pos) == 2:
                screenshot = draw_point(screenshot, pos)
    except Exception:
        # Decoding failed — just return the original screenshot.
        pass
    return screenshot, raw_out, messages
# ────────────────────────── Gradio interface ──────────────────────────
# Components are named first so the Interface call stays readable.
_screenshot_input = gr.Image(type="pil", label="Screenshot")
_task_input = gr.Textbox(
    lines=1,
    placeholder="e.g. Search the weather for New York",
    label="Task",
)

demo = gr.Interface(
    fn=navigate,
    inputs=[_screenshot_input, _task_input],
    outputs=[
        gr.Image(label="With Click Point"),
        gr.Textbox(label="Raw Action JSON"),
        gr.JSON(label="Updated Conversation History"),
    ],
    title="UI-Tars Navigation Demo",
)

demo.launch(
    server_name="0.0.0.0",  # listen on all interfaces
    server_port=7860,
    share=False,            # or True if you need a public link
    ssr_mode=False,         # turn off experimental SSR so the process blocks
)