File size: 6,038 Bytes
309fd4d
b9186cf
2a7f487
 
 
ba7cb71
 
 
908b63e
 
2a7f487
 
8ff4968
2a7f487
 
 
 
3113790
2a7f487
8ff4968
2a7f487
535c73d
2a7f487
535c73d
2a7f487
 
8ff4968
2a7f487
 
8ff4968
2a7f487
 
3113790
2a7f487
ba7cb71
2a7f487
535c73d
2a7f487
535c73d
2a7f487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba7cb71
 
2a7f487
ba7cb71
2a7f487
 
8ff4968
 
 
2a7f487
 
 
 
ba7cb71
2a7f487
ba7cb71
2a7f487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ff4968
2a7f487
 
 
 
ba7cb71
 
2a7f487
8ff4968
 
2a7f487
ba7cb71
 
 
 
 
 
 
 
2a7f487
 
 
 
 
 
 
 
 
ba7cb71
2a7f487
 
ba7cb71
 
 
 
2a7f487
 
 
908b63e
 
2a7f487
 
ba7cb71
2a7f487
ba7cb71
 
2a7f487
ba7cb71
2a7f487
 
 
 
 
 
 
3113790
 
2a7f487
 
 
 
 
3113790
2a7f487
 
 
 
 
ba7cb71
 
2a7f487
 
 
 
 
 
 
 
 
 
 
 
ba7cb71
 
2a7f487
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import sys
import argparse
import random
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import gc

# 将当前文件所在目录添加到 Python 路径中
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# 从 'wan' 库中导入所需模块
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.utils import cache_video

# --- 1. 模型下载器 ---

def download_models():
    """
    从 Hugging Face Hub 下载并缓存所需的模型。
    """
    repo_id = "Wan-AI/Wan2.2-TI2V-5B"
    print(f"正在为 {repo_id} 下载模型检查点...")
    try:
        ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
        print(f"✅ 模型成功下载到: {ckpt_dir}")
    except Exception as e:
        print(f"❌ 下载模型时出错: {e}")
        sys.exit(1)

# --- 2. 视频生成函数 ---

def generate_video_cli(prompt: str):
    """
    使用命令行设置,根据文本提示生成视频。
    """
    print("🎬 开始视频生成流程...")

    # --- 设置 ---
    print("正在加载模型配置...")
    repo_id = "Wan-AI/Wan2.2-TI2V-5B"
    # 确保模型已下载,否则立即下载。
    try:
        # snapshot_download 会检查本地缓存,如果已存在则不会重复下载
        ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
    except Exception as e:
        print(f"❌ 无法找到或下载模型。请先运行 `python app.py --downloader`。")
        print(f"错误详情: {e}")
        sys.exit(1)

    print(f"使用来自 {ckpt_dir} 的检查点")
    
    TASK_NAME = 'ti2v-5B'
    cfg = WAN_CONFIGS[TASK_NAME]
    
    # --- 生成参数 (使用原脚本中的默认值) ---
    height = 704
    width = 1280
    duration_seconds = 2.0
    sampling_steps = 38
    guide_scale = cfg.sample_guide_scale
    shift = cfg.sample_shift
    seed = -1  # -1 代表随机种子
    image = None # 当前命令行版本不处理图像输入

    # --- 处理 ---
    if seed == -1:
        seed = random.randint(0, sys.maxsize)
    print(f"使用随机种子: {seed}")

    # 确保尺寸有效
    MOD_VALUE = 32
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    # 计算帧数
    FIXED_FPS = 24
    MIN_FRAMES_MODEL = 8
    MAX_FRAMES_MODEL = 121
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    print(f"正在生成 {num_frames} 帧 ({duration_seconds}秒 @ {FIXED_FPS}fps),分辨率为 {target_w}x{target_h}。")

    # --- 初始化 Pipeline ---
    print("正在初始化 WanTI2V pipeline... (可能需要一些时间)")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device_id = 0 if torch.cuda.is_available() else -1
    if device == "cpu":
        print("⚠️ 警告: 未检测到 GPU。在 CPU 上运行会非常慢。")

    try:
        pipeline = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=ckpt_dir,
            device_id=device_id,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_sp=False,
            t5_cpu=False,
            init_on_cpu=False,
            convert_model_dtype=True,
        )
        print("Pipeline 初始化完成。")
    except Exception as e:
        print(f"❌ 初始化 pipeline 失败: {e}")
        sys.exit(1)

    # --- 生成视频 ---
    print(f"正在为提示词生成视频: '{prompt}'")
    size_str = f"{target_h}*{target_w}"
    
    video_tensor = pipeline.generate(
        input_prompt=prompt,
        img=image,
        size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
        max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
        frame_num=num_frames,
        shift=shift,
        sample_solver='unipc',
        sampling_steps=int(sampling_steps),
        guide_scale=guide_scale,
        seed=seed,
        offload_model=True
    )

    # --- 保存视频 ---
    print("正在保存视频...")
    
    # 根据提示词生成一个安全的文件名
    safe_prompt = "".join([c for c in prompt if c.isalnum() or c==' ']).rstrip()
    safe_prompt = safe_prompt.replace(" ", "_")
    output_filename = f"{safe_prompt[:50]}_{seed}.mp4"
    output_path = os.path.join(os.getcwd(), output_filename) #保存在当前工作目录

    video_path = cache_video(
        tensor=video_tensor[None],
        save_file=output_path, # 指定保存路径
        fps=cfg.sample_fps,
        normalize=True,
        value_range=(-1, 1)
    )

    # --- 清理 ---
    del pipeline
    del video_tensor
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"✅ 视频生成完成!已保存至: {video_path}")


# --- 3. 主执行模块 ---

def main():
    """
    解析命令行参数并运行相应的功能。
    """
    parser = argparse.ArgumentParser(
        description="Wan 2.2 TI2V-5B 命令行工具。用于从文本生成视频或下载模型。",
        formatter_class=argparse.RawTextHelpFormatter
    )
    
    parser.add_argument(
        '--prompt', 
        nargs='+', 
        type=str, 
        help="用于视频生成的文本提示词。\n示例: --prompt A beautiful waterfall"
    )
    
    parser.add_argument(
        '--downloader', 
        action='store_true', 
        help="如果指定此参数,将只下载所需的模型然后退出。"
    )

    args = parser.parse_args()

    if args.downloader:
        download_models()
    elif args.prompt:
        # 将单词列表合并成一个完整的提示词字符串
        # 这能正确处理 'prompt text' 和 "prompt text" 以及 prompt text
        prompt_text = " ".join(args.prompt)
        generate_video_cli(prompt_text)
    else:
        print("未指定操作。请输入 --prompt 或使用 --downloader 标志。")
        parser.print_help()

if __name__ == "__main__":
    main()