Commit · 9f57ecf
1 Parent(s): 8ecfcea
- app.py +180 -0
- gui_actor/__init__.py +0 -0
- gui_actor/constants.py +40 -0
- gui_actor/dataset.py +533 -0
- gui_actor/inference.py +300 -0
- gui_actor/modeling.py +361 -0
- gui_actor/modeling_qwen25vl.py +337 -0
- gui_actor/trainer.py +313 -0
- gui_actor/utils.py +90 -0
- requirements.txt +16 -0
app.py
ADDED
@@ -0,0 +1,180 @@
import base64, os
# import spaces
import json
import torch
import gradio as gr
from typing import Optional
from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt
from qwen_vl_utils import process_vision_info
from datasets import load_dataset
from transformers import AutoProcessor
from gui_actor.constants import chat_template
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
from gui_actor.inference import inference

MAX_PIXELS = 3200 * 1800

def resize_image(image, resize_to_pixels=MAX_PIXELS):
    image_width, image_height = image.size
    if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
        resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
        image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio)
        image = image.resize((image_width_resized, image_height_resized))
    return image

# @spaces.GPU
@torch.inference_mode()
def draw_point(image: Image.Image, point: list, radius=8, color=(255, 0, 0, 128)):
    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    x, y = point
    overlay_draw.ellipse(
        [(x - radius, y - radius), (x + radius, y + radius)],
        outline=color,
        width=5  # Adjust thickness as needed
    )
    image = image.convert('RGBA')
    combined = Image.alpha_composite(image, overlay)
    combined = combined.convert('RGB')
    return combined

# @spaces.GPU
@torch.inference_mode()
def get_attn_map(image, attn_scores, n_width, n_height):
    w, h = image.size
    scores = np.array(attn_scores[0]).reshape(n_height, n_width)

    scores_norm = (scores - scores.min()) / (scores.max() - scores.min())
    # Resize score map to match image size
    score_map = Image.fromarray((scores_norm * 255).astype(np.uint8)).resize((w, h), resample=Image.NEAREST)  # BILINEAR)
    # Apply colormap
    colormap = plt.get_cmap('jet')
    colored_score_map = colormap(np.array(score_map) / 255.0)  # returns RGBA
    colored_score_map = (colored_score_map[:, :, :3] * 255).astype(np.uint8)
    colored_overlay = Image.fromarray(colored_score_map)

    # Blend with original image
    blended = Image.blend(image, colored_overlay, alpha=0.3)
    return blended

# load model
if torch.cuda.is_available():
    # os.system('pip install flash-attn --no-build-isolation')
    model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2.5-VL"
    data_processor = AutoProcessor.from_pretrained(model_name_or_path)
    tokenizer = data_processor.tokenizer
    model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0",
        attn_implementation="flash_attention_2"
    ).eval()
else:
    model_name_or_path = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
    data_processor = AutoProcessor.from_pretrained(model_name_or_path)
    tokenizer = data_processor.tokenizer
    model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.bfloat16,
        device_map="cpu"
    ).eval()

title = "GUI-Actor"
header = """
<div align="center">
<h1 style="padding-bottom: 10px; padding-top: 10px;">🎯 <strong>GUI-Actor</strong>: Coordinate-Free Visual Grounding for GUI Agents</h1>
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
Qianhui Wu*, Kanzhi Cheng*, Rui Yang*, Chaoyun Zhang, Jianwei Yang, Huiqiang Jiang, Jian Mu, Baolin Peng, Bo Qiao, Reuben Tan, Si Qin, Lars Liden<br>
Qingwei Lin, Huan Zhang, Tong Zhang, Jianbing Zhang, Dongmei Zhang, Jianfeng Gao<br/>
</div>
<div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
<a href="https://microsoft.github.io/GUI-Actor/">🌐 Project Page</a> | <a href="https://arxiv.org/abs/2403.12968">📄 arXiv Paper</a> | <a href="https://github.com/microsoft/GUI-Actor">💻 Github Repo</a><br/>
</div>
</div>
"""

theme = "soft"
css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
#anno-img .mask.active {opacity: 0.7}"""

# @spaces.GPU
@torch.inference_mode()
def process(image, instruction):
    # resize image
    w, h = image.size
    if w * h > MAX_PIXELS:
        image = resize_image(image)

    conversation = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>).",
                }
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,  # PIL.Image.Image or str to path
                    # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
                },
                {
                    "type": "text",
                    "text": instruction,
                },
            ],
        },
    ]

    try:
        pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
    except Exception as e:
        print(e)
        return image, f"Error: {e}", None

    px, py = pred["topk_points"][0]
    output_coord = f"({px:.4f}, {py:.4f})"
    img_with_point = draw_point(image, (px * w, py * h))

    n_width, n_height = pred["n_width"], pred["n_height"]
    attn_scores = pred["attn_scores"]
    att_map = get_attn_map(image, attn_scores, n_width, n_height)

    return img_with_point, output_coord, att_map


with gr.Blocks(title=title, css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type='pil', label='Upload image')
            # text box
            input_instruction = gr.Textbox(label='Instruction', placeholder='Text your (low-level) instruction here')
            submit_button = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_with_point = gr.Image(type='pil', label='Image with Point (red circle)')
            with gr.Accordion('Detailed prediction'):
                pred_xy = gr.Textbox(label='Predicted Coordinates', placeholder='(x, y)')
                att_map = gr.Image(type='pil', label='Attention Map')

    submit_button.click(
        fn=process,
        inputs=[
            input_image,
            input_instruction
        ],
        outputs=[image_with_point, pred_xy, att_map]
    )

# demo.launch(debug=False, show_error=True, share=True)
# demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
demo.queue().launch(share=False)
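Editor's note: the Gradio demo above wires process() to the Submit button. As a minimal sketch of driving the same pipeline without the UI (not part of the commit; the screenshot path and instruction string are invented placeholders), process() can be called directly once app.py has loaded the model:

# Illustrative only: call the demo's process() outside Gradio.
# Assumes app.py has been imported as a module and a local screenshot exists.
from PIL import Image

screenshot = Image.open("screenshot.png").convert("RGB")   # hypothetical input image
img_with_point, coord_text, attention_map = process(screenshot, "Click the search box")
print(coord_text)                      # e.g. "(0.4321, 0.1234)", normalized (x, y)
img_with_point.save("prediction.png")  # screenshot with the predicted click circled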
gui_actor/__init__.py
ADDED
File without changes
gui_actor/constants.py
ADDED
@@ -0,0 +1,40 @@
import json

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_POINTER_START_TOKEN = "<|pointer_start|>"
DEFAULT_POINTER_END_TOKEN = "<|pointer_end|>"
DEFAULT_POINTER_PAD_TOKEN = "<|pointer_pad|>"

# UNMASK_TOKEN_IDS = [198, 151644, 151645]

# System Message
grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."

# Chat Template
chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"

assistant_template = "{% for message in messages %}{{'<|im_start|>' + message['role']}}{% if 'recipient' in message %}<|recipient|>{{ message['recipient'] }}{% endif %}{{'\n' + message['content'][0]['text']}}{% if 'end_turn' in message and message['end_turn'] %}{{'<|diff_marker|>\n'}}{% else %}{{'<|im_end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|recipient|>' }}{% endif %}"

# Special Tokens
ADDITIONAL_SPECIAL_TOKENS = [
    "<|recipient|>",
    "<|diff_marker|>",
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
]

# Action Patterns to be replaced with special tokens
ACTION_PATTENS_XY = [
    r"x=([0-9.]+), y=([0-9.]+)",
    r"from_coord=\[([0-9.]+), ([0-9.]+)\], to_coord=\[([0-9.]+), ([0-9.]+)\]",
]

until = ["<|diff_marker|>"]
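Editor's note: as a small self-contained check (not part of the commit, with an invented action string), this is how the first pattern in ACTION_PATTENS_XY picks normalized coordinates out of a pyautogui-style action before dataset.py swaps that span for the pointer tokens:

# Illustrative only: how ACTION_PATTENS_XY[0] extracts normalized coordinates.
import re

pattern = r"x=([0-9.]+), y=([0-9.]+)"           # same as ACTION_PATTENS_XY[0]
action = "pyautogui.click(x=0.32, y=0.58)"      # invented example action string
match = re.search(pattern, action)
x, y = float(match.group(1)), float(match.group(2))
print(x, y)  # 0.32 0.58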
gui_actor/dataset.py
ADDED
@@ -0,0 +1,533 @@
import copy
import json
import math
import os
import random
import re
import ast
from typing import Dict

import torch
import transformers
import yaml
from qwen_vl_utils import smart_resize, process_vision_info
from torch.utils.data import Dataset

from gui_actor.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    ACTION_PATTENS_XY,
    ADDITIONAL_SPECIAL_TOKENS,
    assistant_template,
    chat_template,
    grounding_system_message,
)
from gui_actor.trainer import rank0_print


def reformat_coordinates(text):
    """
    (1) Find all the coordinates in the text.
    (2) Replace the coordinates with the special tokens.
    (3) Return the new text and the coordinates as a list of (x, y), where x in [0, 1] and y in [0, 1].
    """
    epsilon = 0.001
    def adjust_coord(c):
        """
        Adjust coordinate if it is too close to 0 or 1.
        """
        if abs(c) < epsilon:
            return epsilon
        elif abs(c - 1) < epsilon:
            return 1 - epsilon
        return c

    all_matches = []
    for pattern in ACTION_PATTENS_XY:
        matches = list(re.finditer(pattern, text))
        for match in matches:
            all_matches.append((match.start(), match.groups()))
        if pattern == ACTION_PATTENS_XY[0]:
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        else:
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, {DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        text = re.sub(
            pattern,
            target_text,
            text
        )

    coordinates = []
    all_matches.sort(key=lambda x: x[0])
    # Extract coordinates in order
    for _, groups in all_matches:
        # When two coordinate values are found, parse them as one (x, y) pair.
        if len(groups) == 2:
            x_str, y_str = groups
            x = adjust_coord(ast.literal_eval(x_str))
            y = adjust_coord(ast.literal_eval(y_str))
            coordinates.append((x, y))
        # When four coordinate values are found, parse them as two pairs.
        elif len(groups) == 4:
            x1_str, y1_str, x2_str, y2_str = groups
            x1 = adjust_coord(ast.literal_eval(x1_str))
            y1 = adjust_coord(ast.literal_eval(y1_str))
            x2 = adjust_coord(ast.literal_eval(x2_str))
            y2 = adjust_coord(ast.literal_eval(y2_str))
            coordinates.append((x1, y1))
            coordinates.append((x2, y2))

    return text, coordinates

def get_token_index(image_processor, image, point_x, point_y):
    """
    Get the index of the visual token that contains the point (x, y).
    Args:
        image_processor: the image processor
        image: the image in PIL format
        point_x: the x coordinate of the point, in [0, 1].
        point_y: the y coordinate of the point, in [0, 1].
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    # get the original image size and the resized image size
    image = image[0]
    w, h = image.size
    px, py = w * point_x, h * point_y
    # rank0_print(f"px: {px}, py: {py}")
    # get the token index
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = math.floor(px / merge_patch_size)
    y_index = math.floor(py / merge_patch_size)

    visual_token_index = y_index * (w // merge_patch_size) + x_index

    # merge all above print into one line
    return visual_token_index

def get_multi_patch_labels(image_processor, image, bbox_gt):
    """
    Get the multi-patch labels for the bounding box.
    Args:
        image_processor: the image processor
        image: the image in PIL format
        bbox_gt: the bounding box in the format of (x_min, y_min, x_max, y_max) [0,1]
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    # Get the original image size and the resized image size
    image = image[0]
    w, h = image.size

    bbox_gt = [bbox_gt[0]*w, bbox_gt[1]*h, bbox_gt[2]*w, bbox_gt[3]*h]
    # Extract bounding box coordinates
    x_min, y_min, x_max, y_max = bbox_gt
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    grid_h, grid_w = h // merge_patch_size, w // merge_patch_size

    binary_mask = torch.zeros(grid_h * grid_w)
    # Iterate through all patches, check if they overlap with the bounding box
    for y_idx in range(grid_h):
        for x_idx in range(grid_w):
            # Calculate patch boundaries
            patch_x_min = x_idx * merge_patch_size
            patch_y_min = y_idx * merge_patch_size
            patch_x_max = patch_x_min + merge_patch_size
            patch_y_max = patch_y_min + merge_patch_size

            # Check if patch overlaps with the bounding box
            if not (patch_x_max <= x_min or patch_x_min >= x_max or
                    patch_y_max <= y_min or patch_y_min >= y_max):
                # Calculate patch index in the flattened grid
                patch_idx = y_idx * grid_w + x_idx
                binary_mask[patch_idx] = 1

    return binary_mask

def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = visual_token_index % (image_width // merge_patch_size)
    y_index = visual_token_index // (image_width // merge_patch_size)
    px = x_index * merge_patch_size + merge_patch_size / 2
    py = y_index * merge_patch_size + merge_patch_size / 2
    return px, py

class LazySupervisedDataset(Dataset):
    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        data_path: str,
        data_args,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.list_data_dict = []
        self.list_image_path = []
        self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
        self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
        self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]

        # Handle multiple JSON files specified in the data_path
        if "{" in data_path and "}" in data_path:
            base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
            file_names = file_pattern.split(",")
            rank0_print(f"Loading {file_names} from {base_path}")
            data_args.dataset_paths = []
            for file_name in file_names:
                data_args.dataset_paths.append(f"{base_path}{file_name}.json")
                full_path = f"{base_path}{file_name}.json"
                rank0_print(f"Loading {full_path}")
                with open(full_path) as file:
                    cur_data_dict = json.load(file)
                    rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
                    self.list_data_dict.extend(cur_data_dict)
        elif data_path.endswith(".yaml"):
            with open(data_path) as file:
                yaml_data = yaml.safe_load(file)
                datasets = yaml_data.get("datasets")
                # file should be in the format of:
                # datasets:
                #   - json_path: xxxx1.json
                #     sampling_strategy: first:1000
                #   - json_path: xxxx2.json
                #     sampling_strategy: end:3000
                #   - json_path: xxxx3.json
                #     sampling_strategy: random:999
                data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
                for dataset in datasets:
                    json_path = dataset.get("json_path")
                    sampling_strategy = dataset.get("sampling_strategy", "all")
                    images_folder = dataset.get("images_folder")
                    sampling_number = None

                    rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")

                    if json_path.endswith(".jsonl"):
                        cur_data_dict = []
                        with open(json_path) as json_file:
                            for line in json_file:
                                cur_data_dict.append(json.loads(line.strip()))
                    elif json_path.endswith(".json"):
                        # NOTE: we only use json_path with .json now
                        # Handle the images_folder in yaml
                        with open(json_path) as json_file:
                            cur_data_dict = json.load(json_file)
                    else:
                        raise ValueError(f"Unsupported file type: {json_path}")

                    if ":" in sampling_strategy:
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        if "%" in sampling_number:
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            sampling_number = int(sampling_number)

                    # Apply the sampling strategy
                    if sampling_strategy == "first" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        random.shuffle(cur_data_dict)
                        cur_data_dict = cur_data_dict[:sampling_number]

                    rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    self.list_data_dict.extend(cur_data_dict)
                    self.list_image_path.extend([images_folder] * len(cur_data_dict))
        else:
            data_args.dataset_paths = [data_path]
            rank0_print(f"Loading {data_path}")
            with open(data_path) as file:
                cur_data_dict = json.load(file)
            rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
            self.list_data_dict.extend(cur_data_dict)
            self.list_image_path.extend([""] * len(cur_data_dict))  # NOTE: the image subfolder is empty...

        rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
            )
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            assert cur_len > 0, f"Conversation length is 0 for {sample}"

            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
            )

            if "image" in sample or "video" in sample or self.data_args.early_mix_text:
                length_list.append(cur_len + img_tokens)
            else:
                length_list.append(-cur_len)
        return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sample = self._get_item(i)
        if sample is None:
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        else:
            return sample
        try:
            sample = self._get_item(i)
            if sample is None:
                new_index = random.randint(0, len(self.list_data_dict) - 1)
                return self.__getitem__(new_index)
        except Exception as e:
            print(f"Failed to fetch sample {i}. Exception:", e)
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        return sample

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])

        if "image" in sources:
            image_file = self.list_data_dict[i]["image"]
            if type(image_file) is list:
                image_list = [os.path.join(image_path, image_file) for image_file in image_file]
            else:
                image_list = [os.path.join(image_path, image_file)]

            sources = copy.deepcopy(sources["conversations"])
        elif "video" in sources:
            raise NotImplementedError("Video is not supported for Qwen2VL")
        else:
            sources = copy.deepcopy(sources["conversations"])

        item_id = self.list_data_dict[i].get("id", i)

        data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
        if isinstance(i, int):
            data_dict = {
                "input_ids": data_dict["input_ids"][0],
                "labels": data_dict["labels"][0],
                "coordinates": data_dict["coordinates"][0],
                "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
                "pixel_values": data_dict["pixel_values"],
                "image_grid_thw": data_dict["image_grid_thw"],
                "multi_patch_labels": data_dict["multi_patch_labels"][0],  # add multi_patch_labels
            }

        data_dict["id"] = item_id

        # return None if the input_ids is longer than the model_max_length
        n_image_tokens = (
            data_dict["image_grid_thw"][0][0] *
            data_dict["image_grid_thw"][0][1] *
            data_dict["image_grid_thw"][0][2] /
            self.processor.image_processor.merge_size /
            self.processor.image_processor.merge_size
        )
        if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
            rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
            return None

        return data_dict

    def preprocess_qwen2vl(
        self,
        source,  # conversations
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        image: list,
        system_message: str = grounding_system_message,
        agent_mode: bool = True,
        chat_template: str = chat_template,
        assistant_template: str = assistant_template,
        id: int = None,
    ) -> Dict:
        roles = {"human": "user", "gpt": "assistant", "system": "system"}
        assistant_template = assistant_template if agent_mode else chat_template
        processor.tokenizer = tokenizer
        assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS

        # Apply prompt templates
        pixel_values, image_grid_thw = None, None

        input_id, target = [], []
        coordinates = []
        visual_token_indices_of_coordinates = []
        multi_patch_labels = []

        image_list = []
        image_index = 0

        ## prepare the system message
        if roles[source[0]["from"]] == "system":
            system_message = source[0]["value"]
            source = source[1:self.data_args.max_conv_turns]
        # else: use the constant system message
        system_input_id = tokenizer.apply_chat_template(
            conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
            chat_template=chat_template,
        )
        input_id += system_input_id
        target += [IGNORE_INDEX] * len(system_input_id)

        ## prepare user-assistant conversation
        for conv in source:
            # regularize the conversation format
            try:
                role = conv["role"]
                content = conv["content"]
            except Exception:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)

            # Count the number of <image> tokens in the content
            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"
                # include image information regarding to current conversation turn
                image_placeholders = []
                for _ in range(image_count):
                    image_placeholders.append({
                        "type": "image",
                        "image": image[image_index],
                        "min_pixels": self.processor.image_processor.min_pixels,
                        "max_pixels": self.processor.image_processor.max_pixels,
                    })
                    image_index += 1

                content = content.replace(DEFAULT_IMAGE_TOKEN, "")
                conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}

                image_inputs, _ = process_vision_info([conv])  # list of PIL.Image.Image
                image_list.extend(image_inputs)

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv], chat_template=chat_template, tokenize=False
                )
                inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")

                if pixel_values is None and image_grid_thw is None:
                    pixel_values = inputs["pixel_values"]
                    image_grid_thw = inputs["image_grid_thw"]
                else:
                    pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
                    image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
            else:
                if role in ["user", "system"]:
                    conv = {"role": role, "content": [{"type": "text", "text": content}]}
                else:  # assistant
                    conv = {
                        "role": role,
                        "content": [{"type": "text", "text": content}],
                        "recipient": conv.get("recipient", "os"),
                        "end_turn": conv.get("end_turn", True),
                        "bbox_gt": conv.get("bbox_gt", None),
                    }
                    if conv["recipient"] == "os":
                        if len(image_inputs) == 0:
                            raise ValueError("No image found for visual grounding")
                        # replace the coordinates with the special tokens
                        text, coord = reformat_coordinates(conv["content"][0]["text"])
                        conv["content"][0]["text"] = text
                        # rank0_print(f"coord: {coord}")

                        # get the visual token indices of the coordinates
                        coordinates.extend(coord)
                        for (point_x, point_y) in coord:
                            visual_token_index = get_token_index(
                                processor.image_processor,
                                image_list,
                                point_x,
                                point_y
                            )
                            # px, py = token_index_to_coordinates(
                            #     processor.image_processor,
                            #     visual_token_index,
                            #     image_list[0].size[0], # make sure the size here is after qwen2vl processing
                            #     image_list[0].size[1]
                            # )
                            # rank0_print(f"estimated px: {px}, py: {py}")
                            visual_token_indices_of_coordinates.append(visual_token_index)

                        if conv["bbox_gt"] is not None:
                            patch_mask = get_multi_patch_labels(
                                processor.image_processor,
                                image_list,
                                conv["bbox_gt"]
                            )
                            multi_patch_labels.append(patch_mask)

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv],
                    chat_template=assistant_template,
                    tokenize=False,
                )
                inputs = processor(text=[templated_conv], return_tensors="pt")

            encode_id = inputs.input_ids[0].tolist()

            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        # make the labels of all pointer_end_token_id to be IGNORE_INDEX
        target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]

        input_ids = torch.tensor([input_id], dtype=torch.long)
        targets = torch.tensor([target], dtype=torch.long)
        visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
        coordinates = [coordinates] if len(coordinates) > 0 else [None]

        # process multi_patch_labels
        if len(multi_patch_labels) > 0:
            multi_patch_labels = [torch.stack(multi_patch_labels)]
        else:
            multi_patch_labels = [None]

        data_dict = {
            "input_ids": input_ids,  # tensor(bs x seq_len)
            "labels": targets,  # tensor(bs x seq_len)
        }

        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_grid_thw

        # if len(coordinates[0]) != len(visual_token_indices_of_coordinates[0]):
        #     raise ValueError(f"The number of coordinates ({len(coordinates[0])}) does not match the number of image token indices ({len(visual_token_indices_of_coordinates[0])})")
        data_dict["coordinates"] = coordinates
        data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
        data_dict["multi_patch_labels"] = multi_patch_labels

        return data_dict
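Editor's note: to make the patch-index arithmetic in get_token_index concrete, here is a worked sketch for one hypothetical case (not part of the commit). The patch_size=14 and merge_size=2 values mirror common Qwen2-VL processor defaults but are assumptions here, and the image size and click point are invented:

# Illustrative only: the index arithmetic used by get_token_index, spelled out.
import math

patch_size, merge_size = 14, 2
merge_patch_size = patch_size * merge_size       # 28 pixels per merged patch
w, h = 1120, 784                                 # assumed resized image, divisible by 28
point_x, point_y = 0.5, 0.25                     # normalized ground-truth click point

px, py = w * point_x, h * point_y                # 560.0, 196.0 in pixels
x_index = math.floor(px / merge_patch_size)      # 20
y_index = math.floor(py / merge_patch_size)      # 7
visual_token_index = y_index * (w // merge_patch_size) + x_index  # 7 * 40 + 20 = 300
print(visual_token_index)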
gui_actor/inference.py
ADDED
@@ -0,0 +1,300 @@
import torch
import json
import re
import os
from qwen_vl_utils import process_vision_info
from transformers import (
    Qwen2VLForConditionalGeneration,
    LogitsProcessor,
    LogitsProcessorList,
    AutoModelForCausalLM,
    AutoTokenizer
)
from gui_actor.constants import (
    DEFAULT_POINTER_END_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    chat_template
)

class ForceFollowTokensLogitsProcessor(LogitsProcessor):
    """
    Forces tokens B (pointer_pad_token) and C (pointer_end_token) to follow token A (pointer_start_token).
    Whenever token_a_id is generated, enqueue the forced_sequence (e.g. [B, C]).
    As long as forced tokens remain in the queue, force them in the output.
    """
    def __init__(self, token_a_id, forced_sequence=[DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN]):
        super().__init__()
        self.token_a_id = token_a_id
        self.forced_sequence = forced_sequence  # list of token IDs, e.g. [B_id, C_id]
        self.force_queue = []  # holds the tokens we still need to force

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Called at each decoding step to modify `scores`.

        Args:
            input_ids: shape (batch_size, seq_len). The already-decoded tokens.
            scores: shape (batch_size, vocab_size). Model logits for the next token.
        """
        batch_size = input_ids.shape[0]
        if batch_size > 1:
            raise NotImplementedError("Batch size must be 1 for this logits processor.")

        # We assume batch_size=1 for simplicity; if you have multiple sequences,
        # you'll need to adapt the logic to handle each item in the batch.
        last_token_id = input_ids[0, -1].item()

        # If the last token was A, enqueue B and C
        if last_token_id == self.token_a_id:
            self.force_queue.extend(self.forced_sequence)

        # If we have forced tokens waiting in the queue, override the distribution
        if len(self.force_queue) > 0:
            forced_token = self.force_queue.pop(0)  # next token to force
            # Create a mask of -inf for all tokens except the forced one
            new_scores = torch.full_like(scores, float('-inf'))
            new_scores[0, forced_token] = 0.0  # log prob = 0 => prob = 1
            return new_scores

        # Otherwise, return scores unmodified
        return scores


def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30, activation_threshold=0.3, return_all_regions=True, rect_center=False):
    """
    1. Select activated patches
    2. Divide connected patches into different regions
    3. Calculate the average activation value for each region
    4. Select the region with the highest average activation value
    5. Return the center point of that region as the final prediction point
    """

    # Get patches with activation values greater than a certain proportion of the maximum activation value as activated patches
    # Get the highest activation value and threshold
    max_score = attn_scores[0].max().item()
    threshold = max_score * activation_threshold
    # Select all patches above the threshold
    mask = attn_scores[0] > threshold
    valid_indices = torch.nonzero(mask).squeeze(-1)
    topk_values = attn_scores[0][valid_indices]
    topk_indices = valid_indices

    # Convert indices to 2D coordinates
    topk_coords = []
    for idx in topk_indices.tolist():
        y = idx // n_width
        x = idx % n_width
        topk_coords.append((y, x, idx))

    # Divide into connected regions
    regions = []
    visited = set()
    for i, (y, x, idx) in enumerate(topk_coords):
        if idx in visited:
            continue

        # Start a new region
        region = [(y, x, idx, topk_values[i].item())]
        visited.add(idx)
        queue = [(y, x, idx, topk_values[i].item())]

        # BFS to find connected points
        while queue:
            cy, cx, c_idx, c_val = queue.pop(0)

            # Check 4 adjacent directions
            for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                ny, nx = cy + dy, cx + dx
                n_idx = ny * n_width + nx

                # Check if this adjacent point is in the topk list
                for j, (ty, tx, t_idx) in enumerate(topk_coords):
                    if ty == ny and tx == nx and t_idx not in visited:
                        visited.add(t_idx)
                        region.append((ny, nx, t_idx, topk_values[j].item()))
                        queue.append((ny, nx, t_idx, topk_values[j].item()))

        regions.append(region)

    # Calculate the average activation value for each region
    region_scores = []
    region_centers = []
    region_points = []

    for region in regions:
        # Calculate average score for the region
        avg_score = sum(item[3] for item in region) / len(region)
        region_scores.append(avg_score)

        # Calculate normalized center coordinates for each patch, then take the average
        normalized_centers = []
        weights = []
        y_coords = set()
        x_coords = set()

        for y, x, _, score in region:
            # Normalized coordinates of the center point for each patch
            center_y = (y + 0.5) / n_height
            center_x = (x + 0.5) / n_width
            normalized_centers.append((center_x, center_y))
            weights.append(score)

            y_coords.add(center_y)
            x_coords.add(center_x)

        region_points.append(normalized_centers)

        # Calculate the average of normalized coordinates as the region center
        if not rect_center:
            # Weighted average
            total_weight = sum(weights)
            weighted_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight
            weighted_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight
            avg_center_x, avg_center_y = weighted_x, weighted_y
            # # Simple average
            # avg_center_x = sum(nc[0] for nc in normalized_centers) / len(normalized_centers)
            # avg_center_y = sum(nc[1] for nc in normalized_centers) / len(normalized_centers)
        else:
            avg_center_x = sum(x_coords) / len(x_coords)
            avg_center_y = sum(y_coords) / len(y_coords)
        region_centers.append((avg_center_x, avg_center_y))

    # Select the region with the highest average activation value
    sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
    sorted_scores = [region_scores[i] for i in sorted_indices]
    sorted_centers = [region_centers[i] for i in sorted_indices]
    sorted_points = [region_points[i] for i in sorted_indices]
    best_point = sorted_centers[0]

    if return_all_regions:
        # Outputs:
        # 1. best_point: the center point of the region with the highest average activation value
        # 2. sorted_centers: the center points of all regions, sorted by the average activation value in descending order
        # 3. sorted_scores: the average activation values of all regions, sorted in descending order
        # 4. sorted_points: the normalized center coordinates of all patches, sorted by the average activation value in descending order
        return best_point, sorted_centers, sorted_scores, sorted_points
    else:
        return best_point


def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5):
    """
    conversation = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": grounding_system_message,
                }
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": example["image"], # PIL.Image.Image or str to path
                    # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
                },
                {
                    "type": "text",
                    "text": example["instruction"]
                },
            ],
        },
    ]
    """
    if logits_processor is None:
        logits_processor = ForceFollowTokensLogitsProcessor(
            token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
            forced_sequence=[
                tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
            ]
        )

    assiatant_starter = "" if not use_placeholder else "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)"

    pred = {
        "output_text": None,  # generated text
        "n_width": None,  # number of patch_tokens in width dimension
        "n_height": None,  # number of patch_tokens in height dimension
        "attn_scores": None,  # attention scores over the image patches
        "topk_points": None,  # topk points
        "topk_values": None,  # topk values
        "topk_points_all": None,  # all points
    }

    # prepare text
    text = data_processor.apply_chat_template(conversation,
                                              tokenize=False,
                                              add_generation_prompt=False,
                                              chat_template=chat_template
                                              )
    text += assiatant_starter

    # prepare inputs
    image_inputs, video_inputs = process_vision_info(conversation)
    inputs = data_processor(text=[text],
                            images=image_inputs,
                            videos=video_inputs,
                            padding=True,
                            return_tensors="pt"
                            )
    inputs = inputs.to(model.device)

    # generate
    results = model.generate(**inputs,
                             max_new_tokens=2048 if not use_placeholder else 1,
                             logits_processor=LogitsProcessorList([logits_processor]),
                             return_dict_in_generate=True,
                             output_hidden_states=True
                             )

    # decode the generated ids
    input_ids = inputs["input_ids"][0]
    generated_ids = results.sequences[0][len(input_ids):]
    output_text = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    pred["output_text"] = output_text

    # check if there are <POINTER_TOKEN> is inside the input_ids or generated_ids
    if use_placeholder:
        pointer_pad_mask = (inputs["input_ids"][0] == model.config.pointer_pad_token_id)  # n_all_input_tokens
    else:
        pointer_pad_mask = (generated_ids[:-1] == model.config.pointer_pad_token_id)  # seq_len_generated_ids-1

    # if there are no <POINTER_TOKEN> in the input_ids or generated_ids, return the pred
    if len(pointer_pad_mask) == 0:
        return pred

    # otherwise, get the coordinate from the action head
    if use_placeholder:
        decoder_hidden_states = results.hidden_states[0][-1][0]  # n_all_input_tokens, hidden_size
    else:
        decoder_hidden_states = [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]]
        decoder_hidden_states = torch.cat(decoder_hidden_states, dim=0)  # seq_len_generated_ids-1, hidden_size
    decoder_hidden_states = decoder_hidden_states[pointer_pad_mask]  # n_pointer_pad_tokens, hidden_size

    # get the image embeddings as encoder vectors
    # image_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"]) # n_image_tokens, hidden_size
    image_mask = (inputs["input_ids"][0] == tokenizer.encode("<|image_pad|>")[0])
    image_embeds = results.hidden_states[0][0][0][image_mask]  # n_image_tokens, hidden_size

    attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states)
    pred["attn_scores"] = attn_scores.tolist()

    _, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist()
    pred["n_width"] = n_width
    pred["n_height"] = n_height

    # get the topk points according to the attention scores
    best_point, region_points, region_scores, region_points_all = get_prediction_region_point(attn_scores, n_width, n_height, return_all_regions=True, rect_center=False)
    topk_points = region_points[:topk] if len(region_points) > topk else region_points
    topk_values = region_scores[:topk] if len(region_scores) > topk else region_scores
    topk_points_all = region_points_all[:topk] if len(region_points_all) > topk else region_points_all
    pred["topk_points"] = topk_points
    pred["topk_values"] = topk_values
    pred["topk_points_all"] = topk_points_all

    return pred
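Editor's note: a minimal usage sketch of inference(), not part of the commit. It mirrors the conversation layout shown in the docstring above and assumes the model, tokenizer, and data_processor are loaded the way app.py loads them; the screenshot path and instruction are placeholders:

# Illustrative only: calling inference() the way app.py does, outside the demo.
from gui_actor.constants import grounding_system_message
from gui_actor.inference import inference

conversation = [
    {"role": "system", "content": [{"type": "text", "text": grounding_system_message}]},
    {"role": "user", "content": [
        {"type": "image", "image": "screenshot.png"},         # placeholder image path
        {"type": "text", "text": "Click the Submit button"},  # placeholder instruction
    ]},
]
# model, tokenizer and data_processor are assumed to be loaded as in app.py.
pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
x, y = pred["topk_points"][0]   # normalized (x, y) of the best-scoring region center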
gui_actor/modeling.py
ADDED
@@ -0,0 +1,361 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
|
6 |
+
from gui_actor.constants import IGNORE_INDEX
|
7 |
+
from typing import List, Tuple, Union, Optional
|
8 |
+
from gui_actor.trainer import rank0_print
|
9 |
+
|
10 |
+
class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
|
11 |
+
"""
|
12 |
+
Output class for Qwen2VL with pointer head, extending the base output class.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
16 |
+
Language modeling loss.
|
17 |
+
pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
18 |
+
Vision pointer network loss.
|
19 |
+
pointer_scores (`List[torch.FloatTensor]`, *optional*):
|
20 |
+
Attention scores from the pointer network, one tensor per batch item.
|
21 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
22 |
+
Combined loss (weighted sum of lm_loss and pointer_loss).
|
23 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
24 |
+
Prediction scores from the language modeling head.
|
25 |
+
past_key_values, hidden_states, attentions, rope_deltas:
|
26 |
+
Same as parent class.
|
27 |
+
"""
|
28 |
+
def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
|
29 |
+
super().__init__(*args, **kwargs)
|
30 |
+
self.lm_loss = lm_loss
|
31 |
+
self.pointer_loss = pointer_loss
|
32 |
+
self.pointer_scores = pointer_scores
|
33 |
+
|
34 |
+
|
35 |
+
class VisionHead_MultiPatch(nn.Module):
|
36 |
+
def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
|
37 |
+
super().__init__()
|
38 |
+
self.d_model = d_model
|
39 |
+
|
40 |
+
# Note: We omit additional normalization here because Qwen2VL
|
41 |
+
# already normalizes hidden states using RMSNorm.
|
42 |
+
self.projection_enc = nn.Sequential(
|
43 |
+
nn.Linear(d_model, projection_dim),
|
44 |
+
nn.GELU(),
|
45 |
+
nn.Linear(projection_dim, d_model)
|
46 |
+
)
|
47 |
+
self.projection_dec = nn.Sequential(
|
48 |
+
nn.Linear(d_model, projection_dim),
|
49 |
+
nn.GELU(),
|
50 |
+
nn.Linear(projection_dim, d_model)
|
51 |
+
)
|
52 |
+
|
53 |
+
# Add self-attention layer for visual features
|
54 |
+
self.self_attention = nn.MultiheadAttention(
|
55 |
+
embed_dim=d_model,
|
56 |
+
num_heads=num_attention_heads,
|
57 |
+
dropout=dropout_rate,
|
58 |
+
batch_first=True
|
59 |
+
)
|
60 |
+
|
61 |
+
# Layer normalization and residual connection
|
62 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
63 |
+
self.dropout = nn.Dropout(dropout_rate)
|
64 |
+
|
65 |
+
def forward(self,
|
66 |
+
hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size
|
67 |
+
hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample
|
68 |
+
labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox
|
69 |
+
do_single_patch: bool = False,
|
70 |
+
):
|
71 |
+
|
72 |
+
enc_input = hidden_state_enc.unsqueeze(0)
|
73 |
+
attn_output, _ = self.self_attention(
|
74 |
+
query=enc_input,
|
75 |
+
key=enc_input,
|
76 |
+
value=enc_input,
|
77 |
+
# attn_mask=attention_mask,
|
78 |
+
need_weights=False
|
79 |
+
)
|
80 |
+
# Residual connection and layer normalization
|
81 |
+
hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
|
82 |
+
# Remove batch dimension
|
83 |
+
hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model]
|
84 |
+
|
85 |
+
# Apply the projection networks.
|
86 |
+
proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
|
87 |
+
proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
|
88 |
+
|
89 |
+
# Compute scaled dot-product attention scores.
|
90 |
+
# Scaling by sqrt(d_model) keeps the dot-product logits well-scaled even though n_enc varies with image size.
|
91 |
+
scaling = self.d_model ** 0.5
|
92 |
+
patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
|
93 |
+
|
94 |
+
# Softmax normalization is applied along the encoder dimension.
|
95 |
+
attn_weights = F.softmax(patch_logits, dim=-1)
|
96 |
+
|
97 |
+
loss = None
|
98 |
+
if (labels is not None) and (not do_single_patch):
|
99 |
+
epsilon = 1e-8
|
100 |
+
labels_float = labels.float()
|
101 |
+
# Normalize each row to get target probability distribution
|
102 |
+
target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
|
103 |
+
|
104 |
+
# Apply log_softmax to logits
|
105 |
+
pred_log_probs = F.log_softmax(patch_logits, dim=-1)
|
106 |
+
# Use KL divergence as loss
|
107 |
+
loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
|
108 |
+
|
109 |
+
if do_single_patch and (labels is not None):
|
110 |
+
loss = F.cross_entropy(patch_logits, labels) # use the raw patch_logits for CE; attn_weights are already softmaxed
|
111 |
+
|
112 |
+
return attn_weights, loss
|
113 |
+
|
114 |
+
|
115 |
+
class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
|
116 |
+
def __init__(self, *args, **kwargs):
|
117 |
+
super().__init__(*args, **kwargs)
|
118 |
+
self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
|
119 |
+
self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
|
120 |
+
self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
|
121 |
+
self.post_init()
|
122 |
+
|
123 |
+
def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
|
124 |
+
self.pointer_loss_weight = pointer_loss_weight
|
125 |
+
self.lm_loss_weight = lm_loss_weight
|
126 |
+
|
127 |
+
def forward(self,
|
128 |
+
input_ids: torch.LongTensor = None, # (batch_size, seq_len)
|
129 |
+
attention_mask: Optional[torch.Tensor] = None,
|
130 |
+
position_ids: Optional[torch.LongTensor] = None,
|
131 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
132 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
133 |
+
labels: Optional[torch.LongTensor] = None,
|
134 |
+
use_cache: Optional[bool] = None,
|
135 |
+
output_attentions: Optional[bool] = None,
|
136 |
+
output_hidden_states: Optional[bool] = None,
|
137 |
+
return_dict: Optional[bool] = None,
|
138 |
+
pixel_values: Optional[torch.Tensor] = None,
|
139 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
140 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
141 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
142 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
143 |
+
cache_position: Optional[torch.LongTensor] = None,
|
144 |
+
# Grounding
|
145 |
+
visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token
|
146 |
+
multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
|
147 |
+
if_multi_patch: bool = True,
|
148 |
+
coordinates: Optional[List[Tuple[float, float]]] = None,
|
149 |
+
verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
|
150 |
+
|
151 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
152 |
+
output_hidden_states = (
|
153 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
154 |
+
)
|
155 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
156 |
+
|
157 |
+
if verbose:
|
158 |
+
rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
|
159 |
+
rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
|
160 |
+
rank0_print(f"pixel_values: {pixel_values.shape}")
|
161 |
+
rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
|
162 |
+
rank0_print(f"coordinates: {coordinates}")
|
163 |
+
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
|
164 |
+
rank0_print(f"return_dict: {return_dict}")
|
165 |
+
|
166 |
+
if inputs_embeds is None:
|
167 |
+
inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
|
168 |
+
if pixel_values is not None:
|
169 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
170 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
171 |
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
172 |
+
n_image_features = image_embeds.shape[0]
|
173 |
+
if n_image_tokens != n_image_features:
|
174 |
+
raise ValueError(
|
175 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
176 |
+
)
|
177 |
+
image_mask = (
|
178 |
+
(input_ids == self.config.image_token_id)
|
179 |
+
.unsqueeze(-1)
|
180 |
+
.expand_as(inputs_embeds)
|
181 |
+
.to(inputs_embeds.device)
|
182 |
+
)
|
183 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
184 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
185 |
+
|
186 |
+
if pixel_values_videos is not None:
|
187 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
188 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
189 |
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
190 |
+
n_video_features = video_embeds.shape[0]
|
191 |
+
if n_video_tokens != n_video_features:
|
192 |
+
raise ValueError(
|
193 |
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
194 |
+
)
|
195 |
+
video_mask = (
|
196 |
+
(input_ids == self.config.video_token_id)
|
197 |
+
.unsqueeze(-1)
|
198 |
+
.expand_as(inputs_embeds)
|
199 |
+
.to(inputs_embeds.device)
|
200 |
+
)
|
201 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
202 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
203 |
+
|
204 |
+
if attention_mask is not None:
|
205 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
206 |
+
|
207 |
+
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
208 |
+
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
209 |
+
# calculate RoPE index once per generation in the pre-fill stage only
|
210 |
+
if (
|
211 |
+
(cache_position is not None and cache_position[0] == 0)
|
212 |
+
or self.rope_deltas is None
|
213 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
214 |
+
):
|
215 |
+
position_ids, rope_deltas = self.get_rope_index(
|
216 |
+
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
217 |
+
)
|
218 |
+
self.rope_deltas = rope_deltas
|
219 |
+
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
220 |
+
else:
|
221 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
222 |
+
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
223 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
224 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
225 |
+
if cache_position is not None: # otherwise `deltas` is an int `0`
|
226 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
|
227 |
+
delta = delta.to(position_ids.device)
|
228 |
+
position_ids = position_ids.add(delta)
|
229 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
230 |
+
|
231 |
+
outputs = self.model(
|
232 |
+
input_ids=None,
|
233 |
+
position_ids=position_ids,
|
234 |
+
attention_mask=attention_mask,
|
235 |
+
past_key_values=past_key_values,
|
236 |
+
inputs_embeds=inputs_embeds,
|
237 |
+
use_cache=use_cache,
|
238 |
+
output_attentions=output_attentions,
|
239 |
+
output_hidden_states=output_hidden_states,
|
240 |
+
return_dict=return_dict,
|
241 |
+
cache_position=cache_position,
|
242 |
+
)
|
243 |
+
|
244 |
+
hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model)
|
245 |
+
logits = self.lm_head(hidden_states)
|
246 |
+
|
247 |
+
lm_loss = None
|
248 |
+
if labels is not None and self.lm_loss_weight > 0:
|
249 |
+
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
250 |
+
logits = logits.float()
|
251 |
+
# Shift so that tokens < n predict n
|
252 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
253 |
+
shift_labels = labels[..., 1:].contiguous()
|
254 |
+
# Flatten the tokens
|
255 |
+
loss_fct = nn.CrossEntropyLoss()
|
256 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
257 |
+
shift_labels = shift_labels.view(-1)
|
258 |
+
# Enable model parallelism
|
259 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
260 |
+
lm_loss = loss_fct(shift_logits, shift_labels)
|
261 |
+
|
262 |
+
|
263 |
+
# If vision supervision is requested, process the action head.
|
264 |
+
pointer_loss = None
|
265 |
+
pointer_scores = []
|
266 |
+
if visual_token_indices_of_coordinates is not None:
|
267 |
+
batch_size = input_ids.shape[0]
|
268 |
+
pointer_losses = []
|
269 |
+
|
270 |
+
# Process each sample individually because the number of visual and target tokens may vary.
|
271 |
+
for i in range(batch_size):
|
272 |
+
dummy_target = False
|
273 |
+
|
274 |
+
# Get the token ids and corresponding hidden states for sample i.
|
275 |
+
token_ids = input_ids[i] # shape: (seq_length,)
|
276 |
+
hs = hidden_states[i] # shape: (seq_length, d_model)
|
277 |
+
|
278 |
+
# Identify visual tokens indices.
|
279 |
+
visual_mask = (token_ids == self.config.image_token_id)
|
280 |
+
visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,)
|
281 |
+
|
282 |
+
# Identify target tokens (the ones that should attend to visual features).
|
283 |
+
target_mask = (token_ids == self.config.pointer_pad_token_id)
|
284 |
+
target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
|
285 |
+
|
286 |
+
# If either visual or target tokens are missing, skip this sample.
|
287 |
+
if visual_indices.numel() == 0:
|
288 |
+
raise ValueError(f"No visual or target tokens found for sample {i}.")
|
289 |
+
if target_indices.numel() == 0:
|
290 |
+
target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device) # take the last token as the dummy target token
|
291 |
+
gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
|
292 |
+
if if_multi_patch: # take the first 4 visual tokens as the dummy ground truth
|
293 |
+
sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
|
294 |
+
sample_labels[0][:4] = 1
|
295 |
+
dummy_target = True
|
296 |
+
else:
|
297 |
+
# For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
|
298 |
+
# where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
|
299 |
+
gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
|
300 |
+
if if_multi_patch:
|
301 |
+
sample_labels = multi_patch_labels[i]
|
302 |
+
|
303 |
+
# Gather the corresponding hidden state representations.
|
304 |
+
# visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
|
305 |
+
visual_embeds = inputs_embeds[i][visual_indices]
|
306 |
+
target_hidden = hs[target_indices] # shape: (n_target, d_model)
|
307 |
+
|
308 |
+
# Calculate loss for multi-patch mode
|
309 |
+
if if_multi_patch:
|
310 |
+
# Ensure the number of targets matches between sample and labels
|
311 |
+
if sample_labels.shape[0] != target_indices.shape[0]:
|
312 |
+
raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
|
313 |
+
|
314 |
+
# Process using VisionHead_MultiPatch
|
315 |
+
attn_scores, loss_v = self.multi_patch_pointer_head(
|
316 |
+
visual_embeds,
|
317 |
+
target_hidden,
|
318 |
+
labels=sample_labels
|
319 |
+
)
|
320 |
+
|
321 |
+
else:
|
322 |
+
# Deprecated branch - single patch mode is no longer used
|
323 |
+
# Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
|
324 |
+
attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
|
325 |
+
|
326 |
+
pointer_scores.append(attn_scores.detach().cpu())
|
327 |
+
|
328 |
+
pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
|
329 |
+
|
330 |
+
pointer_loss = torch.stack(pointer_losses).mean()
|
331 |
+
|
332 |
+
# Combine the LM loss and vision loss using the provided loss weights.
|
333 |
+
|
334 |
+
if lm_loss is None:
|
335 |
+
total_loss = pointer_loss
|
336 |
+
elif pointer_loss is None:
|
337 |
+
total_loss = lm_loss
|
338 |
+
else:
|
339 |
+
total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
|
340 |
+
|
341 |
+
if return_dict:
|
342 |
+
return QwenVLwithVisionHeadOutputWithPast(
|
343 |
+
lm_loss=lm_loss,
|
344 |
+
pointer_loss=pointer_loss,
|
345 |
+
pointer_scores=pointer_scores,
|
346 |
+
loss=total_loss,
|
347 |
+
logits=logits,
|
348 |
+
past_key_values=outputs.past_key_values,
|
349 |
+
hidden_states=outputs.hidden_states,
|
350 |
+
attentions=outputs.attentions,
|
351 |
+
rope_deltas=self.rope_deltas,
|
352 |
+
)
|
353 |
+
else:
|
354 |
+
# When labels are provided, parent's forward returns a tuple with loss as the first element.
|
355 |
+
if labels is not None:
|
356 |
+
# Replace the LM loss with the combined loss.
|
357 |
+
output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
|
358 |
+
print(f"returning: total_loss, logits, pointer_scores, ...")
|
359 |
+
return (total_loss,) + output if total_loss is not None else output
|
360 |
+
else:
|
361 |
+
return outputs
|
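The pointer head above scores every visual patch against each decoder query and supervises the result with a KL divergence against the normalized multi-patch mask. A minimal, self-contained sketch of that scoring/loss step with standalone tensors and hypothetical sizes (not the model's full forward pass):

import torch
import torch.nn.functional as F

d_model, n_enc, n_dec = 64, 10, 2           # hypothetical sizes
proj_enc = torch.randn(n_enc, d_model)      # stands in for projection_enc(visual embeddings)
proj_dec = torch.randn(n_dec, d_model)      # stands in for projection_dec(target hidden states)

patch_logits = proj_dec @ proj_enc.T / d_model ** 0.5    # [n_dec, n_enc]
attn_weights = F.softmax(patch_logits, dim=-1)           # pointer scores per query

labels = torch.zeros(n_dec, n_enc)
labels[:, :4] = 1                                        # patches inside the target bbox
target_dist = labels / labels.sum(dim=-1, keepdim=True).clamp_min(1e-8)
loss = F.kl_div(F.log_softmax(patch_logits, dim=-1), target_dist, reduction="batchmean")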
gui_actor/modeling_qwen25vl.py
ADDED
@@ -0,0 +1,337 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from typing import List, Tuple, Union, Optional
|
6 |
+
|
7 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
8 |
+
Qwen2_5_VLCausalLMOutputWithPast,
|
9 |
+
Qwen2_5_VLForConditionalGeneration,
|
10 |
+
)
|
11 |
+
from gui_actor.constants import IGNORE_INDEX
|
12 |
+
from gui_actor.trainer import rank0_print
|
13 |
+
|
14 |
+
|
15 |
+
def _get_token_embedding_layer(hf_model: nn.Module) -> nn.Module:
|
16 |
+
"""
|
17 |
+
Robustly locate the token embedding layer across HF versions.
|
18 |
+
"""
|
19 |
+
if hasattr(hf_model, "get_input_embeddings") and callable(hf_model.get_input_embeddings):
|
20 |
+
return hf_model.get_input_embeddings()
|
21 |
+
# Fallbacks (shouldn't be needed on recent transformers, but safe to keep)
|
22 |
+
lm = getattr(hf_model, "language_model", None)
|
23 |
+
if lm is not None and hasattr(lm, "embed_tokens"):
|
24 |
+
return lm.embed_tokens
|
25 |
+
raise AttributeError("Could not locate token embedding layer on model (no get_input_embeddings/embed_tokens).")
|
26 |
+
|
27 |
+
|
28 |
+
class QwenVLwithVisionHeadOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
|
29 |
+
"""
|
30 |
+
Output class for Qwen2_5_VL with pointer head, extending the base output class.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
34 |
+
Language modeling loss.
|
35 |
+
pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
36 |
+
Vision pointer network loss.
|
37 |
+
pointer_scores (`List[torch.FloatTensor]`, *optional*):
|
38 |
+
Attention scores from the pointer network, one tensor per batch item.
|
39 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
|
40 |
+
Combined loss (weighted sum of lm_loss and pointer_loss).
|
41 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
42 |
+
Prediction scores from the language modeling head.
|
43 |
+
past_key_values, hidden_states, attentions, rope_deltas:
|
44 |
+
Same as parent class.
|
45 |
+
"""
|
46 |
+
def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
|
47 |
+
super().__init__(*args, **kwargs)
|
48 |
+
self.lm_loss = lm_loss
|
49 |
+
self.pointer_loss = pointer_loss
|
50 |
+
self.pointer_scores = pointer_scores
|
51 |
+
|
52 |
+
|
53 |
+
class VisionHead_MultiPatch(nn.Module):
|
54 |
+
def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
|
55 |
+
super().__init__()
|
56 |
+
self.d_model = d_model
|
57 |
+
|
58 |
+
self.projection_enc = nn.Sequential(
|
59 |
+
nn.Linear(d_model, projection_dim),
|
60 |
+
nn.GELU(),
|
61 |
+
nn.Linear(projection_dim, d_model),
|
62 |
+
)
|
63 |
+
self.projection_dec = nn.Sequential(
|
64 |
+
nn.Linear(d_model, projection_dim),
|
65 |
+
nn.GELU(),
|
66 |
+
nn.Linear(projection_dim, d_model),
|
67 |
+
)
|
68 |
+
|
69 |
+
self.self_attention = nn.MultiheadAttention(
|
70 |
+
embed_dim=d_model, num_heads=num_attention_heads, dropout=dropout_rate, batch_first=True
|
71 |
+
)
|
72 |
+
|
73 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
74 |
+
self.dropout = nn.Dropout(dropout_rate)
|
75 |
+
|
76 |
+
def forward(
|
77 |
+
self,
|
78 |
+
hidden_state_enc, # [n_enc, d_model]
|
79 |
+
hidden_state_dec, # [n_dec, d_model]
|
80 |
+
labels: Optional[torch.Tensor] = None, # [n_dec, n_enc] binary mask of patches in bbox
|
81 |
+
do_single_patch: bool = False,
|
82 |
+
):
|
83 |
+
enc_input = hidden_state_enc.unsqueeze(0)
|
84 |
+
attn_output, _ = self.self_attention(query=enc_input, key=enc_input, value=enc_input, need_weights=False)
|
85 |
+
hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)).squeeze(0) # [n_enc, d_model]
|
86 |
+
|
87 |
+
proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
|
88 |
+
proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
|
89 |
+
|
90 |
+
scaling = self.d_model ** 0.5
|
91 |
+
patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
|
92 |
+
|
93 |
+
attn_weights = F.softmax(patch_logits, dim=-1)
|
94 |
+
|
95 |
+
loss = None
|
96 |
+
if (labels is not None) and (not do_single_patch):
|
97 |
+
epsilon = 1e-8
|
98 |
+
labels_float = labels.float()
|
99 |
+
target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
|
100 |
+
pred_log_probs = F.log_softmax(patch_logits, dim=-1)
|
101 |
+
loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
|
102 |
+
|
103 |
+
if do_single_patch and (labels is not None):
|
104 |
+
# NOTE: if you ever enable this branch, use patch_logits for CE
|
105 |
+
loss = F.cross_entropy(patch_logits, labels)
|
106 |
+
|
107 |
+
return attn_weights, loss
|
108 |
+
|
109 |
+
|
110 |
+
class Qwen2_5_VLForConditionalGenerationWithPointer(Qwen2_5_VLForConditionalGeneration):
|
111 |
+
def __init__(self, *args, **kwargs):
|
112 |
+
super().__init__(*args, **kwargs)
|
113 |
+
self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
|
114 |
+
self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
|
115 |
+
self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
|
116 |
+
self.post_init()
|
117 |
+
|
118 |
+
# initialize the cached rope_deltas reused across decoding steps
|
119 |
+
self.rope_deltas = None
|
120 |
+
|
121 |
+
def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
|
122 |
+
self.pointer_loss_weight = pointer_loss_weight
|
123 |
+
self.lm_loss_weight = lm_loss_weight
|
124 |
+
|
125 |
+
def forward(
|
126 |
+
self,
|
127 |
+
input_ids: torch.LongTensor = None, # (batch_size, seq_len)
|
128 |
+
attention_mask: Optional[torch.Tensor] = None,
|
129 |
+
position_ids: Optional[torch.LongTensor] = None,
|
130 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
131 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
132 |
+
labels: Optional[torch.LongTensor] = None,
|
133 |
+
use_cache: Optional[bool] = None,
|
134 |
+
output_attentions: Optional[bool] = None,
|
135 |
+
output_hidden_states: Optional[bool] = None,
|
136 |
+
return_dict: Optional[bool] = None,
|
137 |
+
pixel_values: Optional[torch.Tensor] = None,
|
138 |
+
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
139 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
140 |
+
video_grid_thw: Optional[torch.LongTensor] = None,
|
141 |
+
rope_deltas: Optional[torch.LongTensor] = None,
|
142 |
+
cache_position: Optional[torch.LongTensor] = None,
|
143 |
+
second_per_grid_ts: Optional[torch.Tensor] = None,
|
144 |
+
# Grounding
|
145 |
+
visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # (batch_size, n_target)
|
146 |
+
multi_patch_labels: Optional[torch.Tensor] = None, # list/packed: [(n_target, n_visual), ...]
|
147 |
+
if_multi_patch: bool = True,
|
148 |
+
coordinates: Optional[List[Tuple[float, float]]] = None,
|
149 |
+
verbose: bool = False,
|
150 |
+
) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
|
151 |
+
|
152 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
153 |
+
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
154 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
155 |
+
|
156 |
+
if verbose:
|
157 |
+
rank0_print(f"input_ids: {None if input_ids is None else (input_ids.shape, input_ids[0][:5])}")
|
158 |
+
rank0_print(f"labels: {None if labels is None else (labels.shape, labels[0][:5])}")
|
159 |
+
rank0_print(f"pixel_values: {None if pixel_values is None else pixel_values.shape}")
|
160 |
+
rank0_print(f"image_grid_thw: {None if image_grid_thw is None else image_grid_thw.shape}")
|
161 |
+
rank0_print(f"coordinates: {coordinates}")
|
162 |
+
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
|
163 |
+
rank0_print(f"return_dict: {return_dict}")
|
164 |
+
|
165 |
+
if inputs_embeds is None:
|
166 |
+
if input_ids is None:
|
167 |
+
raise ValueError("Either inputs_embeds or input_ids must be provided.")
|
168 |
+
|
169 |
+
# FIX: use embedding accessor instead of .embed_tokens
|
170 |
+
token_embedding = _get_token_embedding_layer(self.model)
|
171 |
+
inputs_embeds = token_embedding(input_ids) # (batch, seq_len, d_model)
|
172 |
+
|
173 |
+
if pixel_values is not None:
|
174 |
+
pixel_values = pixel_values.type(self.visual.dtype)
|
175 |
+
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
176 |
+
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
177 |
+
n_image_features = image_embeds.shape[0]
|
178 |
+
if n_image_tokens != n_image_features:
|
179 |
+
raise ValueError(
|
180 |
+
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
|
181 |
+
)
|
182 |
+
image_mask = (
|
183 |
+
(input_ids == self.config.image_token_id)
|
184 |
+
.unsqueeze(-1)
|
185 |
+
.expand_as(inputs_embeds)
|
186 |
+
.to(inputs_embeds.device)
|
187 |
+
)
|
188 |
+
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
189 |
+
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
|
190 |
+
|
191 |
+
if pixel_values_videos is not None:
|
192 |
+
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
193 |
+
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
194 |
+
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
195 |
+
n_video_features = video_embeds.shape[0]
|
196 |
+
if n_video_tokens != n_video_features:
|
197 |
+
raise ValueError(
|
198 |
+
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features: {n_video_features}"
|
199 |
+
)
|
200 |
+
video_mask = (
|
201 |
+
(input_ids == self.config.video_token_id)
|
202 |
+
.unsqueeze(-1)
|
203 |
+
.expand_as(inputs_embeds)
|
204 |
+
.to(inputs_embeds.device)
|
205 |
+
)
|
206 |
+
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
207 |
+
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
|
208 |
+
|
209 |
+
if attention_mask is not None:
|
210 |
+
attention_mask = attention_mask.to(inputs_embeds.device)
|
211 |
+
|
212 |
+
# RoPE positions / deltas
|
213 |
+
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
214 |
+
if (
|
215 |
+
(cache_position is not None and cache_position[0] == 0)
|
216 |
+
or self.rope_deltas is None
|
217 |
+
or (past_key_values is None or past_key_values.get_seq_length() == 0)
|
218 |
+
):
|
219 |
+
position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
|
220 |
+
self.rope_deltas = rope_deltas
|
221 |
+
else:
|
222 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
223 |
+
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
|
224 |
+
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
|
225 |
+
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
|
226 |
+
if cache_position is not None:
|
227 |
+
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0).to(position_ids.device)
|
228 |
+
position_ids = position_ids.add(delta)
|
229 |
+
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
|
230 |
+
|
231 |
+
outputs = self.model(
|
232 |
+
input_ids=None,
|
233 |
+
position_ids=position_ids,
|
234 |
+
attention_mask=attention_mask,
|
235 |
+
past_key_values=past_key_values,
|
236 |
+
inputs_embeds=inputs_embeds,
|
237 |
+
use_cache=use_cache,
|
238 |
+
output_attentions=output_attentions,
|
239 |
+
output_hidden_states=output_hidden_states,
|
240 |
+
return_dict=return_dict,
|
241 |
+
cache_position=cache_position,
|
242 |
+
)
|
243 |
+
|
244 |
+
hidden_states = outputs[0] # (batch, seq_len, d_model)
|
245 |
+
logits = self.lm_head(hidden_states)
|
246 |
+
|
247 |
+
lm_loss = None
|
248 |
+
if labels is not None and self.lm_loss_weight > 0:
|
249 |
+
logits = logits.float()
|
250 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
251 |
+
shift_labels = labels[..., 1:].contiguous()
|
252 |
+
loss_fct = nn.CrossEntropyLoss()
|
253 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
254 |
+
shift_labels = shift_labels.view(-1).to(shift_logits.device)
|
255 |
+
lm_loss = loss_fct(shift_logits, shift_labels)
|
256 |
+
|
257 |
+
pointer_loss = None
|
258 |
+
pointer_scores = []
|
259 |
+
if visual_token_indices_of_coordinates is not None:
|
260 |
+
batch_size = input_ids.shape[0]
|
261 |
+
pointer_losses = []
|
262 |
+
|
263 |
+
for i in range(batch_size):
|
264 |
+
dummy_target = False
|
265 |
+
|
266 |
+
token_ids = input_ids[i] # (seq_len,)
|
267 |
+
hs = hidden_states[i] # (seq_len, d_model)
|
268 |
+
|
269 |
+
visual_mask = (token_ids == self.config.image_token_id)
|
270 |
+
visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # (n_visual,)
|
271 |
+
|
272 |
+
target_mask = (token_ids == self.config.pointer_pad_token_id)
|
273 |
+
target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
|
274 |
+
|
275 |
+
if visual_indices.numel() == 0:
|
276 |
+
raise ValueError(f"No visual tokens found for sample {i}.")
|
277 |
+
|
278 |
+
if target_indices.numel() == 0:
|
279 |
+
target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)
|
280 |
+
gt = torch.tensor([0], device=hs.device) # not used in multi-patch
|
281 |
+
if if_multi_patch:
|
282 |
+
sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
|
283 |
+
sample_labels[0][:4] = 1
|
284 |
+
dummy_target = True
|
285 |
+
else:
|
286 |
+
gt = visual_token_indices_of_coordinates[i].to(hs.device) # (n_target,)
|
287 |
+
if if_multi_patch:
|
288 |
+
sample_labels = multi_patch_labels[i]
|
289 |
+
|
290 |
+
# Use input embeddings for visual tokens (image tokens got replaced earlier)
|
291 |
+
visual_embeds = inputs_embeds[i][visual_indices] # (n_visual, d_model)
|
292 |
+
target_hidden = hs[target_indices] # (n_target, d_model)
|
293 |
+
|
294 |
+
if if_multi_patch:
|
295 |
+
if sample_labels.shape[0] != target_indices.shape[0]:
|
296 |
+
raise ValueError(
|
297 |
+
f"Sample {i} mismatched targets: {sample_labels.shape[0]} labels vs {target_indices.shape[0]} targets"
|
298 |
+
)
|
299 |
+
attn_scores, loss_v = self.multi_patch_pointer_head(
|
300 |
+
visual_embeds,
|
301 |
+
target_hidden,
|
302 |
+
labels=sample_labels,
|
303 |
+
)
|
304 |
+
else:
|
305 |
+
# Deprecated: single-patch branch
|
306 |
+
attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
|
307 |
+
|
308 |
+
pointer_scores.append(attn_scores.detach().cpu())
|
309 |
+
pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
|
310 |
+
|
311 |
+
pointer_loss = torch.stack(pointer_losses).mean()
|
312 |
+
|
313 |
+
if lm_loss is None:
|
314 |
+
total_loss = pointer_loss
|
315 |
+
elif pointer_loss is None:
|
316 |
+
total_loss = lm_loss
|
317 |
+
else:
|
318 |
+
total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
|
319 |
+
|
320 |
+
if return_dict:
|
321 |
+
return QwenVLwithVisionHeadOutputWithPast(
|
322 |
+
lm_loss=lm_loss,
|
323 |
+
pointer_loss=pointer_loss,
|
324 |
+
pointer_scores=pointer_scores,
|
325 |
+
loss=total_loss,
|
326 |
+
logits=logits,
|
327 |
+
past_key_values=outputs.past_key_values,
|
328 |
+
hidden_states=outputs.hidden_states,
|
329 |
+
attentions=outputs.attentions,
|
330 |
+
rope_deltas=self.rope_deltas,
|
331 |
+
)
|
332 |
+
else:
|
333 |
+
if labels is not None:
|
334 |
+
output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
|
335 |
+
return (total_loss,) + output if total_loss is not None else output
|
336 |
+
else:
|
337 |
+
return outputs
|
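At inference time the pointer scores form a distribution over the merged image-patch grid; mapping the argmax patch back to a pixel coordinate looks roughly like the sketch below. It assumes a single image, a spatial merge size of 2, and a patch size of 14 (the usual Qwen2.5-VL processor defaults); this is an illustration, not the repo's own inference.py.

import torch

def pointer_scores_to_xy(pointer_scores, image_grid_thw, patch_size=14, merge_size=2):
    # pointer_scores: [n_target, n_visual] attention over merged patches for one image
    t, h, w = image_grid_thw[0].tolist()            # grid size in raw patches
    n_width, n_height = w // merge_size, h // merge_size
    scores = pointer_scores[0].reshape(n_height, n_width)
    idx = torch.argmax(scores)                      # flattened index of the best patch
    row, col = divmod(idx.item(), n_width)
    # center of the selected merged patch, in pixels of the resized input image
    x = (col + 0.5) * patch_size * merge_size
    y = (row + 0.5) * patch_size * merge_size
    return x, y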
gui_actor/trainer.py
ADDED
@@ -0,0 +1,313 @@
1 |
+
from datetime import timedelta
|
2 |
+
from functools import wraps
|
3 |
+
from typing import Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
import transformers
|
8 |
+
from accelerate import Accelerator, DataLoaderConfiguration
|
9 |
+
from accelerate.utils import GradientAccumulationPlugin, InitProcessGroupKwargs
|
10 |
+
from torch.utils.data import DataLoader, RandomSampler
|
11 |
+
from transformers import Trainer
|
12 |
+
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
13 |
+
from transformers.trainer_pt_utils import get_parameter_names
|
14 |
+
from transformers.trainer_utils import has_length
|
15 |
+
from transformers.utils import (
|
16 |
+
is_accelerate_available,
|
17 |
+
is_datasets_available,
|
18 |
+
is_sagemaker_mp_enabled,
|
19 |
+
)
|
20 |
+
from transformers.trainer_pt_utils import LengthGroupedSampler as HFLengthGroupedSampler
|
21 |
+
from transformers.trainer_utils import seed_worker
|
22 |
+
from transformers.utils import logging
|
23 |
+
|
24 |
+
if is_datasets_available():
|
25 |
+
import datasets
|
26 |
+
|
27 |
+
|
28 |
+
def rank0_print(*args):
|
29 |
+
if dist.is_initialized():
|
30 |
+
if dist.get_rank() == 0:
|
31 |
+
print(f"Rank {dist.get_rank()}: ", *args)
|
32 |
+
else:
|
33 |
+
print(*args)
|
34 |
+
|
35 |
+
|
36 |
+
def maybe_zero_3(param, ignore_status=False, name=None):
|
37 |
+
from deepspeed import zero
|
38 |
+
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
|
39 |
+
|
40 |
+
if hasattr(param, "ds_id"):
|
41 |
+
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
|
42 |
+
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
|
43 |
+
with zero.GatheredParameters([param]):
|
44 |
+
param = param.data.detach().cpu().clone()
|
45 |
+
else:
|
46 |
+
param = param.detach().cpu().clone()
|
47 |
+
return param
|
48 |
+
|
49 |
+
|
50 |
+
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
|
51 |
+
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
|
52 |
+
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
|
53 |
+
return to_return
|
54 |
+
|
55 |
+
|
56 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
|
57 |
+
"""Collects the state dict and dump to disk."""
|
58 |
+
trainer.accelerator.wait_for_everyone()
|
59 |
+
torch.cuda.synchronize()
|
60 |
+
|
61 |
+
if trainer.deepspeed:
|
62 |
+
trainer.save_model(output_dir)
|
63 |
+
return
|
64 |
+
|
65 |
+
state_dict = trainer.model.state_dict()
|
66 |
+
if trainer.args.should_save:
|
67 |
+
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
|
68 |
+
del state_dict
|
69 |
+
trainer._save(output_dir, state_dict=cpu_state_dict)
|
70 |
+
|
71 |
+
|
72 |
+
class AGUVISTrainer(Trainer):
|
73 |
+
|
74 |
+
def __init__(self, *args, **kwargs):
|
75 |
+
super().__init__(*args, **kwargs)
|
76 |
+
|
77 |
+
original_save = self._save
|
78 |
+
original_save_model = self.save_model
|
79 |
+
|
80 |
+
def modify_eos_token(func):
|
81 |
+
@wraps(func)
|
82 |
+
def wrapper(*args, **kwargs):
|
83 |
+
tokenizer = self.processing_class.tokenizer
|
84 |
+
old_config_id = self.model.config.eos_token_id
|
85 |
+
old_eos_token = tokenizer.eos_token
|
86 |
+
old_generation_config_eos_token_id = (
|
87 |
+
self.model.generation_config.eos_token_id if hasattr(self.model, "generation_config") else None
|
88 |
+
)
|
89 |
+
|
90 |
+
try:
|
91 |
+
new_eos_token_id = tokenizer.convert_tokens_to_ids("<|diff_marker|>")
|
92 |
+
self.model.config.eos_token_id = [new_eos_token_id]
|
93 |
+
tokenizer.eos_token = "<|diff_marker|>"
|
94 |
+
if hasattr(self.model, "generation_config"):
|
95 |
+
self.model.generation_config.eos_token_id = [new_eos_token_id]
|
96 |
+
|
97 |
+
print("Set eos token id to", new_eos_token_id)
|
98 |
+
print("Set eos token to", "<|diff_marker|>")
|
99 |
+
print("Set generation config eos token id to", [new_eos_token_id])
|
100 |
+
|
101 |
+
result = func(*args, **kwargs)
|
102 |
+
return result
|
103 |
+
finally:
|
104 |
+
self.model.config.eos_token_id = old_config_id
|
105 |
+
tokenizer.eos_token = old_eos_token
|
106 |
+
if hasattr(self.model, "generation_config") and old_generation_config_eos_token_id is not None:
|
107 |
+
self.model.generation_config.eos_token_id = old_generation_config_eos_token_id
|
108 |
+
|
109 |
+
print("Set eos token id back to", old_config_id)
|
110 |
+
print("Set eos token back to", old_eos_token)
|
111 |
+
if old_generation_config_eos_token_id is not None:
|
112 |
+
print("Set generation config eos token id back to", old_generation_config_eos_token_id)
|
113 |
+
|
114 |
+
return wrapper
|
115 |
+
|
116 |
+
self._save = modify_eos_token(original_save)
|
117 |
+
self.save_model = modify_eos_token(original_save_model)
|
118 |
+
|
119 |
+
def create_accelerator_and_postprocess(self):
|
120 |
+
grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
|
121 |
+
grad_acc_kwargs["sync_with_dataloader"] = False
|
122 |
+
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
|
123 |
+
|
124 |
+
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
|
125 |
+
|
126 |
+
# create accelerator object
|
127 |
+
dispatch_batches = getattr(self.args, "dispatch_batches", None)
|
128 |
+
split_batches = getattr(self.args, "split_batches", None)
|
129 |
+
self.dataloader_config = DataLoaderConfiguration(
|
130 |
+
dispatch_batches=dispatch_batches,
|
131 |
+
split_batches=split_batches,
|
132 |
+
)
|
133 |
+
self.accelerator = Accelerator(
|
134 |
+
dataloader_config=self.dataloader_config,
|
135 |
+
deepspeed_plugin=self.args.deepspeed_plugin,
|
136 |
+
gradient_accumulation_plugin=gradient_accumulation_plugin,
|
137 |
+
kwargs_handlers=[accelerator_kwargs],
|
138 |
+
)
|
139 |
+
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
|
140 |
+
self.gather_function = self.accelerator.gather_for_metrics
|
141 |
+
|
142 |
+
# deepspeed and accelerate flags covering both trainer args and accelerate launcher
|
143 |
+
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
144 |
+
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
145 |
+
|
146 |
+
# post accelerator creation setup
|
147 |
+
if self.is_fsdp_enabled:
|
148 |
+
fsdp_plugin = self.accelerator.state.fsdp_plugin
|
149 |
+
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
|
150 |
+
"limit_all_gathers", fsdp_plugin.limit_all_gathers
|
151 |
+
)
|
152 |
+
if is_accelerate_available("0.23.0"):
|
153 |
+
fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
|
154 |
+
"activation_checkpointing", fsdp_plugin.activation_checkpointing
|
155 |
+
)
|
156 |
+
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
|
157 |
+
raise ValueError(
|
158 |
+
"The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
|
159 |
+
"can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
|
160 |
+
"when using FSDP."
|
161 |
+
)
|
162 |
+
|
163 |
+
if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
|
164 |
+
self.propagate_args_to_deepspeed()
|
165 |
+
|
166 |
+
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
167 |
+
if self.train_dataset is None or not has_length(self.train_dataset):
|
168 |
+
return None
|
169 |
+
|
170 |
+
if self.args.group_by_length:
|
171 |
+
lengths = self.train_dataset.lengths
|
172 |
+
return HFLengthGroupedSampler(
|
173 |
+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
174 |
+
dataset=self.train_dataset,
|
175 |
+
lengths=lengths,
|
176 |
+
)
|
177 |
+
elif self.args.group_by_modality_length:
|
178 |
+
lengths = self.train_dataset.modality_lengths
|
179 |
+
return HFLengthGroupedSampler(
|
180 |
+
self.args.train_batch_size * self.args.gradient_accumulation_steps,
|
181 |
+
dataset=self.train_dataset,
|
182 |
+
lengths=lengths,
|
183 |
+
)
|
184 |
+
else:
|
185 |
+
return RandomSampler(self.train_dataset)
|
186 |
+
|
187 |
+
def get_train_dataloader(self) -> DataLoader:
|
188 |
+
"""
|
189 |
+
Returns the training [`~torch.utils.data.DataLoader`].
|
190 |
+
|
191 |
+
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
|
192 |
+
training if necessary) otherwise.
|
193 |
+
|
194 |
+
Subclass and override this method if you want to inject some custom behavior.
|
195 |
+
"""
|
196 |
+
if self.train_dataset is None:
|
197 |
+
raise ValueError("Trainer: training requires a train_dataset.")
|
198 |
+
|
199 |
+
train_dataset = self.train_dataset
|
200 |
+
data_collator = self.data_collator
|
201 |
+
if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
|
202 |
+
train_dataset = self._remove_unused_columns(train_dataset, description="training")
|
203 |
+
else:
|
204 |
+
data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
|
205 |
+
|
206 |
+
dataloader_params = {
|
207 |
+
"batch_size": self._train_batch_size,
|
208 |
+
"collate_fn": data_collator,
|
209 |
+
"num_workers": self.args.dataloader_num_workers,
|
210 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
211 |
+
"persistent_workers": self.args.dataloader_persistent_workers,
|
212 |
+
}
|
213 |
+
|
214 |
+
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
|
215 |
+
dataloader_params["sampler"] = self._get_train_sampler()
|
216 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
217 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
218 |
+
dataloader_params["prefetch_factor"] = (
|
219 |
+
self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
|
220 |
+
)
|
221 |
+
|
222 |
+
dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
|
223 |
+
|
224 |
+
return dataloader
|
225 |
+
|
226 |
+
def create_optimizer(self):
|
227 |
+
"""
|
228 |
+
Setup the optimizer.
|
229 |
+
|
230 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
231 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
232 |
+
"""
|
233 |
+
if is_sagemaker_mp_enabled():
|
234 |
+
return super().create_optimizer()
|
235 |
+
|
236 |
+
opt_model = self.model
|
237 |
+
|
238 |
+
if self.optimizer is None:
|
239 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
240 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
241 |
+
optimizer_grouped_parameters = [
|
242 |
+
{
|
243 |
+
"params": [
|
244 |
+
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
|
245 |
+
],
|
246 |
+
"weight_decay": self.args.weight_decay,
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"params": [
|
250 |
+
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
|
251 |
+
],
|
252 |
+
"weight_decay": 0.0,
|
253 |
+
},
|
254 |
+
]
|
255 |
+
|
256 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
|
257 |
+
|
258 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
259 |
+
|
260 |
+
return self.optimizer
|
261 |
+
|
262 |
+
def create_optimizer_with_different_learning_rates(self):
|
263 |
+
"""
|
264 |
+
Setup the optimizer.
|
265 |
+
|
266 |
+
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
|
267 |
+
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
|
268 |
+
"""
|
269 |
+
if is_sagemaker_mp_enabled():
|
270 |
+
raise NotImplementedError("Sagemaker MP is not supported for separate learning rate yet")
|
271 |
+
return super().create_optimizer()
|
272 |
+
|
273 |
+
opt_model = self.model
|
274 |
+
|
275 |
+
if self.optimizer is None:
|
276 |
+
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
|
277 |
+
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
278 |
+
|
279 |
+
new_parameters = []
|
280 |
+
for name, param in opt_model.named_parameters():
|
281 |
+
if ("pointer_head" in name) or ("embed_tokens" in name):
|
282 |
+
new_parameters.append(name)
|
283 |
+
rank0_print(f"new_parameters: {len(new_parameters)}")
|
284 |
+
|
285 |
+
optimizer_grouped_parameters = [
|
286 |
+
{
|
287 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
|
288 |
+
"weight_decay": self.args.weight_decay,
|
289 |
+
"lr": self.args.learning_rate,
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
|
293 |
+
"weight_decay": 0.0,
|
294 |
+
"lr": self.args.learning_rate,
|
295 |
+
},
|
296 |
+
{
|
297 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n in new_parameters) and p.requires_grad)],
|
298 |
+
"weight_decay": self.args.weight_decay,
|
299 |
+
"lr": self.args.learning_rate_new_params,
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n in new_parameters) and p.requires_grad)],
|
303 |
+
"weight_decay": 0.0,
|
304 |
+
"lr": self.args.learning_rate_new_params,
|
305 |
+
},
|
306 |
+
]
|
307 |
+
|
308 |
+
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) # {'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08}
|
309 |
+
optimizer_kwargs.pop("lr")
|
310 |
+
|
311 |
+
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
|
312 |
+
|
313 |
+
return self.optimizer
|
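create_optimizer_with_different_learning_rates gives newly initialized modules (the pointer head and the resized token embeddings) their own learning rate via optimizer parameter groups. The same pattern in isolation, with hypothetical module names and rates (a sketch, not the trainer code itself):

import torch

model = torch.nn.ModuleDict({
    "backbone": torch.nn.Linear(8, 8),
    "pointer_head": torch.nn.Linear(8, 8),   # hypothetical "new" module
})
base_lr, new_lr = 1e-5, 1e-4
groups = [
    {"params": [p for n, p in model.named_parameters() if "pointer_head" not in n], "lr": base_lr},
    {"params": [p for n, p in model.named_parameters() if "pointer_head" in n], "lr": new_lr},
]
optimizer = torch.optim.AdamW(groups, weight_decay=0.0)   # each group keeps its own lr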
gui_actor/utils.py
ADDED
@@ -0,0 +1,90 @@
|
1 |
+
from PIL import Image, ImageDraw, ImageColor
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
def dump_args_to_json(model_config, data_processor, model_args, data_args, training_args, output_dir):
|
6 |
+
def is_json_serializable(v):
|
7 |
+
try:
|
8 |
+
json.dumps(v)
|
9 |
+
return True
|
10 |
+
except (TypeError, ValueError): # json.dumps rejects non-serializable or circular values
|
11 |
+
return False
|
12 |
+
|
13 |
+
save_path = f"{output_dir}/args.json"
|
14 |
+
if not os.path.exists(save_path):
|
15 |
+
with open(save_path, "w") as f:
|
16 |
+
json.dump({
|
17 |
+
"model_config": {k: v for k, v in model_config.__dict__.items() if is_json_serializable(v)},
|
18 |
+
"data_processor_config": {k: v for k, v in data_processor.__dict__.items() if is_json_serializable(v)},
|
19 |
+
"image_processor_config": {k: v for k, v in data_processor.image_processor.__dict__.items() if is_json_serializable(v)},
|
20 |
+
"model_args": {k: v for k, v in model_args.__dict__.items() if is_json_serializable(v)},
|
21 |
+
"data_args": {k: v for k, v in data_args.__dict__.items() if is_json_serializable(v)},
|
22 |
+
"training_args": {k: v for k, v in training_args.__dict__.items() if is_json_serializable(v)},
|
23 |
+
}, f, indent=4)
|
24 |
+
|
25 |
+
def draw_point(image: Image.Image, point: list, color=None):
|
26 |
+
if isinstance(color, str):
|
27 |
+
try:
|
28 |
+
color = ImageColor.getrgb(color)
|
29 |
+
color = color + (128,)
|
30 |
+
except ValueError:
|
31 |
+
color = (255, 0, 0, 128)
|
32 |
+
else:
|
33 |
+
color = (255, 0, 0, 128)
|
34 |
+
|
35 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
36 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
37 |
+
radius = 14
|
38 |
+
x, y = point
|
39 |
+
|
40 |
+
overlay_draw.rectangle(
|
41 |
+
[x - radius, y - radius, x + radius, y + radius],
|
42 |
+
fill=color
|
43 |
+
)
|
44 |
+
|
45 |
+
center_radius = radius * 0.1
|
46 |
+
overlay_draw.ellipse(
|
47 |
+
[(x - center_radius, y - center_radius),
|
48 |
+
(x + center_radius, y + center_radius)],
|
49 |
+
fill=(0, 255, 0, 255)
|
50 |
+
)
|
51 |
+
|
52 |
+
image = image.convert('RGBA')
|
53 |
+
combined = Image.alpha_composite(image, overlay)
|
54 |
+
|
55 |
+
return combined.convert('RGB')
|
56 |
+
|
57 |
+
def draw_bbox(image: Image.Image, bbox: list, color=None):
|
58 |
+
"""bbox is in the format of [x1, y1, x2, y2]"""
|
59 |
+
if isinstance(color, str):
|
60 |
+
try:
|
61 |
+
color = ImageColor.getrgb(color)
|
62 |
+
color = color + (128,)
|
63 |
+
except ValueError:
|
64 |
+
color = (255, 0, 0, 128)
|
65 |
+
else:
|
66 |
+
color = (255, 0, 0, 128)
|
67 |
+
|
68 |
+
overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
|
69 |
+
overlay_draw = ImageDraw.Draw(overlay)
|
70 |
+
overlay_draw.rectangle(bbox, fill=color)
|
71 |
+
return Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
|
72 |
+
|
73 |
+
def do_boxes_overlap(box1, box2):
|
74 |
+
"""
|
75 |
+
Check if two boxes overlap.
|
76 |
+
|
77 |
+
Each box is represented as a tuple: (x1, y1, x2, y2)
|
78 |
+
Where (x1, y1) is the top-left and (x2, y2) is the bottom-right corner.
|
79 |
+
"""
|
80 |
+
# Unpack the coordinates
|
81 |
+
x1_min, y1_min, x1_max, y1_max = box1
|
82 |
+
x2_min, y2_min, x2_max, y2_max = box2
|
83 |
+
|
84 |
+
# Check for no overlap
|
85 |
+
if x1_max < x2_min or x2_max < x1_min:
|
86 |
+
return False
|
87 |
+
if y1_max < y2_min or y2_max < y1_min:
|
88 |
+
return False
|
89 |
+
|
90 |
+
return True
|
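Typical usage of the helpers above, with a hypothetical screenshot path and coordinates:

from PIL import Image
# from gui_actor.utils import draw_point, draw_bbox, do_boxes_overlap

img = Image.open("screenshot.png")                               # hypothetical file
img = draw_bbox(img, [100, 200, 300, 260], color="blue")         # highlight a target region
img = draw_point(img, [180, 230], color="red")                   # mark the predicted click point
print(do_boxes_overlap((100, 200, 300, 260), (290, 250, 400, 320)))  # True: the boxes touch
img.save("screenshot_annotated.png")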
requirements.txt
ADDED
@@ -0,0 +1,16 @@
|
1 |
+
transformers
|
2 |
+
accelerate
|
3 |
+
torch
|
4 |
+
Pillow
|
5 |
+
requests
|
6 |
+
torchvision
|
7 |
+
torchaudio
|
8 |
+
gradio
|
9 |
+
gradio_client
|
10 |
+
spaces
|
11 |
+
opencv-python-headless
|
12 |
+
datasets
|
13 |
+
qwen-vl-utils
|
14 |
+
pre-commit
|
15 |
+
matplotlib
|
16 |
+
#flash-attn
|