Wan-2.2-5B / app.py
dangthr's picture
Update app.py
2a7f487 verified
import os
import sys
import argparse
import random
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import gc
# 将当前文件所在目录添加到 Python 路径中
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# 从 'wan' 库中导入所需模块
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.utils import cache_video
# --- 1. 模型下载器 ---
def download_models():
"""
从 Hugging Face Hub 下载并缓存所需的模型。
"""
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"正在为 {repo_id} 下载模型检查点...")
try:
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"✅ 模型成功下载到: {ckpt_dir}")
except Exception as e:
print(f"❌ 下载模型时出错: {e}")
sys.exit(1)
# --- 2. 视频生成函数 ---
def generate_video_cli(prompt: str):
"""
使用命令行设置,根据文本提示生成视频。
"""
print("🎬 开始视频生成流程...")
# --- 设置 ---
print("正在加载模型配置...")
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
# 确保模型已下载,否则立即下载。
try:
# snapshot_download 会检查本地缓存,如果已存在则不会重复下载
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
except Exception as e:
print(f"❌ 无法找到或下载模型。请先运行 `python app.py --downloader`。")
print(f"错误详情: {e}")
sys.exit(1)
print(f"使用来自 {ckpt_dir} 的检查点")
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
# --- 生成参数 (使用原脚本中的默认值) ---
height = 704
width = 1280
duration_seconds = 2.0
sampling_steps = 38
guide_scale = cfg.sample_guide_scale
shift = cfg.sample_shift
seed = -1 # -1 代表随机种子
image = None # 当前命令行版本不处理图像输入
# --- 处理 ---
if seed == -1:
seed = random.randint(0, sys.maxsize)
print(f"使用随机种子: {seed}")
# 确保尺寸有效
MOD_VALUE = 32
target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
# 计算帧数
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
print(f"正在生成 {num_frames} 帧 ({duration_seconds}秒 @ {FIXED_FPS}fps),分辨率为 {target_w}x{target_h}。")
# --- 初始化 Pipeline ---
print("正在初始化 WanTI2V pipeline... (可能需要一些时间)")
device = "cuda" if torch.cuda.is_available() else "cpu"
device_id = 0 if torch.cuda.is_available() else -1
if device == "cpu":
print("⚠️ 警告: 未检测到 GPU。在 CPU 上运行会非常慢。")
try:
pipeline = wan.WanTI2V(
config=cfg,
checkpoint_dir=ckpt_dir,
device_id=device_id,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=False,
convert_model_dtype=True,
)
print("Pipeline 初始化完成。")
except Exception as e:
print(f"❌ 初始化 pipeline 失败: {e}")
sys.exit(1)
# --- 生成视频 ---
print(f"正在为提示词生成视频: '{prompt}'")
size_str = f"{target_h}*{target_w}"
video_tensor = pipeline.generate(
input_prompt=prompt,
img=image,
size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
frame_num=num_frames,
shift=shift,
sample_solver='unipc',
sampling_steps=int(sampling_steps),
guide_scale=guide_scale,
seed=seed,
offload_model=True
)
# --- 保存视频 ---
print("正在保存视频...")
# 根据提示词生成一个安全的文件名
safe_prompt = "".join([c for c in prompt if c.isalnum() or c==' ']).rstrip()
safe_prompt = safe_prompt.replace(" ", "_")
output_filename = f"{safe_prompt[:50]}_{seed}.mp4"
output_path = os.path.join(os.getcwd(), output_filename) #保存在当前工作目录
video_path = cache_video(
tensor=video_tensor[None],
save_file=output_path, # 指定保存路径
fps=cfg.sample_fps,
normalize=True,
value_range=(-1, 1)
)
# --- 清理 ---
del pipeline
del video_tensor
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"✅ 视频生成完成!已保存至: {video_path}")
# --- 3. 主执行模块 ---
def main():
"""
解析命令行参数并运行相应的功能。
"""
parser = argparse.ArgumentParser(
description="Wan 2.2 TI2V-5B 命令行工具。用于从文本生成视频或下载模型。",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
'--prompt',
nargs='+',
type=str,
help="用于视频生成的文本提示词。\n示例: --prompt A beautiful waterfall"
)
parser.add_argument(
'--downloader',
action='store_true',
help="如果指定此参数,将只下载所需的模型然后退出。"
)
args = parser.parse_args()
if args.downloader:
download_models()
elif args.prompt:
# 将单词列表合并成一个完整的提示词字符串
# 这能正确处理 'prompt text' 和 "prompt text" 以及 prompt text
prompt_text = " ".join(args.prompt)
generate_video_cli(prompt_text)
else:
print("未指定操作。请输入 --prompt 或使用 --downloader 标志。")
parser.print_help()
if __name__ == "__main__":
main()