Spaces:
Runtime error
Runtime error
generate a folder on the fly for prior preservation
Browse files
train_dreambooth_lora_sdxl_advanced.py
CHANGED
@@ -22,6 +22,7 @@ import math
|
|
22 |
import os
|
23 |
import random
|
24 |
import re
|
|
|
25 |
import shutil
|
26 |
import warnings
|
27 |
from contextlib import nullcontext
|
@@ -827,17 +828,17 @@ def parse_args(input_args=None):
|
|
827 |
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
828 |
args.local_rank = env_local_rank
|
829 |
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
|
842 |
return args
|
843 |
|
@@ -1321,6 +1322,9 @@ def main(args):
|
|
1321 |
|
1322 |
# Generate class images if prior preservation is enabled.
|
1323 |
if args.with_prior_preservation:
|
|
|
|
|
|
|
1324 |
class_images_dir = Path(args.class_data_dir)
|
1325 |
if not class_images_dir.exists():
|
1326 |
class_images_dir.mkdir(parents=True)
|
|
|
22 |
import os
|
23 |
import random
|
24 |
import re
|
25 |
+
import uuid
|
26 |
import shutil
|
27 |
import warnings
|
28 |
from contextlib import nullcontext
|
|
|
828 |
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
829 |
args.local_rank = env_local_rank
|
830 |
|
831 |
+
if args.with_prior_preservation:
|
832 |
+
if args.class_data_dir is None:
|
833 |
+
raise ValueError("You must specify a data directory for class images.")
|
834 |
+
if args.class_prompt is None:
|
835 |
+
raise ValueError("You must specify prompt for class images.")
|
836 |
+
else:
|
837 |
+
# logger is not available yet
|
838 |
+
if args.class_data_dir is not None:
|
839 |
+
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
|
840 |
+
if args.class_prompt is not None:
|
841 |
+
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
|
842 |
|
843 |
return args
|
844 |
|
|
|
1322 |
|
1323 |
# Generate class images if prior preservation is enabled.
|
1324 |
if args.with_prior_preservation:
|
1325 |
+
if arg.class_data_dir == None:
|
1326 |
+
class_folder = str(uuid.uuid4())
|
1327 |
+
args.class_data_dir = os.path.join("ariadne", class_folder)
|
1328 |
class_images_dir = Path(args.class_data_dir)
|
1329 |
if not class_images_dir.exists():
|
1330 |
class_images_dir.mkdir(parents=True)
|