# muddit-interface/train/dataset_utils.py
# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL.ImageOps import exif_transpose
from PIL import Image
import io
import json
import numpy as np
import pyarrow.parquet as pq
import random
import bisect
import pyarrow.fs as fs
@torch.no_grad()
def tokenize_prompt(
tokenizer,
prompt,
text_encoder_architecture='open_clip',
padding='max_length',
max_length=77,
max_length_t5=256,
):
if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
input_ids = tokenizer(
prompt,
truncation=True,
padding=padding,
max_length=max_length,
return_tensors="pt",
).input_ids
return input_ids
elif text_encoder_architecture == 't5_clip': # we have two tokenizers, 1st for CLIP, 2nd for T5
input_ids = []
input_ids.append(tokenizer[0](
prompt,
truncation=True,
padding=padding,
max_length=max_length,
return_tensors="pt",
).input_ids)
input_ids.append(tokenizer[1](
prompt,
truncation=True,
padding=padding,
max_length=max_length_t5,
return_tensors="pt",
).input_ids)
return input_ids
elif text_encoder_architecture == "gemma":
input_ids = []
input_ids.append(tokenizer[0](
prompt,
truncation=True,
padding=padding,
max_length=max_length,
return_tensors="pt",
).input_ids)
input_ids.append(tokenizer[1](
prompt,
truncation=True,
padding=padding,
max_length=max_length_t5,
return_tensors="pt",
).input_ids)
return input_ids
else:
raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
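# Example (illustrative sketch, not part of the training pipeline): tokenizing a
# prompt with a single CLIP-style tokenizer. The checkpoint name below is an assumed
# placeholder; any CLIP tokenizer compatible with max_length=77 behaves the same way.
#
#   from transformers import CLIPTokenizer
#   clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   ids = tokenize_prompt(clip_tok, "a photo of a corgi", text_encoder_architecture="CLIP")
#   # ids: LongTensor of shape (1, 77), padded/truncated to max_length
#
# For the 't5_clip' and 'gemma' branches, pass a 2-element list of tokenizers
# (CLIP first, T5/Gemma second); a list of two id tensors is returned instead.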
def encode_prompt(
text_encoder,
input_ids,
text_encoder_architecture='open_clip'
):
if text_encoder_architecture == 'CLIP' or text_encoder_architecture == 'open_clip':
outputs = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
encoder_hidden_states = outputs.hidden_states[-2]
cond_embeds = outputs[0]
return encoder_hidden_states, cond_embeds
elif text_encoder_architecture == 't5_clip':
outputs_clip = text_encoder[0](
input_ids=input_ids[0],
return_dict=True,
output_hidden_states=True
)
outputs_t5 = text_encoder[1](
input_ids=input_ids[1],
return_dict=True,
output_hidden_states=True
)
encoder_hidden_states = outputs_t5.last_hidden_state
cond_embeds = outputs_clip.text_embeds
return encoder_hidden_states, cond_embeds
elif text_encoder_architecture == "gemma":
outputs_clip = text_encoder[0](
input_ids=input_ids[0],
return_dict=True,
output_hidden_states=True
)
outputs_gemma = text_encoder[1](
input_ids=input_ids[1],
return_dict=True,
output_hidden_states=True
)
encoder_hidden_states = outputs_gemma.last_hidden_state
cond_embeds = outputs_clip.text_embeds
return encoder_hidden_states, cond_embeds
else:
raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")
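# Example (illustrative sketch): pairing tokenize_prompt with encode_prompt on the
# single-CLIP path. The encoder class and checkpoint are assumptions for illustration;
# the repo's actual text encoder may differ, but the call pattern is the same.
#
#   from transformers import CLIPTextModelWithProjection, CLIPTokenizer
#   tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   enc = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
#   ids = tokenize_prompt(tok, "a photo of a corgi", text_encoder_architecture="CLIP")
#   hidden, pooled = encode_prompt(enc, ids, text_encoder_architecture="CLIP")
#   # hidden: penultimate hidden states, shape (1, 77, hidden_dim)
#   # pooled: the encoder's first output (projected/pooled text embedding)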
def process_image(image, size, Norm=False, hps_score=6.0):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
orig_height = image.height
orig_width = image.width
image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
image = transforms.functional.crop(image, c_top, c_left, size, size)
image = transforms.ToTensor()(image)
if Norm:
image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)
micro_conds = torch.tensor(
[orig_width, orig_height, c_top, c_left, hps_score],
)
return {"image": image, "micro_conds": micro_conds}
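# Example (illustrative): process_image resizes the short side to `size`, takes a
# random `size` x `size` crop, and records micro-conditioning values in the order
# [orig_width, orig_height, crop_top, crop_left, hps_score]. The image path below
# is hypothetical.
#
#   img = Image.open("example.jpg")
#   out = process_image(img, size=512, Norm=True)
#   out["image"].shape        # torch.Size([3, 512, 512]), values in [-1, 1] when Norm=True
#   out["micro_conds"].shape  # torch.Size([5])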
class ImageCaptionLargeDataset(Dataset):
def __init__(
self,
root_dir,
tokenizer,
size,
text_encoder_architecture="CLIP",
norm=False
):
self.root_dir = root_dir
self.tokenizer = tokenizer
self.size = size
self.text_encoder_architecture = text_encoder_architecture
self.norm = norm
self.data_list = []
for root, dirnames, filenames in os.walk(root_dir):
for filename in filenames:
if filename.endswith(".jpg") or filename.endswith(".png"):
base_name = os.path.splitext(filename)[0]
txt_file = os.path.join(root, base_name + ".txt")
if os.path.exists(txt_file):
self.data_list.append((root, base_name + ".txt", filename))
def __len__(self):
return len(self.data_list)
def __getitem__(self, idx):
try:
sub_dir, txtfilename, imgfilename = self.data_list[idx]
img_path = os.path.join(sub_dir, imgfilename)
caption_path = os.path.join(sub_dir, txtfilename)
image = Image.open(img_path).convert("RGB")
ret = process_image(image, self.size, self.norm)
with open(caption_path, "r", encoding="utf-8") as f:
caption = f.read().strip()
ret["prompt_input_ids"] = tokenize_prompt(
self.tokenizer, caption, self.text_encoder_architecture
)
return ret
        except Exception as e:
            print("===========================================")
            print(f"[Warning] Error at index {idx}: {self.data_list[idx][2]} ({e})")
            print("===========================================")
            # Skip the broken sample; wrap around so a failing final sample
            # does not recurse on itself forever.
            return self.__getitem__((idx + 1) % len(self.data_list))
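# Example (illustrative sketch): wrapping ImageCaptionLargeDataset in a DataLoader.
# The root directory and `clip_tok` (a CLIP tokenizer, as in the sketch above) are
# assumed placeholders. With the default 'max_length' padding every sample's
# prompt_input_ids has the same shape, so the default collate works for "CLIP".
#
#   from torch.utils.data import DataLoader
#   ds = ImageCaptionLargeDataset("/data/captions", clip_tok, size=512,
#                                 text_encoder_architecture="CLIP", norm=True)
#   dl = DataLoader(ds, batch_size=8, shuffle=True, num_workers=4)
#   batch = next(iter(dl))
#   batch["image"].shape             # (8, 3, 512, 512)
#   batch["prompt_input_ids"].shape  # (8, 1, 77)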
class MultiSourceVLDataset(Dataset):
    """
    A unified dataset for
    • LLaVA-Instruct-150K
    • MMMU (multiple-choice QA)
    • VQAv2
    • GQA, COCO 2014 captions / QA, MG-LLaVA
    • Local image-caption folders (`caption_dir` for understanding, `pdd3_dir` for generation)
    """
def __init__(
self,
tokenizer,
size: int,
text_encoder_architecture: str = "CLIP",
norm: bool = False,
# ----- paths -----
llava_json: str = None, llava_img_root: str = None,
mmmu_json: str = None, mmmu_img_root: str = None,
vqa_ann_json: str = None, vqa_img_root: str = None,
gqa_json: str = None, gqa_img_root: str = None,
coco_json: str = None, coco_img_root: str = None,
coco_qa_json: str = None,
mg_llava_json: str = None, mg_llava_root: str = None,
pdd3_dir: str = None, caption_dir: str = None,
):
self.tokenizer = tokenizer
self.size = size
self.arch = text_encoder_architecture
self.norm = norm
self.gen_samples = [] # [(img_path, prompt), ...]
self.mmu_samples = [] # [(img_path, question, answer), ...]
if llava_json:
self._load_llava(llava_json, llava_img_root)
if mmmu_json:
self._load_mmmu(mmmu_json, mmmu_img_root)
if vqa_ann_json:
self._load_vqav2(vqa_ann_json, vqa_img_root)
if coco_json:
self._load_coco2014_captions(coco_json, coco_img_root)
if coco_qa_json:
self._load_coco2014_qa(coco_qa_json, coco_img_root)
if gqa_json:
self._load_gqa(gqa_json, gqa_img_root)
if mg_llava_json:
self._load_mg_llava(mg_llava_json, mg_llava_root)
if caption_dir:
self._load_caption(caption_dir)
if pdd3_dir:
self._load_pdd3(pdd3_dir)
self.len_mmu = len(self.mmu_samples)
self.len_gen = len(self.gen_samples)
# ------------------------------------------------------------------ #
# dataset parsers #
# ------------------------------------------------------------------ #
def _load_llava(self, json_path, img_root):
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
for ex in data:
img_file = os.path.join(img_root, ex["image"])
human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))
def _load_mmmu(self, json_path, img_root):
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
for ex in data:
img_file = os.path.join(img_root, ex["image"])
choices = "\n".join([f"{chr(65+i)}. {c}" for i, c in enumerate(ex["choices"])])
question = f"{ex['question'].strip()}\n{choices}"
answer = f"{ex['answer']}"
self.mmu_samples.append((img_file, question, answer))
def _load_coco2014_qa(self, ann_jsonl, img_root):
with open(ann_jsonl, "r", encoding="utf-8") as file:
data = [json.loads(line) for line in file if line.strip()]
for ann in data:
image = ann["image"]
question = ann["question"]
answer = ann["label"]
image_path = os.path.join(img_root, image)
self.mmu_samples.append((image_path, question, answer))
def _load_coco2014_captions(self, ann_json, img_root):
"""
Load COCO 2014 image-caption pairs from caption annotation file.
Args:
ann_json (str): Path to COCO-style captions JSON (e.g., captions_train2014.json)
img_root (str): Directory containing COCO images (should include 'train2014/' and 'val2014/' subdirs)
"""
with open(ann_json, "r") as f:
data = json.load(f)
is_train = "train" in os.path.basename(ann_json).lower()
img_subdir = "train2014" if is_train else "val2014"
prefix = "COCO_train2014_" if is_train else "COCO_val2014_"
for ann in data["annotations"]:
image_id = ann["image_id"]
caption = ann["caption"]
image_filename = f"{prefix}{image_id:012d}.jpg"
image_path = os.path.join(img_root, img_subdir, image_filename)
question = "Please describe this image concisely."
self.mmu_samples.append((image_path, question, caption))
def _load_vqav2(self, ann_json, img_root):
with open(ann_json, "r") as file:
annos = json.load(file)
for ann in annos:
q = ann["question"]
answer = ann["answer"]
img_path = ann["image"]
img_file = os.path.join(
img_root,
img_path # if val, modify to val2014
)
self.mmu_samples.append((img_file, q, answer))
    def _load_gqa(self, ann_json_root, img_root):
        annos = {}
        for jsonfile in os.listdir(ann_json_root):
            # Skip anything in the directory that is not a GQA annotation JSON.
            if not jsonfile.endswith(".json"):
                continue
            jsonpath = os.path.join(ann_json_root, jsonfile)
            with open(jsonpath, "r") as file:
                anno = json.load(file)
            annos.update(anno)
for ann in annos.values():
q = ann["question"]
answer = ann["fullAnswer"]
img_name = ann["imageId"] + ".jpg"
img_path = os.path.join(
img_root,
img_name
)
self.mmu_samples.append((img_path, q, answer))
def _load_mg_llava(self, json_path, img_root):
with open(json_path, "r", encoding="utf-8") as f:
data = json.load(f)
for ex in data:
image = ex.get("image", None)
if image is not None:
img_file = os.path.join(img_root, ex["image"])
if os.path.exists(img_file):
human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))
def _load_caption(self, root_dir):
for root, _, files in os.walk(root_dir):
for f in files:
if f.lower().endswith((".jpg", ".png")):
base = os.path.splitext(f)[0]
txt_path = os.path.join(root, base + ".txt")
if os.path.exists(txt_path):
with open(txt_path, "r") as file:
caption = file.read().strip()
q = "Please describe this image."
self.mmu_samples.append((os.path.join(root, f), q, caption))
def _load_pdd3(self, root_dir):
for root, _, files in os.walk(root_dir):
for f in files:
if f.lower().endswith((".jpg", ".png")):
base = os.path.splitext(f)[0]
txt_path = os.path.join(root, base + ".txt")
if os.path.exists(txt_path):
with open(txt_path, "r") as file:
caption = file.read().strip()
self.gen_samples.append((os.path.join(root, f), caption))
# ------------------------------------------------------------------ #
# PyTorch Dataset API #
# ------------------------------------------------------------------ #
def __len__(self):
return max(self.len_gen, self.len_mmu)
    def __getitem__(self, idx):
        # __len__ is the max of the two sources, so `idx` may be out of range for
        # the shorter one; in that case fall back to a random sample from that source.
        get_mmu_data = False
        get_gen_data = False
        while not get_mmu_data:
            try:
                mmu_img_path, question, answer = self.mmu_samples[idx]
                get_mmu_data = True
            except IndexError:
                idx = random.randint(0, self.len_mmu - 1)
        while not get_gen_data:
            try:
                gen_img_path, prompt = self.gen_samples[idx]
                get_gen_data = True
            except IndexError:
                idx = random.randint(0, self.len_gen - 1)
try:
# ---- image ----
mmu_image = Image.open(mmu_img_path).convert("RGB")
mmu_ret = process_image(mmu_image, self.size, self.norm)
gen_image = Image.open(gen_img_path).convert("RGB")
gen_ret = process_image(gen_image, self.size, self.norm)
ret = dict(
gen_image=gen_ret["image"],
gen_micro_conds=gen_ret["micro_conds"],
mmu_image=mmu_ret["image"],
mmu_micro_conds=mmu_ret["micro_conds"]
)
            # ---- text ----
            # Strip the "<image>" placeholder and newlines before tokenizing.
            question = question.replace("<image>", "").replace("\n", "")
            question_ids = tokenize_prompt(
                self.tokenizer,
                question,
                self.arch,
                padding=False,
            )
            # Drop the trailing EOS token so question_len counts only the question tokens.
            question_ids = question_ids[:, :-1]
            q_len = len(question_ids[0])
if answer:
full_prompt = question + " " + answer
else:
full_prompt = question
mmu_input_ids = tokenize_prompt(self.tokenizer, full_prompt, self.arch)
gen_input_ids = tokenize_prompt(self.tokenizer, prompt, self.arch)
ret.update({
"gen_input_ids": gen_input_ids,
"mmu_input_ids": mmu_input_ids,
"question_len": torch.LongTensor([q_len])
})
return ret
        except Exception as e:
            print("================================================================")
            print(f"There is something wrong with {mmu_img_path} or {gen_img_path}: {e}")
            print("================================================================")
            if idx < self.len_gen - 1 or idx < self.len_mmu - 1:
                return self.__getitem__(idx + 1)
            else:
                idx = random.randint(0, self.len_gen - 1)
                return self.__getitem__(idx)
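# Example (illustrative sketch): constructing the multi-source dataset. All paths are
# hypothetical and `clip_tok` is the CLIP tokenizer from the sketches above; pass only
# the sources you actually have, and the loaders for omitted ones are simply skipped.
# Each item pairs one generation sample (gen_*) with one understanding sample (mmu_*).
#
#   ds = MultiSourceVLDataset(
#       tokenizer=clip_tok,
#       size=512,
#       text_encoder_architecture="CLIP",
#       llava_json="/data/llava_instruct_150k.json",
#       llava_img_root="/data/coco/train2017",
#       pdd3_dir="/data/pdd3",
#   )
#   item = ds[0]
#   sorted(item.keys())
#   # ['gen_image', 'gen_input_ids', 'gen_micro_conds',
#   #  'mmu_image', 'mmu_input_ids', 'mmu_micro_conds', 'question_len']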