Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,456 Bytes
779c9ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import random
import traceback
from pathlib import Path
import einops
import numpy as np
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm
from utils import bezier_utils
class PartsDataset(Dataset):
    """Dataset of target images paired with part crops and an optional reference grid.

    Layout read by ``__init__``: ``dataset_dir`` contains sub-directories, each
    holding ``*.jpg`` files. A file whose name splits into exactly two
    ``"_"``-separated pieces (e.g. ``img_1.jpg``) is a *target*; a file with
    exactly three pieces (e.g. ``img_1_2.jpg``) is a *source* part belonging to
    the target obtained by dropping its last piece (``img_1.jpg``).

    ``__getitem__`` returns ``(target, {"crops": parts})`` where every image has
    been passed through ``image_processor`` and indexed as
    ``["pixel_values"][0]``. ``parts`` always has ``max_crops`` entries, plus one
    leading reference grid when ``use_ref`` is true; unused slots are filled
    with blank white images.
    """

    # Minimum number of targets a sub-directory needs when ``use_ref`` is set.
    # Raised automatically when ``grid_size`` needs more (see _required_targets).
    MIN_TARGETS_PER_SUBDIR = 9

    def __init__(
        self,
        dataset_dir: Path,
        clip_image_size: int = 224,
        image_processor=None,
        max_crops=3,
        use_ref: bool = True,
        ref_as_grid: bool = True,
        grid_size: int = 2,
        sketch_prob: float = 0.0,
    ):
        """Index target and source images under ``dataset_dir``.

        Args:
            dataset_dir: Root directory (``Path`` or path-like string) whose
                sub-directories hold the ``*.jpg`` files.
            clip_image_size: Side length of the blank white placeholder images.
            image_processor: Callable applied to every output image; its result
                is indexed as ``["pixel_values"][0]``.
            max_crops: Maximum number of source parts per sample.
            use_ref: Prepend a grid of other targets from the same
                sub-directory as a reference image.
            ref_as_grid: Stored on the instance; not read by this class.
            grid_size: The reference grid is ``grid_size x grid_size`` images.
            sketch_prob: Probability of replacing a part with a Bezier sketch.
        """
        dataset_dir = Path(dataset_dir)
        # Sorted so the index -> sample mapping is reproducible across runs and
        # filesystems (iterdir/glob order is unspecified).
        subdirs = sorted(d for d in dataset_dir.iterdir() if d.is_dir())
        min_targets = self._required_targets(grid_size)
        all_paths = []
        self.subdir_dict = {}
        for subdir in tqdm(subdirs):
            current_paths = sorted(subdir.glob("*.jpg"))
            current_target_paths = [p for p in current_paths if self._is_target(p)]
            if use_ref and len(current_target_paths) < min_targets:
                # Skip if not enough target images to draw a reference grid
                continue
            all_paths.extend(current_paths)
            self.subdir_dict[subdir] = current_target_paths
        # Guard against an empty dataset_dir (previously ZeroDivisionError).
        valid_fraction = len(self.subdir_dict) / len(subdirs) if subdirs else 0.0
        print(f"Percentile of valid subdirs: {valid_fraction}")

        self.target_paths = [p for p in all_paths if self._is_target(p)]
        self.source_target_mappings = {path: [] for path in self.target_paths}
        for source_path in all_paths:
            if not self._is_source(source_path):
                continue
            target_path = self._target_for_source(source_path)
            # Sources whose target file is absent are dropped.
            if target_path in self.source_target_mappings:
                self.source_target_mappings[target_path].append(source_path)
        print(f"Loaded {len(self.target_paths)} target images")

        self.clip_image_size = clip_image_size
        self.image_processor = image_processor
        self.max_crops = max_crops
        self.use_ref = use_ref
        self.ref_as_grid = ref_as_grid
        self.grid_size = grid_size
        self.sketch_prob = sketch_prob

    # ------------------------------------------------------------------ paths

    @staticmethod
    def _is_target(path):
        """Return True for target files: the name contains exactly one ``"_"``."""
        return len(path.name.split("_")) == 2

    @staticmethod
    def _is_source(path):
        """Return True for source part files: the name contains exactly two ``"_"``."""
        return len(path.name.split("_")) == 3

    @staticmethod
    def _target_for_source(source_path):
        """Map ``<dir>/a_b_c.jpg`` to its target path ``<dir>/a_b.jpg``."""
        # Only the file name is split, so "_" in directory names is harmless.
        stem = "_".join(source_path.name.split("_")[:-1])
        return source_path.with_name(stem + ".jpg")

    @classmethod
    def _required_targets(cls, grid_size):
        """Minimum targets a sub-directory needs to support a reference grid.

        The grid samples ``grid_size**2`` targets *other than* the current one,
        so at least ``grid_size**2 + 1`` are required.
        """
        return max(cls.MIN_TARGETS_PER_SUBDIR, grid_size**2 + 1)

    # ----------------------------------------------------------------- sizing

    def __len__(self):
        return len(self.target_paths)

    def _num_parts(self):
        """Image slots per sample: ``max_crops`` plus one reference slot if ``use_ref``."""
        return self.max_crops + 1 if self.use_ref else self.max_crops

    # ----------------------------------------------------------- image helpers

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _blank_part(self):
        """Return a new white ``clip_image_size`` square RGB image."""
        return Image.new("RGB", (self.clip_image_size, self.clip_image_size), (255, 255, 255))

    def paste_on_background(self, image, background, min_scale=0.4, max_scale=0.8):
        """Paste ``image`` onto ``background`` (in place) at a random scale and position.

        The image keeps its aspect ratio and is scaled to a random fraction in
        ``[min_scale, max_scale]`` of the largest size that fits ``background``.
        Its alpha channel, if any, is used as the paste mask.

        Returns:
            ``background`` (mutated in place).
        """
        aspect_ratio = image.width / image.height
        scale = random.uniform(min_scale, max_scale)
        # Clamp to >= 1 px: extreme aspect ratios could otherwise yield a
        # zero-sized resize, which PIL rejects.
        new_width = max(1, int(min(background.width, background.height * aspect_ratio) * scale))
        new_height = max(1, int(new_width / aspect_ratio))
        image = image.resize((new_width, new_height), resample=Image.LANCZOS)
        pos_x = random.randint(0, background.width - new_width)
        pos_y = random.randint(0, background.height - new_height)
        background.paste(image, (pos_x, pos_y), image if "A" in image.mode else None)
        return background

    def get_random_crop(self, image):
        """Return a random crop covering 80-100% of ``image`` along each axis."""
        crop_x = int(image.width * random.uniform(0.8, 1.0))
        crop_y = int(image.height * random.uniform(0.8, 1.0))
        x = random.randint(0, image.width - crop_x)
        y = random.randint(0, image.height - crop_y)
        return image.crop((x, y, x + crop_x, y + crop_y))

    def get_empty_image(self):
        """Return the processed pixel values of a blank white placeholder."""
        return self.image_processor(self._blank_part())["pixel_values"][0]

    # ---------------------------------------------------------------- sampling

    def _sample_source_paths(self, source_paths):
        """Choose which source parts to use for one sample.

        Picks between 1 and ``max_crops`` parts at random; with probability 0.1
        returns no parts at all (the reference grid, if enabled, is still used).
        A target without any sources yields ``[]`` instead of raising.
        """
        if not source_paths:
            return []
        n_samples = min(random.randint(1, len(source_paths)), self.max_crops)
        chosen = random.sample(source_paths, n_samples)
        if random.random() < 0.1:
            # Use empty image, but maybe still pass reference
            return []
        return chosen

    def _build_reference_grid(self, target_path):
        """Tile ``grid_size**2`` other targets from the same sub-directory into a 512x512 image."""
        # A list (not a set) keeps candidate order deterministic, so seeded
        # runs draw the same references.
        candidates = [p for p in self.subdir_dict[target_path.parent] if p != target_path]
        reference_paths = random.sample(candidates, self.grid_size**2)
        reference_images = [self._load_rgb(p) for p in reference_paths]
        # np.stack needs equal shapes; resize any mismatched tile to the first
        # tile's size (no-op when all references already match).
        tile_size = reference_images[0].size
        tiles = np.stack(
            [np.array(img if img.size == tile_size else img.resize(tile_size)) for img in reference_images]
        )
        grid_image = einops.rearrange(
            tiles,
            "(h w) h1 w1 c -> (h h1) (w w1) c",
            h=self.grid_size,
        )
        return Image.fromarray(grid_image).resize((512, 512))

    def _make_part(self, source_path, canvas_size):
        """Render one augmented source part on a white canvas of ``canvas_size``."""
        part = self._load_rgb(source_path)
        if random.random() < 0.2:
            # Use a random sub-window of the source part itself.
            part = self.get_random_crop(part)
        if random.random() < 0.2:
            part = T.v2.RandomRotation(degrees=30, expand=True, fill=255)(part)
        canvas = Image.new("RGB", canvas_size, (255, 255, 255))
        self.paste_on_background(part, canvas, min_scale=0.8, max_scale=0.95)
        if self.sketch_prob > 0 and random.random() < self.sketch_prob:
            num_lines = random.randint(8, 15)
            canvas = bezier_utils.get_sketch(
                canvas,
                total_curves=num_lines,
                drop_line_prob=0.1,
            )
        return canvas

    def __getitem__(self, i: int):
        """Return ``(processed_target, {"crops": [processed_parts...]})`` for index ``i``.

        Any failure while loading or augmenting falls back to blank white
        images, so a single bad file does not abort a training run.
        """
        try:
            target_path = self.target_paths[i]
            image = self._load_rgb(target_path)
            source_paths = self._sample_source_paths(self.source_target_mappings[target_path])
            input_parts = []
            if self.use_ref:
                # Always add the reference image first
                input_parts.append(self._build_reference_grid(target_path))
            input_parts.extend(self._make_part(p, image.size) for p in source_paths)
            # Pad with blanks so every sample has the same number of parts.
            input_parts.extend(self._blank_part() for _ in range(self._num_parts() - len(input_parts)))
        except Exception as e:
            print(f"Error processing image: {e}")
            traceback.print_exc()
            image = self._blank_part()
            input_parts = [image] * self._num_parts()
        clip_target_image = self.image_processor(image)["pixel_values"][0]
        clip_parts = [self.image_processor(part)["pixel_values"][0] for part in input_parts]
        return clip_target_image, {"crops": clip_parts}
|