from torch.utils.data import Dataset
from PIL import Image
import os
import io
import json
import random
import torch
import numpy as np
from einops import rearrange

try:
    from aoss_client.client import Client
except ImportError:
    try:
        from petrel_client.client import Client
    except ImportError:
        Client = None

from glob import glob
from xtuner.registry import BUILDER
from xtuner.dataset.utils import expand2square
from src.datasets.utils import crop2square, encode_fn
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from src.datasets.understanding.caption_prompts import dense_prompts, short_prompts
from typing import List, Dict, Any, Optional, Callable, Tuple

@BUILDER.register_module()
class CaptionDataset(Dataset):
    def __init__(self,
                 data_path,
                 local_folder,
                 image_size,
                 ceph_folder=None,
                 ceph_config=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 min_image_size=80,
                 image_length=256,
                 pad_image=True,
                 brief=False,
                 cap_folder=None,
                 cap_source='caption',
                 ):
        super().__init__()
        self.data_path = data_path
        self._load_data(data_path)
        self.local_folder = local_folder
        self.cap_folder = local_folder if cap_folder is None else cap_folder
        self.cap_source = cap_source
        self.image_size = image_size

        self.tokenizer = BUILDER.build(tokenizer)
        self.prompt_template = template_map_fn['template']
        self.template_map_fn = BUILDER.build(template_map_fn)
        self.max_length = max_length
        self.image_length = image_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size

        self.FILE_CLIENT = None
        self.ceph_folder = ceph_folder
        self.ceph_config = ceph_config
        # Read from ceph only when a client backend imported successfully and
        # the config file actually exists; otherwise fall back to local files.
        self.use_ceph = ((Client is not None) and (ceph_folder is not None)
                         and (ceph_config is not None) and os.path.exists(ceph_config))

        self.brief = brief
        self.caption_prompts = short_prompts if self.brief else dense_prompts

    def _load_data(self, data_path: str):
        if data_path.endswith('.json'):
            with open(data_path, 'r') as f:
                self.data_list = json.load(f)
        else:
            # Treat data_path as a directory and merge every *.json inside it.
            data_list = []
            for json_file in glob(f"{data_path}/*.json"):
                with open(json_file, 'r') as f:
                    data_list += json.load(f)
            self.data_list = data_list

        print(f"Loaded {len(self.data_list)} data samples from {data_path}", flush=True)

    def __len__(self):
        return len(self.data_list)

    def _read_ceph(self, ceph_path):
        # Instantiate the file client lazily, on first use.
        if self.FILE_CLIENT is None:
            self.FILE_CLIENT = Client(self.ceph_config)
        data_bytes = self.FILE_CLIENT.get(ceph_path)
        return io.BytesIO(data_bytes)

    def _read_image(self, image_file):
        if self.use_ceph:
            image = Image.open(self._read_ceph(os.path.join(self.ceph_folder, image_file)))
        else:
            image = Image.open(os.path.join(self.local_folder, image_file))
        # Reject tiny images and extreme aspect ratios up front.
        assert image.width > self.min_image_size and image.height > self.min_image_size, f"Image: {image.size}"
        assert 0.1 < image.width / image.height < 10, f"Image: {image.size}"
        return image.convert('RGB')

    def _read_json(self, annotation_file):
        if self.use_ceph:
            annotation = json.load(self._read_ceph(os.path.join(self.ceph_folder, annotation_file)))
        else:
            with open(os.path.join(self.local_folder, annotation_file), 'r') as f:
                annotation = json.load(f)
        return annotation

    def _process_image(self, image):
        data = dict()
        if self.pad_image:
            # Pad to a square canvas with mid-gray borders.
            image = expand2square(image, (127, 127, 127))
        else:
            # Center-crop to a square instead of padding.
            image = crop2square(image)

        image = image.resize(size=(self.image_size, self.image_size))
        # HWC uint8 in [0, 255] -> CHW float in [-1, 1].
        pixel_values = torch.from_numpy(np.array(image)).float()
        pixel_values = pixel_values / 255
        pixel_values = 2 * pixel_values - 1
        pixel_values = rearrange(pixel_values, 'h w c -> c h w')

        data.update(pixel_values=pixel_values)
        return data

    def _process_text(self, text):
        assert DEFAULT_IMAGE_TOKEN not in text, text
        # Build a single-turn conversation: a random caption prompt as input,
        # the caption itself as output.
        data_dict = dict(conversation=[{'input': f"{DEFAULT_IMAGE_TOKEN}\n{random.choice(self.caption_prompts)}",
                                        'output': text.strip()}])
        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(encode_fn(data_dict, self.tokenizer, self.max_length,
                                   self.image_length, True, True))

        assert (torch.tensor(data_dict['input_ids']).long() == IMAGE_TOKEN_INDEX).sum() == self.image_length, \
            "Error in image format"

        data_dict['type'] = 'image2text'
        return data_dict

    def _retry(self):
        return self.__getitem__(random.choice(range(self.__len__())))

    def __getitem__(self, idx):
        # Index outside the try block so the except handler can reference it.
        data_sample = self.data_list[idx]
        try:
            image = self._read_image(data_sample['image'])  # already RGB
            data = self._process_image(image)
            del image
            # The annotation file is a JSON dict; the caption text lives under
            # the `cap_source` key (e.g. 'caption').
            with open(f"{self.cap_folder}/{data_sample['annotation']}", 'r') as f:
                caption = json.load(f)[self.cap_source]
            data.update(self._process_text(caption))
            data.update(image_dir=self.local_folder, image_file=data_sample['image'])
            return data
        except Exception as e:
            # Fall back to a random sample so one bad file never kills training.
            print(f"Error when reading {self.data_path}:{data_sample['image']}: {e}", flush=True)
            return self._retry()


@BUILDER.register_module()
class VqaDataset(Dataset):
    """Generic VQA / multimodal conversation dataset with robust IO and validation."""
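    # Expected JSONL sample schema, inferred from __getitem__ and
    # _format_conversation below (only "image", "conversations", and each
    # turn's "value" are actually read; roles alternate human -> gpt by
    # position, so any "from" field is ignored):
    #   {"image": "path/or/relative.jpg",
    #    "conversations": [{"value": "<image>\nQuestion ..."},
    #                      {"value": "Answer ..."}]}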
    def __init__(
        self,
        data_path: str,
        tokenizer,
        template_map_fn: Callable,
        img_prefix: Optional[str] = None,
        image_size: int = 512,
        max_length: int = 2048,
        image_length: int = 1089,
        pad_image: bool = True,
        min_image_size: int = 80,
        image_token_patterns: Tuple[str, ...] = ('<image>', '[image]', '<img>'),
        max_retry: int = 5,
    ):
        super().__init__()

        self.img_prefix = img_prefix.rstrip("/") if img_prefix else None
        self.image_size = image_size
        self.max_length = max_length
        self.image_length = image_length
        self.pad_image = pad_image
        self.min_image_size = min_image_size
        self.image_token_patterns = list(image_token_patterns)
        self.max_retry = max_retry
        # Token index that marks image positions in input_ids. The methods
        # below rely on this attribute, so bind it to xtuner's constant here.
        self.image_token_idx = IMAGE_TOKEN_INDEX

        self.tokenizer = BUILDER.build(tokenizer)
        self.template_map_fn = BUILDER.build(template_map_fn) if template_map_fn else None

        self.data_list = self._load_jsonl_list(data_path)
        print(f"Loaded {len(self.data_list)} samples from {data_path}", flush=True)

    @staticmethod
    def _load_jsonl_list(path: str) -> List[Dict[str, Any]]:
        data: List[Dict[str, Any]] = []
        if path.endswith(".jsonl"):
            files = [path]
        else:
            # Treat path as a directory and collect *.jsonl files recursively.
            files = sorted(glob(os.path.join(path, "**/*.jsonl"), recursive=True))

        for file in files:
            with open(file, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        data.append(json.loads(line))
        return data

    def __len__(self) -> int:
        return len(self.data_list)

    def _get_image_path(self, img_file: str) -> str:
        """Keep absolute paths as-is; otherwise prepend the image prefix."""
        if os.path.isabs(img_file) or self.img_prefix is None:
            return img_file
        return os.path.join(self.img_prefix, img_file)

    def _read_image(self, img_file: str) -> torch.Tensor:
        img_path = self._get_image_path(img_file)
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            raise FileNotFoundError(f"Cannot open image: {img_path} ({e})")

        w, h = image.size
        if w < self.min_image_size or h < self.min_image_size:
            raise ValueError(f"Image too small: {img_path} ({w}x{h})")
        ratio = w / h
        if not (0.1 < ratio < 10):
            raise ValueError(f"Odd aspect ratio ({ratio:.3f}) for {img_path}")

        # Square-ify (pad or crop), resize, then scale to CHW float in [-1, 1].
        image = expand2square(image, (127, 127, 127)) if self.pad_image else crop2square(image)
        image = image.resize((self.image_size, self.image_size), resample=Image.BICUBIC)

        px = torch.from_numpy(np.asarray(image)).float() / 255.0
        px = 2 * px - 1.0
        px = rearrange(px, "h w c -> c h w")
        return px

    def _replace_image_tokens(self, txt: str) -> str:
        # Normalize any known image-token pattern to the numeric token index.
        for pat in self.image_token_patterns:
            if pat in txt:
                txt = txt.replace(pat, str(self.image_token_idx))
        return txt

    def _format_conversation(self, turns: List[Dict[str, str]]) -> Dict[str, Any]:
        """Merge alternating human/gpt turns into {'input': ..., 'output': ...} pairs.

        Turns pair up positionally as human -> gpt; pairs missing either side
        are skipped.
        """
        pairs = []
        for i in range(0, len(turns), 2):
            if i + 1 < len(turns):
                human_content = turns[i].get("value", "").strip()
                gpt_content = turns[i + 1].get("value", "").strip()

                if not human_content or not gpt_content:
                    continue

                # Each sample carries a single image, so only the first pair
                # gets the image token prepended.
                if not pairs and not human_content.startswith("<image>"):
                    human_content = f"<image>\n{human_content}"

                pairs.append({"input": human_content, "output": gpt_content})

        data_dict = {"conversation": pairs}
        if self.template_map_fn:
            data_dict = self.template_map_fn(data_dict)

        data_dict = encode_fn(
            data_dict,
            self.tokenizer,
            self.max_length,
            self.image_length,
            input_ids_with_output=True,
            with_image_token=True,
            image_token_idx=self.image_token_idx,
        )

        # encode_fn should expand the single image token to exactly
        # image_length entries; flag any sample where that does not hold.
        img_tokens = (torch.tensor(data_dict["input_ids"]) == self.image_token_idx).sum().item()
        print(f"[validation] input_ids length: {len(data_dict['input_ids'])}, "
              f"image-token count: {img_tokens}", flush=True)
        if img_tokens != self.image_length:
            print(f"[abnormal conversation]: {pairs}", flush=True)

        data_dict["type"] = "image2text"
        return data_dict

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        for attempt in range(self.max_retry):
            try:
                sample = self.data_list[idx]
                img_tensor = self._read_image(sample["image"])
                text_data = self._format_conversation(sample.get("conversations", []))
                return {
                    **text_data,
                    "pixel_values": img_tensor,
                    "image_file": sample["image"],
                }
            except Exception as e:
                # Retry with a random index so one corrupt sample cannot
                # stall the dataloader.
                print(f"[Retry {attempt + 1}/{self.max_retry}] idx={idx} error: {e}", flush=True)
                idx = random.randint(0, len(self) - 1)

        raise RuntimeError(f"Failed to fetch a valid sample after {self.max_retry} retries.")
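

# Minimal usage sketch (hypothetical paths and configs, not part of this repo;
# assumes xtuner-style BUILDER configs for `tokenizer` and `template_map_fn`,
# e.g. built around transformers.AutoTokenizer and xtuner's
# template_map_fn_factory):
#
#   dataset = VqaDataset(
#       data_path='data/vqa_train.jsonl',       # hypothetical JSONL file
#       tokenizer=tokenizer_cfg,                # dict(type=..., ...)
#       template_map_fn=template_map_fn_cfg,    # dict(type=..., template=...)
#       img_prefix='data/images',               # hypothetical image folder
#   )
#   sample = dataset[0]
#   # -> dict with input_ids/labels from encode_fn, pixel_values of shape
#   #    (C, H, W) scaled to [-1, 1], and the originating image_file.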