dkebudi commited on
Commit
cfc0981
·
verified ·
1 Parent(s): f4612c7

generate a folder on the fly for prior preservation

Browse files
train_dreambooth_lora_sdxl_advanced.py CHANGED
@@ -22,6 +22,7 @@ import math
22
  import os
23
  import random
24
  import re
 
25
  import shutil
26
  import warnings
27
  from contextlib import nullcontext
@@ -827,17 +828,17 @@ def parse_args(input_args=None):
827
  if env_local_rank != -1 and env_local_rank != args.local_rank:
828
  args.local_rank = env_local_rank
829
 
830
- # if args.with_prior_preservation:
831
- # if args.class_data_dir is None:
832
- # raise ValueError("You must specify a data directory for class images.")
833
- # if args.class_prompt is None:
834
- # raise ValueError("You must specify prompt for class images.")
835
- # else:
836
- # # logger is not available yet
837
- # if args.class_data_dir is not None:
838
- # warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
839
- # if args.class_prompt is not None:
840
- # warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
841
 
842
  return args
843
 
@@ -1321,6 +1322,9 @@ def main(args):
1321
 
1322
  # Generate class images if prior preservation is enabled.
1323
  if args.with_prior_preservation:
 
 
 
1324
  class_images_dir = Path(args.class_data_dir)
1325
  if not class_images_dir.exists():
1326
  class_images_dir.mkdir(parents=True)
 
22
  import os
23
  import random
24
  import re
25
+ import uuid
26
  import shutil
27
  import warnings
28
  from contextlib import nullcontext
 
828
  if env_local_rank != -1 and env_local_rank != args.local_rank:
829
  args.local_rank = env_local_rank
830
 
831
+ if args.with_prior_preservation:
832
+ if args.class_data_dir is None:
833
+ raise ValueError("You must specify a data directory for class images.")
834
+ if args.class_prompt is None:
835
+ raise ValueError("You must specify prompt for class images.")
836
+ else:
837
+ # logger is not available yet
838
+ if args.class_data_dir is not None:
839
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
840
+ if args.class_prompt is not None:
841
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
842
 
843
  return args
844
 
 
1322
 
1323
  # Generate class images if prior preservation is enabled.
1324
  if args.with_prior_preservation:
1325
+ if arg.class_data_dir == None:
1326
+ class_folder = str(uuid.uuid4())
1327
+ args.class_data_dir = os.path.join("ariadne", class_folder)
1328
  class_images_dir = Path(args.class_data_dir)
1329
  if not class_images_dir.exists():
1330
  class_images_dir.mkdir(parents=True)