# gui_actor/dataset.py
import copy
import json
import math
import os
import random
import re
import ast
from typing import Dict
import torch
import transformers
import yaml
from qwen_vl_utils import smart_resize, process_vision_info
from torch.utils.data import Dataset
from gui_actor.constants import (
IGNORE_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_POINTER_START_TOKEN,
DEFAULT_POINTER_PAD_TOKEN,
DEFAULT_POINTER_END_TOKEN,
ACTION_PATTENS_XY,
ADDITIONAL_SPECIAL_TOKENS,
assistant_template,
chat_template,
grounding_system_message,
)
from gui_actor.trainer import rank0_print


def reformat_coordinates(text):
"""
(1) Find all the coordinates in the text.
(2) Replace the coordinates with the special tokens.
    (3) Return the new text and the coordinates as a list of (x, y) pairs, where x and y are both in [0, 1].
"""
epsilon = 0.001
def adjust_coord(c):
"""
Adjust coordinate if it is too close to 0 or 1.
"""
if abs(c) < epsilon:
return epsilon
elif abs(c - 1) < epsilon:
return 1 - epsilon
return c
all_matches = []
for pattern in ACTION_PATTENS_XY:
matches = list(re.finditer(pattern, text))
for match in matches:
all_matches.append((match.start(), match.groups()))
if pattern == ACTION_PATTENS_XY[0]:
target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
else:
target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, {DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
text = re.sub(
pattern,
target_text,
text
)
coordinates = []
all_matches.sort(key=lambda x: x[0])
# Extract coordinates in order
for _, groups in all_matches:
# When two coordinate values are found, parse them as one (x, y) pair.
if len(groups) == 2:
x_str, y_str = groups
x = adjust_coord(ast.literal_eval(x_str))
y = adjust_coord(ast.literal_eval(y_str))
coordinates.append((x, y))
# When four coordinate values are found, parse them as two pairs.
elif len(groups) == 4:
x1_str, y1_str, x2_str, y2_str = groups
x1 = adjust_coord(ast.literal_eval(x1_str))
y1 = adjust_coord(ast.literal_eval(y1_str))
x2 = adjust_coord(ast.literal_eval(x2_str))
y2 = adjust_coord(ast.literal_eval(y2_str))
coordinates.append((x1, y1))
coordinates.append((x2, y2))
return text, coordinates
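

# Illustrative behavior of reformat_coordinates (a sketch, not taken from the repo's data;
# the exact regexes live in ACTION_PATTENS_XY and the pointer tokens in gui_actor.constants).
# Assuming one of the patterns captures a pair of floats such as "x=0.131, y=0.427", a call like
#     reformat_coordinates("click(x=0.131, y=0.427)")
# would return the text with the coordinate span replaced by the
# pointer_start/pointer_pad/pointer_end special tokens, plus the list [(0.131, 0.427)],
# where values too close to 0 or 1 are nudged inward by epsilon.
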
def get_token_index(image_processor, image, point_x, point_y):
"""
Get the index of the visual token that contains the point (x, y).
Args:
image_processor: the image processor
image: the image in PIL format
point_x: the x coordinate of the point, in [0, 1].
point_y: the y coordinate of the point, in [0, 1].
"""
if len(image) != 1:
raise ValueError(f"Expected 1 image, got {len(image)}")
    # get the (already resized) image size
image = image[0]
w, h = image.size
px, py = w * point_x, h * point_y
# rank0_print(f"px: {px}, py: {py}")
# get the token index
merge_patch_size = image_processor.patch_size * image_processor.merge_size
x_index = math.floor(px / merge_patch_size)
y_index = math.floor(py / merge_patch_size)
visual_token_index = y_index * (w // merge_patch_size) + x_index
return visual_token_index
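

# Worked example for get_token_index (illustrative; assumes the Qwen2-VL defaults
# patch_size=14 and merge_size=2, i.e. each merged visual token covers a 28x28 pixel block):
# for a 1008x728 image and a normalized point (0.5, 0.5), px=504 and py=364, so
# x_index=18 and y_index=13 on a 36x26 token grid, giving visual_token_index = 13*36 + 18 = 486.
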
def get_multi_patch_labels(image_processor, image, bbox_gt):
"""
Get the multi-patch labels for the bounding box.
Args:
image_processor: the image processor
        image: a list containing a single image in PIL format
bbox_gt: the bounding box in the format of (x_min, y_min, x_max, y_max) [0,1]
"""
if len(image) != 1:
raise ValueError(f"Expected 1 image, got {len(image)}")
    # Get the (already resized) image size
image = image[0]
w, h = image.size
bbox_gt = [bbox_gt[0]*w, bbox_gt[1]*h, bbox_gt[2]*w, bbox_gt[3]*h]
# Extract bounding box coordinates
x_min, y_min, x_max, y_max = bbox_gt
x_min = max(0, x_min)
y_min = max(0, y_min)
x_max = min(w, x_max)
y_max = min(h, y_max)
merge_patch_size = image_processor.patch_size * image_processor.merge_size
assert w % merge_patch_size == 0 and h % merge_patch_size == 0, f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
grid_h, grid_w = h // merge_patch_size, w // merge_patch_size
binary_mask = torch.zeros(grid_h * grid_w)
# Iterate through all patches, check if they overlap with the bounding box
for y_idx in range(grid_h):
for x_idx in range(grid_w):
# Calculate patch boundaries
patch_x_min = x_idx * merge_patch_size
patch_y_min = y_idx * merge_patch_size
patch_x_max = patch_x_min + merge_patch_size
patch_y_max = patch_y_min + merge_patch_size
# Check if patch overlaps with the bounding box
if not (patch_x_max <= x_min or patch_x_min >= x_max or
patch_y_max <= y_min or patch_y_min >= y_max):
# Calculate patch index in the flattened grid
patch_idx = y_idx * grid_w + x_idx
binary_mask[patch_idx] = 1
return binary_mask
def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
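    """
    Approximate inverse of get_token_index: map a flattened visual token index back to
    the pixel coordinates of the center of the corresponding merged patch.
    """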
merge_patch_size = image_processor.patch_size * image_processor.merge_size
x_index = visual_token_index % (image_width // merge_patch_size)
y_index = visual_token_index // (image_width // merge_patch_size)
px = x_index * merge_patch_size + merge_patch_size / 2
py = y_index * merge_patch_size + merge_patch_size / 2
    return px, py


class LazySupervisedDataset(Dataset):
def __init__(
self,
tokenizer: transformers.PreTrainedTokenizer,
processor: transformers.ProcessorMixin,
data_path: str,
data_args,
):
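        """
        A sketch of the data_path formats this constructor accepts (inferred from the branches below):
          - a single .json file, e.g. "train.json"
          - a brace pattern expanding to several .json files, e.g. "/data/{part_a,part_b}.json"
          - a .yaml file with a "datasets" list, where each entry has a json_path, an optional
            sampling_strategy ("all", "first:N", "end:N", "random:N", or a percentage like "first:10%"),
            and an optional images_folder
        """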
super().__init__()
self.tokenizer = tokenizer
self.processor = processor
self.list_data_dict = []
self.list_image_path = []
self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        # Handle multiple JSON files specified in the data_path, e.g. "/path/to/{file1,file2}.json"
        if "{" in data_path and "}" in data_path:
            base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
            file_names = file_pattern.split(",")
            rank0_print(f"Loading {file_names} from {base_path}")
            data_args.dataset_paths = []
            for file_name in file_names:
                full_path = f"{base_path}{file_name}.json"
                data_args.dataset_paths.append(full_path)
                rank0_print(f"Loading {full_path}")
                with open(full_path) as file:
                    cur_data_dict = json.load(file)
                rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
                self.list_data_dict.extend(cur_data_dict)
                self.list_image_path.extend([""] * len(cur_data_dict))  # no per-dataset image subfolder here
elif data_path.endswith(".yaml"):
with open(data_path) as file:
yaml_data = yaml.safe_load(file)
datasets = yaml_data.get("datasets")
# file should be in the format of:
# datasets:
# - json_path: xxxx1.json
# sampling_strategy: first:1000
# - json_path: xxxx2.json
# sampling_strategy: end:3000
# - json_path: xxxx3.json
# sampling_strategy: random:999
data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
for dataset in datasets:
json_path = dataset.get("json_path")
sampling_strategy = dataset.get("sampling_strategy", "all")
                images_folder = dataset.get("images_folder", "")  # default: no extra image subfolder
sampling_number = None
rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")
if json_path.endswith(".jsonl"):
cur_data_dict = []
with open(json_path) as json_file:
for line in json_file:
cur_data_dict.append(json.loads(line.strip()))
elif json_path.endswith(".json"):
                    # NOTE: currently only .json files are used for json_path;
                    # the optional images_folder from the yaml is applied below via self.list_image_path
with open(json_path) as json_file:
cur_data_dict = json.load(json_file)
else:
raise ValueError(f"Unsupported file type: {json_path}")
if ":" in sampling_strategy:
sampling_strategy, sampling_number = sampling_strategy.split(":")
if "%" in sampling_number:
sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
else:
sampling_number = int(sampling_number)
# Apply the sampling strategy
if sampling_strategy == "first" and sampling_number is not None:
cur_data_dict = cur_data_dict[:sampling_number]
elif sampling_strategy == "end" and sampling_number is not None:
cur_data_dict = cur_data_dict[-sampling_number:]
elif sampling_strategy == "random" and sampling_number is not None:
random.shuffle(cur_data_dict)
cur_data_dict = cur_data_dict[:sampling_number]
rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
self.list_data_dict.extend(cur_data_dict)
self.list_image_path.extend([images_folder] * len(cur_data_dict))
else:
data_args.dataset_paths = [data_path]
rank0_print(f"Loading {data_path}")
with open(data_path) as file:
cur_data_dict = json.load(file)
rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
self.list_data_dict.extend(cur_data_dict)
self.list_image_path.extend([""] * len(cur_data_dict)) # NOTE: the image subfolder is empty...
rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.data_args = data_args
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
            # 1200 is a rough per-image token estimate
            if "image" in sample:
                img_tokens = 1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200
            else:
                img_tokens = 0
length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
assert cur_len > 0, f"Conversation length is 0 for {sample}"
            if "image" in sample:
                img_tokens = 1200 * len(sample["image"]) if isinstance(sample["image"], list) else 1200
            else:
                img_tokens = 0
if "image" in sample or "video" in sample or self.data_args.early_mix_text:
length_list.append(cur_len + img_tokens)
else:
length_list.append(-cur_len)
return length_list

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        try:
            sample = self._get_item(i)
        except Exception as e:
            print(f"Failed to fetch sample {i}. Exception:", e)
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        if sample is None:
            # fall back to a random sample, e.g. when the current one exceeds model_max_length
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        return sample
def _get_item(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])
if "image" in sources:
image_file = self.list_data_dict[i]["image"]
            if isinstance(image_file, list):
                image_list = [os.path.join(image_path, f) for f in image_file]
            else:
                image_list = [os.path.join(image_path, image_file)]
sources = copy.deepcopy(sources["conversations"])
elif "video" in sources:
raise NotImplementedError("Video is not supported for Qwen2VL")
        else:
            image_list = []  # text-only sample: no images to attach
            sources = copy.deepcopy(sources["conversations"])
item_id = self.list_data_dict[i].get("id", i)
data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
if isinstance(i, int):
data_dict = {
"input_ids": data_dict["input_ids"][0],
"labels": data_dict["labels"][0],
"coordinates": data_dict["coordinates"][0],
"visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
"pixel_values": data_dict["pixel_values"],
"image_grid_thw": data_dict["image_grid_thw"],
"multi_patch_labels": data_dict["multi_patch_labels"][0], # add multi_patch_labels
}
data_dict["id"] = item_id
        # return None if the sequence (text plus image tokens) would exceed model_max_length
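        # image_grid_thw stores the per-image (t, h, w) patch grid; dividing the product by
        # merge_size twice gives the number of merged visual tokens the image occupies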
n_image_tokens = (
data_dict["image_grid_thw"][0][0] *
data_dict["image_grid_thw"][0][1] *
data_dict["image_grid_thw"][0][2] /
self.processor.image_processor.merge_size /
self.processor.image_processor.merge_size
)
if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
return None
return data_dict
def preprocess_qwen2vl(
self,
source, # conversations
tokenizer: transformers.PreTrainedTokenizer,
processor: transformers.ProcessorMixin,
image: list,
system_message: str = grounding_system_message,
agent_mode: bool = True,
chat_template: str = chat_template,
assistant_template: str = assistant_template,
id: int = None,
) -> Dict:
roles = {"human": "user", "gpt": "assistant", "system": "system"}
assistant_template = assistant_template if agent_mode else chat_template
processor.tokenizer = tokenizer
assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS
# Apply prompt templates
pixel_values, image_grid_thw = None, None
input_id, target = [], []
coordinates = []
visual_token_indices_of_coordinates = []
multi_patch_labels = []
image_list = []
image_index = 0
## prepare the system message
if roles[source[0]["from"]] == "system":
system_message = source[0]["value"]
source = source[1:self.data_args.max_conv_turns]
# else: use the constant system message
system_input_id = tokenizer.apply_chat_template(
conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
chat_template=chat_template,
)
input_id += system_input_id
target += [IGNORE_INDEX] * len(system_input_id)
## prepare user-assistant conversation
for conv in source:
# regularize the conversation format
try:
role = conv["role"]
content = conv["content"]
except Exception:
role = conv["from"]
content = conv["value"]
role = roles.get(role, role)
# Count the number of <image> tokens in the content
image_count = content.count(DEFAULT_IMAGE_TOKEN)
if image_count > 0:
assert role == "user", "Images are only supported for user messages"
                # include the image information for the current conversation turn
image_placeholders = []
for _ in range(image_count):
image_placeholders.append({
"type": "image",
"image": image[image_index],
"min_pixels": self.processor.image_processor.min_pixels,
"max_pixels": self.processor.image_processor.max_pixels,
})
image_index += 1
content = content.replace(DEFAULT_IMAGE_TOKEN, "")
conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}
image_inputs, _ = process_vision_info([conv]) # list of PIL.Image.Image
image_list.extend(image_inputs)
templated_conv = tokenizer.apply_chat_template(
conversation=[conv], chat_template=chat_template, tokenize=False
)
inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")
if pixel_values is None and image_grid_thw is None:
pixel_values = inputs["pixel_values"]
image_grid_thw = inputs["image_grid_thw"]
else:
pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
else:
if role in ["user", "system"]:
conv = {"role": role, "content": [{"type": "text", "text": content}]}
else: # assistant
conv = {
"role": role,
"content": [{"type": "text", "text": content}],
"recipient": conv.get("recipient", "os"),
"end_turn": conv.get("end_turn", True),
"bbox_gt": conv.get("bbox_gt", None),
}
if conv["recipient"] == "os":
if len(image_inputs) == 0:
raise ValueError("No image found for visual grounding")
# replace the coordinates with the special tokens
text, coord = reformat_coordinates(conv["content"][0]["text"])
conv["content"][0]["text"] = text
# rank0_print(f"coord: {coord}")
# get the visual token indices of the coordinates
coordinates.extend(coord)
for (point_x, point_y) in coord:
visual_token_index = get_token_index(
processor.image_processor,
image_list,
point_x,
point_y
)
# px, py = token_index_to_coordinates(
# processor.image_processor,
# visual_token_index,
# image_list[0].size[0], # make sure the size here is after qwen2vl processing
# image_list[0].size[1]
# )
# rank0_print(f"estimated px: {px}, py: {py}")
visual_token_indices_of_coordinates.append(visual_token_index)
if conv["bbox_gt"] is not None:
patch_mask = get_multi_patch_labels(
processor.image_processor,
image_list,
conv["bbox_gt"]
)
multi_patch_labels.append(patch_mask)
templated_conv = tokenizer.apply_chat_template(
conversation=[conv],
chat_template=assistant_template,
tokenize=False,
)
inputs = processor(text=[templated_conv], return_tensors="pt")
encode_id = inputs.input_ids[0].tolist()
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
# make the labels of all pointer_end_token_id to be IGNORE_INDEX
target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]
input_ids = torch.tensor([input_id], dtype=torch.long)
targets = torch.tensor([target], dtype=torch.long)
visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
coordinates = [coordinates] if len(coordinates) > 0 else [None]
# process multi_patch_labels
if len(multi_patch_labels) > 0:
multi_patch_labels = [torch.stack(multi_patch_labels)]
else:
multi_patch_labels = [None]
data_dict = {
"input_ids": input_ids, # tensor(bs x seq_len)
"labels": targets, # tensor(bs x seq_len)
}
if pixel_values is not None:
data_dict["pixel_values"] = pixel_values
data_dict["image_grid_thw"] = image_grid_thw
# if len(coordinates[0]) != len(visual_token_indices_of_coordinates[0]):
# raise ValueError(f"The number of coordinates ({len(coordinates[0])}) does not match the number of image token indices ({len(visual_token_indices_of_coordinates[0])})")
data_dict["coordinates"] = coordinates
data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
data_dict["multi_patch_labels"] = multi_patch_labels
return data_dict
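

if __name__ == "__main__":
    # Minimal sanity-check sketch for the patch-index helpers above (not part of training).
    # Assumption: a Qwen2-VL-style image processor with patch_size=14 and merge_size=2,
    # i.e. one merged visual token per 28x28 pixel block. SimpleNamespace stands in for the
    # real processor/image objects because the helpers only read .patch_size, .merge_size and .size.
    from types import SimpleNamespace

    fake_image_processor = SimpleNamespace(patch_size=14, merge_size=2)
    fake_image = [SimpleNamespace(size=(1008, 728))]  # (width, height), both divisible by 28

    idx = get_token_index(fake_image_processor, fake_image, 0.5, 0.5)
    print("visual token index:", idx)  # 13 * 36 + 18 = 486 on a 36x26 token grid

    px, py = token_index_to_coordinates(fake_image_processor, idx, 1008, 728)
    print("patch center in pixels:", (px, py))  # (518.0, 378.0)

    mask = get_multi_patch_labels(fake_image_processor, fake_image, (0.25, 0.25, 0.5, 0.5))
    print("patches overlapping the box:", int(mask.sum()))  # 9 x 7 = 63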