Spaces:
Runtime error
Runtime error
File size: 6,038 Bytes
309fd4d b9186cf 2a7f487 ba7cb71 908b63e 2a7f487 8ff4968 2a7f487 3113790 2a7f487 8ff4968 2a7f487 535c73d 2a7f487 535c73d 2a7f487 8ff4968 2a7f487 8ff4968 2a7f487 3113790 2a7f487 ba7cb71 2a7f487 535c73d 2a7f487 535c73d 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 8ff4968 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 8ff4968 2a7f487 ba7cb71 2a7f487 8ff4968 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 908b63e 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 3113790 2a7f487 3113790 2a7f487 ba7cb71 2a7f487 ba7cb71 2a7f487 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import os
import sys
import argparse
import random
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import gc
# 将当前文件所在目录添加到 Python 路径中
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# 从 'wan' 库中导入所需模块
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.utils import cache_video
# --- 1. 模型下载器 ---
def download_models():
"""
从 Hugging Face Hub 下载并缓存所需的模型。
"""
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"正在为 {repo_id} 下载模型检查点...")
try:
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"✅ 模型成功下载到: {ckpt_dir}")
except Exception as e:
print(f"❌ 下载模型时出错: {e}")
sys.exit(1)
# --- 2. 视频生成函数 ---
def generate_video_cli(prompt: str):
"""
使用命令行设置,根据文本提示生成视频。
"""
print("🎬 开始视频生成流程...")
# --- 设置 ---
print("正在加载模型配置...")
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
# 确保模型已下载,否则立即下载。
try:
# snapshot_download 会检查本地缓存,如果已存在则不会重复下载
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
except Exception as e:
print(f"❌ 无法找到或下载模型。请先运行 `python app.py --downloader`。")
print(f"错误详情: {e}")
sys.exit(1)
print(f"使用来自 {ckpt_dir} 的检查点")
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
# --- 生成参数 (使用原脚本中的默认值) ---
height = 704
width = 1280
duration_seconds = 2.0
sampling_steps = 38
guide_scale = cfg.sample_guide_scale
shift = cfg.sample_shift
seed = -1 # -1 代表随机种子
image = None # 当前命令行版本不处理图像输入
# --- 处理 ---
if seed == -1:
seed = random.randint(0, sys.maxsize)
print(f"使用随机种子: {seed}")
# 确保尺寸有效
MOD_VALUE = 32
target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
# 计算帧数
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
print(f"正在生成 {num_frames} 帧 ({duration_seconds}秒 @ {FIXED_FPS}fps),分辨率为 {target_w}x{target_h}。")
# --- 初始化 Pipeline ---
print("正在初始化 WanTI2V pipeline... (可能需要一些时间)")
device = "cuda" if torch.cuda.is_available() else "cpu"
device_id = 0 if torch.cuda.is_available() else -1
if device == "cpu":
print("⚠️ 警告: 未检测到 GPU。在 CPU 上运行会非常慢。")
try:
pipeline = wan.WanTI2V(
config=cfg,
checkpoint_dir=ckpt_dir,
device_id=device_id,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=False,
convert_model_dtype=True,
)
print("Pipeline 初始化完成。")
except Exception as e:
print(f"❌ 初始化 pipeline 失败: {e}")
sys.exit(1)
# --- 生成视频 ---
print(f"正在为提示词生成视频: '{prompt}'")
size_str = f"{target_h}*{target_w}"
video_tensor = pipeline.generate(
input_prompt=prompt,
img=image,
size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
frame_num=num_frames,
shift=shift,
sample_solver='unipc',
sampling_steps=int(sampling_steps),
guide_scale=guide_scale,
seed=seed,
offload_model=True
)
# --- 保存视频 ---
print("正在保存视频...")
# 根据提示词生成一个安全的文件名
safe_prompt = "".join([c for c in prompt if c.isalnum() or c==' ']).rstrip()
safe_prompt = safe_prompt.replace(" ", "_")
output_filename = f"{safe_prompt[:50]}_{seed}.mp4"
output_path = os.path.join(os.getcwd(), output_filename) #保存在当前工作目录
video_path = cache_video(
tensor=video_tensor[None],
save_file=output_path, # 指定保存路径
fps=cfg.sample_fps,
normalize=True,
value_range=(-1, 1)
)
# --- 清理 ---
del pipeline
del video_tensor
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"✅ 视频生成完成!已保存至: {video_path}")
# --- 3. 主执行模块 ---
def main():
"""
解析命令行参数并运行相应的功能。
"""
parser = argparse.ArgumentParser(
description="Wan 2.2 TI2V-5B 命令行工具。用于从文本生成视频或下载模型。",
formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
'--prompt',
nargs='+',
type=str,
help="用于视频生成的文本提示词。\n示例: --prompt A beautiful waterfall"
)
parser.add_argument(
'--downloader',
action='store_true',
help="如果指定此参数,将只下载所需的模型然后退出。"
)
args = parser.parse_args()
if args.downloader:
download_models()
elif args.prompt:
# 将单词列表合并成一个完整的提示词字符串
# 这能正确处理 'prompt text' 和 "prompt text" 以及 prompt text
prompt_text = " ".join(args.prompt)
generate_video_cli(prompt_text)
else:
print("未指定操作。请输入 --prompt 或使用 --downloader 标志。")
parser.print_help()
if __name__ == "__main__":
main() |