hard code 1080x1920
train_dreambooth_lora_sdxl_advanced.py
CHANGED
@@ -991,7 +991,7 @@ class DreamBoothDataset(Dataset):
         class_data_root=None,
         class_num=None,
         token_abstraction_dict=None,  # token mapping for textual inversion
-        size=
+        size=(1080,1920),
         repeats=1,
         center_crop=False,
     ):
@@ -1070,8 +1070,8 @@ class DreamBoothDataset(Dataset):
         self.original_sizes = []
         self.crop_top_lefts = []
         self.pixel_values = []
-
-
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) # if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
             [
@@ -1087,7 +1087,7 @@ class DreamBoothDataset(Dataset):
             if not image.mode == "RGB":
                 image = image.convert("RGB")
             self.original_sizes.append((image.height, image.width))
-
+            image = train_resize(image)

             if not single_image and args.random_flip and random.random() < 0.5:
                 # flip
@@ -1123,7 +1123,7 @@ class DreamBoothDataset(Dataset):
                 if not image.mode == "RGB":
                     image = image.convert("RGB")
                 self.original_sizes_class_imgs.append((image.height, image.width))
-
+                image = train_resize(image)
                 if args.random_flip and random.random() < 0.5:
                     # flip
                     image = train_flip(image)
@@ -1149,8 +1149,8 @@ class DreamBoothDataset(Dataset):

         self.image_transforms = transforms.Compose(
             [
-
-
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size), #if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -1815,7 +1815,7 @@ def main(args):
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
        class_num=args.num_class_images,
-       size=
+       size=(1080,1920),
        repeats=args.repeats,
        center_crop=args.center_crop,
    )
@@ -1835,7 +1835,7 @@ def main(args):

    def compute_time_ids(crops_coords_top_left, original_size=None):
        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-       target_size = (
+       target_size = (1080, 1920)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])
        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
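For context on the compute_time_ids change: SDXL's UNet is micro-conditioned on six values, original_size + crops_coords_top_left + target_size, which are embedded and added to the timestep embedding. Hard-coding target_size means every training example now tells the model the intended output resolution is 1080x1920. A minimal sketch of the computation under those assumptions (the function name and defaults are illustrative, not from the script):

import torch

def compute_time_ids_sketch(original_size, crops_coords_top_left,
                            device="cpu", dtype=torch.float32):
    # SDXL micro-conditioning: concatenate (h, w) of the source image,
    # (top, left) of the crop, and the hard-coded (h, w) target size
    # into a single (1, 6) tensor.
    target_size = (1080, 1920)  # hard-coded, as in the commit
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], device=device, dtype=dtype)

# Example: a 720x1280 source image used without cropping.
ids = compute_time_ids_sketch((720, 1280), (0, 0))
print(ids.shape)  # torch.Size([1, 6]) -> [720, 1280, 0, 0, 1080, 1920]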