dkebudi committed · Commit c30061b (verified) · Parent: e055b3b

hard code 1080x1920

train_dreambooth_lora_sdxl_advanced.py CHANGED
@@ -991,7 +991,7 @@ class DreamBoothDataset(Dataset):
         class_data_root=None,
         class_num=None,
         token_abstraction_dict=None, # token mapping for textual inversion
-        size=1024,
+        size=(1080,1920),
         repeats=1,
         center_crop=False,
     ):
@@ -1070,8 +1070,8 @@ class DreamBoothDataset(Dataset):
         self.original_sizes = []
         self.crop_top_lefts = []
         self.pixel_values = []
-        #train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
-        #train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) # if center_crop else transforms.RandomCrop(size)
         train_flip = transforms.RandomHorizontalFlip(p=1.0)
         train_transforms = transforms.Compose(
             [
@@ -1087,7 +1087,7 @@ class DreamBoothDataset(Dataset):
             if not image.mode == "RGB":
                 image = image.convert("RGB")
             self.original_sizes.append((image.height, image.width))
-            #image = train_resize(image)
+            image = train_resize(image)

             if not single_image and args.random_flip and random.random() < 0.5:
                 # flip
@@ -1123,7 +1123,7 @@ class DreamBoothDataset(Dataset):
                 if not image.mode == "RGB":
                     image = image.convert("RGB")
                 self.original_sizes_class_imgs.append((image.height, image.width))
-                # image = train_resize(image)
+                image = train_resize(image)
                 if args.random_flip and random.random() < 0.5:
                     # flip
                     image = train_flip(image)
@@ -1149,8 +1149,8 @@ class DreamBoothDataset(Dataset):

         self.image_transforms = transforms.Compose(
             [
-                # transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
-                # transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+                transforms.CenterCrop(size), #if center_crop else transforms.RandomCrop(size),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
@@ -1815,7 +1815,7 @@ def main(args):
         class_data_root=args.class_data_dir if args.with_prior_preservation else None,
         token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
         class_num=args.num_class_images,
-        size=args.resolution,
+        size=(1080,1920),
         repeats=args.repeats,
         center_crop=args.center_crop,
     )
@@ -1835,7 +1835,7 @@ def main(args):

     def compute_time_ids(crops_coords_top_left, original_size=None):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        target_size = (args.resolution, args.resolution)
+        target_size = (1080, 1920)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
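For context, a minimal sketch of what the instance-image path does after this change, assuming torchvision and Pillow are installed; the file name is hypothetical and the flip and crop-coordinate bookkeeping from the script are omitted. Because `transforms.Resize` now receives an explicit `(height, width)` tuple, every image is stretched to exactly that size and the subsequent `CenterCrop` removes nothing:

```python
from PIL import Image
from torchvision import transforms

size = (1080, 1920)  # torchvision order: (height, width), i.e. a 1920-wide landscape frame

train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(size)  # identity after an exact (h, w) resize
train_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map [0, 1] to [-1, 1]
    ]
)

image = Image.open("example.jpg").convert("RGB")  # hypothetical input file
original_size = (image.height, image.width)  # recorded for SDXL micro-conditioning
image = train_resize(image)  # stretched to 1080x1920; aspect ratio is not preserved
image = train_crop(image)    # no-op here, the image is already 1080x1920
pixel_values = train_transforms(image)  # tensor of shape (3, 1080, 1920)
```

Note that torchvision reads the tuple as `(height, width)`, so `(1080, 1920)` yields a 1920-wide landscape frame, and since `main()` still passes `center_crop=args.center_crop` while the `RandomCrop` branch stays commented out, that flag no longer changes behavior.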
 
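For the time-id change, a hedged sketch of the micro-conditioning tuple SDXL receives, mirroring the hunk above but dropping the device and dtype handling the script performs; the example sizes are invented. SDXL's UNet is conditioned on `(original_size, crops_coords_top_left, target_size)`, and this commit pins the target to 1080x1920 regardless of `--resolution`:

```python
import torch

def compute_time_ids(crops_coords_top_left, original_size):
    """Simplified stand-in for the script's compute_time_ids closure."""
    target_size = (1080, 1920)  # hard-coded (height, width)
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids])

# e.g. a 2160x3840 source image with no crop offset:
print(compute_time_ids((0, 0), (2160, 3840)))
# -> tensor([[2160, 3840, 0, 0, 1080, 1920]])
```

Since the resize above already produces a 1080x1920 buffer, the center crop should contribute an effective offset of `(0, 0)`, while `original_size` still carries the true source dimensions.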