uitars / app.py
Stremly's picture
Update app.py
dc6b09c verified
# app.py
import ast
import base64
import re
from io import BytesIO

# NOTE(review): `spaces` is deliberately kept ahead of `torch`, as in the
# original import order; presumably ZeroGPU needs it imported first — confirm
# before reordering.
import spaces
import torch
from PIL import Image, ImageDraw
import gradio as gr
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info  # include this file in your repo if not pip-installable
# Checkpoint id shared by the model weights and the processor configuration.
_MODEL_ID = "ByteDance-Seed/UI-TARS-1.5-7B"

# Image-size limits handed to the processor ("sane res" per the original
# author). NOTE(review): the magnitudes (multiples of 28 * 28) look like
# pixel-count budgets rather than edge lengths — confirm against the image
# processor's documentation.
_IMAGE_SIZE_LIMITS = {
    "shortest_edge": 100 * 28 * 28,
    "longest_edge": 16384 * 28 * 28,
}

# Loaded once at import time; half-precision weights, automatic device placement.
_MODEL = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    _MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)
_PROCESSOR = AutoProcessor.from_pretrained(
    _MODEL_ID,
    size=_IMAGE_SIZE_LIMITS,
    use_fast=True,
)

# Public aliases used by the inference code below.
model = _MODEL
processor = _PROCESSOR
def draw_point(image: Image.Image, point=None, radius: int = 5):
    """Return a copy of ``image`` with a filled red dot marking ``point``.

    Args:
        image: Picture to annotate. It is copied, never modified in place.
        point: ``(x, y)`` pair expressed as fractions of the image width and
            height. A falsy value (``None``, empty sequence) means "no dot".
        radius: Dot radius in pixels.

    Returns:
        The annotated copy (an unmarked copy when ``point`` is falsy).
    """
    canvas = image.copy()
    if not point:
        return canvas

    # Scale the fractional coordinates up to pixel positions on the copy.
    centre_x = point[0] * canvas.width
    centre_y = point[1] * canvas.height
    bounding_box = (
        centre_x - radius,
        centre_y - radius,
        centre_x + radius,
        centre_y + radius,
    )
    ImageDraw.Draw(canvas).ellipse(bounding_box, fill="red")
    return canvas
# Matches one "(x, y)" coordinate pair as it appears inside a UI-TARS action,
# e.g. click(start_box='(235, 512)'), with or without the box marker tokens.
_POINT_RE = re.compile(r"\(\s*(-?\d+(?:\.\d+)?)\s*,\s*(-?\d+(?:\.\d+)?)\s*\)")

# Fixed instructions placed ahead of the user's task in every request.
_PROMPT_HEADER = (
    "You are a GUI agent. You are given a task and your action history, with screenshots. "
    "You need to perform the next action to complete the task. \n\n"
    "## Output Format\n```\nThought: ...\nAction: ...\n```\n\n"
    "## Action Space\n\n"
    "click(start_box='<|box_start|>(x1, y1)<|box_end|>')\n"
    "left_double(start_box='<|box_start|>(x1, y1)<|box_end|>')\n"
    "right_single(start_box='<|box_start|>(x1, y1)<|box_end|>')\n"
    "drag(start_box='<|box_start|>(x1, y1)<|box_end|>', end_box='<|box_start|>(x3, y3)<|box_end|>')\n"
    "\n\n"
    "## Note\n"
    "- Use English in `Thought` part.\n"
    "- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.\n"
    "ONLY OUTPUT THE CLICKS ACTIONS(CLICK, RIGHT SINGLE, LEFT DOUBLE)\n\n"
)


def _legacy_positions(raw_out):
    """Return fractional (x, y) pairs from a dict / list-of-dicts literal output."""
    try:
        parsed = ast.literal_eval(raw_out)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        # Not a Python literal (the normal case for "Thought/Action" text).
        return []
    records = parsed if isinstance(parsed, list) else [parsed]
    positions = []
    for record in records:
        if not isinstance(record, dict):
            continue
        pos = record.get("position")
        if (
            isinstance(pos, list)
            and len(pos) == 2
            and all(isinstance(value, (int, float)) for value in pos)
        ):
            positions.append((pos[0], pos[1]))
    return positions


