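"""Dataset classes for paired image-to-image translation.

pix2pixDataset loads side-by-side image pairs (e.g. the pix2pix "maps" set)
and splits each file into its left/right halves; FishDataset additionally
parses a class id from each filename for class-conditional generation.
"""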
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Resize, ToTensor
import imageio
from tqdm import tqdm

class pix2pixDataset(Dataset):
    """Side-by-side image pairs split into a left half (x) and a right half (y)."""

    def __init__(self, dataset="maps", data_dir="/projects/ml4science/datasets_pix2pix/",
                 split="train", normalize=True, transforms=None, preload=False,
                 image_size=256, direction="BtoA"):
        self.datadir = os.path.join(data_dir, dataset)
        self.img_name_list_path = os.path.join(data_dir, dataset, split)
        if not os.path.exists(self.datadir):
            raise FileNotFoundError(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        self.direction = direction
        if transforms is None:
            self.transforms = Compose([
                ToTensor(),  # HWC ndarray -> CHW torch tensor
                Resize((image_size, image_size), antialias=False),  # resize to image_size x image_size
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Load every pair into memory once so __getitem__ is a cheap index.
            x_list, y_list = [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                x_list.append(x)
                y_list.append(y)
            self.x_list = torch.stack(x_list, 0)
            self.y_list = torch.stack(y_list, 0)
            print(f"{split} dataset preloaded!")
    def load_every(self, name):
        # Read the side-by-side pair as an HWC uint8 array.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_W = img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        # Left half is x, right half is y.
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Map [0, 255] pixel values to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values in [0, 255].
        return ((x / 2 + 0.5) * 255).clamp(0, 255).int()
    def __getitem__(self, index):
        if self.preload:
            x_img, y_img = self.x_list[index], self.y_list[index]
        else:
            name = self.image_name_list[index]
            x_img, y_img = self.load_every(name)
        # "BtoA" keeps the on-disk (left, right) order; "AtoB" swaps the pair.
        if self.direction == "AtoB":
            x_img, y_img = y_img, x_img
        return {
            "image1": x_img,
            "image2": y_img,
        }

    def __len__(self):
        return len(self.image_name_list)
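
# Example usage (a minimal sketch, not from the original training script):
# batch_size and num_workers below are illustrative, and the default data_dir
# is cluster-specific, so pass your own if running elsewhere.
#
#   from torch.utils.data import DataLoader
#
#   train_set = pix2pixDataset(dataset="maps", split="train", direction="BtoA")
#   train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)
#   batch = next(iter(train_loader))
#   print(batch["image1"].shape)  # torch.Size([16, 3, 256, 256])
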
class FishDataset(Dataset):
    """Paired, class-conditional dataset; the class id is encoded in each filename."""

    def __init__(self, data_dir="/projects/ml4science/FishDiffusion/", split="train",
                 normalize=True, transforms=None, preload=False, image_size=128):
        self.datadir = os.path.join(data_dir)
        self.img_name_list_path = os.path.join(data_dir, split)
        if not os.path.exists(self.datadir):
            raise FileNotFoundError(f"Dataset directory {self.datadir} does not exist")
        self.normalize = normalize
        self.image_name_list = os.listdir(self.img_name_list_path)
        self.preload = preload
        if transforms is None:
            # NOTE: image_size is accepted for API parity with pix2pixDataset,
            # but no Resize is applied: images are used at native resolution.
            self.transforms = Compose([
                ToTensor(),  # HWC ndarray -> CHW torch tensor
            ])
        else:
            self.transforms = transforms
        if self.preload:
            # Stacking below assumes every image shares one resolution.
            x_list, y_list, class_ids = [], [], []
            for name in tqdm(self.image_name_list):
                x, y = self.load_every(name)
                # Class id is the last "_"-separated token of the filename,
                # with the 4-character extension stripped (e.g. "..._12.png" -> 12).
                cls_id = int(name.split("_")[-1][:-4])
                x_list.append(x)
                y_list.append(y)
                class_ids.append(cls_id)
            self.x_list = torch.stack(x_list, 0)
            self.y_list = torch.stack(y_list, 0)
            self.class_id = torch.tensor(class_ids)
            print(f"{split} dataset preloaded!")
    def load_every(self, name):
        # Read the side-by-side pair as an HWC uint8 array.
        img_array = np.asarray(imageio.imread(os.path.join(self.img_name_list_path, name)))
        img_W = img_array.shape[1]
        if self.normalize:
            img_array = self.normalize_fn(img_array)
        # Left half is x, right half is y.
        x_img, y_img = img_array[:, :img_W // 2, :], img_array[:, img_W // 2:, :]
        x_img, y_img = self.transforms(x_img), self.transforms(y_img)
        return x_img.float(), y_img.float()

    def normalize_fn(self, x):
        # Map [0, 255] pixel values to [-1, 1].
        return (x / 255. - 0.5) * 2

    def unnormalize_fn(self, x):
        # Map [-1, 1] back to integer pixel values in [0, 255].
        return ((x / 2 + 0.5) * 255).clamp(0, 255).int()
    def __getitem__(self, index):
        # Returns (x, y, class_id), where class_id is the label for
        # class-conditional generation.
        if self.preload:
            x_img, y_img, class_id = self.x_list[index], self.y_list[index], self.class_id[index]
        else:
            name = self.image_name_list[index]
            class_id = torch.tensor(int(name.split("_")[-1][:-4]))
            x_img, y_img = self.load_every(name)
        return x_img, y_img, class_id

    def __len__(self):
        return len(self.image_name_list)
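
if __name__ == "__main__":
    # Smoke test (a minimal sketch): the default data_dir points at a cluster
    # path, so pass a local copy before running. Filenames are assumed to end
    # in "_<class_id>.<ext>" with a three-character extension, as parsed above.
    ds = FishDataset(split="train", preload=False)
    x, y, cls = ds[0]
    print(x.shape, y.shape, int(cls))       # each half is (3, H, W // 2)
    print(x.min().item(), x.max().item())   # roughly [-1, 1] when normalize=True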