File size: 2,805 Bytes
c1f7300
d491e94
 
093967b
035f115
62066bc
c1f7300
035f115
 
 
093967b
035f115
 
b74bc4e
035f115
 
b74bc4e
035f115
 
 
c1f7300
035f115
 
 
d491e94
035f115
 
 
edc3608
035f115
62066bc
 
035f115
 
62066bc
edc3608
035f115
 
d491e94
035f115
 
 
fedb718
d491e94
035f115
41394ac
035f115
 
 
41394ac
 
 
 
035f115
 
41394ac
 
 
 
 
 
 
035f115
41394ac
 
 
 
 
 
 
035f115
41394ac
 
 
035f115
 
41394ac
 
035f115
41394ac
035f115
 
 
41394ac
035f115
41394ac
 
 
035f115
 
 
41394ac
035f115
41394ac
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import sys
import subprocess
import time
from pathlib import Path
from huggingface_hub import snapshot_download

# --------------------
# CONFIGURATION
# --------------------
# Hugging Face repo id handed to snapshot_download() in download_model().
MODEL_REPO = "tencent/HunyuanVideo-Avatar"
# Local directory the repo snapshot is downloaded into; it is also exported to
# the sampling subprocess as MODEL_BASE. Created as a side effect of importing
# this module.
HF_CACHE_DIR = Path("/home/user/.cache/huggingface/hf_cache/hunyuan_avatar")
HF_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Checkpoint files expected under HF_CACHE_DIR. download_model() checks that
# both exist; only the FP8 file is passed to the sampling script as --ckpt.
CHECKPOINT_FILE = HF_CACHE_DIR / "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt"
CHECKPOINT_FP8_FILE = HF_CACHE_DIR / "ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states_fp8.pt"

# Relative paths, so they resolve against the directory the script is run from.
# ASSETS_CSV is passed to the sampling script as --input, OUTPUT_DIR as
# --save-path. OUTPUT_DIR is created as a side effect of importing this module.
ASSETS_CSV = "assets/test.csv"
OUTPUT_DIR = Path("results-poor")
OUTPUT_DIR.mkdir(exist_ok=True)

# --------------------
# Download the model (if needed)
# --------------------
def download_model():
    """Make sure both model checkpoints are on disk, downloading if needed.

    If CHECKPOINT_FILE and CHECKPOINT_FP8_FILE both already exist, nothing is
    downloaded. Otherwise the whole MODEL_REPO snapshot is fetched into
    HF_CACHE_DIR and the two files are checked again.

    Raises:
        SystemExit: with status 1 when either checkpoint is still missing
            after the download.
    """
    if CHECKPOINT_FILE.exists() and CHECKPOINT_FP8_FILE.exists():
        print("✅ Model checkpoint already exists. Skipping download.")
        return

    print("⬇️ Downloading model into HF Space cache...")
    # NOTE(review): local_dir_use_symlinks appears to be deprecated in recent
    # huggingface_hub releases; kept so older releases still copy real files.
    snapshot_download(
        repo_id=MODEL_REPO,
        local_dir=HF_CACHE_DIR,
        local_dir_use_symlinks=False,
    )

    # A finished download does not prove the repo layout matches the paths
    # this script expects, so check each required file explicitly.
    required = (
        ("checkpoint", CHECKPOINT_FILE),
        ("FP8 checkpoint", CHECKPOINT_FP8_FILE),
    )
    for label, path in required:
        if not path.exists():
            print(f"❌ Missing {label}: {path}")
            sys.exit(1)

    print("✅ Model download complete.")

# --------------------
# Run sample_gpu_poor.py
# --------------------
def run_sample_gpu_poor():
    """Run hymm_sp/sample_gpu_poor.py as a child process and wait for it.

    The child gets a copy of the current environment with PYTHONPATH,
    MODEL_BASE, CPU_OFFLOAD and CUDA_VISIBLE_DEVICES set. It reads ASSETS_CSV,
    loads CHECKPOINT_FP8_FILE and is told to write into OUTPUT_DIR.

    Raises:
        SystemExit: with status 1 when the child exits with a non-zero code.
    """
    print("🎬 Running sample_gpu_poor.py...")

    # Launch the child with the interpreter running this script so it sees the
    # same installed packages; a bare "python3" may resolve to a different one.
    # sys.executable can be empty in embedded setups, hence the fallback.
    python = sys.executable or "python3"
    cmd = [
        python, "hymm_sp/sample_gpu_poor.py",
        "--input", ASSETS_CSV,
        "--ckpt", str(CHECKPOINT_FP8_FILE),
        "--sample-n-frames", "129",
        "--seed", "128",
        "--image-size", "704",
        "--cfg-scale", "7.5",
        "--infer-steps", "50",
        "--use-deepcache", "1",
        "--flow-shift-eval-video", "5.0",
        "--save-path", str(OUTPUT_DIR),
        "--use-fp8",
        "--cpu-offload",
        "--infer-min",
    ]

    env = os.environ.copy()
    # Put the working directory first on PYTHONPATH but keep whatever the
    # caller already had there instead of silently discarding it.
    inherited_pythonpath = env.get("PYTHONPATH")
    if inherited_pythonpath:
        env["PYTHONPATH"] = os.pathsep.join(("./", inherited_pythonpath))
    else:
        env["PYTHONPATH"] = "./"
    env["MODEL_BASE"] = str(HF_CACHE_DIR)
    env["CPU_OFFLOAD"] = "1"
    env["CUDA_VISIBLE_DEVICES"] = "0"

    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        print("❌ sample_gpu_poor.py failed.")
        sys.exit(1)
    print("✅ sample_gpu_poor.py completed successfully.")

# --------------------
# Optional: Start UI
# --------------------
def run_gradio_ui():
    """Start hymm_gradio/gradio_audio.py in the background.

    The call does not wait for the UI process to finish.

    Returns:
        subprocess.Popen: handle to the launched process, so the caller can
        wait on it or terminate it. Callers that ignore the return value
        behave as before.
    """
    print("🟢 Launching Gradio interface...")
    # Use the interpreter running this script so the UI sees the same
    # installed packages; fall back to "python3" if it is not known.
    # NOTE(review): unlike run_sample_gpu_poor(), no PYTHONPATH or MODEL_BASE
    # is set for this child -- confirm the UI script does not need them.
    cmd = [sys.executable or "python3", "hymm_gradio/gradio_audio.py"]
    return subprocess.Popen(cmd)

# --------------------
# Entry point
# --------------------
def main():
    """Run the pipeline: fetch the model, run sampling, then start the UI.

    The first two stages call sys.exit(1) themselves on failure, so a later
    stage only runs when the earlier ones succeeded.
    """
    for stage in (download_model, run_sample_gpu_poor):
        stage()

    # Pause for five seconds after sampling before the UI is launched.
    time.sleep(5)
    run_gradio_ui()


if __name__ == "__main__":
    main()