import os import sys import argparse import random import numpy as np import torch from huggingface_hub import snapshot_download from PIL import Image import gc # 将当前文件所在目录添加到 Python 路径中 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) # 从 'wan' 库中导入所需模块 import wan from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS from wan.utils.utils import cache_video # --- 1. 模型下载器 --- def download_models(): """ 从 Hugging Face Hub 下载并缓存所需的模型。 """ repo_id = "Wan-AI/Wan2.2-TI2V-5B" print(f"正在为 {repo_id} 下载模型检查点...") try: ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False) print(f"✅ 模型成功下载到: {ckpt_dir}") except Exception as e: print(f"❌ 下载模型时出错: {e}") sys.exit(1) # --- 2. 视频生成函数 --- def generate_video_cli(prompt: str): """ 使用命令行设置,根据文本提示生成视频。 """ print("🎬 开始视频生成流程...") # --- 设置 --- print("正在加载模型配置...") repo_id = "Wan-AI/Wan2.2-TI2V-5B" # 确保模型已下载,否则立即下载。 try: # snapshot_download 会检查本地缓存,如果已存在则不会重复下载 ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False) except Exception as e: print(f"❌ 无法找到或下载模型。请先运行 `python app.py --downloader`。") print(f"错误详情: {e}") sys.exit(1) print(f"使用来自 {ckpt_dir} 的检查点") TASK_NAME = 'ti2v-5B' cfg = WAN_CONFIGS[TASK_NAME] # --- 生成参数 (使用原脚本中的默认值) --- height = 704 width = 1280 duration_seconds = 2.0 sampling_steps = 38 guide_scale = cfg.sample_guide_scale shift = cfg.sample_shift seed = -1 # -1 代表随机种子 image = None # 当前命令行版本不处理图像输入 # --- 处理 --- if seed == -1: seed = random.randint(0, sys.maxsize) print(f"使用随机种子: {seed}") # 确保尺寸有效 MOD_VALUE = 32 target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE) target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE) # 计算帧数 FIXED_FPS = 24 MIN_FRAMES_MODEL = 8 MAX_FRAMES_MODEL = 121 num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL) print(f"正在生成 {num_frames} 帧 ({duration_seconds}秒 @ {FIXED_FPS}fps),分辨率为 {target_w}x{target_h}。") # --- 初始化 Pipeline --- print("正在初始化 WanTI2V pipeline... (可能需要一些时间)") device = "cuda" if torch.cuda.is_available() else "cpu" device_id = 0 if torch.cuda.is_available() else -1 if device == "cpu": print("⚠️ 警告: 未检测到 GPU。在 CPU 上运行会非常慢。") try: pipeline = wan.WanTI2V( config=cfg, checkpoint_dir=ckpt_dir, device_id=device_id, rank=0, t5_fsdp=False, dit_fsdp=False, use_sp=False, t5_cpu=False, init_on_cpu=False, convert_model_dtype=True, ) print("Pipeline 初始化完成。") except Exception as e: print(f"❌ 初始化 pipeline 失败: {e}") sys.exit(1) # --- 生成视频 --- print(f"正在为提示词生成视频: '{prompt}'") size_str = f"{target_h}*{target_w}" video_tensor = pipeline.generate( input_prompt=prompt, img=image, size=SIZE_CONFIGS.get(size_str, (target_h, target_w)), max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w), frame_num=num_frames, shift=shift, sample_solver='unipc', sampling_steps=int(sampling_steps), guide_scale=guide_scale, seed=seed, offload_model=True ) # --- 保存视频 --- print("正在保存视频...") # 根据提示词生成一个安全的文件名 safe_prompt = "".join([c for c in prompt if c.isalnum() or c==' ']).rstrip() safe_prompt = safe_prompt.replace(" ", "_") output_filename = f"{safe_prompt[:50]}_{seed}.mp4" output_path = os.path.join(os.getcwd(), output_filename) #保存在当前工作目录 video_path = cache_video( tensor=video_tensor[None], save_file=output_path, # 指定保存路径 fps=cfg.sample_fps, normalize=True, value_range=(-1, 1) ) # --- 清理 --- del pipeline del video_tensor gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"✅ 视频生成完成!已保存至: {video_path}") # --- 3. 主执行模块 --- def main(): """ 解析命令行参数并运行相应的功能。 """ parser = argparse.ArgumentParser( description="Wan 2.2 TI2V-5B 命令行工具。用于从文本生成视频或下载模型。", formatter_class=argparse.RawTextHelpFormatter ) parser.add_argument( '--prompt', nargs='+', type=str, help="用于视频生成的文本提示词。\n示例: --prompt A beautiful waterfall" ) parser.add_argument( '--downloader', action='store_true', help="如果指定此参数,将只下载所需的模型然后退出。" ) args = parser.parse_args() if args.downloader: download_models() elif args.prompt: # 将单词列表合并成一个完整的提示词字符串 # 这能正确处理 'prompt text' 和 "prompt text" 以及 prompt text prompt_text = " ".join(args.prompt) generate_video_cli(prompt_text) else: print("未指定操作。请输入 --prompt 或使用 --downloader 标志。") parser.print_help() if __name__ == "__main__": main()