Spaces:
Runtime error
Runtime error
import os | |
import sys | |
import argparse | |
import random | |
import numpy as np | |
import torch | |
from huggingface_hub import snapshot_download | |
from PIL import Image | |
import gc | |
# Put the directory containing this file at the front of the import path so that
# modules living next to it (such as `wan`, imported below) resolve first.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
# 从 'wan' 库中导入所需模块 | |
import wan | |
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS | |
from wan.utils.utils import cache_video | |
# --- 1. Model downloader ---
def download_models(repo_id: str = "Wan-AI/Wan2.2-TI2V-5B") -> str:
    """Download (or reuse the locally cached copy of) a model snapshot from the Hub.

    Args:
        repo_id: Hugging Face Hub repository to fetch. Defaults to the
            Wan2.2 TI2V-5B checkpoint used by this tool.

    Returns:
        The local directory that holds the downloaded snapshot.

    Exits the process with status 1 if the download fails.
    """
    print(f"正在为 {repo_id} 下载模型检查点...")
    try:
        # No `local_dir` is given, so the snapshot goes to the standard HF cache.
        # `local_dir_use_symlinks` only ever applied together with `local_dir` and is
        # deprecated in recent huggingface_hub releases, so it is not passed.
        ckpt_dir = snapshot_download(repo_id)
    except Exception as e:  # top-level CLI boundary: report and exit
        print(f"❌ 下载模型时出错: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"✅ 模型成功下载到: {ckpt_dir}")
    return ckpt_dir
# --- 2. Video generation ---
def _snap_frame_count(duration_seconds: float, fps: int, min_frames: int, max_frames: int) -> int:
    """Return the frame count for `duration_seconds` at `fps`, clamped and snapped to 4n+1.

    The Wan pipelines document that `frame_num` should be of the form 4n+1; other
    values are not produced exactly, so snapping keeps the reported count honest.
    The result is the 4n+1 value nearest the requested count, kept inside
    [min_frames, max_frames].
    """
    requested = int(round(duration_seconds * fps))
    requested = max(min_frames, min(requested, max_frames))
    # Smallest and largest 4n+1 values that lie inside [min_frames, max_frames].
    lowest = -(-(min_frames - 1) // 4) * 4 + 1
    highest = (max_frames - 1) // 4 * 4 + 1
    # Round to the nearest 4n+1 (ties go up).
    snapped = (requested + 1) // 4 * 4 + 1
    return max(lowest, min(snapped, highest))


def _safe_output_name(prompt: str, seed: int, max_len: int = 50) -> str:
    """Build an `.mp4` file name from the prompt and the seed.

    Only alphanumeric characters and spaces are kept, and each run of spaces becomes
    a single underscore. The stem is truncated to `max_len` characters and falls
    back to "video" when nothing usable remains.
    """
    kept = "".join(c for c in prompt if c.isalnum() or c == " ")
    stem = "_".join(kept.split())[:max_len].rstrip("_") or "video"
    return f"{stem}_{seed}.mp4"


def generate_video_cli(prompt: str):
    """Generate a video from a text prompt with fixed CLI defaults and save it as MP4.

    Fetches the checkpoint with `snapshot_download` (a cached snapshot is reused),
    builds a `wan.WanTI2V` pipeline and writes the result to the current working
    directory as `<prompt-slug>_<seed>.mp4`.

    Args:
        prompt: Text prompt describing the video.

    Exits the process with status 1 if the checkpoint cannot be obtained, the
    pipeline cannot be initialised, or the video file cannot be written.
    """
    print("🎬 开始视频生成流程...")

    # --- Setup ---
    print("正在加载模型配置...")
    repo_id = "Wan-AI/Wan2.2-TI2V-5B"
    try:
        # snapshot_download reuses the local cache and does not download again if present.
        ckpt_dir = snapshot_download(repo_id)
    except Exception as e:  # top-level CLI boundary: report and exit
        print("❌ 无法找到或下载模型。请先运行 `python app.py --downloader`。", file=sys.stderr)
        print(f"错误详情: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"使用来自 {ckpt_dir} 的检查点")

    TASK_NAME = 'ti2v-5B'
    cfg = WAN_CONFIGS[TASK_NAME]

    # --- Generation parameters (defaults carried over from the original script) ---
    height = 704
    width = 1280
    duration_seconds = 2.0
    sampling_steps = 38
    guide_scale = cfg.sample_guide_scale
    shift = cfg.sample_shift
    seed = -1  # -1 selects a random seed
    image = None  # this CLI version does not take an input image

    # --- Pre-processing ---
    if seed == -1:
        seed = random.randint(0, sys.maxsize)
        print(f"使用随机种子: {seed}")

    # Round each side down to a multiple of 32, never below 32.
    MOD_VALUE = 32
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    # NOTE(review): frames are counted at FIXED_FPS but the file is written at
    # cfg.sample_fps; the printed duration is only exact if the two agree.
    FIXED_FPS = 24
    MIN_FRAMES_MODEL = 8
    MAX_FRAMES_MODEL = 121
    num_frames = _snap_frame_count(duration_seconds, FIXED_FPS, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    print(f"正在生成 {num_frames} 帧 ({duration_seconds}秒 @ {FIXED_FPS}fps),分辨率为 {target_w}x{target_h}。")

    # --- Pipeline initialisation ---
    print("正在初始化 WanTI2V pipeline... (可能需要一些时间)")
    cuda_ok = torch.cuda.is_available()
    # NOTE(review): -1 is passed when no GPU is present; confirm WanTI2V accepts it,
    # otherwise initialisation fails below and is reported as such.
    device_id = 0 if cuda_ok else -1
    if not cuda_ok:
        print("⚠️ 警告: 未检测到 GPU。在 CPU 上运行会非常慢。")
    try:
        pipeline = wan.WanTI2V(
            config=cfg,
            checkpoint_dir=ckpt_dir,
            device_id=device_id,
            rank=0,
            t5_fsdp=False,
            dit_fsdp=False,
            use_sp=False,
            t5_cpu=False,
            init_on_cpu=False,
            convert_model_dtype=True,
        )
    except Exception as e:  # top-level CLI boundary: report and exit
        print(f"❌ 初始化 pipeline 失败: {e}", file=sys.stderr)
        sys.exit(1)
    print("Pipeline 初始化完成。")

    # --- Generation ---
    print(f"正在为提示词生成视频: '{prompt}'")
    # Wan pipelines take `size` as (width, height), and SIZE_CONFIGS is keyed "W*H"
    # (e.g. "1280*704"). The previous "H*W" key / (h, w) fallback therefore produced
    # a portrait 704x1280 clip instead of the 1280x704 reported above.
    video_tensor = pipeline.generate(
        input_prompt=prompt,
        img=image,
        size=(target_w, target_h),
        max_area=target_w * target_h,
        frame_num=num_frames,
        shift=shift,
        sample_solver='unipc',
        sampling_steps=int(sampling_steps),
        guide_scale=guide_scale,
        seed=seed,
        offload_model=True,
    )

    # --- Saving ---
    print("正在保存视频...")
    output_path = os.path.join(os.getcwd(), _safe_output_name(prompt, seed))  # save in the CWD
    try:
        video_path = cache_video(
            tensor=video_tensor[None],  # add a leading batch axis
            save_file=output_path,
            fps=cfg.sample_fps,
            normalize=True,
            value_range=(-1, 1),
        )
    finally:
        # Release the model and the output tensor even if saving raises.
        del pipeline, video_tensor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Presumably cache_video signals failure by returning None (as upstream Wan does);
    # treat any falsy result as a failed save instead of reporting "saved to None".
    if not video_path:
        print(f"❌ 保存视频失败: {output_path}", file=sys.stderr)
        sys.exit(1)
    print(f"✅ 视频生成完成!已保存至: {video_path}")
# --- 3. Entry point ---
def main():
    """Parse command-line arguments and dispatch to the downloader or the generator.

    `--downloader` takes precedence over `--prompt`. If neither is given, a short
    hint is printed followed by the full help text.
    """
    cli = argparse.ArgumentParser(
        description="Wan 2.2 TI2V-5B 命令行工具。用于从文本生成视频或下载模型。",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    cli.add_argument(
        '--prompt',
        nargs='+',
        type=str,
        help="用于视频生成的文本提示词。\n示例: --prompt A beautiful waterfall",
    )
    cli.add_argument(
        '--downloader',
        action='store_true',
        help="如果指定此参数,将只下载所需的模型然后退出。",
    )
    opts = cli.parse_args()

    if opts.downloader:
        download_models()
        return

    if not opts.prompt:
        print("未指定操作。请输入 --prompt 或使用 --downloader 标志。")
        cli.print_help()
        return

    # nargs='+' yields a list of words; joining them means quoted and unquoted
    # prompts produce the same prompt string.
    generate_video_cli(" ".join(opts.prompt))


if __name__ == "__main__":
    main()