File size: 3,338 Bytes
f7aef4e
c1f7300
47783c0
 
f7aef4e
 
47783c0
 
 
 
 
 
f7aef4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5a7706
 
 
 
 
 
f7aef4e
e31f052
 
 
 
f7aef4e
 
bb8fcdf
f7aef4e
 
 
 
abcba3a
f7aef4e
a98af42
47783c0
f7aef4e
47783c0
 
 
f7aef4e
47783c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7aef4e
 
 
 
 
abcba3a
f7aef4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abcba3a
f7aef4e
 
 
 
abcba3a
 
4fe3949
f7aef4e
 
 
47783c0
abcba3a
47783c0
67c0d60
41394ac
 
f7aef4e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121

import os
import sys
import subprocess
import time
from huggingface_hub import snapshot_download

# Hugging Face repository that hosts all model files.
MODEL_REPO = "tencent/HunyuanVideo-Avatar"

# Every path below is anchored at the directory the script was launched from.
BASE_DIR = os.getcwd()
WEIGHTS_DIR = os.path.join(BASE_DIR, "weights")
OUTPUT_BASEPATH = os.path.join(BASE_DIR, "results-poor")

# Shared parent of both transformer checkpoints inside the weights tree.
_TRANSFORMERS_DIR = os.path.join(
    WEIGHTS_DIR, "ckpts", "hunyuan-video-t2v-720p", "transformers"
)

# Non-FP8 checkpoint, passed to flask_audio.py by run_flask_audio().
CHECKPOINT_FILE = os.path.join(_TRANSFORMERS_DIR, "mp_rank_00_model_states.pt")
# FP8 checkpoint, passed to sample_gpu_poor.py and probed to decide whether to download.
CHECKPOINT_FP8_FILE = os.path.join(_TRANSFORMERS_DIR, "mp_rank_00_model_states_fp8.pt")

def download_model():
    """Download the model repository into WEIGHTS_DIR.

    Two passes are made against MODEL_REPO:
      1. every file except the ``ckpts/`` tree;
      2. only the FP8 transformer checkpoint (CHECKPOINT_FP8_FILE) used by
         sample_gpu_poor.py.

    Exits the process with status 1 if the FP8 checkpoint is still missing
    after both passes.
    """
    print("⬇️  Model not found. Downloading with snapshot_download into weights directory...")
    os.makedirs(WEIGHTS_DIR, exist_ok=True)

    # NOTE(review): recent huggingface_hub releases deprecate
    # `local_dir_use_symlinks`; confirm the pinned version still needs it.

    # Pass 1: everything except the checkpoint tree.
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=WEIGHTS_DIR,
        local_dir_use_symlinks=False,
        ignore_patterns=["ckpts/**"],
    )

    # Pass 2: only the FP8 checkpoint. The pattern is derived from
    # CHECKPOINT_FP8_FILE so that the download target and the existence check
    # below cannot drift apart. Hub patterns use "/" regardless of the OS.
    fp8_pattern = os.path.relpath(CHECKPOINT_FP8_FILE, WEIGHTS_DIR).replace(os.sep, "/")
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=WEIGHTS_DIR,
        local_dir_use_symlinks=False,
        allow_patterns=[fp8_pattern],
    )

    if not os.path.isfile(CHECKPOINT_FP8_FILE):
        print(f"❌ FP8 checkpoint file not found at {CHECKPOINT_FP8_FILE}. Cannot proceed with sample_gpu_poor.py.")
        sys.exit(1)

    print("✅ Model downloaded successfully.")

def run_sample_gpu_poor():
    """Run hymm_sp/sample_gpu_poor.py on assets/test.csv with the FP8 checkpoint.

    The child process is started with the current interpreter, FP8 weights,
    CPU offloading and only GPU 0 visible. Output is written under
    OUTPUT_BASEPATH. Blocks until the child finishes and exits this process
    with status 1 if the child fails.
    """
    print("🎬 Running sample_gpu_poor.py...")
    cmd = [
        # sys.executable rather than "python3": the child must use the same
        # interpreter/virtualenv that has this script's dependencies installed.
        sys.executable, "hymm_sp/sample_gpu_poor.py",
        "--input", "assets/test.csv",
        "--ckpt", CHECKPOINT_FP8_FILE,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
        "--save-path", OUTPUT_BASEPATH,
        "--use-fp8",
        "--cpu-offload",
        "--infer-min",
    ]

    env = os.environ.copy()
    # Script and input paths above are relative to the cwd, so the package
    # root is put on the import path the same way.
    # NOTE(review): this replaces (not extends) any inherited PYTHONPATH.
    env["PYTHONPATH"] = "./"
    env["MODEL_BASE"] = WEIGHTS_DIR
    env["CPU_OFFLOAD"] = "1"
    # Pin the run to the first GPU regardless of the caller's setting.
    env["CUDA_VISIBLE_DEVICES"] = "0"

    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        print(f"❌ sample_gpu_poor.py failed (exit code {proc.returncode}).")
        sys.exit(1)
    print("✅ sample_gpu_poor.py completed successfully.")

def run_flask_audio():
    """Launch hymm_gradio/flask_audio.py under torchrun in the background.

    Starts one node with 8 processes on master port 29605 and does not wait
    for it to finish.

    Returns:
        subprocess.Popen: handle of the torchrun process, so callers can
        ``wait()`` on it or ``terminate()`` it.
    """
    print("🚀 Starting flask_audio.py...")
    # NOTE(review): this passes the non-FP8 CHECKPOINT_FILE, which
    # download_model() does not fetch (it skips ckpts/** except the FP8 file).
    # Confirm it is provisioned some other way before calling this.
    cmd = [
        "torchrun",
        "--nnodes=1",
        "--nproc_per_node=8",
        "--master_port=29605",
        "hymm_gradio/flask_audio.py",
        "--input", "assets/test.csv",
        "--ckpt", CHECKPOINT_FILE,
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
    ]
    return subprocess.Popen(cmd)

def run_gradio_ui():
    """Launch the hymm_gradio/gradio_audio.py UI in the background.

    The UI runs under the current interpreter. This function does not wait for
    it to exit.

    Returns:
        subprocess.Popen: handle of the UI process, so callers can ``wait()``
        on it or ``terminate()`` it.
    """
    print("🟒 Starting gradio_audio.py UI...")
    # sys.executable rather than "python3" so the UI uses this interpreter's environment.
    cmd = [sys.executable, "hymm_gradio/gradio_audio.py"]
    return subprocess.Popen(cmd)

def main():
    """Ensure the FP8 checkpoint exists (downloading if needed), then run sample_gpu_poor.py.

    Only the FP8 checkpoint file is probed. If it is present, the whole
    download step is skipped.
    """
    if os.path.isfile(CHECKPOINT_FP8_FILE):
        print("✅ Model checkpoint already exists. Skipping download.")
    else:
        download_model()

    run_sample_gpu_poor()


if __name__ == "__main__":
    main()