import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms

from diffusionsfm.dataset.co3d_v2 import square_bbox
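
# square_bbox comes from the DiffusionSFM repo. To run this file without the
# repo installed, the commented sketch below is a plausible stand-in; it is an
# assumption, not the repo's actual implementation. With tight=True it returns
# the largest centered square contained in the bbox, which matches how
# get_data() uses it to take a centered square crop.
#
# def square_bbox(bbox, tight=False):  # hypothetical stand-in
#     center = (bbox[:2] + bbox[2:]) / 2
#     extent = (bbox[2:] - bbox[:2]) / 2
#     s = extent.min() if tight else extent.max()
#     return np.array([center[0] - s, center[1] - s, center[0] + s, center[1] + s])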

class CustomDataset(Dataset):
    def __init__(
        self,
        image_list,
    ):
        self.images = []
        for image_path in sorted(image_list):
            img = Image.open(image_path)
            img = ImageOps.exif_transpose(img).convert("RGB")  # Apply EXIF rotation
            self.images.append(img)
        self.n = len(self.images)
        # Identity jitter: no scale or translation augmentation at inference.
        self.jitter_scale = [1, 1]
        self.jitter_trans = [0, 0]
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize(224),
                # Normalize with ImageNet mean and std.
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.transform_for_vis = transforms.Compose(
            [
                transforms.Resize(224),
            ]
        )

    def __len__(self):
        # The whole image list is treated as a single scene, so the dataset
        # exposes exactly one item.
        return 1

    def _crop_image(self, image, bbox, white_bg=False):
        # bbox is (x0, y0, x1, y1) in pixel coordinates.
        if white_bg:
            # Only supports PIL Images: paste onto a white canvas so regions
            # outside the original image become white.
            image_crop = Image.new(
                "RGB", (bbox[2] - bbox[0], bbox[3] - bbox[1]), (255, 255, 255)
            )
            image_crop.paste(image, (-bbox[0], -bbox[1]))
        else:
            image_crop = transforms.functional.crop(
                image,
                top=bbox[1],
                left=bbox[0],
                height=bbox[3] - bbox[1],
                width=bbox[2] - bbox[0],
            )
        return image_crop

    def __getitem__(self, index):
        return self.get_data()

    def get_data(self):
        cmap = plt.get_cmap("hsv")
        ids = list(range(len(self.images)))
        images = [self.images[i] for i in ids]
        images_transformed = []
        images_for_vis = []
        crop_parameters = []
        for i, image in enumerate(images):
            # Crop each image to a centered square bounding box.
            bbox = np.array([0, 0, image.width, image.height])
            bbox = square_bbox(bbox, tight=True)
            bbox = np.around(bbox).astype(int)
            image = self._crop_image(image, bbox)
            images_transformed.append(self.transform(image))

            # Draw a color-coded border on each image for visualization.
            image_for_vis = self.transform_for_vis(image)
            color_float = cmap(i / len(images))
            color_rgb = tuple(int(255 * c) for c in color_float[:3])
            image_for_vis = ImageOps.expand(image_for_vis, border=3, fill=color_rgb)
            images_for_vis.append(image_for_vis)

            # Express the crop as (principal point, crop width, scale).
            width, height = image.size
            length = max(width, height)
            s = length / min(width, height)
            crop_center = (bbox[:2] + bbox[2:]) / 2
            crop_center = crop_center + (length - np.array([width, height])) / 2
            # Convert to NDC
            cc = s - 2 * s * crop_center / length
            crop_width = 2 * s * (bbox[2] - bbox[0]) / length
            crop_params = torch.tensor([-cc[0], -cc[1], crop_width, s])
            crop_parameters.append(crop_params)
        images = images_transformed

        batch = {}
        batch["image"] = torch.stack(images)
        batch["image_for_vis"] = images_for_vis
        batch["n"] = len(images)
        batch["ind"] = torch.tensor(ids)
        batch["crop_parameters"] = torch.stack(crop_parameters)
        batch["distortion_parameters"] = torch.zeros(4)
        return batch
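

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The image paths are
# hypothetical placeholders; point them at real files before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    image_list = ["frames/frame_000.jpg", "frames/frame_001.jpg"]
    dataset = CustomDataset(image_list)
    batch = dataset.get_data()
    print(batch["image"].shape)            # typically (n, 3, 224, 224)
    print(batch["crop_parameters"].shape)  # (n, 4)
    print(batch["ind"])                    # tensor([0, 1])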