# Copyright 2024 The HuggingFace Team and The MeissonFlow Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import bisect
import io
import json
import os
import random

import numpy as np
import pyarrow.fs as fs
import pyarrow.parquet as pq
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torchvision import transforms


def tokenize_prompt(
    tokenizer,
    prompt,
    text_encoder_architecture="open_clip",
    padding="max_length",
    max_length=77,
    max_length_t5=256,
):
    """Tokenize `prompt` for the given text-encoder architecture.

    For "CLIP"/"open_clip" a single tokenizer is expected and a single id tensor
    is returned. For "t5_clip" and "gemma" a pair of tokenizers is expected
    (CLIP first, T5/Gemma second) and a list of two id tensors is returned.
    """
    if text_encoder_architecture == "CLIP" or text_encoder_architecture == "open_clip":
        input_ids = tokenizer(
            prompt,
            truncation=True,
            padding=padding,
            max_length=max_length,
            return_tensors="pt",
        ).input_ids
        return input_ids
    elif text_encoder_architecture in ("t5_clip", "gemma"):
        # Two tokenizers: the first for CLIP, the second for T5 / Gemma.
        input_ids = []
        input_ids.append(
            tokenizer[0](
                prompt,
                truncation=True,
                padding=padding,
                max_length=max_length,
                return_tensors="pt",
            ).input_ids
        )
        input_ids.append(
            tokenizer[1](
                prompt,
                truncation=True,
                padding=padding,
                max_length=max_length_t5,
                return_tensors="pt",
            ).input_ids
        )
        return input_ids
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")

def encode_prompt(
    text_encoder,
    input_ids,
    text_encoder_architecture="open_clip",
):
    """Run the text encoder(s) and return (encoder_hidden_states, cond_embeds)."""
    if text_encoder_architecture == "CLIP" or text_encoder_architecture == "open_clip":
        outputs = text_encoder(input_ids=input_ids, return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs.hidden_states[-2]
        cond_embeds = outputs[0]
        return encoder_hidden_states, cond_embeds
    elif text_encoder_architecture == "t5_clip":
        outputs_clip = text_encoder[0](input_ids=input_ids[0], return_dict=True, output_hidden_states=True)
        outputs_t5 = text_encoder[1](input_ids=input_ids[1], return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs_t5.last_hidden_state
        cond_embeds = outputs_clip.text_embeds
        return encoder_hidden_states, cond_embeds
    elif text_encoder_architecture == "gemma":
        outputs_clip = text_encoder[0](input_ids=input_ids[0], return_dict=True, output_hidden_states=True)
        outputs_gemma = text_encoder[1](input_ids=input_ids[1], return_dict=True, output_hidden_states=True)
        encoder_hidden_states = outputs_gemma.last_hidden_state
        cond_embeds = outputs_clip.text_embeds
        return encoder_hidden_states, cond_embeds
    else:
        raise ValueError(f"Unknown text_encoder_architecture: {text_encoder_architecture}")

def process_image(image, size, Norm=False, hps_score=6.0):
    """Resize, random-crop and tensorize an image; also return its micro-conditioning values."""
    image = exif_transpose(image)

    if not image.mode == "RGB":
        image = image.convert("RGB")

    orig_height = image.height
    orig_width = image.width

    image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image)

    c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size))
    image = transforms.functional.crop(image, c_top, c_left, size, size)
    image = transforms.ToTensor()(image)

    if Norm:
        image = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)(image)

    micro_conds = torch.tensor(
        [orig_width, orig_height, c_top, c_left, hps_score],
    )

    return {"image": image, "micro_conds": micro_conds}

class ImageCaptionLargeDataset(Dataset):
    """Image/caption pairs stored on disk as `<name>.jpg|.png` plus `<name>.txt`."""

    def __init__(
        self,
        root_dir,
        tokenizer,
        size,
        text_encoder_architecture="CLIP",
        norm=False,
    ):
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.size = size
        self.text_encoder_architecture = text_encoder_architecture
        self.norm = norm

        self.data_list = []
        for root, dirnames, filenames in os.walk(root_dir):
            for filename in filenames:
                if filename.endswith(".jpg") or filename.endswith(".png"):
                    base_name = os.path.splitext(filename)[0]
                    txt_file = os.path.join(root, base_name + ".txt")
                    if os.path.exists(txt_file):
                        self.data_list.append((root, base_name + ".txt", filename))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sub_dir, txtfilename, imgfilename = self.data_list[idx]
        img_path = os.path.join(sub_dir, imgfilename)
        caption_path = os.path.join(sub_dir, txtfilename)
        try:
            image = Image.open(img_path).convert("RGB")
            ret = process_image(image, self.size, self.norm)

            with open(caption_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()

            ret["prompt_input_ids"] = tokenize_prompt(
                self.tokenizer, caption, self.text_encoder_architecture
            )
            return ret
        except Exception:
            print("===========================================")
            print(f"[Warning] Error at index {idx}: {img_path}")
            print("===========================================")
            # Fall back to the next sample, wrapping around, so a corrupt file
            # at the last index does not recurse on itself forever.
            return self.__getitem__((idx + 1) % len(self.data_list))
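
# Usage sketch (the directory path and tokenizer are hypothetical; shown only
# to illustrate the keys each sample carries):
#
#     dataset = ImageCaptionLargeDataset("/data/caption_pairs", clip_tok, size=512)
#     sample = dataset[0]
#     # sample keys: "image", "micro_conds", "prompt_input_ids"
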

class MultiSourceVLDataset(Dataset):
    """
    A unified dataloader mixing multimodal-understanding (mmu) and generation (gen) samples from:
        • LLaVA-Instruct-150K
        • MMMU (multiple-choice QA)
        • VQAv2
        • COCO 2014 captions / QA
        • GQA
        • MG-LLaVA
        • Local caption files under `pdd3/` (generation) and `caption_dir` (understanding)
    """

    def __init__(
        self,
        tokenizer,
        size: int,
        text_encoder_architecture: str = "CLIP",
        norm: bool = False,
        # ----- paths -----
        llava_json: str = None, llava_img_root: str = None,
        mmmu_json: str = None, mmmu_img_root: str = None,
        vqa_ann_json: str = None, vqa_img_root: str = None,
        gqa_json: str = None, gqa_img_root: str = None,
        coco_json: str = None, coco_img_root: str = None,
        coco_qa_json: str = None,
        mg_llava_json: str = None, mg_llava_root: str = None,
        pdd3_dir: str = None, caption_dir: str = None,
    ):
        self.tokenizer = tokenizer
        self.size = size
        self.arch = text_encoder_architecture
        self.norm = norm

        self.gen_samples = []  # [(img_path, prompt), ...]
        self.mmu_samples = []  # [(img_path, question, answer), ...]

        if llava_json:
            self._load_llava(llava_json, llava_img_root)
        if mmmu_json:
            self._load_mmmu(mmmu_json, mmmu_img_root)
        if vqa_ann_json:
            self._load_vqav2(vqa_ann_json, vqa_img_root)
        if coco_json:
            self._load_coco2014_captions(coco_json, coco_img_root)
        if coco_qa_json:
            self._load_coco2014_qa(coco_qa_json, coco_img_root)
        if gqa_json:
            self._load_gqa(gqa_json, gqa_img_root)
        if mg_llava_json:
            self._load_mg_llava(mg_llava_json, mg_llava_root)
        if caption_dir:
            self._load_caption(caption_dir)
        if pdd3_dir:
            self._load_pdd3(pdd3_dir)

        self.len_mmu = len(self.mmu_samples)
        self.len_gen = len(self.gen_samples)

    # ------------------------------------------------------------------ #
    #                           dataset parsers                           #
    # ------------------------------------------------------------------ #
    def _load_llava(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            img_file = os.path.join(img_root, ex["image"])
            human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
            gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
            self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))

    def _load_mmmu(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            img_file = os.path.join(img_root, ex["image"])
            choices = "\n".join([f"{chr(65 + i)}. {c}" for i, c in enumerate(ex["choices"])])
            question = f"{ex['question'].strip()}\n{choices}"
            answer = f"{ex['answer']}"
            self.mmu_samples.append((img_file, question, answer))

    def _load_coco2014_qa(self, ann_jsonl, img_root):
        with open(ann_jsonl, "r", encoding="utf-8") as file:
            data = [json.loads(line) for line in file if line.strip()]
        for ann in data:
            image = ann["image"]
            question = ann["question"]
            answer = ann["label"]
            image_path = os.path.join(img_root, image)
            self.mmu_samples.append((image_path, question, answer))

    def _load_coco2014_captions(self, ann_json, img_root):
        """
        Load COCO 2014 image-caption pairs from a caption annotation file.

        Args:
            ann_json (str): Path to a COCO-style captions JSON (e.g., captions_train2014.json).
            img_root (str): Directory containing COCO images (should include 'train2014/' and 'val2014/' subdirs).
        """
        with open(ann_json, "r") as f:
            data = json.load(f)

        is_train = "train" in os.path.basename(ann_json).lower()
        img_subdir = "train2014" if is_train else "val2014"
        prefix = "COCO_train2014_" if is_train else "COCO_val2014_"

        for ann in data["annotations"]:
            image_id = ann["image_id"]
            caption = ann["caption"]
            image_filename = f"{prefix}{image_id:012d}.jpg"
            image_path = os.path.join(img_root, img_subdir, image_filename)
            question = "Please describe this image concisely."
            self.mmu_samples.append((image_path, question, caption))

    def _load_vqav2(self, ann_json, img_root):
        with open(ann_json, "r") as file:
            annos = json.load(file)
        for ann in annos:
            q = ann["question"]
            answer = ann["answer"]
            img_path = ann["image"]
            # For the val split, the relative paths should point into val2014.
            img_file = os.path.join(img_root, img_path)
            self.mmu_samples.append((img_file, q, answer))

    def _load_gqa(self, ann_json_root, img_root):
        annos = {}
        for jsonfile in os.listdir(ann_json_root):
            jsonpath = os.path.join(ann_json_root, jsonfile)
            with open(jsonpath, "r") as file:
                anno = json.load(file)
            annos.update(anno)
        for ann in annos.values():
            q = ann["question"]
            answer = ann["fullAnswer"]
            img_name = ann["imageId"] + ".jpg"
            img_path = os.path.join(img_root, img_name)
            self.mmu_samples.append((img_path, q, answer))

    def _load_mg_llava(self, json_path, img_root):
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for ex in data:
            image = ex.get("image", None)
            if image is not None:
                img_file = os.path.join(img_root, image)
                if os.path.exists(img_file):
                    human_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "human")
                    gpt_msg = next(m["value"] for m in ex["conversations"] if m["from"] == "gpt")
                    self.mmu_samples.append((img_file, human_msg.strip(), gpt_msg.strip()))

    def _load_caption(self, root_dir):
        for root, _, files in os.walk(root_dir):
            for f in files:
                if f.lower().endswith((".jpg", ".png")):
                    base = os.path.splitext(f)[0]
                    txt_path = os.path.join(root, base + ".txt")
                    if os.path.exists(txt_path):
                        with open(txt_path, "r") as file:
                            caption = file.read().strip()
                        q = "Please describe this image."
                        self.mmu_samples.append((os.path.join(root, f), q, caption))

    def _load_pdd3(self, root_dir):
        for root, _, files in os.walk(root_dir):
            for f in files:
                if f.lower().endswith((".jpg", ".png")):
                    base = os.path.splitext(f)[0]
                    txt_path = os.path.join(root, base + ".txt")
                    if os.path.exists(txt_path):
                        with open(txt_path, "r") as file:
                            caption = file.read().strip()
                        self.gen_samples.append((os.path.join(root, f), caption))

    # ------------------------------------------------------------------ #
    #                         PyTorch Dataset API                         #
    # ------------------------------------------------------------------ #
    def __len__(self):
        return max(self.len_gen, self.len_mmu)

    def __getitem__(self, idx):
        # The mmu and gen pools usually differ in length, so the shared index is
        # re-drawn at random whenever it falls outside one of them.
        get_mmu_data = False
        get_gen_data = False
        while not get_mmu_data:
            try:
                mmu_img_path, question, answer = self.mmu_samples[idx]
                get_mmu_data = True
            except IndexError:
                idx = random.randint(0, self.len_mmu - 1)
        while not get_gen_data:
            try:
                gen_img_path, prompt = self.gen_samples[idx]
                get_gen_data = True
            except IndexError:
                idx = random.randint(0, self.len_gen - 1)

        try:
            # ---- images ----
            mmu_image = Image.open(mmu_img_path).convert("RGB")
            mmu_ret = process_image(mmu_image, self.size, self.norm)
            gen_image = Image.open(gen_img_path).convert("RGB")
            gen_ret = process_image(gen_image, self.size, self.norm)

            ret = dict(
                gen_image=gen_ret["image"],
                gen_micro_conds=gen_ret["micro_conds"],
                mmu_image=mmu_ret["image"],
                mmu_micro_conds=mmu_ret["micro_conds"],
            )

            # ---- text ----
            question = question.replace("<image>", "").replace("\n", "")
            # Tokenize the question alone (unpadded) to measure its length, dropping
            # the final token; this slicing assumes a single-tokenizer architecture
            # such as CLIP/open_clip, which returns one id tensor.
            question_ids = tokenize_prompt(
                self.tokenizer,
                question,
                self.arch,
                padding=False,
            )
            question_ids = question_ids[:, :-1]
            q_len = len(question_ids[0])

            if answer:
                full_prompt = question + " " + answer
            else:
                full_prompt = question

            mmu_input_ids = tokenize_prompt(self.tokenizer, full_prompt, self.arch)
            gen_input_ids = tokenize_prompt(self.tokenizer, prompt, self.arch)

            ret.update({
                "gen_input_ids": gen_input_ids,
                "mmu_input_ids": mmu_input_ids,
                "question_len": torch.LongTensor([q_len]),
            })
            return ret
        except Exception:
            print("================================================================")
            print(f"There is something wrong with {mmu_img_path} or {gen_img_path}.")
            print("================================================================")
            if idx < self.len_gen - 1 or idx < self.len_mmu - 1:
                return self.__getitem__(idx + 1)
            else:
                idx = random.randint(0, self.len_gen - 1)
                return self.__getitem__(idx)
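

# ---------------------------------------------------------------------- #
# Minimal smoke-test sketch.  Everything below is illustrative: the CLIP
# checkpoint name and the data directory are assumptions rather than part of
# the original training setup, and the block only runs when this file is
# executed directly.
# ---------------------------------------------------------------------- #
if __name__ == "__main__":
    from transformers import CLIPTokenizer  # assumed to be installed in the training environment

    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Hypothetical local directory of <name>.jpg / <name>.txt pairs.
    caption_dataset = ImageCaptionLargeDataset(
        root_dir="./example_caption_data",
        tokenizer=clip_tokenizer,
        size=512,
        text_encoder_architecture="CLIP",
    )
    print(f"ImageCaptionLargeDataset: {len(caption_dataset)} image/caption pairs found.")
    if len(caption_dataset) > 0:
        sample = caption_dataset[0]
        print(sample["image"].shape, sample["micro_conds"], sample["prompt_input_ids"].shape)

    # With no source paths supplied, the multi-source dataset is simply empty.
    multi_dataset = MultiSourceVLDataset(tokenizer=clip_tokenizer, size=512)
    print(f"MultiSourceVLDataset: {len(multi_dataset)} samples.")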