def _action_points(raw_out):
    """Return every (x, y) pair found in the ``Action:`` part of the model output."""
    # Only look after the last "Action:" marker so that numbers mentioned in
    # the free-text "Thought:" part are not mistaken for click targets.
    _, marker, action = raw_out.rpartition("Action:")
    haystack = action if marker else raw_out
    return [(float(x), float(y)) for x, y in _POINT_RE.findall(haystack)]


def _model_image_size(inputs, fallback):
    """Return the (width, height) of the image as the model saw it, or ``fallback``."""
    grid = inputs.get("image_grid_thw")
    if grid is None or len(grid) == 0:
        return fallback
    dims = grid[0].tolist()
    if len(dims) != 3:
        return fallback
    # NOTE(review): assumes each grid cell is one square patch of
    # `image_processor.patch_size` pixels (14 for Qwen2.5-VL) — confirm.
    patch = getattr(getattr(processor, "image_processor", None), "patch_size", 14)
    _, rows, cols = dims
    width, height = cols * patch, rows * patch
    if width <= 0 or height <= 0:
        return fallback
    return width, height


@spaces.GPU
def navigate(screenshot, task: str):
    """Run one inference step of the GUI-reasoning model on a screenshot.

    Args:
        screenshot (PIL.Image): Latest UI screenshot.
        task (str): Natural-language task description.

    Returns:
        tuple: ``(annotated_screenshot, raw_out, messages)`` where
        ``annotated_screenshot`` is the screenshot with a red dot on every
        predicted click point (the untouched screenshot when none could be
        parsed), ``raw_out`` is the decoded model text and ``messages`` is
        the single-turn conversation that was sent to the model.

    Raises:
        gr.Error: If no screenshot was provided.
    """
    if screenshot is None:
        raise gr.Error("Please provide a screenshot.")

    # ───────────────────────── build the request ────────────────────────────
    prompt = f"{_PROMPT_HEADER}## User Instruction\n{task}"
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": screenshot},
            ],
        }
    ]

    # ─────────────────────────── model forward ─────────────────────────────
    images, videos = process_vision_info(messages)
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    generated = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so only the newly generated answer is decoded.
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    raw_out = processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # ─────── draw predicted click for quick visual verification (optional) ──────
    # Best effort: if nothing can be parsed, the original screenshot is returned.
    points = _legacy_positions(raw_out)
    if not points:
        # The prompt asks for "Action: click(start_box='(x, y)')". UI-TARS-1.5
        # reports absolute pixels in the model's resized image space, so
        # convert them to the fractions that draw_point expects.
        width, height = _model_image_size(inputs, screenshot.size)
        points = [(x / width, y / height) for x, y in _action_points(raw_out)]

    annotated = screenshot
    for point in points:
        annotated = draw_point(annotated, point)

    # NOTE(review): `messages` still embeds the PIL screenshot object; confirm
    # the gr.JSON output renders it acceptably.
    return annotated, raw_out, messages
# ────────────────────────── Gradio interface ───────────────────────────────
# Input widgets: the screenshot to act on and the task to accomplish.
screenshot_input = gr.Image(type="pil", label="Screenshot")
task_input = gr.Textbox(
    lines=1,
    placeholder="e.g. Search the weather for New York",
    label="Task",
)

# Output widgets, in the same order as navigate()'s return tuple.
annotated_output = gr.Image(label="With Click Point")
raw_action_output = gr.Textbox(label="Raw Action JSON")
history_output = gr.JSON(label="Updated Conversation History")

demo = gr.Interface(
    fn=navigate,
    inputs=[screenshot_input, task_input],
    outputs=[annotated_output, raw_action_output, history_output],
    title="UI-Tars Navigation Demo",
)

launch_options = {
    "server_name": "0.0.0.0",
    "server_port": 7860,
    "share": False,  # or True if you need a public link
    "ssr_mode": False,  # turn off experimental SSR so the process blocks
}
demo.launch(**launch_options)