|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
import os |
|
import json |
|
import random |
|
import torch |
|
import numpy as np |
|
from einops import rearrange |
|
from xtuner.registry import BUILDER |
|
from xtuner.dataset.utils import expand2square |
|
from src.datasets.utils import crop2square, encode_fn, load_jsonl |
|
from xtuner.utils import DEFAULT_IMAGE_TOKEN |
|
from transformers import AutoImageProcessor |
|
|
|
|
|
class VLMDataset(Dataset):
    """Single-image question-answering dataset for VLM training.

    Loads jsonl records of the form
    ``{"image": <path>, "conversations": [{"value": q}, {"value": a}]}``,
    converts each image to a ``[-1, 1]``-normalized CHW float tensor, and
    tokenizes the single-turn conversation with an xtuner chat template,
    expanding the ``<image>`` placeholder to ``image_token_repeat`` tokens.
    Corrupt samples are logged and replaced by a random retry.
    """

    def __init__(
        self,
        data_path,
        image_size,
        tokenizer=None,
        template_map_fn=None,
        max_length=2048,
        min_image_size=80,
        pad_image=True,
        local_folder="",
        key_value="conversations",
        processor_path="checkpoint/siglip2-so400m-patch16-512",
    ):
        """Build the dataset.

        Args:
            data_path: Path to a jsonl file of samples.
            image_size: Target square side length in pixels.
            tokenizer: xtuner BUILDER config for the tokenizer.
            template_map_fn: xtuner BUILDER config; must also be indexable
                with ``"template"`` to expose the prompt template.
            max_length: Truncation length (tokens) passed to ``encode_fn``.
            min_image_size: Images whose width or height is not strictly
                greater than this are rejected.
            pad_image: If True, square-crop images before resizing so the
                resize does not distort aspect ratio.
            local_folder: Unused; kept for config compatibility.
            key_value: Record key under which the conversation list lives.
            processor_path: Checkpoint path/id for the image processor
                (previously hard-coded; default preserves old behavior).
        """
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.image_size = image_size

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = template_map_fn["template"]
        self.template_map_fn = BUILDER.build(template_map_fn)
        self.max_length = max_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size
        self.key_value = key_value
        self.processor = AutoImageProcessor.from_pretrained(processor_path)
        self.metainfo = {'task': 'unified'}
        self.DEFAULT_IMAGE_TOKEN = DEFAULT_IMAGE_TOKEN
        # One token per 16x16 patch of the square image, plus 64 extra
        # tokens -- presumably pooled/register tokens of the vision tower;
        # must match the tower's actual output length (TODO confirm).
        m = n = self.image_size // 16
        self.image_token_repeat = m * n + 64

        self.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.image_token_idx = self.tokenizer.convert_tokens_to_ids("<image>")
        print(f"Registered <image> token at index {self.image_token_idx}")

    def _load_data(self, data_path: str):
        """Read the jsonl sample list into ``self.data_list``."""
        self.data_list = load_jsonl(data_path)
        print(f"Load {len(self.data_list)} data samples from {data_path}", flush=True)

    def full_init(self):
        """Dummy full_init to be compatible with MMEngine ConcatDataset."""
        return

    def __len__(self):
        return len(self.data_list)

    def _read_image(self, image_file):
        """Open ``image_file``, validate its geometry, and return it as RGB.

        Raises:
            ValueError: If either side is <= ``min_image_size`` or the
                aspect ratio falls outside (0.1, 10). ``__getitem__``
                catches this and retries with another sample.
        """
        image = Image.open(image_file)
        # Explicit raises (not ``assert``) so validation survives `python -O`.
        if not (image.width > self.min_image_size and image.height > self.min_image_size):
            raise ValueError(f"Image: {image.size}")
        ratio = image.width / image.height
        if not (0.1 < ratio < 10):
            raise ValueError(f"Image: {image.size}")
        return image.convert("RGB")

    def _process_image(self, image: "Image.Image"):
        """Resize an image and normalize it to a ``[-1, 1]`` CHW tensor."""
        if self.pad_image:
            # Square-crop first so the fixed-size resize below does not
            # distort the aspect ratio.
            image = crop2square(image)

        image = image.resize((self.image_size, self.image_size))

        arr = np.array(image).astype(np.float32) / 255.0
        arr = 2 * arr - 1  # map [0, 1] -> [-1, 1]
        tensor = torch.from_numpy(arr)
        tensor = tensor.permute(2, 0, 1)  # HWC -> CHW (same as einops rearrange)
        return {"pixel_values": tensor}

    def _process_text(self, question, answer):
        """Tokenize one question/answer turn with the chat template.

        The question is prefixed with the image placeholder; ``encode_fn``
        expands that placeholder to ``image_token_repeat`` image tokens and
        produces ``input_ids``/``labels`` (right-truncated to ``max_length``).
        """
        data_dict = dict(
            conversation=[
                {
                    "input": f"{self.DEFAULT_IMAGE_TOKEN}\n{question}",
                    "output": answer,
                }
            ]
        )
        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(
            encode_fn(
                example=data_dict,
                tokenizer=self.tokenizer,
                max_length=self.max_length,
                image_length=self.image_token_repeat,
                input_ids_with_output=True,
                with_image_token=True,
                truncation='right',
                image_token_idx=self.image_token_idx,
                image_token_str=self.DEFAULT_IMAGE_TOKEN,
            )
        )
        data_dict["type"] = "image2text"
        return data_dict

    def _retry(self):
        """Return a uniformly random sample (fallback after a bad read)."""
        return self[random.randrange(len(self))]

    def __getitem__(self, idx):
        # Fetch the record outside the try so the error handler below can
        # always reference it (previously a NameError risk).
        data_sample = self.data_list[idx]
        try:
            # _read_image already returns an RGB-converted image; no second
            # .convert("RGB") needed.
            image = self._read_image(data_sample["image"])
            data = self._process_image(image)
            del image  # release the PIL buffer before tokenization

            question = (
                data_sample[self.key_value][0]["value"]
                .replace("<image>", "")
                .strip()
            )
            answer = (
                data_sample[self.key_value][1]["value"]
                .replace("<image>", "")
                .strip()
            )

            data.update(self._process_text(question, answer))
            data.update(image_file=data_sample["image"])
            return data

        except Exception as e:
            # Best effort: one corrupt record must not abort training.
            # .get() keeps the handler itself from raising on a missing key.
            print(
                f"Error when reading data_sample:{data_sample},{self.data_path}:{data_sample.get('image')}: {e}",
                flush=True,
            )
            return self._retry()
|
|