import ast
import copy
import json
import math
import os
import random
import re
from typing import Dict

import torch
import transformers
import yaml
from qwen_vl_utils import smart_resize, process_vision_info
from torch.utils.data import Dataset

from gui_actor.constants import (
    IGNORE_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    ACTION_PATTENS_XY,
    ADDITIONAL_SPECIAL_TOKENS,
    assistant_template,
    chat_template,
    grounding_system_message,
)
from gui_actor.trainer import rank0_print


def reformat_coordinates(text):
    """
    (1) Find all the coordinates in the text.
    (2) Replace the coordinates with the special pointer tokens.
    (3) Return the new text and the coordinates as a list of (x, y), where x and y are in [0, 1].
    """
    epsilon = 0.001

    def adjust_coord(c):
        """Adjust a coordinate if it is too close to 0 or 1."""
        if abs(c) < epsilon:
            return epsilon
        elif abs(c - 1) < epsilon:
            return 1 - epsilon
        return c

    all_matches = []
    for pattern in ACTION_PATTENS_XY:
        matches = list(re.finditer(pattern, text))
        for match in matches:
            all_matches.append((match.start(), match.groups()))
        if pattern == ACTION_PATTENS_XY[0]:
            # Single-point pattern: one (x, y) pair.
            target_text = f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
        else:
            # Two-point pattern: (x1, y1), (x2, y2).
            target_text = (
                f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}, "
                f"{DEFAULT_POINTER_START_TOKEN}{DEFAULT_POINTER_PAD_TOKEN}{DEFAULT_POINTER_END_TOKEN}"
            )
        text = re.sub(pattern, target_text, text)

    # Sort by position in the original text so the returned coordinates follow the
    # order of the pointer tokens in the rewritten text.
    coordinates = []
    all_matches.sort(key=lambda x: x[0])

    for _, groups in all_matches:
        if len(groups) == 2:
            x_str, y_str = groups
            x = adjust_coord(ast.literal_eval(x_str))
            y = adjust_coord(ast.literal_eval(y_str))
            coordinates.append((x, y))
        elif len(groups) == 4:
            x1_str, y1_str, x2_str, y2_str = groups
            x1 = adjust_coord(ast.literal_eval(x1_str))
            y1 = adjust_coord(ast.literal_eval(y1_str))
            x2 = adjust_coord(ast.literal_eval(x2_str))
            y2 = adjust_coord(ast.literal_eval(y2_str))
            coordinates.append((x1, y1))
            coordinates.append((x2, y2))

    return text, coordinates
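

# Illustrative example only: the exact regexes in ACTION_PATTENS_XY and the literal
# values of the pointer tokens are defined in gui_actor.constants and not shown here.
# Assuming an action string that matches the single-point pattern, e.g.
#     "pyautogui.click(x=0.3201, y=0.7805)"
# reformat_coordinates would return roughly
#     ("pyautogui.click(<POINTER_START><POINTER_PAD><POINTER_END>)", [(0.3201, 0.7805)])
# i.e. the numeric coordinates are stripped from the text and handed back separately
# so they can be supervised through the pointer mechanism rather than as text tokens.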


def get_token_index(image_processor, image, point_x, point_y):
    """
    Get the index of the visual token that contains the point (x, y).
    Args:
        image_processor: the image processor
        image: a list containing exactly one PIL image (as returned by process_vision_info)
        point_x: the x coordinate of the point, in [0, 1]
        point_y: the y coordinate of the point, in [0, 1]
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    image = image[0]
    w, h = image.size
    px, py = w * point_x, h * point_y

    # Each visual token corresponds to a merge_size x merge_size block of patches,
    # i.e. a square of patch_size * merge_size pixels.
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = math.floor(px / merge_patch_size)
    y_index = math.floor(py / merge_patch_size)

    # Row-major index into the flattened token grid.
    visual_token_index = y_index * (w // merge_patch_size) + x_index

    return visual_token_index
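

# Worked example (assuming Qwen2-VL's defaults of patch_size=14 and merge_size=2, so one
# visual token covers a 28x28 pixel block): for a 1008x700 image and the normalized point
# (0.5, 0.5), the pixel point is (504, 350), the token grid is 36x25, and the index is
# floor(350 / 28) * 36 + floor(504 / 28) = 12 * 36 + 18 = 450. In doctest form:
#     >>> class _P: patch_size, merge_size = 14, 2
#     >>> get_token_index(_P, [Image.new("RGB", (1008, 700))], 0.5, 0.5)
#     450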


def get_multi_patch_labels(image_processor, image, bbox_gt):
    """
    Get the multi-patch labels for the bounding box: a binary mask over the visual
    tokens, with 1 for every token whose patch overlaps the box.
    Args:
        image_processor: the image processor
        image: a list containing exactly one PIL image (as returned by process_vision_info)
        bbox_gt: the bounding box as (x_min, y_min, x_max, y_max), normalized to [0, 1]
    """
    if len(image) != 1:
        raise ValueError(f"Expected 1 image, got {len(image)}")

    image = image[0]
    w, h = image.size

    # Convert the normalized box to pixel coordinates and clamp it to the image.
    bbox_gt = [bbox_gt[0] * w, bbox_gt[1] * h, bbox_gt[2] * w, bbox_gt[3] * h]
    x_min, y_min, x_max, y_max = bbox_gt
    x_min = max(0, x_min)
    y_min = max(0, y_min)
    x_max = min(w, x_max)
    y_max = min(h, y_max)

    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    assert w % merge_patch_size == 0 and h % merge_patch_size == 0, \
        f"Image size {w}x{h} is not divisible by merge_patch_size {merge_patch_size}"
    grid_h, grid_w = h // merge_patch_size, w // merge_patch_size

    binary_mask = torch.zeros(grid_h * grid_w)

    for y_idx in range(grid_h):
        for x_idx in range(grid_w):
            patch_x_min = x_idx * merge_patch_size
            patch_y_min = y_idx * merge_patch_size
            patch_x_max = patch_x_min + merge_patch_size
            patch_y_max = patch_y_min + merge_patch_size

            # Mark the token if its patch overlaps the bounding box at all.
            if not (patch_x_max <= x_min or patch_x_min >= x_max or
                    patch_y_max <= y_min or patch_y_min >= y_max):
                patch_idx = y_idx * grid_w + x_idx
                binary_mask[patch_idx] = 1

    return binary_mask
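

# Worked example: for a 112x112 image (a 4x4 grid of 28x28 tokens, again assuming
# patch_size=14 and merge_size=2) and bbox_gt = (0.3, 0.3, 0.6, 0.6), the pixel box is
# (33.6, 33.6, 67.2, 67.2); it overlaps grid rows 1-2 and columns 1-2, so the returned
# mask has ones at indices 5, 6, 9 and 10 and zeros elsewhere.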


def token_index_to_coordinates(image_processor, visual_token_index, image_width, image_height):
    """Map a visual token index back to the pixel coordinates of its patch center."""
    merge_patch_size = image_processor.patch_size * image_processor.merge_size
    x_index = visual_token_index % (image_width // merge_patch_size)
    y_index = visual_token_index // (image_width // merge_patch_size)
    px = x_index * merge_patch_size + merge_patch_size / 2
    py = y_index * merge_patch_size + merge_patch_size / 2
    return px, py
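

# This is the inverse of get_token_index up to patch resolution: mapping index 450 back
# on the 1008x700 example above gives the patch center (18 * 28 + 14, 12 * 28 + 14) =
# (518, 350), i.e. the middle of the token the original point fell into. image_height is
# accepted for signature symmetry but is not needed for the computation.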


class LazySupervisedDataset(Dataset):
    """Lazily tokenized supervised dataset for GUI grounding / agent conversations."""

    def __init__(
        self,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        data_path: str,
        data_args,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.processor = processor
        self.list_data_dict = []
        self.list_image_path = []
        self.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0]
        self.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0]
        self.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]

        if "{" in data_path and "}" in data_path:
            # Brace-expanded list of JSON files, e.g. "/path/to/{train_a,train_b}.json".
            base_path, file_pattern = re.match(r"^(.*)\{(.*)\}\.json$", data_path).groups()
            file_names = file_pattern.split(",")
            rank0_print(f"Loading {file_names} from {base_path}")
            data_args.dataset_paths = []
            for file_name in file_names:
                full_path = f"{base_path}{file_name}.json"
                data_args.dataset_paths.append(full_path)
                rank0_print(f"Loading {full_path}")
                with open(full_path) as file:
                    cur_data_dict = json.load(file)
                rank0_print(f"Loaded {len(cur_data_dict)} samples from {full_path}")
                self.list_data_dict.extend(cur_data_dict)
                # Keep image paths aligned with the samples; they are resolved against
                # data_args.image_folder in _get_item.
                self.list_image_path.extend([""] * len(cur_data_dict))
        elif data_path.endswith(".yaml"):
            with open(data_path) as file:
                yaml_data = yaml.safe_load(file)
                datasets = yaml_data.get("datasets")

            data_args.dataset_paths = [dataset.get("json_path") for dataset in datasets]
            for dataset in datasets:
                json_path = dataset.get("json_path")
                sampling_strategy = dataset.get("sampling_strategy", "all")
                images_folder = dataset.get("images_folder")
                sampling_number = None

                rank0_print(f"Loading {json_path} with {sampling_strategy} sampling strategy")

                if json_path.endswith(".jsonl"):
                    cur_data_dict = []
                    with open(json_path) as json_file:
                        for line in json_file:
                            cur_data_dict.append(json.loads(line.strip()))
                elif json_path.endswith(".json"):
                    with open(json_path) as json_file:
                        cur_data_dict = json.load(json_file)
                else:
                    raise ValueError(f"Unsupported file type: {json_path}")

                # "strategy:count" or "strategy:percent%", e.g. "first:1000" or "random:50%".
                if ":" in sampling_strategy:
                    sampling_strategy, sampling_number = sampling_strategy.split(":")
                    if "%" in sampling_number:
                        sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                    else:
                        sampling_number = int(sampling_number)

                if sampling_strategy == "first" and sampling_number is not None:
                    cur_data_dict = cur_data_dict[:sampling_number]
                elif sampling_strategy == "end" and sampling_number is not None:
                    cur_data_dict = cur_data_dict[-sampling_number:]
                elif sampling_strategy == "random" and sampling_number is not None:
                    random.shuffle(cur_data_dict)
                    cur_data_dict = cur_data_dict[:sampling_number]

                rank0_print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                self.list_data_dict.extend(cur_data_dict)
                self.list_image_path.extend([images_folder] * len(cur_data_dict))
        else:
            data_args.dataset_paths = [data_path]
            rank0_print(f"Loading {data_path}")
            with open(data_path) as file:
                cur_data_dict = json.load(file)
            rank0_print(f"Loaded {len(cur_data_dict)} samples from {data_path}")
            self.list_data_dict.extend(cur_data_dict)
            self.list_image_path.extend([""] * len(cur_data_dict))

        rank0_print(f"Loaded {len(self.list_data_dict)} samples from {data_path}")
        rank0_print("Formatting inputs... skipped in lazy mode")
        self.data_args = data_args
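
        # Sketch of the YAML layout consumed above (key names come from the .get() calls;
        # the paths and values are illustrative only):
        #
        #   datasets:
        #     - json_path: /data/grounding_train.json
        #       images_folder: grounding_images
        #       sampling_strategy: random:50%
        #     - json_path: /data/agent_train.jsonl
        #       images_folder: agent_images
        #       sampling_strategy: first:10000
        #
        # "sampling_strategy" is either "all" or one of first/end/random followed by a
        # colon and an absolute count or a percentage.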

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            # Budget roughly 1200 tokens per image when estimating sample length.
            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list) else 1200 if "image" in sample else 0
            )
            length_list.append(sum(len(conv["value"].split()) for conv in sample["conversations"]) + img_tokens)
        return length_list

    @property
    def modality_lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(len(conv["value"].split()) for conv in sample["conversations"])
            assert cur_len > 0, f"Conversation length is 0 for {sample}"

            img_tokens = (
                1200 * len(sample["image"]) if isinstance(sample.get("image"), list) else 1200 if "image" in sample else 0
            )

            if "image" in sample or "video" in sample or self.data_args.early_mix_text:
                length_list.append(cur_len + img_tokens)
            else:
                # Text-only samples are reported as negative lengths.
                length_list.append(-cur_len)
        return length_list
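
    # lengths / modality_lengths are rough per-sample size estimates (1200 tokens budgeted
    # per image), presumably consumed by a length- or modality-grouped batch sampler; the
    # negative sign lets such a sampler keep text-only samples apart from multimodal ones.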

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # If a sample fails to load or is filtered out (e.g. it exceeds the model's
        # maximum length), fall back to a random sample so batches stay full.
        try:
            sample = self._get_item(i)
        except Exception as e:
            print(f"Failed to fetch sample {i}. Exception:", e)
            sample = None
        if sample is None:
            new_index = random.randint(0, len(self.list_data_dict) - 1)
            return self.__getitem__(new_index)
        return sample

    def _get_item(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        image_path = os.path.join(self.data_args.image_folder, self.list_image_path[i])

        if "image" in sources:
            image_file = self.list_data_dict[i]["image"]
            if isinstance(image_file, list):
                image_list = [os.path.join(image_path, f) for f in image_file]
            else:
                image_list = [os.path.join(image_path, image_file)]
            sources = copy.deepcopy(sources["conversations"])
        elif "video" in sources:
            raise NotImplementedError("Video is not supported for Qwen2VL")
        else:
            image_list = []  # text-only sample
            sources = copy.deepcopy(sources["conversations"])

        item_id = self.list_data_dict[i].get("id", i)

        data_dict = self.preprocess_qwen2vl(sources, self.tokenizer, self.processor, image_list, id=item_id)
        if isinstance(i, int):
            data_dict = {
                "input_ids": data_dict["input_ids"][0],
                "labels": data_dict["labels"][0],
                "coordinates": data_dict["coordinates"][0],
                "visual_token_indices_of_coordinates": data_dict["visual_token_indices_of_coordinates"][0],
                "pixel_values": data_dict["pixel_values"],
                "image_grid_thw": data_dict["image_grid_thw"],
                "multi_patch_labels": data_dict["multi_patch_labels"][0],
            }

        data_dict["id"] = item_id

        # image_grid_thw holds the (t, h, w) patch grid of the first image; each visual
        # token corresponds to a merge_size x merge_size block of patches.
        n_image_tokens = (
            data_dict["image_grid_thw"][0][0] *
            data_dict["image_grid_thw"][0][1] *
            data_dict["image_grid_thw"][0][2] /
            self.processor.image_processor.merge_size /
            self.processor.image_processor.merge_size
        )
        if (len(data_dict["input_ids"]) + n_image_tokens) > self.tokenizer.model_max_length:
            rank0_print(f"=== Removed data_dict {i} because it is longer than the model_max_length: {len(data_dict['input_ids'])} + {n_image_tokens} > {self.tokenizer.model_max_length}")
            return None

        return data_dict
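
    # Example of the length check above (assuming merge_size=2): an image_grid_thw of
    # (1, 50, 72) means 1 * 50 * 72 = 3600 raw patches, i.e. 3600 / 2 / 2 = 900 visual
    # tokens, which is added to len(input_ids) before comparing against model_max_length.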

    def preprocess_qwen2vl(
        self,
        source,
        tokenizer: transformers.PreTrainedTokenizer,
        processor: transformers.ProcessorMixin,
        image: list,
        system_message: str = grounding_system_message,
        agent_mode: bool = True,
        chat_template: str = chat_template,
        assistant_template: str = assistant_template,
        id: int = None,
    ) -> Dict:
        roles = {"human": "user", "gpt": "assistant", "system": "system"}
        assistant_template = assistant_template if agent_mode else chat_template
        processor.tokenizer = tokenizer
        assert tokenizer.additional_special_tokens == ADDITIONAL_SPECIAL_TOKENS

        pixel_values, image_grid_thw = None, None

        input_id, target = [], []
        coordinates = []
        visual_token_indices_of_coordinates = []
        multi_patch_labels = []

        image_list = []
        image_index = 0

        if roles[source[0]["from"]] == "system":
            system_message = source[0]["value"]
            source = source[1:self.data_args.max_conv_turns]

        system_input_id = tokenizer.apply_chat_template(
            conversation=[{"role": "system", "content": [{"type": "text", "text": system_message}]}],
            chat_template=chat_template,
        )
        input_id += system_input_id
        target += [IGNORE_INDEX] * len(system_input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except Exception:
                role = conv["from"]
                content = conv["value"]
            role = roles.get(role, role)

            image_count = content.count(DEFAULT_IMAGE_TOKEN)
            if image_count > 0:
                assert role == "user", "Images are only supported for user messages"

                image_placeholders = []
                for _ in range(image_count):
                    image_placeholders.append({
                        "type": "image",
                        "image": image[image_index],
                        "min_pixels": self.processor.image_processor.min_pixels,
                        "max_pixels": self.processor.image_processor.max_pixels,
                    })
                    image_index += 1

                content = content.replace(DEFAULT_IMAGE_TOKEN, "")
                conv = {"role": role, "content": image_placeholders + [{"type": "text", "text": content}]}

                image_inputs, _ = process_vision_info([conv])
                image_list.extend(image_inputs)

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv], chat_template=chat_template, tokenize=False
                )
                inputs = processor(text=[templated_conv], images=image_inputs, return_tensors="pt")

                if pixel_values is None and image_grid_thw is None:
                    pixel_values = inputs["pixel_values"]
                    image_grid_thw = inputs["image_grid_thw"]
                else:
                    pixel_values = torch.concat([pixel_values, inputs["pixel_values"]], dim=0)
                    image_grid_thw = torch.concat([image_grid_thw, inputs["image_grid_thw"]], dim=0)
            else:
                if role in ["user", "system"]:
                    conv = {"role": role, "content": [{"type": "text", "text": content}]}
                else:
                    conv = {
                        "role": role,
                        "content": [{"type": "text", "text": content}],
                        "recipient": conv.get("recipient", "os"),
                        "end_turn": conv.get("end_turn", True),
                        "bbox_gt": conv.get("bbox_gt", None),
                    }
                    if conv["recipient"] == "os":
                        if len(image_list) == 0:
                            raise ValueError("No image found for visual grounding")

                        # Swap literal coordinates for pointer tokens and record which
                        # visual token each coordinate falls into.
                        text, coord = reformat_coordinates(conv["content"][0]["text"])
                        conv["content"][0]["text"] = text

                        coordinates.extend(coord)
                        for (point_x, point_y) in coord:
                            visual_token_index = get_token_index(
                                processor.image_processor,
                                image_list,
                                point_x,
                                point_y
                            )
                            visual_token_indices_of_coordinates.append(visual_token_index)

                        if conv["bbox_gt"] is not None:
                            patch_mask = get_multi_patch_labels(
                                processor.image_processor,
                                image_list,
                                conv["bbox_gt"]
                            )
                            multi_patch_labels.append(patch_mask)

                templated_conv = tokenizer.apply_chat_template(
                    conversation=[conv],
                    chat_template=assistant_template,
                    tokenize=False,
                )
                inputs = processor(text=[templated_conv], return_tensors="pt")

            encode_id = inputs.input_ids[0].tolist()

            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"

        # Exclude the pointer-end token from the language-modeling labels.
        target = [IGNORE_INDEX if token == self.pointer_end_token_id else token for token in target]

        input_ids = torch.tensor([input_id], dtype=torch.long)
        targets = torch.tensor([target], dtype=torch.long)
        visual_token_indices_of_coordinates = torch.tensor([visual_token_indices_of_coordinates], dtype=torch.long) if len(visual_token_indices_of_coordinates) > 0 else [None]
        coordinates = [coordinates] if len(coordinates) > 0 else [None]

        if len(multi_patch_labels) > 0:
            multi_patch_labels = [torch.stack(multi_patch_labels)]
        else:
            multi_patch_labels = [None]

        data_dict = {
            "input_ids": input_ids,
            "labels": targets,
        }

        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
            data_dict["image_grid_thw"] = image_grid_thw

        data_dict["coordinates"] = coordinates
        data_dict["visual_token_indices_of_coordinates"] = visual_token_indices_of_coordinates
        data_dict["multi_patch_labels"] = multi_patch_labels

        return data_dict
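
    # Shape sketch of the returned dict for a typical grounding sample with one image and
    # one click coordinate (the leading batch dimension of 1 is stripped later in
    # _get_item): input_ids / labels are (1, seq_len); pixel_values is the flattened
    # patch tensor produced by the Qwen2-VL processor; image_grid_thw is (1, 3);
    # coordinates is [[(x, y)]]; visual_token_indices_of_coordinates is (1, 1); and
    # multi_patch_labels is a list holding one (n_boxes, n_visual_tokens) mask or [None].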