# app.py
import spaces
import ast
import torch
from PIL import Image, ImageDraw
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # pip install qwen-vl-utils, or vendor the file in this repo

# ---- model & processor loaded once at startup ----
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "ByteDance-Seed/UI-TARS-1.5-7B",
    device_map="auto",
    torch_dtype=torch.float16,  # half precision keeps the footprint small; inference runs on GPU via @spaces.GPU
)

processor = AutoProcessor.from_pretrained(
    "ByteDance-Seed/UI-TARS-1.5-7B",
    size={"shortest_edge": 100 * 28 * 28, "longest_edge": 16384 * 28 * 28},
    use_fast=True,
)
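# Note: the `size` bounds above are pixel budgets, not edge lengths. Qwen2.5-VL
# maps roughly one vision token to each 28x28 pixel patch, so shortest_edge /
# longest_edge bound how many image tokens a single screenshot can consume.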


def draw_point(image: Image.Image, point=None, radius: int = 5):
    """Overlay a red dot on the screenshot where the model clicked.

    `point` is assumed to be a normalized (x, y) pair in [0, 1] relative to the
    image width and height.
    """
    img = image.copy()
    if point:
        x, y = point[0] * img.width, point[1] * img.height
        ImageDraw.Draw(img).ellipse(
            (x - radius, y - radius, x + radius, y + radius), fill="red"
        )
    return img


@spaces.GPU
def navigate(screenshot, task: str, platform: str, history):
    """Run one inference step of the GUI-reasoning model.

    Args:
        screenshot (PIL.Image): Latest UI screenshot.
        task (str): Natural-language task description.
        platform (str): Either "web" or "phone" (reserved for prompt
            conditioning; not used in the current prompt).
        history (list | str | None): Previous messages. Accepts either an
            actual Python list (via gr.JSON) or a JSON/Python-literal string.
    """
    # ───────────────────── normalise history input ──────────────────────────
    messages = []
    if isinstance(history, str):
        try:
            messages = ast.literal_eval(history)
        except Exception as exc:
            raise ValueError("`history` must be a JSON/Python list: " + str(exc))
    else:
        messages = history or []  # gr.JSON may pass None
    prompt_header = (
        "You are a GUI agent. You are given a task and your action history, with screenshots. "
        "You need to perform the next action to complete the task. \n\n## Output Format\n```\nThought: ...\nAction: ...\n```\n\n## Action Space\n\nclick(start_box='<|box_start|>(x1, y1)<|box_end|>')\nleft_double(start_box='<|box_start|>(x1, y1)<|box_end|>')\nright_single(start_box='<|box_start|>(x1, y1)<|box_end|>')\ndrag(start_box='<|box_start|>(x1, y1)<|box_end|>', end_box='<|box_start|>(x3, y3)<|box_end|>')\nhotkey(key='')\ntype(content='') #If you want to submit your input, use \"\\n\" at the end of `content`.\nscroll(start_box='<|box_start|>(x1, y1)<|box_end|>', direction='down or up or right or left')\nwait() #Sleep for 5s and take a screenshot to check for any changes.\nfinished(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format.\n\n\n## Note\n- Use English in `Thought` part.\n- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.\n\n"
        f"## User Instruction\n{task}"
    )
    current = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt_header},
            {"type": "image", "image": screenshot},  # Qwen2.5-VL's native image entry
        ],
    }
    messages.append(current)

    # ─────────────────────────── model forward ─────────────────────────────
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # GPU inside @spaces.GPU, CPU otherwise
    generated = model.generate(**inputs, max_new_tokens=128)
    trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    raw_out = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
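    # With the prompt above, raw_out is typically plain text of the form
    # "Thought: ...\nAction: click(start_box='...')" rather than a Python/JSON literal,
    # in which case the literal_eval below simply falls through to the except branch.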
    # ─────── draw predicted click for quick visual verification (optional) ──────
    try:
        actions = ast.literal_eval(raw_out)
        for act in actions if isinstance(actions, list) else [actions]:
            pos = act.get("position")
            if pos and isinstance(pos, list) and len(pos) == 2:
                screenshot = draw_point(screenshot, pos)
    except Exception:
        # decoding failed -> just return original screenshot
        pass
    return screenshot, raw_out
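

# Local smoke test (hypothetical names; assumes a GPU and a screenshot.png on disk):
#   img = Image.open("screenshot.png")
#   annotated, action_text = navigate(img, "Open the settings menu", "web", [])
#   print(action_text)
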
# ────────────────────────── Gradio interface ───────────────────────────────
demo = gr.Interface(
    fn=navigate,
    inputs=[
        gr.Image(type="pil", label="Screenshot"),
        gr.Textbox(
            lines=1,
            placeholder="e.g. Search the weather for New York",
            label="Task",
        ),
        gr.Dropdown(choices=["web", "phone"], value="web", label="Platform"),
        gr.JSON(label="Conversation History (list)", value=[]),
    ],
    outputs=[
        gr.Image(label="With Click Point"),
        gr.Textbox(label="Raw Model Output"),
    ],
    title="UI-TARS Navigation Demo",
)

demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    share=False,  # or True if you need a public link
    ssr_mode=False,  # turn off experimental SSR so the process blocks
)