Spaces:
Paused
Paused
Upload uno.py
Browse files- uno/dataset/uno.py +132 -0
uno/dataset/uno.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
|
| 2 |
+
# Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
|
| 3 |
+
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
import torchvision.transforms.functional as TVF
|
| 22 |
+
from torch.utils.data import DataLoader, Dataset
|
| 23 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
| 24 |
+
|
| 25 |
+
def bucket_images(images: list[torch.Tensor], resolution: int = 512):
|
| 26 |
+
bucket_override=[
|
| 27 |
+
# h w
|
| 28 |
+
(256, 768),
|
| 29 |
+
(320, 768),
|
| 30 |
+
(320, 704),
|
| 31 |
+
(384, 640),
|
| 32 |
+
(448, 576),
|
| 33 |
+
(512, 512),
|
| 34 |
+
(576, 448),
|
| 35 |
+
(640, 384),
|
| 36 |
+
(704, 320),
|
| 37 |
+
(768, 320),
|
| 38 |
+
(768, 256)
|
| 39 |
+
]
|
| 40 |
+
bucket_override = [(int(h / 512 * resolution), int(w / 512 * resolution)) for h, w in bucket_override]
|
| 41 |
+
bucket_override = [(h // 16 * 16, w // 16 * 16) for h, w in bucket_override]
|
| 42 |
+
|
| 43 |
+
aspect_ratios = [image.shape[-2] / image.shape[-1] for image in images]
|
| 44 |
+
mean_aspect_ratio = np.mean(aspect_ratios)
|
| 45 |
+
|
| 46 |
+
new_h, new_w = bucket_override[0]
|
| 47 |
+
min_aspect_diff = np.abs(new_h / new_w - mean_aspect_ratio)
|
| 48 |
+
for h, w in bucket_override:
|
| 49 |
+
aspect_diff = np.abs(h / w - mean_aspect_ratio)
|
| 50 |
+
if aspect_diff < min_aspect_diff:
|
| 51 |
+
min_aspect_diff = aspect_diff
|
| 52 |
+
new_h, new_w = h, w
|
| 53 |
+
|
| 54 |
+
images = [TVF.resize(image, (new_h, new_w)) for image in images]
|
| 55 |
+
images = torch.stack(images, dim=0)
|
| 56 |
+
return images
|
| 57 |
+
|
| 58 |
+
class FluxPairedDatasetV2(Dataset):
|
| 59 |
+
def __init__(self, json_file: str, resolution: int, resolution_ref: int | None = None):
|
| 60 |
+
super().__init__()
|
| 61 |
+
self.json_file = json_file
|
| 62 |
+
self.resolution = resolution
|
| 63 |
+
self.resolution_ref = resolution_ref if resolution_ref is not None else resolution
|
| 64 |
+
self.image_root = os.path.dirname(json_file)
|
| 65 |
+
|
| 66 |
+
with open(self.json_file, "rt") as f:
|
| 67 |
+
self.data_dicts = json.load(f)
|
| 68 |
+
|
| 69 |
+
self.transform = Compose([
|
| 70 |
+
ToTensor(),
|
| 71 |
+
Normalize([0.5], [0.5]),
|
| 72 |
+
])
|
| 73 |
+
|
| 74 |
+
def __getitem__(self, idx):
|
| 75 |
+
data_dict = self.data_dicts[idx]
|
| 76 |
+
image_paths = [data_dict["image_path"]] if "image_path" in data_dict else data_dict["image_paths"]
|
| 77 |
+
txt = data_dict["prompt"]
|
| 78 |
+
image_tgt_path = data_dict.get("image_tgt_path", None)
|
| 79 |
+
ref_imgs = [
|
| 80 |
+
Image.open(os.path.join(self.image_root, path)).convert("RGB")
|
| 81 |
+
for path in image_paths
|
| 82 |
+
]
|
| 83 |
+
ref_imgs = [self.transform(img) for img in ref_imgs]
|
| 84 |
+
img = None
|
| 85 |
+
if image_tgt_path is not None:
|
| 86 |
+
img = Image.open(os.path.join(self.image_root, image_tgt_path)).convert("RGB")
|
| 87 |
+
img = self.transform(img)
|
| 88 |
+
|
| 89 |
+
return {
|
| 90 |
+
"img": img,
|
| 91 |
+
"txt": txt,
|
| 92 |
+
"ref_imgs": ref_imgs,
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def __len__(self):
|
| 96 |
+
return len(self.data_dicts)
|
| 97 |
+
|
| 98 |
+
def collate_fn(self, batch):
|
| 99 |
+
img = [data["img"] for data in batch]
|
| 100 |
+
txt = [data["txt"] for data in batch]
|
| 101 |
+
ref_imgs = [data["ref_imgs"] for data in batch]
|
| 102 |
+
assert all([len(ref_imgs[0]) == len(ref_imgs[i]) for i in range(len(ref_imgs))])
|
| 103 |
+
|
| 104 |
+
n_ref = len(ref_imgs[0])
|
| 105 |
+
|
| 106 |
+
img = bucket_images(img, self.resolution)
|
| 107 |
+
ref_imgs_new = []
|
| 108 |
+
for i in range(n_ref):
|
| 109 |
+
ref_imgs_i = [refs[i] for refs in ref_imgs]
|
| 110 |
+
ref_imgs_i = bucket_images(ref_imgs_i, self.resolution_ref)
|
| 111 |
+
ref_imgs_new.append(ref_imgs_i)
|
| 112 |
+
|
| 113 |
+
return {
|
| 114 |
+
"txt": txt,
|
| 115 |
+
"img": img,
|
| 116 |
+
"ref_imgs": ref_imgs_new,
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
if __name__ == '__main__':
|
| 120 |
+
import argparse
|
| 121 |
+
from pprint import pprint
|
| 122 |
+
parser = argparse.ArgumentParser()
|
| 123 |
+
# parser.add_argument("--json_file", type=str, required=True)
|
| 124 |
+
parser.add_argument("--json_file", type=str, default="datasets/fake_train_data.json")
|
| 125 |
+
args = parser.parse_args()
|
| 126 |
+
dataset = FluxPairedDatasetV2(args.json_file, 512)
|
| 127 |
+
dataloder = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
|
| 128 |
+
|
| 129 |
+
for i, data_dict in enumerate(dataloder):
|
| 130 |
+
pprint(i)
|
| 131 |
+
pprint(data_dict)
|
| 132 |
+
breakpoint()
|