johnbridges committed
Commit 9f57ecf · 1 Parent(s): 8ecfcea
app.py ADDED
@@ -0,0 +1,180 @@
1
+ import base64, os
2
+ # import spaces
3
+ import json
4
+ import torch
5
+ import gradio as gr
6
+ from typing import Optional
7
+ from PIL import Image, ImageDraw
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from qwen_vl_utils import process_vision_info
11
+ from datasets import load_dataset
12
+ from transformers import AutoProcessor
13
+ from gui_actor.constants import chat_template
14
+ from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
15
+ from gui_actor.inference import inference
16
+
17
+ MAX_PIXELS = 3200 * 1800
18
+
19
+ def resize_image(image, resize_to_pixels=MAX_PIXELS):
20
+ image_width, image_height = image.size
21
+ if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
22
+ resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
23
+ image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio)
24
+ image = image.resize((image_width_resized, image_height_resized))
25
+ return image
26
+
27
+ # @spaces.GPU
28
+ @torch.inference_mode()
29
+ def draw_point(image: Image.Image, point: list, radius=8, color=(255, 0, 0, 128)):
30
+ overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
31
+ overlay_draw = ImageDraw.Draw(overlay)
32
+ x, y = point
33
+ overlay_draw.ellipse(
34
+ [(x - radius, y - radius), (x + radius, y + radius)],
35
+ outline=color,
36
+ width=5 # Adjust thickness as needed
37
+ )
38
+ image = image.convert('RGBA')
39
+ combined = Image.alpha_composite(image, overlay)
40
+ combined = combined.convert('RGB')
41
+ return combined
42
+
43
+ # @spaces.GPU
44
+ @torch.inference_mode()
45
+ def get_attn_map(image, attn_scores, n_width, n_height):
46
+ w, h = image.size
47
+ scores = np.array(attn_scores[0]).reshape(n_height, n_width)
48
+
49
+ scores_norm = (scores - scores.min()) / (scores.max() - scores.min())
50
+ # Resize score map to match image size
51
+ score_map = Image.fromarray((scores_norm * 255).astype(np.uint8)).resize((w, h), resample=Image.NEAREST)  # Image.BILINEAR gives a smoother overlay
52
+ # Apply colormap
53
+ colormap = plt.get_cmap('jet')
54
+ colored_score_map = colormap(np.array(score_map) / 255.0) # returns RGBA
55
+ colored_score_map = (colored_score_map[:, :, :3] * 255).astype(np.uint8)
56
+ colored_overlay = Image.fromarray(colored_score_map)
57
+
58
+ # Blend with original image
59
+ blended = Image.blend(image, colored_overlay, alpha=0.3)
60
+ return blended
61
+
62
+ # load model
63
+ if torch.cuda.is_available():
64
+ # os.system('pip install flash-attn --no-build-isolation')
65
+ model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2.5-VL"
66
+ data_processor = AutoProcessor.from_pretrained(model_name_or_path)
67
+ tokenizer = data_processor.tokenizer
68
+ model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
69
+ model_name_or_path,
70
+ torch_dtype=torch.bfloat16,
71
+ device_map="cuda:0",
72
+ attn_implementation="flash_attention_2"
73
+ ).eval()
74
+ else:
75
+ model_name_or_path = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
76
+ data_processor = AutoProcessor.from_pretrained(model_name_or_path)
77
+ tokenizer = data_processor.tokenizer
78
+ model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
79
+ model_name_or_path,
80
+ torch_dtype=torch.bfloat16,
81
+ device_map="cpu"
82
+ ).eval()
83
+
84
+ title = "GUI-Actor"
85
+ header = """
86
+ <div align="center">
87
+ <h1 style="padding-bottom: 10px; padding-top: 10px;">🎯 <strong>GUI-Actor</strong>: Coordinate-Free Visual Grounding for GUI Agents</h1>
88
+ <div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
89
+ Qianhui Wu*, Kanzhi Cheng*, Rui Yang*, Chaoyun Zhang, Jianwei Yang, Huiqiang Jiang, Jian Mu, Baolin Peng, Bo Qiao, Reuben Tan, Si Qin, Lars Liden<br>
90
+ Qingwei Lin, Huan Zhang, Tong Zhang, Jianbing Zhang, Dongmei Zhang, Jianfeng Gao<br/>
91
+ </div>
92
+ <div style="padding-bottom: 10px; padding-top: 10px; font-size: 16px;">
93
+ <a href="https://microsoft.github.io/GUI-Actor/">🌐 Project Page</a> | <a href="https://arxiv.org/abs/2403.12968">📄 arXiv Paper</a> | <a href="https://github.com/microsoft/GUI-Actor">💻 Github Repo</a><br/>
94
+ </div>
95
+ </div>
96
+ """
97
+
98
+ theme = "soft"
99
+ css = """#anno-img .mask {opacity: 0.5; transition: all 0.2s ease-in-out;}
100
+ #anno-img .mask.active {opacity: 0.7}"""
101
+
102
+ # @spaces.GPU
103
+ @torch.inference_mode()
104
+ def process(image, instruction):
105
+ # resize image
106
+ w, h = image.size
107
+ if w * h > MAX_PIXELS:
108
+ image = resize_image(image)
+ w, h = image.size  # refresh dimensions so the predicted point maps onto the resized image
109
+
110
+ conversation = [
111
+ {
112
+ "role": "system",
113
+ "content": [
114
+ {
115
+ "type": "text",
116
+ "text": "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>).",
117
+ }
118
+ ]
119
+ },
120
+ {
121
+ "role": "user",
122
+ "content": [
123
+ {
124
+ "type": "image",
125
+ "image": image, # PIL.Image.Image or str to path
126
+ # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
127
+ },
128
+ {
129
+ "type": "text",
130
+ "text": instruction,
131
+ },
132
+ ],
133
+ },
134
+ ]
135
+
136
+ try:
137
+ pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
138
+ except Exception as e:
139
+ print(e)
140
+ return image, f"Error: {e}", None
141
+
142
+ px, py = pred["topk_points"][0]
143
+ output_coord = f"({px:.4f}, {py:.4f})"
144
+ img_with_point = draw_point(image, (px * w, py * h))
145
+
146
+ n_width, n_height = pred["n_width"], pred["n_height"]
147
+ attn_scores = pred["attn_scores"]
148
+ att_map = get_attn_map(image, attn_scores, n_width, n_height)
149
+
150
+ return img_with_point, output_coord, att_map
151
+
152
+
153
+ with gr.Blocks(title=title, css=css) as demo:
154
+ gr.Markdown(header)
155
+ with gr.Row():
156
+ with gr.Column():
157
+ input_image = gr.Image(
158
+ type='pil', label='Upload image')
159
+ # text box
160
+ input_instruction = gr.Textbox(label='Instruction', placeholder='Type your (low-level) instruction here')
161
+ submit_button = gr.Button(
162
+ value='Submit', variant='primary')
163
+ with gr.Column():
164
+ image_with_point = gr.Image(type='pil', label='Image with Point (red circle)')
165
+ with gr.Accordion('Detailed prediction'):
166
+ pred_xy = gr.Textbox(label='Predicted Coordinates', placeholder='(x, y)')
167
+ att_map = gr.Image(type='pil', label='Attention Map')
168
+
169
+ submit_button.click(
170
+ fn=process,
171
+ inputs=[
172
+ input_image,
173
+ input_instruction
174
+ ],
175
+ outputs=[image_with_point, pred_xy, att_map]
176
+ )
177
+
178
+ # demo.launch(debug=False, show_error=True, share=True)
179
+ # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
180
+ demo.queue().launch(share=False)
gui_actor/__init__.py ADDED
File without changes
gui_actor/constants.py ADDED
@@ -0,0 +1,40 @@
1
+ import json
2
+
3
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
4
+ WORKER_HEART_BEAT_INTERVAL = 15
5
+
6
+ LOGDIR = "."
7
+
8
+ # Model Constants
9
+ IGNORE_INDEX = -100
10
+ DEFAULT_IMAGE_TOKEN = "<image>"
11
+ DEFAULT_POINTER_START_TOKEN = "<|pointer_start|>"
12
+ DEFAULT_POINTER_END_TOKEN = "<|pointer_end|>"
13
+ DEFAULT_POINTER_PAD_TOKEN = "<|pointer_pad|>"
14
+
15
+ # UNMASK_TOKEN_IDS = [198, 151644, 151645]
16
+
17
+ # System Message
18
+ grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click(<your_special_token_here>)."
19
+
20
+ # Chat Template
21
+ chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
22
+
23
+ assistant_template = "{% for message in messages %}{{'<|im_start|>' + message['role']}}{% if 'recipient' in message %}<|recipient|>{{ message['recipient'] }}{% endif %}{{'\n' + message['content'][0]['text']}}{% if 'end_turn' in message and message['end_turn'] %}{{'<|diff_marker|>\n'}}{% else %}{{'<|im_end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|recipient|>' }}{% endif %}"
24
+
25
+ # Special Tokens
26
+ ADDITIONAL_SPECIAL_TOKENS = [
27
+ "<|recipient|>",
28
+ "<|diff_marker|>",
29
+ DEFAULT_POINTER_START_TOKEN,
30
+ DEFAULT_POINTER_END_TOKEN,
31
+ DEFAULT_POINTER_PAD_TOKEN,
32
+ ]
33
+
34
+ # Action Patterns to be replaced with special tokens
35
+ ACTION_PATTENS_XY = [
36
+ r"x=([0-9.]+), y=([0-9.]+)",
37
+ r"from_coord=\[([0-9.]+), ([0-9.]+)\], to_coord=\[([0-9.]+), ([0-9.]+)\]",
38
+ ]
39
+
40
+ until = ["<|diff_marker|>"]
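For reference, a minimal sketch of how the ACTION_PATTENS_XY patterns and the pointer tokens are meant to interact (the action string below is an illustrative example, not data from this repo):

import re
from gui_actor.constants import (
    ACTION_PATTENS_XY,
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
)

# A click action with literal normalized coordinates, as it might appear in training data.
action = "pyautogui.click(x=0.6312, y=0.2745)"
pointer = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"

# The first pattern matches the "x=..., y=..." form; substitution yields the coordinate-free target text.
rewritten = re.sub(ACTION_PATTENS_XY[0], pointer, action)
print(rewritten)  # pyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)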
gui_actor/dataset.py ADDED
@@ -0,0 +1,533 @@
1
+ import copy
2
+ import json
3
+ import math
4
+ import os
5
+ import random
6
+ import re
7
+ import ast
8
+ from typing import Dict
9
+
10
+ import torch
11
+ import transformers
12
+ import yaml
13
+ from qwen_vl_utils import smart_resize, process_vision_info
14
+ from torch.utils.data import Dataset
15
+
16
+ from gui_actor.constants import (
17
+ IGNORE_INDEX,
18
+ DEFAULT_IMAGE_TOKEN,
19
+ DEFAULT_POINTER_START_TOKEN,
20
+ DEFAULT_POINTER_PAD_TOKEN,
21
+ DEFAULT_POINTER_END_TOKEN,
22
+ ACTION_PATTENS_XY,
23
+ ADDITIONAL_SPECIAL_TOKENS,
24
+ assistant_template,
25
+ chat_template,
26
+ grounding_system_message,
27
+ )
28
+ from gui_actor.trainer import rank0_print
29
+
30
+
31
+ def reformat_coordinates(text):
32
+ """
33
+ (1) Find all the coordinates in the text.
34
+ (2) Replace the coordinates with the special tokens.
35
+ (3) Return the new text and the coordinates as a list of (x, y), where x in [0, 1] and y in [0, 1].
36
+ """
37
+ epsilon = 0.001
38
+ def adjust_coord(c):
39
+ """
40
+ Adjust coordinate if it is too close to 0 or 1.
41
+ """
42
+ if abs(c) < epsilon:
43
+ return epsilon
44
+ elif abs(c - 1) < epsilon:
45
+ return 1 - epsilon
46
+ return c
47
+
48
+ all_matches = []
49
+ for pattern in ACTION_PATTENS_XY:
50
+ matches = list(re.finditer(pattern, text))
51
+ for match in matches:
52
+ all_matches.append((match.start(), match.groups()))
53
+ if pattern == ACTION_PATTENS_XY[0]:
54
+ target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
55
+ else:
56
+ target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, {DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
57
+ text = re.sub(
58
+ pattern,
59
+ target_text,
60
+ text
61
+ )
62
+
63
+ coordinates = []
64
+ all_matches.sort(key=lambda x: x[0])
65
+ # Extract coordinates in order
66
+ for _, groups in all_matches:
67
+ # When two coordinate values are found, parse them as one (x, y) pair.
68
+ if len(groups) == 2:
69
+ x_str, y_str = groups
70
+ x = adjust_coord(ast.literal_eval(x_str))
71
+ y = adjust_coord(ast.literal_eval(y_str))
72
+ coordinates.append((x, y))
73
+ # When four coordinate values are found, parse them as two pairs.
74
+ elif len(groups) == 4:
75
+ x1_str, y1_str, x2_str, y2_str = groups
76
+ x1 = adjust_coord(ast.literal_eval(x1_str))
77
+ y1 = adjust_coord(ast.literal_eval(y1_str))
78
+ x2 = adjust_coord(ast.literal_eval(x2_str))
79
+ y2 = adjust_coord(ast.literal_eval(y2_str))
80
+ coordinates.append((x1, y1))
81
+ coordinates.append((x2, y2))
82
+
83
+ return text, coordinates
84
+
85
+ def get_token_index(image_processor, image, point_x, point_y):
86
+ """
87
+ Get the index of the visual token that contains the point (x, y).
88
+ Args:
89
+ image_processor: the image processor
90
+ image: the image in PIL format
91
+ point_x: the x coordinate of the point, in [0, 1].
92
+ point_y: the y coordinate of the point, in [0, 1].
93
+ """
94
+ if len(image) != 1:
95
+ raise ValueError(f"Expected 1 image, got {len(image)}")
96
+
97
+ # get the original image size and the resized image size
98
+ image = image[0]
99
+ w, h = image.size
100
+ px, py = w * point_x, h * point_y
101
+ # rank0_print(f"px: {px}, py: {py}")
102
+ # get the token index
103
+ merge_patch_size = image_processor.patch_size * image_processor.merge_size
104
+ x_index = math.floor(px / merge_patch_size)
105
+ y_index = math.floor(py / merge_patch_size)
106
+
107
+ visual_token_index = y_index * (w // merge_patch_size) + x_index
108
+
109
+ # merge all above print into one line
110
+ return visual_token_index
111
+
112
+ def get_multi_patch_labels(image_processor, image, bbox_gt):
113
+ """
114
+ Get the multi-patch labels for the bounding box.
115
+ Args:
116
+ image_processor: the image processor
117
+ image: the image in PIL format
118
+ bbox_gt: the bounding box in the format of (x_min, y_min, x_max, y_max) [0,1]
119
+ """
120
+ if len(image) != 1:
121
+ raise ValueError(f"Expected 1 image, got {len(image)}")
122
+
123
+ # Get the original image size and the resized image size
124
+ image = image[0]
125
+ w, h = image.size
126
+
127
+ bbox_gt = [bbox_gt[0]*w, bbox_gt[1]*h, bbox_gt[2]*w, bbox_gt[3]*h]
128
+ # Extract bounding box coordinates
129
+ x_min, y_min, x_max, y_max = bbox_gt
130
+ x_min = max(0, x_min)
131
+ y_min = max(0, y_min)
132
+ x_max = min(w, x_max)
133
+ y_max = min(h, y_max)
134
+
135
+ merge_patch_size = image_processor.patch_size * image_processor.merge_size
136
+ assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
137
+ grid_h, grid_w = h // merge_patch_size, w // merge_patch_size
138
+
139
+ binary_mask = torch.zeros(grid_h * grid_w)
140
+ # Iterate through all patches, check if they overlap with the bounding box
141
+ for y_idx in range(grid_h):
142
+ for x_idx in range(grid_w):
143
+ # Calculate patch boundaries
144
+ patch_x_min = x_idx * merge_patch_size
145
+ patch_y_min = y_idx * merge_patch_size
146
+ patch_x_max = patch_x_min + merge_patch_size
147
+ patch_y_max = patch_y_min + merge_patch_size
148
+
149
+ # Check if patch overlaps with the bounding box
150
+ if not (patch_x_max <= x_min or patch_x_min >= x_max or
151
+ patch_y_max <= y_min or patch_y_min >= y_max):
152
+ # Calculate patch index in the flattened grid
153
+ patch_idx = y_idx * grid_w + x_idx
154
+ binary_mask[patch_idx] = 1
155
+
156
+ return binary_mask
157
+
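As a sanity check of the patch-overlap logic above, a small standalone example; the SimpleNamespace stands in for the Qwen2-VL image processor (patch_size=14 and merge_size=2 are the usual defaults, assumed here), and it presumes the gui_actor package and its dependencies import cleanly:

from types import SimpleNamespace
from PIL import Image
from gui_actor.dataset import get_multi_patch_labels

fake_processor = SimpleNamespace(patch_size=14, merge_size=2)   # 28x28-pixel merged patches
image = [Image.new("RGB", (112, 56))]                           # 4x2 grid of merged patches
bbox = (0.25, 0.0, 0.75, 1.0)                                   # normalized (x_min, y_min, x_max, y_max)

mask = get_multi_patch_labels(fake_processor, image, bbox)
print(mask.reshape(2, 4))
# tensor([[0., 1., 1., 0.],
#         [0., 1., 1., 0.]])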
158
+ def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
159
+ merge_patch_size = image_processor.patch_size * image_processor.merge_size
160
+ x_index = visual_token_index % (image_width // merge_patch_size)
161
+ y_index = visual_token_index // (image_width // merge_patch_size)
162
+ px = x_index * merge_patch_size + merge_patch_size / 2
163
+ py = y_index * merge_patch_size + merge_patch_size / 2
164
+ return px, py
165
+
166
+ class LazySupervisedDataset(Dataset):
167
+ def __init__(
168
+ self,
169
+ tokenizer: transformers.PreTrainedTokenizer,
170
+ processor: transformers.ProcessorMixin,
171
+ data_path: str,
172
+ data_args,
173
+ ):
174
+ super().__init__()
175
+ self.tokenizer = tokenizer
176
+ self.processor = processor
177
+ self.list_data_dict = []
178
+ self.list_image_path = []
179
+ self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
180
+ self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
181
+ self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
182
+
183
+ # Handle multiple JSON files specified in the data_path
184
+ if "{" in data_path and "}" in data_path:
185
+ base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
186
+ file_names = file_pattern.split(",")
187
+ rank0_print(f"Loading {file_names} from {base_path}")
188
+ data_args.dataset_paths = []
189
+ for file_name in file_names:
190
+ data_args.dataset_paths.append(f"{base_path}{file_name}.json")
191
+ full_path = f"{base_path}{file_name}.json"
192
+ rank0_print(f"Loading {full_path}")
193
+ with open(full_path) as file:
194
+ cur_data_dict = json.load(file)
195
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
196
+ self.list_data_dict.extend(cur_data_dict)
197
+ elif data_path.endswith(".yaml"):
198
+ with open(data_path) as file:
199
+ yaml_data = yaml.safe_load(file)
200
+ datasets = yaml_data.get("datasets")
201
+ # file should be in the format of:
202
+ # datasets:
203
+ # - json_path: xxxx1.json
204
+ # sampling_strategy: first:1000
205
+ # - json_path: xxxx2.json
206
+ # sampling_strategy: end:3000
207
+ # - json_path: xxxx3.json
208
+ # sampling_strategy: random:999
209
+ data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
210
+ for dataset in datasets:
211
+ json_path = dataset.get("json_path")
212
+ sampling_strategy = dataset.get("sampling_strategy", "all")
213
+ images_folder = dataset.get("images_folder")
214
+ sampling_number = None
215
+
216
+ rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
217
+
218
+ if json_path.endswith(".jsonl"):
219
+ cur_data_dict = []
220
+ with open(json_path) as json_file:
221
+ for line in json_file:
222
+ cur_data_dict.append(json.loads(line.strip()))
223
+ elif json_path.endswith(".json"):
224
+ # NOTE: we only use json_path with .json now
225
+ # Handle the images_folder in yaml
226
+ with open(json_path) as json_file:
227
+ cur_data_dict = json.load(json_file)
228
+ else:
229
+ raise ValueError(f"Unsupported file type: {json_path}")
230
+
231
+ if ":" in sampling_strategy:
232
+ sampling_strategy, sampling_number = sampling_strategy.split(":")
233
+ if "%" in sampling_number:
234
+ sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
235
+ else:
236
+ sampling_number = int(sampling_number)
237
+
238
+ # Apply the sampling strategy
239
+ if sampling_strategy == "first" and sampling_number is not None:
240
+ cur_data_dict = cur_data_dict[:sampling_number]
241
+ elif sampling_strategy == "end" and sampling_number is not None:
242
+ cur_data_dict = cur_data_dict[-sampling_number:]
243
+ elif sampling_strategy == "random" and sampling_number is not None:
244
+ random.shuffle(cur_data_dict)
245
+ cur_data_dict = cur_data_dict[:sampling_number]
246
+
247
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
248
+ self.list_data_dict.extend(cur_data_dict)
249
+ self.list_image_path.extend([images_folder] * len(cur_data_dict))
250
+ else:
251
+ data_args.dataset_paths = [data_path]
252
+ rank0_print(f"Loading {data_path}")
253
+ with open(data_path) as file:
254
+ cur_data_dict = json.load(file)
255
+ rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
256
+ self.list_data_dict.extend(cur_data_dict)
257
+ self.list_image_path.extend([""] * len(cur_data_dict)) # NOTE: the image subfolder is empty...
258
+
259
+ rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
260
+ rank0_print("Formatting inputs...Skip in lazy mode")
261
+ self.tokenizer = tokenizer
262
+ self.data_args = data_args
263
+
264
+ def __len__(self):
265
+ return len(self.list_data_dict)
266
+
267
+ @property
268
+ def lengths(self):
269
+ length_list = []
270
+ for sample in self.list_data_dict:
271
+ img_tokens = (
272
+ 1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
273
+ )
274
+ length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
275
+ return length_list
276
+
277
+ @property
278
+ def modality_lengths(self):
279
+ length_list = []
280
+ for sample in self.list_data_dict:
281
+ cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
282
+ assert cur_len > 0, f"Conversation length is 0 for {sample}"
283
+
284
+ img_tokens = (
285
+ 1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200 if "image" in sample else 0
286
+ )
287
+
288
+ if "image" in sample or "video" in sample or self.data_args.early_mix_text:
289
+ length_list.append(cur_len + img_tokens)
290
+ else:
291
+ length_list.append(-cur_len)
292
+ return length_list
293
+
294
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
295
+ sample = self._get_item(i)
296
+ if sample is None:
297
+ new_index = random.randint(0, len(self.list_data_dict) - 1)
298
+ return self.__getitem__(new_index)
299
+ else:
300
+ return sample
311
+
312
+ def _get_item(self, i) -> Dict[str, torch.Tensor]:
313
+ sources = self.list_data_dict[i]
314
+ image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])
315
+
316
+ if "image" in sources:
317
+ image_file = self.list_data_dict[i]["image"]
318
+ if type(image_file) is list:
319
+ image_list = [os.path.join(image_path, image_file) for image_file in image_file]
320
+ else:
321
+ image_list = [os.path.join(image_path, image_file)]
322
+
323
+ sources = copy.deepcopy(sources["conversations"])
324
+ elif "video" in sources:
325
+ raise NotImplementedError("Video is not supported for Qwen2VL")
326
+ else:
327
+ sources = copy.deepcopy(sources["conversations"])
328
+
329
+ item_id = self.list_data_dict[i].get("id", i)
330
+
331
+ data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
332
+ if isinstance(i, int):
333
+ data_dict = {
334
+ "input_ids": data_dict["input_ids"][0],
335
+ "labels": data_dict["labels"][0],
336
+ "coordinates": data_dict["coordinates"][0],
337
+ "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
338
+ "pixel_values": data_dict["pixel_values"],
339
+ "image_grid_thw": data_dict["image_grid_thw"],
340
+ "multi_patch_labels": data_dict["multi_patch_labels"][0], # add multi_patch_labels
341
+ }
342
+
343
+ data_dict["id"] = item_id
344
+
345
+ # return None if the input_ids is longer than the model_max_length
346
+ n_image_tokens = (
347
+ data_dict["image_grid_thw"][0][0] *
348
+ data_dict["image_grid_thw"][0][1] *
349
+ data_dict["image_grid_thw"][0][2] /
350
+ self.processor.image_processor.merge_size /
351
+ self.processor.image_processor.merge_size
352
+ )
353
+ if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
354
+ rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
355
+ return None
356
+
357
+ return data_dict
358
+
359
+ def preprocess_qwen2vl(
360
+ self,
361
+ source, # conversations
362
+ tokenizer: transformers.PreTrainedTokenizer,
363
+ processor: transformers.ProcessorMixin,
364
+ image: list,
365
+ system_message: str = grounding_system_message,
366
+ agent_mode: bool = True,
367
+ chat_template: str = chat_template,
368
+ assistant_template: str = assistant_template,
369
+ id: int = None,
370
+ ) -> Dict:
371
+ roles = {"human": "user", "gpt": "assistant", "system": "system"}
372
+ assistant_template = assistant_template if agent_mode else chat_template
373
+ processor.tokenizer = tokenizer
374
+ assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS
375
+
376
+ # Apply prompt templates
377
+ pixel_values, image_grid_thw = None, None
378
+
379
+ input_id, target = [], []
380
+ coordinates = []
381
+ visual_token_indices_of_coordinates = []
382
+ multi_patch_labels = []
383
+
384
+ image_list = []
385
+ image_index = 0
386
+
387
+ ## prepare the system message
388
+ if roles[source[0]["from"]] == "system":
389
+ system_message = source[0]["value"]
390
+ source = source[1:self.data_args.max_conv_turns]
391
+ # else: use the constant system message
392
+ system_input_id = tokenizer.apply_chat_template(
393
+ conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
394
+ chat_template=chat_template,
395
+ )
396
+ input_id += system_input_id
397
+ target += [IGNORE_INDEX] * len(system_input_id)
398
+
399
+ ## prepare user-assistant conversation
400
+ for conv in source:
401
+ # regularize the conversation format
402
+ try:
403
+ role = conv["role"]
404
+ content = conv["content"]
405
+ except Exception:
406
+ role = conv["from"]
407
+ content = conv["value"]
408
+ role = roles.get(role, role)
409
+
410
+ # Count the number of <image> tokens in the content
411
+ image_count = content.count(DEFAULT_IMAGE_TOKEN)
412
+ if image_count > 0:
413
+ assert role == "user", "Images are only supported for user messages"
414
+ # include image information regarding to current conversation turn
415
+ image_placeholders = []
416
+ for _ in range(image_count):
417
+ image_placeholders.append({
418
+ "type": "image",
419
+ "image": image[image_index],
420
+ "min_pixels": self.processor.image_processor.min_pixels,
421
+ "max_pixels": self.processor.image_processor.max_pixels,
422
+ })
423
+ image_index += 1
424
+
425
+ content = content.replace(DEFAULT_IMAGE_TOKEN, "")
426
+ conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}
427
+
428
+ image_inputs, _ = process_vision_info([conv]) # list of PIL.Image.Image
429
+ image_list.extend(image_inputs)
430
+
431
+ templated_conv = tokenizer.apply_chat_template(
432
+ conversation=[conv], chat_template=chat_template, tokenize=False
433
+ )
434
+ inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")
435
+
436
+ if pixel_values is None and image_grid_thw is None:
437
+ pixel_values = inputs["pixel_values"]
438
+ image_grid_thw = inputs["image_grid_thw"]
439
+ else:
440
+ pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
441
+ image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
442
+ else:
443
+ if role in ["user", "system"]:
444
+ conv = {"role": role, "content": [{"type": "text", "text": content}]}
445
+ else: # assistant
446
+ conv = {
447
+ "role": role,
448
+ "content": [{"type": "text", "text": content}],
449
+ "recipient": conv.get("recipient", "os"),
450
+ "end_turn": conv.get("end_turn", True),
451
+ "bbox_gt": conv.get("bbox_gt", None),
452
+ }
453
+ if conv["recipient"] == "os":
454
+ if len(image_inputs) == 0:
455
+ raise ValueError("No image found for visual grounding")
456
+ # replace the coordinates with the special tokens
457
+ text, coord = reformat_coordinates(conv["content"][0]["text"])
458
+ conv["content"][0]["text"] = text
459
+ # rank0_print(f"coord: {coord}")
460
+
461
+ # get the visual token indices of the coordinates
462
+ coordinates.extend(coord)
463
+ for (point_x, point_y) in coord:
464
+ visual_token_index = get_token_index(
465
+ processor.image_processor,
466
+ image_list,
467
+ point_x,
468
+ point_y
469
+ )
470
+ # px, py = token_index_to_coordinates(
471
+ # processor.image_processor,
472
+ # visual_token_index,
473
+ # image_list[0].size[0], # make sure the size here is after qwen2vl processing
474
+ # image_list[0].size[1]
475
+ # )
476
+ # rank0_print(f"estimated px: {px}, py: {py}")
477
+ visual_token_indices_of_coordinates.append(visual_token_index)
478
+
479
+ if conv["bbox_gt"] is not None:
480
+ patch_mask = get_multi_patch_labels(
481
+ processor.image_processor,
482
+ image_list,
483
+ conv["bbox_gt"]
484
+ )
485
+ multi_patch_labels.append(patch_mask)
486
+
487
+ templated_conv = tokenizer.apply_chat_template(
488
+ conversation=[conv],
489
+ chat_template=assistant_template,
490
+ tokenize=False,
491
+ )
492
+ inputs = processor(text=[templated_conv], return_tensors="pt")
493
+
494
+ encode_id = inputs.input_ids[0].tolist()
495
+
496
+ input_id += encode_id
497
+ if role in ["user", "system"]:
498
+ target += [IGNORE_INDEX] * len(encode_id)
499
+ else:
500
+ target += encode_id
501
+
502
+ assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
503
+
504
+ # make the labels of all pointer_end_token_id to be IGNORE_INDEX
505
+ target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]
506
+
507
+ input_ids = torch.tensor([input_id], dtype=torch.long)
508
+ targets = torch.tensor([target], dtype=torch.long)
509
+ visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
510
+ coordinates = [coordinates] if len(coordinates) > 0 else [None]
511
+
512
+ # process multi_patch_labels
513
+ if len(multi_patch_labels) > 0:
514
+ multi_patch_labels = [torch.stack(multi_patch_labels)]
515
+ else:
516
+ multi_patch_labels = [None]
517
+
518
+ data_dict = {
519
+ "input_ids": input_ids, # tensor(bs x seq_len)
520
+ "labels": targets, # tensor(bs x seq_len)
521
+ }
522
+
523
+ if pixel_values is not None:
524
+ data_dict["pixel_values"] = pixel_values
525
+ data_dict["image_grid_thw"] = image_grid_thw
526
+
527
+ # if len(coordinates[0]) != len(visual_token_indices_of_coordinates[0]):
528
+ # raise ValueError(f"The number of coordinates ({len(coordinates[0])}) does not match the number of image token indices ({len(visual_token_indices_of_coordinates[0])})")
529
+ data_dict["coordinates"] = coordinates
530
+ data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
531
+ data_dict["multi_patch_labels"] = multi_patch_labels
532
+
533
+ return data_dict
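For intuition, the patch-index arithmetic used by get_token_index() above works out as follows (pure-Python sketch; patch_size=14 and merge_size=2 are the assumed Qwen2-VL defaults):

import math

merge_patch_size = 14 * 2                    # image_processor.patch_size * image_processor.merge_size
w, h = 1008, 672                             # image size after Qwen2-VL resizing (divisible by 28)
point_x, point_y = 0.53, 0.50                # normalized ground-truth click point

px, py = w * point_x, h * point_y            # 534.24, 336.0 in pixels
x_index = math.floor(px / merge_patch_size)  # column 19 of 36
y_index = math.floor(py / merge_patch_size)  # row 12 of 24
visual_token_index = y_index * (w // merge_patch_size) + x_index
print(visual_token_index)                    # 451, the patch the pointer head should select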
gui_actor/inference.py ADDED
@@ -0,0 +1,300 @@
1
+ import torch
2
+ import json
3
+ import re
4
+ import os
5
+ from qwen_vl_utils import process_vision_info
6
+ from transformers import (
7
+ Qwen2VLForConditionalGeneration,
8
+ LogitsProcessor,
9
+ LogitsProcessorList,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer
12
+ )
13
+ from gui_actor.constants import (
14
+ DEFAULT_POINTER_END_TOKEN,
15
+ DEFAULT_POINTER_PAD_TOKEN,
16
+ chat_template
17
+ )
18
+
19
+ class ForceFollowTokensLogitsProcessor(LogitsProcessor):
20
+ """
21
+ Forces tokens B (pointer_pad_token) and C (pointer_end_token) to follow token A (pointer_start_token).
22
+ Whenever token_a_id is generated, enqueue the forced_sequence (e.g. [B, C]).
23
+ As long as forced tokens remain in the queue, force them in the output.
24
+ """
25
+ def __init__(self, token_a_id, forced_sequence=[DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN]):
26
+ super().__init__()
27
+ self.token_a_id = token_a_id
28
+ self.forced_sequence = forced_sequence # list of token IDs, e.g. [B_id, C_id]
29
+ self.force_queue = [] # holds the tokens we still need to force
30
+
31
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
32
+ """
33
+ Called at each decoding step to modify `scores`.
34
+
35
+ Args:
36
+ input_ids: shape (batch_size, seq_len). The already-decoded tokens.
37
+ scores: shape (batch_size, vocab_size). Model logits for the next token.
38
+ """
39
+ batch_size = input_ids.shape[0]
40
+ if batch_size > 1:
41
+ raise NotImplementedError("Batch size must be 1 for this logits processor.")
42
+
43
+ # We assume batch_size=1 for simplicity; if you have multiple sequences,
44
+ # you'll need to adapt the logic to handle each item in the batch.
45
+ last_token_id = input_ids[0, -1].item()
46
+
47
+ # If the last token was A, enqueue B and C
48
+ if last_token_id == self.token_a_id:
49
+ self.force_queue.extend(self.forced_sequence)
50
+
51
+ # If we have forced tokens waiting in the queue, override the distribution
52
+ if len(self.force_queue) > 0:
53
+ forced_token = self.force_queue.pop(0) # next token to force
54
+ # Create a mask of -inf for all tokens except the forced one
55
+ new_scores = torch.full_like(scores, float('-inf'))
56
+ new_scores[0, forced_token] = 0.0 # log prob = 0 => prob = 1
57
+ return new_scores
58
+
59
+ # Otherwise, return scores unmodified
60
+ return scores
61
+
62
+
63
+ def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30, activation_threshold=0.3, return_all_regions=True, rect_center=False):
64
+ """
65
+ 1. Select activated patches
66
+ 2. Divide connected patches into different regions
67
+ 3. Calculate the average activation value for each region
68
+ 4. Select the region with the highest average activation value
69
+ 5. Return the center point of that region as the final prediction point
70
+ """
71
+
72
+ # Get patches with activation values greater than a certain proportion of the maximum activation value as activated patches
73
+ # Get the highest activation value and threshold
74
+ max_score = attn_scores[0].max().item()
75
+ threshold = max_score * activation_threshold
76
+ # Select all patches above the threshold
77
+ mask = attn_scores[0] > threshold
78
+ valid_indices = torch.nonzero(mask).squeeze(-1)
79
+ topk_values = attn_scores[0][valid_indices]
80
+ topk_indices = valid_indices
81
+
82
+ # Convert indices to 2D coordinates
83
+ topk_coords = []
84
+ for idx in topk_indices.tolist():
85
+ y = idx // n_width
86
+ x = idx % n_width
87
+ topk_coords.append((y, x, idx))
88
+
89
+ # Divide into connected regions
90
+ regions = []
91
+ visited = set()
92
+ for i, (y, x, idx) in enumerate(topk_coords):
93
+ if idx in visited:
94
+ continue
95
+
96
+ # Start a new region
97
+ region = [(y, x, idx, topk_values[i].item())]
98
+ visited.add(idx)
99
+ queue = [(y, x, idx, topk_values[i].item())]
100
+
101
+ # BFS to find connected points
102
+ while queue:
103
+ cy, cx, c_idx, c_val = queue.pop(0)
104
+
105
+ # Check 4 adjacent directions
106
+ for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
107
+ ny, nx = cy + dy, cx + dx
108
+ n_idx = ny * n_width + nx
109
+
110
+ # Check if this adjacent point is in the topk list
111
+ for j, (ty, tx, t_idx) in enumerate(topk_coords):
112
+ if ty == ny and tx == nx and t_idx not in visited:
113
+ visited.add(t_idx)
114
+ region.append((ny, nx, t_idx, topk_values[j].item()))
115
+ queue.append((ny, nx, t_idx, topk_values[j].item()))
116
+
117
+ regions.append(region)
118
+
119
+ # Calculate the average activation value for each region
120
+ region_scores = []
121
+ region_centers = []
122
+ region_points = []
123
+
124
+ for region in regions:
125
+ # Calculate average score for the region
126
+ avg_score = sum(item[3] for item in region) / len(region)
127
+ region_scores.append(avg_score)
128
+
129
+ # Calculate normalized center coordinates for each patch, then take the average
130
+ normalized_centers = []
131
+ weights = []
132
+ y_coords = set()
133
+ x_coords = set()
134
+
135
+ for y, x, _, score in region:
136
+ # Normalized coordinates of the center point for each patch
137
+ center_y = (y + 0.5) / n_height
138
+ center_x = (x + 0.5) / n_width
139
+ normalized_centers.append((center_x, center_y))
140
+ weights.append(score)
141
+
142
+ y_coords.add(center_y)
143
+ x_coords.add(center_x)
144
+
145
+ region_points.append(normalized_centers)
146
+
147
+ # Calculate the average of normalized coordinates as the region center
148
+ if not rect_center:
149
+ # Weighted average
150
+ total_weight = sum(weights)
151
+ weighted_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight
152
+ weighted_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight
153
+ avg_center_x, avg_center_y = weighted_x, weighted_y
154
+ # # Simple average
155
+ # avg_center_x = sum(nc[0] for nc in normalized_centers) / len(normalized_centers)
156
+ # avg_center_y = sum(nc[1] for nc in normalized_centers) / len(normalized_centers)
157
+ else:
158
+ avg_center_x = sum(x_coords) / len(x_coords)
159
+ avg_center_y = sum(y_coords) / len(y_coords)
160
+ region_centers.append((avg_center_x, avg_center_y))
161
+
162
+ # Select the region with the highest average activation value
163
+ sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
164
+ sorted_scores = [region_scores[i] for i in sorted_indices]
165
+ sorted_centers = [region_centers[i] for i in sorted_indices]
166
+ sorted_points = [region_points[i] for i in sorted_indices]
167
+ best_point = sorted_centers[0]
168
+
169
+ if return_all_regions:
170
+ # Outputs:
171
+ # 1. best_point: the center point of the region with the highest average activation value
172
+ # 2. sorted_centers: the center points of all regions, sorted by the average activation value in descending order
173
+ # 3. sorted_scores: the average activation values of all regions, sorted in descending order
174
+ # 4. sorted_points: the normalized center coordinates of all patches, sorted by the average activation value in descending order
175
+ return best_point, sorted_centers, sorted_scores, sorted_points
176
+ else:
177
+ return best_point
178
+
179
+
180
+ def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5):
181
+ """
182
+ conversation = [
183
+ {
184
+ "role": "system",
185
+ "content": [
186
+ {
187
+ "type": "text",
188
+ "text": grounding_system_message,
189
+ }
190
+ ]
191
+ },
192
+ {
193
+ "role": "user",
194
+ "content": [
195
+ {
196
+ "type": "image",
197
+ "image": example["image"], # PIL.Image.Image or str to path
198
+ # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64,"
199
+ },
200
+ {
201
+ "type": "text",
202
+ "text": example["instruction"]
203
+ },
204
+ ],
205
+ },
206
+ ]
207
+ """
208
+ if logits_processor is None:
209
+ logits_processor = ForceFollowTokensLogitsProcessor(
210
+ token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
211
+ forced_sequence=[
212
+ tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
213
+ ]
214
+ )
215
+
216
+ assistant_starter = "" if not use_placeholder else "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)"
217
+
218
+ pred = {
219
+ "output_text": None, # generated text
220
+ "n_width": None, # number of patch_tokens in width dimension
221
+ "n_height": None, # number of patch_tokens in height dimension
222
+ "attn_scores": None, # attention scores over the image patches
223
+ "topk_points": None, # topk points
224
+ "topk_values": None, # topk values
225
+ "topk_points_all": None, # all points
226
+ }
227
+
228
+ # prepare text
229
+ text = data_processor.apply_chat_template(conversation,
230
+ tokenize=False,
231
+ add_generation_prompt=False,
232
+ chat_template=chat_template
233
+ )
234
+ text += assistant_starter
235
+
236
+ # prepare inputs
237
+ image_inputs, video_inputs = process_vision_info(conversation)
238
+ inputs = data_processor(text=[text],
239
+ images=image_inputs,
240
+ videos=video_inputs,
241
+ padding=True,
242
+ return_tensors="pt"
243
+ )
244
+ inputs = inputs.to(model.device)
245
+
246
+ # generate
247
+ results = model.generate(**inputs,
248
+ max_new_tokens=2048 if not use_placeholder else 1,
249
+ logits_processor=LogitsProcessorList([logits_processor]),
250
+ return_dict_in_generate=True,
251
+ output_hidden_states=True
252
+ )
253
+
254
+
255
+ # decode the generated ids
256
+ input_ids = inputs["input_ids"][0]
257
+ generated_ids = results.sequences[0][len(input_ids):]
258
+ output_text = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
259
+ pred["output_text"] = output_text
260
+
261
+ # check whether any pointer pad tokens are present in the input_ids or generated_ids
262
+ if use_placeholder:
263
+ pointer_pad_mask = (inputs["input_ids"][0] == model.config.pointer_pad_token_id) # n_all_input_tokens
264
+ else:
265
+ pointer_pad_mask = (generated_ids[:-1] == model.config.pointer_pad_token_id) # seq_len_generated_ids-1
266
+
267
+ # if no pointer pad tokens were found, return the prediction as-is
268
+ if not pointer_pad_mask.any():
269
+ return pred
270
+
271
+ # otherwise, get the coordinate from the action head
272
+ if use_placeholder:
273
+ decoder_hidden_states = results.hidden_states[0][-1][0] # n_all_input_tokens, hidden_size
274
+ else:
275
+ decoder_hidden_states = [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]]
276
+ decoder_hidden_states = torch.cat(decoder_hidden_states, dim=0) # seq_len_generated_ids-1, hidden_size
277
+ decoder_hidden_states = decoder_hidden_states[pointer_pad_mask] # n_pointer_pad_tokens, hidden_size
278
+
279
+ # get the image embeddings as encoder vectors
280
+ # image_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"]) # n_image_tokens, hidden_size
281
+ image_mask = (inputs["input_ids"][0] == tokenizer.encode("<|image_pad|>")[0])
282
+ image_embeds = results.hidden_states[0][0][0][image_mask] # n_image_tokens, hidden_size
283
+
284
+ attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states)
285
+ pred["attn_scores"] = attn_scores.tolist()
286
+
287
+ _, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist()
288
+ pred["n_width"] = n_width
289
+ pred["n_height"] = n_height
290
+
291
+ # get the topk points according to the attention scores
292
+ best_point, region_points, region_scores, region_points_all = get_prediction_region_point(attn_scores, n_width, n_height, return_all_regions=True, rect_center=False)
293
+ topk_points = region_points[:topk] if len(region_points) > topk else region_points
294
+ topk_values = region_scores[:topk] if len(region_scores) > topk else region_scores
295
+ topk_points_all = region_points_all[:topk] if len(region_points_all) > topk else region_points_all
296
+ pred["topk_points"] = topk_points
297
+ pred["topk_values"] = topk_values
298
+ pred["topk_points_all"] = topk_points_all
299
+
300
+ return pred
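Putting it together, a hedged sketch of calling inference() outside the Gradio demo, mirroring app.py (model download, a local screenshot.png, and the gui_actor.modeling_qwen25vl module imported by app.py are all assumed):

import torch
from PIL import Image
from transformers import AutoProcessor
from gui_actor.constants import grounding_system_message
from gui_actor.inference import inference
from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer

model_name = "microsoft/GUI-Actor-3B-Qwen2.5-VL"
data_processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

conversation = [
    {"role": "system", "content": [{"type": "text", "text": grounding_system_message}]},
    {"role": "user", "content": [
        {"type": "image", "image": Image.open("screenshot.png").convert("RGB")},
        {"type": "text", "text": "Click the Submit button"},
    ]},
]

pred = inference(conversation, model, data_processor.tokenizer, data_processor,
                 use_placeholder=True, topk=3)
print(pred["topk_points"][0])  # normalized (x, y) of the highest-scoring region center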
gui_actor/modeling.py ADDED
@@ -0,0 +1,361 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
6
+ from gui_actor.constants import IGNORE_INDEX
7
+ from typing import List, Tuple, Union, Optional
8
+ from gui_actor.trainer import rank0_print
9
+
10
+ class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
11
+ """
12
+ Output class for Qwen2VL with pointer head, extending the base output class.
13
+
14
+ Args:
15
+ lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
16
+ Language modeling loss.
17
+ pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
18
+ Vision pointer network loss.
19
+ pointer_scores (`List[torch.FloatTensor]`, *optional*):
20
+ Attention scores from the pointer network, one tensor per batch item.
21
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
22
+ Combined loss (weighted sum of lm_loss and pointer_loss).
23
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
24
+ Prediction scores from the language modeling head.
25
+ past_key_values, hidden_states, attentions, rope_deltas:
26
+ Same as parent class.
27
+ """
28
+ def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
29
+ super().__init__(*args, **kwargs)
30
+ self.lm_loss = lm_loss
31
+ self.pointer_loss = pointer_loss
32
+ self.pointer_scores = pointer_scores
33
+
34
+
35
+ class VisionHead_MultiPatch(nn.Module):
36
+ def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
37
+ super().__init__()
38
+ self.d_model = d_model
39
+
40
+ # Note: We omit additional normalization here because Qwen2VL
41
+ # already normalizes hidden states using RMSNorm.
42
+ self.projection_enc = nn.Sequential(
43
+ nn.Linear(d_model, projection_dim),
44
+ nn.GELU(),
45
+ nn.Linear(projection_dim, d_model)
46
+ )
47
+ self.projection_dec = nn.Sequential(
48
+ nn.Linear(d_model, projection_dim),
49
+ nn.GELU(),
50
+ nn.Linear(projection_dim, d_model)
51
+ )
52
+
53
+ # Add self-attention layer for visual features
54
+ self.self_attention = nn.MultiheadAttention(
55
+ embed_dim=d_model,
56
+ num_heads=num_attention_heads,
57
+ dropout=dropout_rate,
58
+ batch_first=True
59
+ )
60
+
61
+ # Layer normalization and residual connection
62
+ self.layer_norm = nn.LayerNorm(d_model)
63
+ self.dropout = nn.Dropout(dropout_rate)
64
+
65
+ def forward(self,
66
+ hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size
67
+ hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample
68
+ labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox
69
+ do_single_patch: bool = False,
70
+ ):
71
+
72
+ enc_input = hidden_state_enc.unsqueeze(0)
73
+ attn_output, _ = self.self_attention(
74
+ query=enc_input,
75
+ key=enc_input,
76
+ value=enc_input,
77
+ # attn_mask=attention_mask,
78
+ need_weights=False
79
+ )
80
+ # Residual connection and layer normalization
81
+ hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
82
+ # Remove batch dimension
83
+ hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model]
84
+
85
+ # Apply the projection networks.
86
+ proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
87
+ proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
88
+
89
+ # Compute scaled dot-product attention scores.
90
+ # Scaling by sqrt(d_model) is critical regardless of variable n_enc.
91
+ scaling = self.d_model ** 0.5
92
+ patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
93
+
94
+ # Softmax normalization is applied along the encoder dimension.
95
+ attn_weights = F.softmax(patch_logits, dim=-1)
96
+
97
+ loss = None
98
+ if (labels is not None) and (not do_single_patch):
99
+ epsilon = 1e-8
100
+ labels_float = labels.float()
101
+ # Normalize each row to get target probability distribution
102
+ target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
103
+
104
+ # Apply log_softmax to logits
105
+ pred_log_probs = F.log_softmax(patch_logits, dim=-1)
106
+ # Use KL divergence as loss
107
+ loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
108
+
109
+ if do_single_patch and (labels is not None):
110
+ loss = F.cross_entropy(patch_logits, labels)
111
+
112
+ return attn_weights, loss
113
+
114
+
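A quick shape check of the pointer head in isolation (random tensors; d_model=8 is arbitrary, and the import assumes the rest of the gui_actor package, e.g. gui_actor.trainer, is available):

import torch
from gui_actor.modeling import VisionHead_MultiPatch

head = VisionHead_MultiPatch(d_model=8, projection_dim=16).eval()
enc = torch.randn(36, 8)   # 36 visual patch embeddings (encoder side)
dec = torch.randn(2, 8)    # 2 <|pointer_pad|> hidden states (decoder side)

attn_weights, loss = head(enc, dec)   # no labels, so loss is None
print(attn_weights.shape)             # torch.Size([2, 36])
print(attn_weights.sum(dim=-1))       # each row sums to 1: a distribution over patches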
115
+ class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
116
+ def __init__(self, *args, **kwargs):
117
+ super().__init__(*args, **kwargs)
118
+ self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
119
+ self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
120
+ self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
121
+ self.post_init()
122
+
123
+ def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
124
+ self.pointer_loss_weight = pointer_loss_weight
125
+ self.lm_loss_weight = lm_loss_weight
126
+
127
+ def forward(self,
128
+ input_ids: torch.LongTensor = None, # (batch_size, seq_len)
129
+ attention_mask: Optional[torch.Tensor] = None,
130
+ position_ids: Optional[torch.LongTensor] = None,
131
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
132
+ inputs_embeds: Optional[torch.FloatTensor] = None,
133
+ labels: Optional[torch.LongTensor] = None,
134
+ use_cache: Optional[bool] = None,
135
+ output_attentions: Optional[bool] = None,
136
+ output_hidden_states: Optional[bool] = None,
137
+ return_dict: Optional[bool] = None,
138
+ pixel_values: Optional[torch.Tensor] = None,
139
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
140
+ image_grid_thw: Optional[torch.LongTensor] = None,
141
+ video_grid_thw: Optional[torch.LongTensor] = None,
142
+ rope_deltas: Optional[torch.LongTensor] = None,
143
+ cache_position: Optional[torch.LongTensor] = None,
144
+ # Grounding
145
+ visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token
146
+ multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
147
+ if_multi_patch: bool = True,
148
+ coordinates: Optional[List[Tuple[float, float]]] = None,
149
+ verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
150
+
151
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
152
+ output_hidden_states = (
153
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
154
+ )
155
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
156
+
157
+ if verbose:
158
+ rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
159
+ rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
160
+ rank0_print(f"pixel_values: {pixel_values.shape}")
161
+ rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
162
+ rank0_print(f"coordinates: {coordinates}")
163
+ rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
164
+ rank0_print(f"return_dict: {return_dict}")
165
+
166
+ if inputs_embeds is None:
167
+ inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
168
+ if pixel_values is not None:
169
+ pixel_values = pixel_values.type(self.visual.dtype)
170
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
171
+ n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
172
+ n_image_features = image_embeds.shape[0]
173
+ if n_image_tokens != n_image_features:
174
+ raise ValueError(
175
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
176
+ )
177
+ image_mask = (
178
+ (input_ids == self.config.image_token_id)
179
+ .unsqueeze(-1)
180
+ .expand_as(inputs_embeds)
181
+ .to(inputs_embeds.device)
182
+ )
183
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
184
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
185
+
186
+ if pixel_values_videos is not None:
187
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
188
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
189
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
190
+ n_video_features = video_embeds.shape[0]
191
+ if n_video_tokens != n_video_features:
192
+ raise ValueError(
193
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
194
+ )
195
+ video_mask = (
196
+ (input_ids == self.config.video_token_id)
197
+ .unsqueeze(-1)
198
+ .expand_as(inputs_embeds)
199
+ .to(inputs_embeds.device)
200
+ )
201
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
202
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
203
+
204
+ if attention_mask is not None:
205
+ attention_mask = attention_mask.to(inputs_embeds.device)
206
+
207
+ # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
208
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
209
+ # calculate RoPE index once per generation in the pre-fill stage only
210
+ if (
211
+ (cache_position is not None and cache_position[0] == 0)
212
+ or self.rope_deltas is None
213
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
214
+ ):
215
+ position_ids, rope_deltas = self.get_rope_index(
216
+ input_ids, image_grid_thw, video_grid_thw, attention_mask
217
+ )
218
+ self.rope_deltas = rope_deltas
219
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
220
+ else:
221
+ batch_size, seq_length, _ = inputs_embeds.shape
222
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
223
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
224
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
225
+ if cache_position is not None: # otherwise `deltas` is an int `0`
226
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
227
+ delta = delta.to(position_ids.device)
228
+ position_ids = position_ids.add(delta)
229
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
230
+
231
+ outputs = self.model(
232
+ input_ids=None,
233
+ position_ids=position_ids,
234
+ attention_mask=attention_mask,
235
+ past_key_values=past_key_values,
236
+ inputs_embeds=inputs_embeds,
237
+ use_cache=use_cache,
238
+ output_attentions=output_attentions,
239
+ output_hidden_states=output_hidden_states,
240
+ return_dict=return_dict,
241
+ cache_position=cache_position,
242
+ )
243
+
244
+ hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model)
245
+ logits = self.lm_head(hidden_states)
246
+
247
+ lm_loss = None
248
+ if labels is not None and self.lm_loss_weight > 0:
249
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
250
+ logits = logits.float()
251
+ # Shift so that tokens < n predict n
252
+ shift_logits = logits[..., :-1, :].contiguous()
253
+ shift_labels = labels[..., 1:].contiguous()
254
+ # Flatten the tokens
255
+ loss_fct = nn.CrossEntropyLoss()
256
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
257
+ shift_labels = shift_labels.view(-1)
258
+ # Enable model parallelism
259
+ shift_labels = shift_labels.to(shift_logits.device)
260
+ lm_loss = loss_fct(shift_logits, shift_labels)
261
+
262
+
263
+ # If vision supervision is requested, process the action head.
264
+ pointer_loss = None
265
+ pointer_scores = []
266
+ if visual_token_indices_of_coordinates is not None:
267
+ batch_size = input_ids.shape[0]
268
+ pointer_losses = []
269
+
270
+ # Process each sample individually because the number of visual and target tokens may vary.
271
+ for i in range(batch_size):
272
+ dummy_target = False
273
+
274
+ # Get the token ids and corresponding hidden states for sample i.
275
+ token_ids = input_ids[i] # shape: (seq_length,)
276
+ hs = hidden_states[i] # shape: (seq_length, d_model)
277
+
278
+ # Identify visual tokens indices.
279
+ visual_mask = (token_ids == self.config.image_token_id)
280
+ visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,)
281
+
282
+ # Identify target tokens (the ones that should attend to visual features).
283
+ target_mask = (token_ids == self.config.pointer_pad_token_id)
284
+ target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
285
+
286
+ # A sample without visual tokens cannot be grounded; missing target tokens fall back to a dummy target below.
287
+ if visual_indices.numel() == 0:
288
+ raise ValueError(f"No visual tokens found for sample {i}.")
289
+ if target_indices.numel() == 0:
290
+ target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device) # take the last token as the dummy target token
291
+ gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
292
+ if if_multi_patch: # take the first 4 visual tokens as the dummy ground truth
293
+ sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
294
+ sample_labels[0][:4] = 1
295
+ dummy_target = True
296
+ else:
297
+ # For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
298
+ # where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
299
+ gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
300
+ if if_multi_patch:
301
+ sample_labels = multi_patch_labels[i]
302
+
303
+ # Gather the corresponding hidden state representations.
304
+ # visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
305
+ visual_embeds = inputs_embeds[i][visual_indices]
306
+ target_hidden = hs[target_indices] # shape: (n_target, d_model)
307
+
308
+ # Calculate loss for multi-patch mode
309
+ if if_multi_patch:
310
+ # Ensure the number of targets matches between sample and labels
311
+ if sample_labels.shape[0] != target_indices.shape[0]:
312
+ raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
313
+
314
+ # Process using VisionHead_MultiPatch
315
+ attn_scores, loss_v = self.multi_patch_pointer_head(
316
+ visual_embeds,
317
+ target_hidden,
318
+ labels=sample_labels
319
+ )
320
+
321
+ else:
322
+ # Deprecated branch - single patch mode is no longer used
323
+ # Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
324
+ attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
325
+
326
+ pointer_scores.append(attn_scores.detach().cpu())
327
+
328
+ pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
329
+
330
+ pointer_loss = torch.stack(pointer_losses).mean()
331
+
332
+ # Combine the LM loss and vision loss using the provided loss weights.
333
+
334
+ if lm_loss is None:
335
+ total_loss = pointer_loss
336
+ elif pointer_loss is None:
337
+ total_loss = lm_loss
338
+ else:
339
+ total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
340
+
341
+ if return_dict:
342
+ return QwenVLwithVisionHeadOutputWithPast(
343
+ lm_loss=lm_loss,
344
+ pointer_loss=pointer_loss,
345
+ pointer_scores=pointer_scores,
346
+ loss=total_loss,
347
+ logits=logits,
348
+ past_key_values=outputs.past_key_values,
349
+ hidden_states=outputs.hidden_states,
350
+ attentions=outputs.attentions,
351
+ rope_deltas=self.rope_deltas,
352
+ )
353
+ else:
354
+ # When labels are provided, parent's forward returns a tuple with loss as the first element.
355
+ if labels is not None:
356
+ # Replace the LM loss with the combined loss.
357
+ output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
358
+ return (total_loss,) + output if total_loss is not None else output
360
+ else:
361
+ return outputs
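
Note: the multi-patch pointer supervision above consumes `multi_patch_labels`, a per-sample binary mask over the visual patches that fall inside the ground-truth bounding box. The commit itself does not show how that mask is produced; the sketch below is a hypothetical illustration only (the helper name `make_multi_patch_labels`, the pixel/grid conventions, and the example numbers are assumptions, not code from this repository).

```python
# Hypothetical helper (not part of this commit): build the binary patch mask that
# multi_patch_labels / VisionHead_MultiPatch expect, marking every visual patch
# whose center falls inside a ground-truth bounding box.
import torch

def make_multi_patch_labels(bbox, image_size, grid_size):
    """bbox = (x1, y1, x2, y2) in pixels; image_size = (W, H); grid_size = (n_width, n_height)."""
    x1, y1, x2, y2 = bbox
    W, H = image_size
    n_w, n_h = grid_size
    patch_w, patch_h = W / n_w, H / n_h

    labels = torch.zeros(n_h, n_w)
    for row in range(n_h):
        for col in range(n_w):
            cx = (col + 0.5) * patch_w   # patch center x
            cy = (row + 0.5) * patch_h   # patch center y
            if x1 <= cx <= x2 and y1 <= cy <= y2:
                labels[row, col] = 1.0
    return labels.flatten().unsqueeze(0)  # (1, n_visual): one target token, one row

# Example: a 100x40 px button on a 1000x600 screenshot with a 36x21 patch grid.
labels = make_multi_patch_labels((200, 300, 300, 340), (1000, 600), (36, 21))
print(labels.shape, int(labels.sum()))
```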
gui_actor/modeling_qwen25vl.py ADDED
@@ -0,0 +1,337 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from typing import List, Tuple, Union, Optional
6
+
7
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
8
+ Qwen2_5_VLCausalLMOutputWithPast,
9
+ Qwen2_5_VLForConditionalGeneration,
10
+ )
11
+ from gui_actor.constants import IGNORE_INDEX
12
+ from gui_actor.trainer import rank0_print
13
+
14
+
15
+ def _get_token_embedding_layer(hf_model: nn.Module) -> nn.Module:
16
+ """
17
+ Robustly locate the token embedding layer across HF versions.
18
+ """
19
+ if hasattr(hf_model, "get_input_embeddings") and callable(hf_model.get_input_embeddings):
20
+ return hf_model.get_input_embeddings()
21
+ # Fallbacks (shouldn't be needed on recent transformers, but safe to keep)
22
+ lm = getattr(hf_model, "language_model", None)
23
+ if lm is not None and hasattr(lm, "embed_tokens"):
24
+ return lm.embed_tokens
25
+ raise AttributeError("Could not locate token embedding layer on model (no get_input_embeddings/embed_tokens).")
26
+
27
+
28
+ class QwenVLwithVisionHeadOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
29
+ """
30
+ Output class for Qwen2_5_VL with pointer head, extending the base output class.
31
+
32
+ Args:
33
+ lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
34
+ Language modeling loss.
35
+ pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
36
+ Vision pointer network loss.
37
+ pointer_scores (`List[torch.FloatTensor]`, *optional*):
38
+ Attention scores from the pointer network, one tensor per batch item.
39
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
40
+ Combined loss (weighted sum of lm_loss and pointer_loss).
41
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
42
+ Prediction scores from the language modeling head.
43
+ past_key_values, hidden_states, attentions, rope_deltas:
44
+ Same as parent class.
45
+ """
46
+ def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
47
+ super().__init__(*args, **kwargs)
48
+ self.lm_loss = lm_loss
49
+ self.pointer_loss = pointer_loss
50
+ self.pointer_scores = pointer_scores
51
+
52
+
53
+ class VisionHead_MultiPatch(nn.Module):
54
+ def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
55
+ super().__init__()
56
+ self.d_model = d_model
57
+
58
+ self.projection_enc = nn.Sequential(
59
+ nn.Linear(d_model, projection_dim),
60
+ nn.GELU(),
61
+ nn.Linear(projection_dim, d_model),
62
+ )
63
+ self.projection_dec = nn.Sequential(
64
+ nn.Linear(d_model, projection_dim),
65
+ nn.GELU(),
66
+ nn.Linear(projection_dim, d_model),
67
+ )
68
+
69
+ self.self_attention = nn.MultiheadAttention(
70
+ embed_dim=d_model, num_heads=num_attention_heads, dropout=dropout_rate, batch_first=True
71
+ )
72
+
73
+ self.layer_norm = nn.LayerNorm(d_model)
74
+ self.dropout = nn.Dropout(dropout_rate)
75
+
76
+ def forward(
77
+ self,
78
+ hidden_state_enc, # [n_enc, d_model]
79
+ hidden_state_dec, # [n_dec, d_model]
80
+ labels: Optional[torch.Tensor] = None, # [n_dec, n_enc] binary mask of patches in bbox
81
+ do_single_patch: bool = False,
82
+ ):
83
+ enc_input = hidden_state_enc.unsqueeze(0)
84
+ attn_output, _ = self.self_attention(query=enc_input, key=enc_input, value=enc_input, need_weights=False)
85
+ hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)).squeeze(0) # [n_enc, d_model]
86
+
87
+ proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model]
88
+ proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model]
89
+
90
+ scaling = self.d_model ** 0.5
91
+ patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc]
92
+
93
+ attn_weights = F.softmax(patch_logits, dim=-1)
94
+
95
+ loss = None
96
+ if (labels is not None) and (not do_single_patch):
97
+ epsilon = 1e-8
98
+ labels_float = labels.float()
99
+ target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
100
+ pred_log_probs = F.log_softmax(patch_logits, dim=-1)
101
+ loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')
102
+
103
+ if do_single_patch and (labels is not None):
104
+ # NOTE: if you ever enable this branch, use patch_logits for CE
105
+ loss = F.cross_entropy(patch_logits, labels)
106
+
107
+ return attn_weights, loss
108
+
109
+
110
+ class Qwen2_5_VLForConditionalGenerationWithPointer(Qwen2_5_VLForConditionalGeneration):
111
+ def __init__(self, *args, **kwargs):
112
+ super().__init__(*args, **kwargs)
113
+ self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
114
+ self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
115
+ self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
116
+ self.post_init()
117
+
118
+ # init rope cache slot (used in return_dict path)
119
+ self.rope_deltas = None
120
+
121
+ def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
122
+ self.pointer_loss_weight = pointer_loss_weight
123
+ self.lm_loss_weight = lm_loss_weight
124
+
125
+ def forward(
126
+ self,
127
+ input_ids: torch.LongTensor = None, # (batch_size, seq_len)
128
+ attention_mask: Optional[torch.Tensor] = None,
129
+ position_ids: Optional[torch.LongTensor] = None,
130
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
131
+ inputs_embeds: Optional[torch.FloatTensor] = None,
132
+ labels: Optional[torch.LongTensor] = None,
133
+ use_cache: Optional[bool] = None,
134
+ output_attentions: Optional[bool] = None,
135
+ output_hidden_states: Optional[bool] = None,
136
+ return_dict: Optional[bool] = None,
137
+ pixel_values: Optional[torch.Tensor] = None,
138
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
139
+ image_grid_thw: Optional[torch.LongTensor] = None,
140
+ video_grid_thw: Optional[torch.LongTensor] = None,
141
+ rope_deltas: Optional[torch.LongTensor] = None,
142
+ cache_position: Optional[torch.LongTensor] = None,
143
+ second_per_grid_ts: Optional[torch.Tensor] = None,
144
+ # Grounding
145
+ visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # (batch_size, n_target)
146
+ multi_patch_labels: Optional[torch.Tensor] = None, # list/packed: [(n_target, n_visual), ...]
147
+ if_multi_patch: bool = True,
148
+ coordinates: Optional[List[Tuple[float, float]]] = None,
149
+ verbose: bool = False,
150
+ ) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
151
+
152
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
153
+ output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
154
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
155
+
156
+ if verbose:
157
+ rank0_print(f"input_ids: {None if input_ids is None else (input_ids.shape, input_ids[0][:5])}")
158
+ rank0_print(f"labels: {None if labels is None else (labels.shape, labels[0][:5])}")
159
+ rank0_print(f"pixel_values: {None if pixel_values is None else pixel_values.shape}")
160
+ rank0_print(f"image_grid_thw: {None if image_grid_thw is None else image_grid_thw.shape}")
161
+ rank0_print(f"coordinates: {coordinates}")
162
+ rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
163
+ rank0_print(f"return_dict: {return_dict}")
164
+
165
+ if inputs_embeds is None:
166
+ if input_ids is None:
167
+ raise ValueError("Either inputs_embeds or input_ids must be provided.")
168
+
169
+ # FIX: use embedding accessor instead of .embed_tokens
170
+ token_embedding = _get_token_embedding_layer(self.model)
171
+ inputs_embeds = token_embedding(input_ids) # (batch, seq_len, d_model)
172
+
173
+ if pixel_values is not None:
174
+ pixel_values = pixel_values.type(self.visual.dtype)
175
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
176
+ n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
177
+ n_image_features = image_embeds.shape[0]
178
+ if n_image_tokens != n_image_features:
179
+ raise ValueError(
180
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
181
+ )
182
+ image_mask = (
183
+ (input_ids == self.config.image_token_id)
184
+ .unsqueeze(-1)
185
+ .expand_as(inputs_embeds)
186
+ .to(inputs_embeds.device)
187
+ )
188
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
189
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
190
+
191
+ if pixel_values_videos is not None:
192
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
193
+ video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
194
+ n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
195
+ n_video_features = video_embeds.shape[0]
196
+ if n_video_tokens != n_video_features:
197
+ raise ValueError(
198
+ f"Video features and video tokens do not match: tokens: {n_video_tokens}, features: {n_video_features}"
199
+ )
200
+ video_mask = (
201
+ (input_ids == self.config.video_token_id)
202
+ .unsqueeze(-1)
203
+ .expand_as(inputs_embeds)
204
+ .to(inputs_embeds.device)
205
+ )
206
+ video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
207
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
208
+
209
+ if attention_mask is not None:
210
+ attention_mask = attention_mask.to(inputs_embeds.device)
211
+
212
+ # RoPE positions / deltas
213
+ if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
214
+ if (
215
+ (cache_position is not None and cache_position[0] == 0)
216
+ or self.rope_deltas is None
217
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
218
+ ):
219
+ position_ids, rope_deltas = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
220
+ self.rope_deltas = rope_deltas
221
+ else:
222
+ batch_size, seq_length, _ = inputs_embeds.shape
223
+ delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
224
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
225
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
226
+ if cache_position is not None:
227
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0).to(position_ids.device)
228
+ position_ids = position_ids.add(delta)
229
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
230
+
231
+ outputs = self.model(
232
+ input_ids=None,
233
+ position_ids=position_ids,
234
+ attention_mask=attention_mask,
235
+ past_key_values=past_key_values,
236
+ inputs_embeds=inputs_embeds,
237
+ use_cache=use_cache,
238
+ output_attentions=output_attentions,
239
+ output_hidden_states=output_hidden_states,
240
+ return_dict=return_dict,
241
+ cache_position=cache_position,
242
+ )
243
+
244
+ hidden_states = outputs[0] # (batch, seq_len, d_model)
245
+ logits = self.lm_head(hidden_states)
246
+
247
+ lm_loss = None
248
+ if labels is not None and self.lm_loss_weight > 0:
249
+ logits = logits.float()
250
+ shift_logits = logits[..., :-1, :].contiguous()
251
+ shift_labels = labels[..., 1:].contiguous()
252
+ loss_fct = nn.CrossEntropyLoss()
253
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
254
+ shift_labels = shift_labels.view(-1).to(shift_logits.device)
255
+ lm_loss = loss_fct(shift_logits, shift_labels)
256
+
257
+ pointer_loss = None
258
+ pointer_scores = []
259
+ if visual_token_indices_of_coordinates is not None:
260
+ batch_size = input_ids.shape[0]
261
+ pointer_losses = []
262
+
263
+ for i in range(batch_size):
264
+ dummy_target = False
265
+
266
+ token_ids = input_ids[i] # (seq_len,)
267
+ hs = hidden_states[i] # (seq_len, d_model)
268
+
269
+ visual_mask = (token_ids == self.config.image_token_id)
270
+ visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # (n_visual,)
271
+
272
+ target_mask = (token_ids == self.config.pointer_pad_token_id)
273
+ target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
274
+
275
+ if visual_indices.numel() == 0:
276
+ raise ValueError(f"No visual tokens found for sample {i}.")
277
+
278
+ if target_indices.numel() == 0:
279
+ target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)
280
+ gt = torch.tensor([0], device=hs.device) # not used in multi-patch
281
+ if if_multi_patch:
282
+ sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
283
+ sample_labels[0][:4] = 1
284
+ dummy_target = True
285
+ else:
286
+ gt = visual_token_indices_of_coordinates[i].to(hs.device) # (n_target,)
287
+ if if_multi_patch:
288
+ sample_labels = multi_patch_labels[i]
289
+
290
+ # Use input embeddings for visual tokens (image tokens got replaced earlier)
291
+ visual_embeds = inputs_embeds[i][visual_indices] # (n_visual, d_model)
292
+ target_hidden = hs[target_indices] # (n_target, d_model)
293
+
294
+ if if_multi_patch:
295
+ if sample_labels.shape[0] != target_indices.shape[0]:
296
+ raise ValueError(
297
+ f"Sample {i} mismatched targets: {sample_labels.shape[0]} labels vs {target_indices.shape[0]} targets"
298
+ )
299
+ attn_scores, loss_v = self.multi_patch_pointer_head(
300
+ visual_embeds,
301
+ target_hidden,
302
+ labels=sample_labels,
303
+ )
304
+ else:
305
+ # Deprecated: single-patch branch
306
+ attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
307
+
308
+ pointer_scores.append(attn_scores.detach().cpu())
309
+ pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
310
+
311
+ pointer_loss = torch.stack(pointer_losses).mean()
312
+
313
+ if lm_loss is None:
314
+ total_loss = pointer_loss
315
+ elif pointer_loss is None:
316
+ total_loss = lm_loss
317
+ else:
318
+ total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
319
+
320
+ if return_dict:
321
+ return QwenVLwithVisionHeadOutputWithPast(
322
+ lm_loss=lm_loss,
323
+ pointer_loss=pointer_loss,
324
+ pointer_scores=pointer_scores,
325
+ loss=total_loss,
326
+ logits=logits,
327
+ past_key_values=outputs.past_key_values,
328
+ hidden_states=outputs.hidden_states,
329
+ attentions=outputs.attentions,
330
+ rope_deltas=self.rope_deltas,
331
+ )
332
+ else:
333
+ if labels is not None:
334
+ output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
335
+ return (total_loss,) + output if total_loss is not None else output
336
+ else:
337
+ return outputs
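
For readers skimming the diff, the core of `VisionHead_MultiPatch` is a KL divergence between the predicted patch distribution and the normalized binary patch mask. The standalone sketch below reproduces just that computation with toy tensors; the shapes and label positions are illustrative, not values taken from the model.

```python
# Standalone sketch of the multi-patch supervision used by VisionHead_MultiPatch:
# the binary patch mask is normalized into a target distribution and compared to
# the predicted patch distribution with KL divergence.
import torch
import torch.nn.functional as F

n_dec, n_enc, d_model = 2, 16, 8            # 2 target tokens, 16 visual patches
proj_dec = torch.randn(n_dec, d_model)      # projected decoder (target) states
proj_enc = torch.randn(n_enc, d_model)      # projected encoder (visual) states

patch_logits = proj_dec @ proj_enc.T / d_model ** 0.5      # (n_dec, n_enc)

labels = torch.zeros(n_dec, n_enc)
labels[0, :4] = 1.0                          # first target: 4 patches inside the bbox
labels[1, 10:12] = 1.0                       # second target: 2 patches inside the bbox

target_dist = labels / (labels.sum(dim=-1, keepdim=True) + 1e-8)
pred_log_probs = F.log_softmax(patch_logits, dim=-1)
loss = F.kl_div(pred_log_probs, target_dist, reduction="batchmean")

attn_weights = F.softmax(patch_logits, dim=-1)             # pointer scores per target
print(loss.item(), attn_weights.shape)
```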
gui_actor/trainer.py ADDED
@@ -0,0 +1,313 @@
1
+ from datetime import timedelta
2
+ from functools import wraps
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ import transformers
8
+ from accelerate import Accelerator, DataLoaderConfiguration
9
+ from accelerate.utils import GradientAccumulationPlugin, InitProcessGroupKwargs
10
+ from torch.utils.data import DataLoader, RandomSampler
11
+ from transformers import Trainer
12
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
13
+ from transformers.trainer_pt_utils import get_parameter_names
14
+ from transformers.trainer_utils import has_length
15
+ from transformers.utils import (
16
+ is_accelerate_available,
17
+ is_datasets_available,
18
+ is_sagemaker_mp_enabled,
19
+ )
20
+ from transformers.trainer_pt_utils import LengthGroupedSampler as HFLengthGroupedSampler
21
+ from transformers.trainer_utils import seed_worker
22
+ from transformers.utils import logging
23
+
24
+ if is_datasets_available():
25
+ import datasets
26
+
27
+
28
+ def rank0_print(*args):
29
+ if dist.is_initialized():
30
+ if dist.get_rank() == 0:
31
+ print(f"Rank {dist.get_rank()}: ", *args)
32
+ else:
33
+ print(*args)
34
+
35
+
36
+ def maybe_zero_3(param, ignore_status=False, name=None):
37
+ from deepspeed import zero
38
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
39
+
40
+ if hasattr(param, "ds_id"):
41
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
42
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
43
+ with zero.GatheredParameters([param]):
44
+ param = param.data.detach().cpu().clone()
45
+ else:
46
+ param = param.detach().cpu().clone()
47
+ return param
48
+
49
+
50
+ def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
51
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
52
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
53
+ return to_return
54
+
55
+
56
+ def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
57
+ """Collects the state dict and dump to disk."""
58
+ trainer.accelerator.wait_for_everyone()
59
+ torch.cuda.synchronize()
60
+
61
+ if trainer.deepspeed:
62
+ trainer.save_model(output_dir)
63
+ return
64
+
65
+ state_dict = trainer.model.state_dict()
66
+ if trainer.args.should_save:
67
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
68
+ del state_dict
69
+ trainer._save(output_dir, state_dict=cpu_state_dict)
70
+
71
+
72
+ class AGUVISTrainer(Trainer):
73
+
74
+ def __init__(self, *args, **kwargs):
75
+ super().__init__(*args, **kwargs)
76
+
77
+ original_save = self._save
78
+ original_save_model = self.save_model
79
+
80
+ def modify_eos_token(func):
81
+ @wraps(func)
82
+ def wrapper(*args, **kwargs):
83
+ tokenizer = self.processing_class.tokenizer
84
+ old_config_id = self.model.config.eos_token_id
85
+ old_eos_token = tokenizer.eos_token
86
+ old_generation_config_eos_token_id = (
87
+ self.model.generation_config.eos_token_id if hasattr(self.model, "generation_config") else None
88
+ )
89
+
90
+ try:
91
+ new_eos_token_id = tokenizer.convert_tokens_to_ids("<|diff_marker|>")
92
+ self.model.config.eos_token_id = [new_eos_token_id]
93
+ tokenizer.eos_token = "<|diff_marker|>"
94
+ if hasattr(self.model, "generation_config"):
95
+ self.model.generation_config.eos_token_id = [new_eos_token_id]
96
+
97
+ print("Set eos token id to", new_eos_token_id)
98
+ print("Set eos token to", "<|diff_marker|>")
99
+ print("Set generation config eos token id to", [new_eos_token_id])
100
+
101
+ result = func(*args, **kwargs)
102
+ return result
103
+ finally:
104
+ self.model.config.eos_token_id = old_config_id
105
+ tokenizer.eos_token = old_eos_token
106
+ if hasattr(self.model, "generation_config") and old_generation_config_eos_token_id is not None:
107
+ self.model.generation_config.eos_token_id = old_generation_config_eos_token_id
108
+
109
+ print("Set eos token id back to", old_config_id)
110
+ print("Set eos token back to", old_eos_token)
111
+ if old_generation_config_eos_token_id is not None:
112
+ print("Set generation config eos token id back to", old_generation_config_eos_token_id)
113
+
114
+ return wrapper
115
+
116
+ self._save = modify_eos_token(original_save)
117
+ self.save_model = modify_eos_token(original_save_model)
118
+
119
+ def create_accelerator_and_postprocess(self):
120
+ grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
121
+ grad_acc_kwargs["sync_with_dataloader"] = False
122
+ gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
123
+
124
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
125
+
126
+ # create accelerator object
127
+ dispatch_batches = getattr(self.args, "dispatch_batches", None)
128
+ split_batches = getattr(self.args, "split_batches", None)
129
+ self.dataloader_config = DataLoaderConfiguration(
130
+ dispatch_batches=dispatch_batches,
131
+ split_batches=split_batches,
132
+ )
133
+ self.accelerator = Accelerator(
134
+ dataloader_config=self.dataloader_config,
135
+ deepspeed_plugin=self.args.deepspeed_plugin,
136
+ gradient_accumulation_plugin=gradient_accumulation_plugin,
137
+ kwargs_handlers=[accelerator_kwargs],
138
+ )
139
+ # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
140
+ self.gather_function = self.accelerator.gather_for_metrics
141
+
142
+ # deepspeed and accelerate flags covering both trainer args and accelerate launcher
143
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
144
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
145
+
146
+ # post accelerator creation setup
147
+ if self.is_fsdp_enabled:
148
+ fsdp_plugin = self.accelerator.state.fsdp_plugin
149
+ fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
150
+ "limit_all_gathers", fsdp_plugin.limit_all_gathers
151
+ )
152
+ if is_accelerate_available("0.23.0"):
153
+ fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get(
154
+ "activation_checkpointing", fsdp_plugin.activation_checkpointing
155
+ )
156
+ if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
157
+ raise ValueError(
158
+ "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
159
+ "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
160
+ "when using FSDP."
161
+ )
162
+
163
+ if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
164
+ self.propagate_args_to_deepspeed()
165
+
166
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
167
+ if self.train_dataset is None or not has_length(self.train_dataset):
168
+ return None
169
+
170
+ if self.args.group_by_length:
171
+ lengths = self.train_dataset.lengths
172
+ return HFLengthGroupedSampler(
173
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
174
+ dataset=self.train_dataset,
175
+ lengths=lengths,
176
+ )
177
+ elif self.args.group_by_modality_length:
178
+ lengths = self.train_dataset.modality_lengths
179
+ return HFLengthGroupedSampler(
180
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
181
+ dataset=self.train_dataset,
182
+ lengths=lengths,
183
+ )
184
+ else:
185
+ return RandomSampler(self.train_dataset)
186
+
187
+ def get_train_dataloader(self) -> DataLoader:
188
+ """
189
+ Returns the training [`~torch.utils.data.DataLoader`].
190
+
191
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
192
+ training if necessary) otherwise.
193
+
194
+ Subclass and override this method if you want to inject some custom behavior.
195
+ """
196
+ if self.train_dataset is None:
197
+ raise ValueError("Trainer: training requires a train_dataset.")
198
+
199
+ train_dataset = self.train_dataset
200
+ data_collator = self.data_collator
201
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
202
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
203
+ else:
204
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
205
+
206
+ dataloader_params = {
207
+ "batch_size": self._train_batch_size,
208
+ "collate_fn": data_collator,
209
+ "num_workers": self.args.dataloader_num_workers,
210
+ "pin_memory": self.args.dataloader_pin_memory,
211
+ "persistent_workers": self.args.dataloader_persistent_workers,
212
+ }
213
+
214
+ if not isinstance(train_dataset, torch.utils.data.IterableDataset):
215
+ dataloader_params["sampler"] = self._get_train_sampler()
216
+ dataloader_params["drop_last"] = self.args.dataloader_drop_last
217
+ dataloader_params["worker_init_fn"] = seed_worker
218
+ dataloader_params["prefetch_factor"] = (
219
+ self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None
220
+ )
221
+
222
+ dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
223
+
224
+ return dataloader
225
+
226
+ def create_optimizer(self):
227
+ """
228
+ Setup the optimizer.
229
+
230
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
231
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
232
+ """
233
+ if is_sagemaker_mp_enabled():
234
+ return super().create_optimizer()
235
+
236
+ opt_model = self.model
237
+
238
+ if self.optimizer is None:
239
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
240
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
241
+ optimizer_grouped_parameters = [
242
+ {
243
+ "params": [
244
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
245
+ ],
246
+ "weight_decay": self.args.weight_decay,
247
+ },
248
+ {
249
+ "params": [
250
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
251
+ ],
252
+ "weight_decay": 0.0,
253
+ },
254
+ ]
255
+
256
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
257
+
258
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
259
+
260
+ return self.optimizer
261
+
262
+ def create_optimizer_with_different_learning_rates(self):
263
+ """
264
+ Setup the optimizer.
265
+
266
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
267
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
268
+ """
269
+ if is_sagemaker_mp_enabled():
270
+ raise NotImplementedError("Sagemaker MP is not supported for separate learning rate yet")
271
+
272
+
273
+ opt_model = self.model
274
+
275
+ if self.optimizer is None:
276
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
277
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
278
+
279
+ new_parameters = []
280
+ for name, param in opt_model.named_parameters():
281
+ if ("pointer_head" in name) or ("embed_tokens" in name):
282
+ new_parameters.append(name)
283
+ rank0_print(f"new_parameters: {len(new_parameters)}")
284
+
285
+ optimizer_grouped_parameters = [
286
+ {
287
+ "params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
288
+ "weight_decay": self.args.weight_decay,
289
+ "lr": self.args.learning_rate,
290
+ },
291
+ {
292
+ "params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n not in new_parameters) and p.requires_grad)],
293
+ "weight_decay": 0.0,
294
+ "lr": self.args.learning_rate,
295
+ },
296
+ {
297
+ "params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n in new_parameters) and p.requires_grad)],
298
+ "weight_decay": self.args.weight_decay,
299
+ "lr": self.args.learning_rate_new_params,
300
+ },
301
+ {
302
+ "params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n in new_parameters) and p.requires_grad)],
303
+ "weight_decay": 0.0,
304
+ "lr": self.args.learning_rate_new_params,
305
+ },
306
+ ]
307
+
308
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) # {'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08}
309
+ optimizer_kwargs.pop("lr")
310
+
311
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
312
+
313
+ return self.optimizer
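
The four optimizer groups in `create_optimizer_with_different_learning_rates` split parameters along two axes: weight decay (LayerNorm and bias parameters get none) and whether the parameter belongs to the newly introduced modules (`pointer_head`, `embed_tokens`), which receive their own learning rate. A toy illustration of the same grouping on a dummy module is sketched below; the module layout and learning-rate values are placeholders, not the training configuration used here.

```python
# Toy illustration of the parameter grouping above: LayerNorm/bias parameters get no
# weight decay, and parameters whose names contain "pointer_head" or "embed_tokens"
# get their own (typically larger) learning rate. Values are placeholders.
import torch
import torch.nn as nn
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

model = nn.ModuleDict({
    "embed_tokens": nn.Embedding(100, 32),
    "backbone": nn.Sequential(nn.Linear(32, 32), nn.LayerNorm(32)),
    "pointer_head": nn.Linear(32, 32),
})

decay = [n for n in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in n]
new = [n for n, _ in model.named_parameters() if "pointer_head" in n or "embed_tokens" in n]

groups = [
    {"params": [p for n, p in model.named_parameters() if n in decay and n not in new],
     "weight_decay": 0.01, "lr": 1e-5},
    {"params": [p for n, p in model.named_parameters() if n not in decay and n not in new],
     "weight_decay": 0.0, "lr": 1e-5},
    {"params": [p for n, p in model.named_parameters() if n in decay and n in new],
     "weight_decay": 0.01, "lr": 1e-4},
    {"params": [p for n, p in model.named_parameters() if n not in decay and n in new],
     "weight_decay": 0.0, "lr": 1e-4},
]
optimizer = torch.optim.AdamW(groups, betas=(0.9, 0.999), eps=1e-8)
print([len(g["params"]) for g in groups])
```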
gui_actor/utils.py ADDED
@@ -0,0 +1,90 @@
1
+ from PIL import Image, ImageDraw, ImageColor
2
+ import json
3
+ import os
4
+
5
+ def dump_args_to_json(model_config, data_processor, model_args, data_args, training_args, output_dir):
6
+ def is_json_serializable(v):
7
+ try:
8
+ json.dumps(v)
9
+ return True
10
+ except (TypeError, ValueError): # json.dumps raises these for non-serializable or circular values
11
+ return False
12
+
13
+ save_path = f"{output_dir}/args.json"
14
+ if not os.path.exists(save_path):
15
+ with open(save_path, "w") as f:
16
+ json.dump({
17
+ "model_config": {k: v for k, v in model_config.__dict__.items() if is_json_serializable(v)},
18
+ "data_processor_config": {k: v for k, v in data_processor.__dict__.items() if is_json_serializable(v)},
19
+ "image_processor_config": {k: v for k, v in data_processor.image_processor.__dict__.items() if is_json_serializable(v)},
20
+ "model_args": {k: v for k, v in model_args.__dict__.items() if is_json_serializable(v)},
21
+ "data_args": {k: v for k, v in data_args.__dict__.items() if is_json_serializable(v)},
22
+ "training_args": {k: v for k, v in training_args.__dict__.items() if is_json_serializable(v)},
23
+ }, f, indent=4)
24
+
25
+ def draw_point(image: Image.Image, point: list, color=None):
26
+ if isinstance(color, str):
27
+ try:
28
+ color = ImageColor.getrgb(color)
29
+ color = color + (128,)
30
+ except ValueError:
31
+ color = (255, 0, 0, 128)
32
+ else:
33
+ color = (255, 0, 0, 128)
34
+
35
+ overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
36
+ overlay_draw = ImageDraw.Draw(overlay)
37
+ radius = 14
38
+ x, y = point
39
+
40
+ overlay_draw.rectangle(
41
+ [x - radius, y - radius, x + radius, y + radius],
42
+ fill=color
43
+ )
44
+
45
+ center_radius = radius * 0.1
46
+ overlay_draw.ellipse(
47
+ [(x - center_radius, y - center_radius),
48
+ (x + center_radius, y + center_radius)],
49
+ fill=(0, 255, 0, 255)
50
+ )
51
+
52
+ image = image.convert('RGBA')
53
+ combined = Image.alpha_composite(image, overlay)
54
+
55
+ return combined.convert('RGB')
56
+
57
+ def draw_bbox(image: Image.Image, bbox: list, color=None):
58
+ """bbox is in the format of [x1, y1, x2, y2]"""
59
+ if isinstance(color, str):
60
+ try:
61
+ color = ImageColor.getrgb(color)
62
+ color = color + (128,)
63
+ except ValueError:
64
+ color = (255, 0, 0, 128)
65
+ else:
66
+ color = (255, 0, 0, 128)
67
+
68
+ overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
69
+ overlay_draw = ImageDraw.Draw(overlay)
70
+ overlay_draw.rectangle(bbox, fill=color)
71
+ return Image.alpha_composite(image, overlay).convert('RGB')
72
+
73
+ def do_boxes_overlap(box1, box2):
74
+ """
75
+ Check if two boxes overlap.
76
+
77
+ Each box is represented as a tuple: (x1, y1, x2, y2)
78
+ Where (x1, y1) is the top-left and (x2, y2) is the bottom-right corner.
79
+ """
80
+ # Unpack the coordinates
81
+ x1_min, y1_min, x1_max, y1_max = box1
82
+ x2_min, y2_min, x2_max, y2_max = box2
83
+
84
+ # Check for no overlap
85
+ if x1_max < x2_min or x2_max < x1_min:
86
+ return False
87
+ if y1_max < y2_min or y2_max < y1_min:
88
+ return False
89
+
90
+ return True
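
A quick usage sketch for the drawing helpers above, run on a synthetic image; the Space itself feeds real screenshots and model-predicted points, so the coordinates and output file name here are made up.

```python
# Minimal usage of the gui_actor.utils helpers added in this commit,
# assuming the package is importable from the repo root.
from PIL import Image
from gui_actor.utils import draw_point, draw_bbox, do_boxes_overlap

img = Image.new("RGB", (400, 300), (240, 240, 240))

marked = draw_point(img, [120, 80], color="blue")          # translucent square + green center dot
boxed = draw_bbox(marked, [100, 60, 180, 120], color="red")

print(do_boxes_overlap((100, 60, 180, 120), (150, 100, 300, 200)))  # True
print(do_boxes_overlap((0, 0, 50, 50), (60, 60, 100, 100)))         # False

boxed.save("utils_demo.png")
```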
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ transformers
2
+ accelerate
3
+ torch
4
+ Pillow
5
+ requests
6
+ torchvision
7
+ torchaudio
8
+ gradio
9
+ gradio_client
10
+ spaces
11
+ opencv-python-headless
12
+ datasets
13
+ qwen-vl-utils
14
+ pre-commit
15
+ matplotlib
16
+ #flash-attn