# lisa-on-cuda/utils/dataset.py
import glob
import os
import random
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pycocotools import mask
from transformers import CLIPImageProcessor
from model.segment_anything.utils.transforms import ResizeLongestSide
from .conversation import get_default_conv_template
from .data_processing import get_mask_from_json
from .reason_seg_dataset import ReasonSegDataset
from .refer import REFER
from .refer_seg_dataset import ReferSegDataset
from .sem_seg_dataset import SemSegDataset
from .utils import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN,
DEFAULT_IMAGE_TOKEN,
)
from .vqa_dataset import VQADataset


def collate_fn(batch, tokenizer=None):
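    """Collate a list of dataset samples into a single LISA training batch.

    Conversations from all samples are flattened and tokenized together;
    `offset` records where each sample's conversations start and end. Labels
    are built Vicuna-style by masking the instruction parts and padding with
    the ignore index (-100), so only assistant responses contribute to the
    loss. Segmentation masks and label maps stay in Python lists because
    their spatial sizes differ across images.
    """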
image_path_list = []
images_list = []
images_clip_list = []
conversation_list = []
masks_list = []
label_list = []
resize_list = []
questions_list = []
sampled_classes_list = []
offset_list = [0]
cnt = 0
inferences = []
for (
image_path,
images,
images_clip,
conversations,
masks,
label,
resize,
questions,
sampled_classes,
inference,
) in batch:
image_path_list.append(image_path)
images_list.append(images)
images_clip_list.append(images_clip)
conversation_list.extend(conversations)
label_list.append(label)
masks_list.append(masks.float())
resize_list.append(resize)
questions_list.append(questions)
sampled_classes_list.append(sampled_classes)
cnt += len(conversations)
offset_list.append(cnt)
inferences.append(inference)
tokenize_data = tokenizer(
conversation_list,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
input_ids = tokenize_data.input_ids
attention_masks = tokenize_data.attention_mask
IGNORE_TOKEN_ID = -100
conv = get_default_conv_template("vicuna").copy()
targets = input_ids.clone()
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversation_list, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep)
# if len(parts) != 2:
# break
assert len(parts) == 2, (len(parts), rou)
parts[0] += sep
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += round_len
target[cur_len:] = IGNORE_TOKEN_ID
        if False:  # set to True to print the decoded, instruction-masked target for debugging
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
# rank0_print(tokenizer.decode(z))
print(
"conversation: ",
conversation,
"tokenizer.decode(z): ",
tokenizer.decode(z),
)
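        # Sanity check: unless the conversation was truncated at model_max_length,
        # the masking cursor must have consumed exactly the non-padded tokens.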
if cur_len < tokenizer.model_max_length:
assert cur_len == total_len
return {
"image_paths": image_path_list,
"images": torch.stack(images_list, dim=0),
"images_clip": torch.stack(images_clip_list, dim=0),
"input_ids": input_ids,
"labels": targets,
"attention_masks": attention_masks,
"masks_list": masks_list,
"label_list": label_list,
"resize_list": resize_list,
"offset": torch.LongTensor(offset_list),
"questions_list": questions_list,
"sampled_classes_list": sampled_classes_list,
"inference": inferences[0],
"conversation_list": conversation_list,
    }


class HybridDataset(torch.utils.data.Dataset):
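    """Training dataset that mixes several task-specific datasets.

    Each __getitem__ draws one of the configured sub-datasets (semantic seg,
    referring seg, VQA, reasoning seg) according to the normalized
    `sample_rate` and returns a sample from it.
    """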
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
img_size = 1024
ignore_label = 255
def __init__(
self,
base_image_dir,
tokenizer,
vision_tower,
samples_per_epoch=500 * 8 * 2 * 10,
precision: str = "fp32",
image_size: int = 224,
num_classes_per_sample: int = 3,
exclude_val=False,
dataset="sem_seg||refer_seg||vqa||reason_seg",
sample_rate=[9, 3, 3, 1],
sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
refer_seg_data="refclef||refcoco||refcoco+||refcocog",
vqa_data="llava_instruct_150k",
reason_seg_data="ReasonSeg|train",
explanatory=0.1,
):
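        """Build the requested sub-datasets; `sample_rate` is normalized into
        per-dataset sampling probabilities."""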
self.exclude_val = exclude_val
self.dataset = dataset
self.samples_per_epoch = samples_per_epoch
self.explanatory = explanatory
self.num_classes_per_sample = num_classes_per_sample
sample_rate = np.array(sample_rate)
self.sample_rate = sample_rate / sample_rate.sum()
self.base_image_dir = base_image_dir
self.image_size = image_size
self.tokenizer = tokenizer
self.precision = precision
self.datasets = dataset.split("||")
self.all_datasets = []
for dataset in self.datasets:
if dataset == "sem_seg":
self.all_datasets.append(
SemSegDataset(
base_image_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
image_size,
num_classes_per_sample,
exclude_val,
sem_seg_data,
)
)
elif dataset == "refer_seg":
self.all_datasets.append(
ReferSegDataset(
base_image_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
image_size,
num_classes_per_sample,
exclude_val,
refer_seg_data,
)
)
elif dataset == "vqa":
self.all_datasets.append(
VQADataset(
base_image_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
image_size,
num_classes_per_sample,
exclude_val,
vqa_data,
)
)
elif dataset == "reason_seg":
self.all_datasets.append(
ReasonSegDataset(
base_image_dir,
tokenizer,
vision_tower,
samples_per_epoch,
precision,
image_size,
num_classes_per_sample,
exclude_val,
reason_seg_data,
explanatory,
)
)
def __len__(self):
return self.samples_per_epoch
def __getitem__(self, idx):
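        # Draw one sub-dataset according to the normalized sampling probabilities
        # and return its sample, marked as a training (non-inference) item.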
ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
data = self.all_datasets[ind]
inference = False
        return *data[0], inference


class ValDataset(torch.utils.data.Dataset):
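    """Validation dataset covering two formats: "ds|split" for ReasonSeg-style
    data and "ds|splitBy|split" for referring-segmentation benchmarks."""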
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
img_size = 1024
ignore_label = 255
def __init__(
self,
base_image_dir,
tokenizer,
vision_tower,
val_dataset,
image_size=1024,
):
self.base_image_dir = base_image_dir
splits = val_dataset.split("|")
if len(splits) == 2:
ds, split = splits
images = glob.glob(
os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
)
self.images = images
self.data_type = "reason_seg"
elif len(splits) == 3:
ds, splitBy, split = splits
refer_api = REFER(self.base_image_dir, ds, splitBy)
ref_ids_val = refer_api.getRefIds(split=split)
images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)
refer_seg_ds = {}
refer_seg_ds["images"] = []
loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
for item in loaded_images:
item = item.copy()
if ds == "refclef":
item["file_name"] = os.path.join(
base_image_dir, "images/saiapr_tc-12", item["file_name"]
)
elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
item["file_name"] = os.path.join(
base_image_dir,
"images/mscoco/images/train2014",
item["file_name"],
)
refer_seg_ds["images"].append(item)
refer_seg_ds["annotations"] = refer_api.Anns # anns_val
img2refs = {}
for ref in refs_val:
image_id = ref["image_id"]
img2refs[image_id] = img2refs.get(image_id, []) + [
ref,
]
refer_seg_ds["img2refs"] = img2refs
self.refer_seg_ds = refer_seg_ds
self.data_type = "refer_seg"
self.ds = ds
self.image_size = image_size
self.tokenizer = tokenizer
self.transform = ResizeLongestSide(image_size)
self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
def __len__(self):
if self.data_type == "refer_seg":
return len(self.refer_seg_ds["images"])
else:
return len(self.images)
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.img_size - h
padw = self.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
def __getitem__(self, idx):
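        # refer_seg: evaluate every sentence of every ref on this image.
        # reason_seg: load the target mask and query from the paired JSON file.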
if self.data_type == "refer_seg":
refer_seg_ds = self.refer_seg_ds
images = refer_seg_ds["images"]
annotations = refer_seg_ds["annotations"]
img2refs = refer_seg_ds["img2refs"]
image = images[idx]
image_path = image["file_name"]
image_id = image["id"]
refs = img2refs[image_id]
if len(refs) == 0:
raise ValueError("image {} has no refs".format(image_id))
sents = []
ann_ids = []
for ref in refs:
for sent in ref["sentences"]:
sents.append(sent["sent"].strip().lower())
ann_ids.append(ref["ann_id"])
sampled_sents = sents
sampled_ann_ids = ann_ids
img = cv2.imread(image_path)
images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
is_sentence = False
else:
image_path = self.images[idx]
img = cv2.imread(image_path)
images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
json_path = image_path.replace(".jpg", ".json")
mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, img)
sampled_sents = [sampled_sents[0]]
conversations = []
conv = get_default_conv_template("vicuna").copy()
i = 0
while i < len(sampled_sents):
conv.messages = []
text = sampled_sents[i].strip()
if is_sentence:
conv.append_message(
conv.roles[0],
DEFAULT_IMAGE_TOKEN
+ " {} Please output segmentation mask.".format(text),
)
conv.append_message(conv.roles[1], "[SEG].")
else:
conv.append_message(
conv.roles[0],
DEFAULT_IMAGE_TOKEN
+ " What is {} in this image? Please output segmentation mask.".format(
text
),
)
conv.append_message(conv.roles[1], "[SEG].")
conversations.append(conv.get_prompt())
i += 1
# replace <image> token
        image_token_len = 256  # 224x224 CLIP input with 14x14 patches -> 16*16 tokens
for i in range(len(conversations)):
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
replace_token = (
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
)
conversations[i] = conversations[i].replace(
DEFAULT_IMAGE_TOKEN, replace_token
)
# preprocess images for clip
images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
"pixel_values"
][0]
        image_token_len = (images_clip.shape[1] // 14) * (
            images_clip.shape[2] // 14
        )  # FIXME: 14 is the hardcoded CLIP patch size; this recomputed value is
        # unused below, since the conversations were already built with 256 tokens.
# preprocess images for sam
images = self.transform.apply_image(images)
resize = images.shape[:2]
images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
if self.data_type == "refer_seg":
masks = []
            for i, ann_id in enumerate(sampled_ann_ids):
                ann = annotations[ann_id]
                if len(ann["segmentation"]) == 0 and sampled_sents[i] != "":
                    m = np.zeros((image["height"], image["width"], 1))
                else:
                    if isinstance(ann["segmentation"][0], list):  # polygon
                        rle = mask.frPyObjects(
                            ann["segmentation"], image["height"], image["width"]
                        )
                    else:
                        rle = ann["segmentation"]
                        for seg_rle in rle:
                            # pycocotools expects RLE counts as bytes
                            if not isinstance(seg_rle["counts"], bytes):
                                seg_rle["counts"] = seg_rle["counts"].encode()
                    m = mask.decode(rle)
                    # Sometimes there are multiple binary maps (one per segment);
                    # collapse them into a single uint8 mask.
                    m = np.sum(m, axis=2)
                    m = m.astype(np.uint8)
                masks.append(m)
else:
masks = [mask_json]
masks = np.stack(masks, axis=0)
masks = torch.from_numpy(masks)
labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
inference = True
return (
image_path,
images,
images_clip,
conversations,
masks,
labels,
resize,
None,
None,
inference,
)
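

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training code): how
# HybridDataset and collate_fn could be wired into a DataLoader. The tokenizer
# and CLIP checkpoints and the data root below are assumptions, and the
# sub-datasets expect their image/annotation folders to exist on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import functools

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.1")  # assumed tokenizer
    tokenizer.pad_token = tokenizer.unk_token  # LLaMA tokenizers ship without a pad token

    train_data = HybridDataset(
        base_image_dir="./dataset",  # assumed data root
        tokenizer=tokenizer,
        vision_tower="openai/clip-vit-large-patch14",  # assumed CLIP vision tower
        samples_per_epoch=100,
    )
    loader = DataLoader(
        train_data,
        batch_size=2,
        shuffle=False,  # HybridDataset already samples randomly per item
        collate_fn=functools.partial(collate_fn, tokenizer=tokenizer),
    )
    batch = next(iter(loader))
    print(batch["images"].shape, batch["input_ids"].shape)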