Spaces:
Runtime error
Runtime error
| import argparse | |
| import shlex | |
| import os | |
| import pickle | |
| import swapae.util as util | |
| import swapae.models as models | |
| import swapae.models.networks as networks | |
| import swapae.data as data | |
| import swapae.evaluation as evaluation | |
| import swapae.optimizers as optimizers | |
| from swapae.util import IterationCounter | |
| from swapae.util import Visualizer | |
class BaseOptions:
    """Build, print and save the command-line options shared by train and test.

    This class never sets ``self.isTrain`` itself; ``gather_options`` and
    ``parse`` read it, so a subclass must set it (as TrainOptions and
    TestOptions do).
    """

    def initialize(self, parser):
        """Register the options common to every phase on ``parser`` and return it."""
        # experiment specifics
        parser.add_argument('--name', type=str, default="ffhq512_pretrained", help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--easy_label', type=str, default="")
        parser.add_argument('--num_gpus', type=int, default=1, help='#GPUs to use. 0 means CPU mode')
        # NOTE(review): machine-specific absolute default path; callers on other
        # hosts must override --checkpoints_dir.
        parser.add_argument('--checkpoints_dir', type=str, default='/home/xtli/Documents/GITHUB/swapping-autoencoder-pytorch/checkpoints/', help='models are saved here')
        parser.add_argument('--model', type=str, default='swapping_autoencoder', help='which model to use')
        parser.add_argument('--optimizer', type=str, default='swapping_autoencoder', help='which optimizer to use')
        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        parser.add_argument('--resume_iter', type=str, default="latest",
                            help="# iterations (in thousands) to resume")
        parser.add_argument('--num_classes', type=int, default=0)

        # input/output sizes
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--preprocess', type=str, default='resize', help='scaling and cropping of images at load time.')
        parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.')
        parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
        parser.add_argument('--preprocess_crop_padding', type=int, default=None, help='padding parameter of transforms.RandomCrop(). It is not used if --preprocess does not contain crop option.')
        parser.add_argument('--no_flip', action='store_true')
        parser.add_argument('--shuffle_dataset', type=str, default=None, choices=('true', 'false'))

        # for setting inputs
        # NOTE(review): machine-specific absolute default path; override --dataroot.
        parser.add_argument('--dataroot', type=str, default="/home/xtli/Dropbox/swapping-autoencoder-pytorch/testphotos/ffhq512/fig9/")
        parser.add_argument('--dataset_mode', type=str, default='imagefolder')
        parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data')

        # networks
        parser.add_argument("--netG", default="StyleGAN2Resnet")
        parser.add_argument("--netD", default="StyleGAN2")
        parser.add_argument("--netE", default="StyleGAN2Resnet")
        parser.add_argument("--netPatchD", default="StyleGAN2")
        parser.add_argument("--use_antialias", type=util.str2bool, default=True)
        parser.add_argument("-f", "--config_file", type=str, default='models/swap/json/sem_cons.json', help='json files including all arguments')
        parser.add_argument("--local_rank", type=int)
        return parser

    def gather_options(self, command=None):
        """Build the complete parser and parse the command line.

        Parsing happens in two passes. A lenient first pass
        (``parse_known_args``) reads the base options. Its ``--model``,
        ``--optimizer`` and ``--dataset_mode`` values select which extra option
        setters are applied. A final ``parse_args`` pass then parses against
        the fully configured parser.

        Args:
            command: optional full command string. When not None it is parsed
                instead of ``sys.argv``. AugmentedArgumentParser drops its first
                two shell tokens (e.g. ``python train.py``).

        Returns:
            The parsed ``argparse.Namespace``. The final parser is kept on
            ``self.parser`` for ``print_options`` / ``save_options``.
        """
        parser = AugmentedArgumentParser()
        parser.custom_command = command

        # get basic options
        parser = self.initialize(parser)

        # First pass: component-specific options are not registered yet, so
        # unrecognized arguments must be tolerated here.
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_option_setter = models.get_option_setter(opt.model)
        parser = model_option_setter(parser, self.isTrain)

        # modify network-related parser options
        parser = networks.modify_commandline_options(parser, self.isTrain)

        # modify optimizer-related parser options
        optimizer_option_setter = optimizers.get_option_setter(opt.optimizer)
        parser = optimizer_option_setter(parser, self.isTrain)

        # modify dataset-related parser options
        dataset_option_setter = data.get_option_setter(opt.dataset_mode)
        parser = dataset_option_setter(parser, self.isTrain)

        # let Visualizer add its parser options
        parser = Visualizer.modify_commandline_options(parser, self.isTrain)

        # modify parser options related to iteration counting
        parser = IterationCounter.modify_commandline_options(parser, self.isTrain)

        # modify evaluation-related parser options
        evaluation_option_setter = evaluation.get_option_setter()
        parser = evaluation_option_setter(parser, self.isTrain)

        # Final pass over the fully configured parser; unrecognized arguments
        # are an error here.
        opt = parser.parse_args()
        self.parser = parser
        return opt

    def _format_option_lines(self, opt):
        """Return one formatted, newline-terminated line per option in ``opt``.

        Options are sorted by name. A value that differs from
        ``self.parser.get_default`` is annotated with that default.
        """
        lines = []
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
        return lines

    def print_options(self, opt):
        """Print all options to stdout.

        Values that differ from the parser default are annotated with that
        default. This method only prints; ``save_options`` writes the files.
        """
        message = '----------------- Options ---------------\n'
        message += ''.join(self._format_option_lines(opt))
        message += '----------------- End -------------------'
        print(message)

    def option_file_path(self, opt, makedir=False):
        """Return ``<checkpoints_dir>/<name>/opt`` (no extension).

        Args:
            opt: namespace providing ``checkpoints_dir`` and ``name``.
            makedir: if True, create the experiment directory via ``util.mkdirs``.
        """
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if makedir:
            util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt')
        return file_name

    def save_options(self, opt):
        """Save ``opt`` as ``opt.txt`` (readable) and ``opt.pkl`` (pickled Namespace)."""
        file_name = self.option_file_path(opt, makedir=True)
        with open(file_name + '.txt', 'wt', encoding='utf-8') as opt_file:
            opt_file.writelines(self._format_option_lines(opt))
        with open(file_name + '.pkl', 'wb') as opt_file:
            pickle.dump(opt, opt_file)

    def parse(self, save=False, command=None):
        """Gather and print the options, saving them in training mode.

        Args:
            save: currently unused. Options are saved whenever ``self.isTrain``
                is true. Kept for backward compatibility.
            command: optional command string forwarded to ``gather_options``.

        Returns:
            The parsed namespace, with ``isTrain`` set and ``dataroot``
            user-expanded.
        """
        opt = self.gather_options(command)
        opt.isTrain = self.isTrain  # train or test
        self.print_options(opt)
        if opt.isTrain:
            # Saved before dataroot expansion below, matching prior behavior.
            self.save_options(opt)
        opt.dataroot = os.path.expanduser(opt.dataroot)
        # NOTE(review): assert is stripped under `python -O`.
        assert opt.num_gpus <= opt.batch_size, "Batch size must not be smaller than num_gpus"
        return opt
class TrainOptions(BaseOptions):
    """Option set used for training: the shared options plus resume/pretrain flags."""

    def __init__(self):
        super().__init__()
        # BaseOptions reads this flag when configuring the parser and in parse().
        self.isTrain = True

    def initialize(self, parser):
        """Register the shared options, then the training-only ones; return ``parser``."""
        super().initialize(parser)
        parser.add_argument(
            '--continue_train',
            type=util.str2bool,
            default=False,
            help="resume training from last checkpoint",
        )
        parser.add_argument(
            '--pretrained_name',
            type=str,
            default=None,
            help="Load weights from the checkpoint of another experiment",
        )
        return parser
class TestOptions(BaseOptions):
    """Option set used at test time: the shared options plus ``--result_dir``."""

    def __init__(self):
        super().__init__()
        # BaseOptions reads this flag when configuring the parser and in parse().
        self.isTrain = False

    def initialize(self, parser):
        """Register the shared options, then the test-only ``--result_dir``; return ``parser``."""
        super().initialize(parser)
        parser.add_argument(
            "--result_dir",
            type=str,
            default="results",
        )
        return parser
class AugmentedArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that can parse a stored command string and supports str2bool flags.

    When ``custom_command`` is a non-None string and ``parse_args`` /
    ``parse_known_args`` are called without explicit ``args``, that command is
    parsed instead of ``sys.argv``. Its first two shell tokens
    (e.g. ``python train.py``) are skipped.
    """

    # Optional full command line to parse instead of sys.argv; set by callers.
    custom_command = None

    def _custom_command_args(self):
        """Return the argument list taken from ``custom_command``, or None if unset."""
        command = getattr(self, 'custom_command', None)
        if command is None:
            return None
        # Drop the interpreter and script tokens, e.g. "python train.py".
        return shlex.split(command)[2:]

    def parse_args(self, args=None, namespace=None):
        """ Enables passing bash commands as arguments to the class.
        """
        print("parsing args...")
        if args is None and getattr(self, 'custom_command', None) is not None:
            print('using custom command')
            print(self.custom_command)
            args = self._custom_command_args()
        return super().parse_args(args, namespace)

    def parse_known_args(self, args=None, namespace=None):
        """Like ``argparse.ArgumentParser.parse_known_args``, honoring ``custom_command``."""
        if args is None:
            # Stays None (-> sys.argv) when no custom command is set.
            args = self._custom_command_args()
        return super().parse_known_args(args, namespace)

    def add_argument(self, *args, **kwargs):
        """ Support for providing a new argument type called "str2bool"

        Example:
        parser.add_argument("--my_option", type=util.str2bool, default=|bool|)
        1. "python train.py" sets my_option to be |bool|
        2. "python train.py --my_option" sets my_option to be True
        3. "python train.py --my_option False" sets my_option to be False
        4. "python train.py --my_option True" sets my_option to be True
        https://stackoverflow.com/a/43357954

        Returns the created ``argparse.Action``, like the base method.
        """
        # Short-circuit keeps `util` from being looked up for non-typed arguments.
        if 'type' in kwargs and kwargs['type'] == util.str2bool:
            # A bare flag (no value) means True; an explicit value is still accepted.
            kwargs.setdefault('nargs', "?")
            kwargs.setdefault('const', True)
        return super().add_argument(*args, **kwargs)