import argparse
import collections.abc
import re

from hymm_sp.constants import *

def as_tuple(x):
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    if x is None or isinstance(x, (int, float, str)):
        return (x,)
    raise ValueError(f"Unknown type {type(x)}")
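
# Illustrative examples (added for clarity, not in the original file): scalars,
# strings, and None are wrapped as one-element tuples; other iterables are
# converted directly:
#   as_tuple(704)        -> (704,)
#   as_tuple([704, 704]) -> (704, 704)
#   as_tuple("fp16")     -> ("fp16",)
#   as_tuple(None)       -> (None,)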

def parse_args(namespace=None):
    parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
    parser = add_extra_args(parser)
    args = parser.parse_args(namespace=namespace)
    args = sanity_check_args(args)
    return args

def add_extra_args(parser: argparse.ArgumentParser):
    parser = add_network_args(parser)
    parser = add_extra_models_args(parser)
    parser = add_denoise_schedule_args(parser)
    parser = add_evaluation_args(parser)
    return parser

def add_network_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Network")
    group.add_argument("--model", type=str, default="HYVideo-T/2",
                       help="Model architecture to use. It is also used to determine the experiment directory.")
    group.add_argument("--latent-channels", type=int, default=None,
                       help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
                            "it still needs to match the latent channels of the VAE model.")
    group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
    return parser
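
# Illustrative note (derived from sanity_check_args below): with the default VAE
# "884-16c-hy0801", --latent-channels must be 16, or left unset so that it is
# inferred from the VAE name.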

def add_extra_models_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
    # VAE
    group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
    group.add_argument("--vae-precision", type=str, default="fp16",
                       help="Precision mode for the VAE model.")
    group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
    group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
                       help="Name of the text encoder model.")
    group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
                       help="Precision mode for the text encoder model.")
    group.add_argument("--text-states-dim", type=int, default=4096,
                       help="Dimension of the text encoder hidden states.")
    group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
    group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
                       help="Name of the tokenizer model.")
    group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
                       help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
                            "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
    group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
                       help="Video prompt template for the decoder-only text encoder model.")
    group.add_argument("--hidden-state-skip-layer", type=int, default=2,
                       help="Skip layer for hidden states.")
    group.add_argument("--apply-final-norm", action="store_true",
                       help="Apply final normalization to the used text encoder hidden states.")
    # - CLIP
    group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
                       help="Name of the second text encoder model.")
    group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
                       help="Precision mode for the second text encoder model.")
    group.add_argument("--text-states-dim-2", type=int, default=768,
                       help="Dimension of the second text encoder hidden states.")
    group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
                       help="Name of the second tokenizer model.")
    group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
    group.set_defaults(use_attention_mask=True)
    group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
                       help="A projection layer for bridging the text encoder hidden states and the diffusion model "
                            "conditions.")
    return parser

def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Denoise schedule")
    group.add_argument("--flow-shift-eval-video", type=float, default=None,
                       help="Shift factor for flow matching schedulers when using video data.")
    group.add_argument("--flow-reverse", action="store_true", default=True,
                       help="If reverse, learning/sampling from t=1 -> t=0.")
    group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
    group.add_argument("--use-linear-quadratic-schedule", action="store_true",
                       help="Use linear-quadratic schedule for flow matching. "
                            "Follows MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper).")
    group.add_argument("--linear-schedule-end", type=int, default=25,
                       help="End step for the linear-quadratic schedule for flow matching.")
    return parser
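
# Illustrative CLI fragment (entry-point script assumed, not defined in this
# file) enabling the MovieGen-style schedule defined above:
#   --use-linear-quadratic-schedule --linear-schedule-end 25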

def add_evaluation_args(parser: argparse.ArgumentParser):
    group = parser.add_argument_group(title="Validation Loss Evaluation")
    group.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
                       help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
    group.add_argument("--reproduce", action="store_true",
                       help="Enable reproducibility by setting random seeds and deterministic algorithms.")
    group.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
    group.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
                       help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
    group.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
    group.add_argument("--infer-min", action="store_true", help="Infer only 5 seconds of video.")
    group.add_argument("--use-fp8", action="store_true", help="Enable fp8 for inference acceleration.")
    group.add_argument("--video-size", type=int, nargs='+', default=512,
                       help="Video size for training. If a single value is provided, it will be used for both width "
                            "and height. If two values are provided, they will be used for width and height "
                            "respectively.")
    group.add_argument("--sample-n-frames", type=int, default=1,
                       help="How many frames to sample from a video. If using a 3D VAE, the number should be 4n+1.")
    group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
    group.add_argument("--val-disable-autocast", action="store_true",
                       help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
    group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
    group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
    group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
    group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
    group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
    group.add_argument("--image-size", type=int, default=704, help="Image size for sampling.")
    group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
    group.add_argument("--image-path", type=str, default="", help="Path to the input image.")
    group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
    group.add_argument("--input", type=str, default=None, help="Test data.")
    group.add_argument("--item-name", type=str, default=None, help="Name of the item to evaluate.")
    group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier-free guidance scale.")
    group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier-free guidance scale for the IP condition.")
    group.add_argument("--use-deepcache", type=int, default=1, help="Use DeepCache to accelerate inference.")
    return parser
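
# Illustrative invocation (the script name is assumed, not from this file)
# combining the evaluation flags defined above; note --sample-n-frames follows
# the 4n+1 rule (129 = 4 * 32 + 1):
#   python sample.py --ckpt ./weights/model.pt --video-size 704 1216 \
#       --sample-n-frames 129 --infer-steps 50 --cfg-scale 7.5 --seed 1024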

def sanity_check_args(args):
    # VAE channels
    vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
    if not re.match(vae_pattern, args.vae):
        raise ValueError(
            f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
        )
    vae_channels = int(args.vae.split("-")[1][:-1])
    if args.latent_channels is None:
        args.latent_channels = vae_channels
    if vae_channels != args.latent_channels:
        raise ValueError(
            f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
        )
    return args
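
if __name__ == "__main__":
    # Minimal smoke test (illustrative addition, not in the original file):
    # parse the defaults and confirm the sanity checks pass. With the default
    # VAE name "884-16c-hy0801", vae_channels parses to 16 and latent_channels
    # is set to match it.
    args = parse_args()
    print(args)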