# cora/utils/args.py
# armikaeili — commit 79c5088: "code added"
import argparse
def add_general_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the general run options on *parser*.

    Adds ``--cache_dir``, ``--prompts_file``, ``--fp16``, ``--seed``,
    ``--output_dir``, ``--eval_dataset_folder`` and ``--num_of_timesteps``.
    ``--num_of_timesteps`` is range-checked later by ``check_args``, which
    accepts only 3, 4, 5 or 10.
    """
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--prompts_file", type=str, default="")
    # ``store_true`` already defaults to False; no separate set_defaults needed.
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--seed", type=int, default=7, help="Seed value for random number generation.")
    parser.add_argument("--output_dir", type=str, default="output")
    parser.add_argument("--eval_dataset_folder", type=str, default="dataset")
    parser.add_argument(
        "--num_of_timesteps",
        type=int,
        default=5,
        help="Allowed values: 3, 4, 5, 10 (enforced by check_args).",
    )
def add_extra_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the method-specific switches and tuning knobs on *parser*.

    Every option carries a default (boolean flags default to False), so none
    of them is required on the command line. Options are registered in a
    fixed order, which is also the order they appear in ``--help``.
    """
    option_table = (
        ("--guidance_scale", {"type": float, "default": 0.0, "help": "Guidance scale value."}),
        ("--apply_dift_correction", {"action": "store_true", "help": "Apply DIFT correction."}),
        ("--w1", {"type": float, "default": 1.9, "help": "Weight for CTRL-X mode."}),
        (
            "--support_new_object",
            {"action": "store_true", "help": "Enable support for new object detection."},
        ),
        (
            "--mode",
            {
                "type": str,
                "default": "slerp_dift",
                "help": "Attention Type (e.g., normal, slerp, lerp, ...).",
            },
        ),
        ("--dift_timestep", {"type": int, "default": 400, "help": "DIFT timestep."}),
        (
            "--movement_intensifier",
            {"type": float, "default": 0.2, "help": "Movement intensifier factor."},
        ),
        ("--structural_alignment", {"action": "store_true", "help": "Enable structural alignment."}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    # Explicit default for the DIFT flag (same value store_true already gives).
    parser.set_defaults(apply_dift_correction=False)
def add_editing_arguments(parser: argparse.ArgumentParser) -> None:
    """Register the editing / inversion schedule options on *parser*.

    ``--max_norm_zs``, ``--noise_timesteps`` and ``--timesteps`` take one or
    more values (``nargs="+"``); ``check_args`` later verifies their lengths
    agree with each other.
    """
    add = parser.add_argument

    add(
        "--max_norm_zs",
        nargs="+",
        type=float,
        default=[-1, -1, -1, 15.5],
        help="A list of floats for max_norm_zs.",
    )
    add("--noise_shift_delta", default=1, type=float)

    # The two integer schedules share the same shape of definition.
    int_schedules = (
        ("--noise_timesteps", [799, 499, 199, 0], "A list of ints for noise_timesteps."),
        ("--timesteps", [999, 799, 499, 199], "A list of ints for timesteps."),
    )
    for flag, schedule, description in int_schedules:
        add(flag, nargs="+", type=int, default=schedule, help=description)

    add("--num_steps_inversion", default=5, type=int)
    add("--step_start", default=1, type=int)
class _ArgsValidationError(ValueError, AssertionError):
    """Raised by ``check_args`` when parsed arguments are invalid or inconsistent.

    Subclasses ``ValueError`` (the type raised for a bad ``num_of_timesteps``)
    and ``AssertionError`` (the type the length checks historically raised via
    ``assert``), so ``except`` clauses written for either type keep working.
    """


_VALID_NUM_OF_TIMESTEPS = (3, 4, 5, 10)


def check_args(args):
    """Validate *args* and normalise ``max_norm_zs`` in place.

    Checks performed:

    * ``num_of_timesteps`` must be one of 3, 4, 5 or 10.
    * The expected schedule length is ``len(args.timesteps)`` when
      ``timesteps`` is not ``None``, otherwise
      ``num_steps_inversion - step_start``.
    * A scalar (int/float) ``max_norm_zs`` is broadcast to a list of that
      length; the resulting list must have exactly that length.
    * ``noise_timesteps``, when not ``None``, must have that length too.

    Explicit ``raise`` is used instead of ``assert`` so the checks still run
    under ``python -O``.

    Args:
        args: Parsed namespace; ``args.max_norm_zs`` may be replaced.

    Raises:
        ValueError: If any check fails (the exception is also an
            ``AssertionError`` for backward compatibility).
    """
    if args.num_of_timesteps not in _VALID_NUM_OF_TIMESTEPS:
        raise _ArgsValidationError(
            f"num_of_timesteps must be one of 3, 4, 5 or 10 (got {args.num_of_timesteps!r})"
        )

    if args.timesteps is not None:
        num_steps_actual = len(args.timesteps)
    else:
        num_steps_actual = args.num_steps_inversion - args.step_start

    # A single number means "use the same max norm at every step".
    if isinstance(args.max_norm_zs, (int, float)):
        args.max_norm_zs = [args.max_norm_zs] * num_steps_actual

    if len(args.max_norm_zs) != num_steps_actual:
        raise _ArgsValidationError(
            f"len(args.max_norm_zs) ({len(args.max_norm_zs)}) != num_steps_actual ({num_steps_actual})"
        )

    if args.noise_timesteps is not None and len(args.noise_timesteps) != num_steps_actual:
        raise _ArgsValidationError(
            f"len(args.noise_timesteps) ({len(args.noise_timesteps)}) != num_steps_actual ({num_steps_actual})"
        )
def get_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser, parse *argv*, validate and return it.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as calling
            ``get_args()`` with no arguments always has.

    Returns:
        The parsed namespace after ``check_args`` has validated it (and
        possibly normalised ``max_norm_zs``).
    """
    parser = argparse.ArgumentParser()
    add_general_arguments(parser)
    add_editing_arguments(parser)
    add_extra_arguments(parser)
    args = parser.parse_args(argv)
    check_args(args)
    return args