| |
| import argparse |
| import os.path as osp |
| import os |
| import sys |
| import warnings |
|
|
| import gradio as gr |
|
|
| warnings.filterwarnings('ignore') |
|
|
| |
| sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) |
| import wan |
| from wan.configs import WAN_CONFIGS |
| from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander |
| from wan.utils.utils import cache_video |
|
|
| |
# Module-level handles shared with the Gradio callbacks below. Both start as
# None and are assigned in the ``__main__`` section before the UI is launched;
# ``prompt_enc`` reads ``prompt_expander`` and ``t2v_generation`` reads
# ``wan_t2v``.
prompt_expander = None
wan_t2v = None
|
|
|
|
| |
def prompt_enc(prompt, tar_lang):
    """Rewrite ``prompt`` with the module-level prompt expander.

    Args:
        prompt: The text currently in the prompt box.
        tar_lang: Target language label from the UI (e.g. ``"CH"`` or
            ``"EN"``); it is lower-cased before being handed to the expander.

    Returns:
        The expanded prompt when the expander reports success; otherwise the
        original ``prompt`` unchanged. The original prompt is also returned
        when no expander has been configured, so the callback never fails
        with a ``TypeError`` from calling ``None``.
    """
    # The expander is created by the script entry point; if this function is
    # reached before that happened, degrade to a no-op instead of crashing.
    if prompt_expander is None:
        return prompt

    prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
    # A falsy status means the expansion failed: keep the user's own text.
    if not prompt_output.status:
        return prompt
    return prompt_output.prompt
|
|
|
|
def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
                   shift_scale, seed, n_prompt):
    """Generate a video for the given prompt and return the saved file path.

    Args:
        txt2vid_prompt: Prompt text forwarded to ``wan_t2v.generate``.
        resolution: ``"<width>*<height>"`` string; the two leading
            ``*``-separated fields are converted to integers.
        sd_steps: Passed to the model as ``sampling_steps``.
        guide_scale: Passed to the model as ``guide_scale``.
        shift_scale: Passed to the model as ``shift``.
        seed: Passed to the model as ``seed``.
        n_prompt: Negative prompt, passed to the model as ``n_prompt``.

    Returns:
        The string ``"example.mp4"``, the file the video was written to.
    """
    output_path = "example.mp4"

    # Width comes first in the resolution string, height second.
    dims = resolution.split("*")
    width = int(dims[0])
    height = int(dims[1])

    frames = wan_t2v.generate(
        txt2vid_prompt,
        size=(width, height),
        shift=shift_scale,
        sampling_steps=sd_steps,
        guide_scale=guide_scale,
        n_prompt=n_prompt,
        seed=seed,
        offload_model=True,
    )

    # ``frames[None]`` adds a leading axis before the result is handed to
    # cache_video, which writes it to ``output_path``.
    cache_video(
        tensor=frames[None],
        save_file=output_path,
        fps=16,
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )

    return output_path
|
|
|
|
| |
def gradio_interface():
    """Assemble the Gradio Blocks UI for the Wan2.1 T2V-14B demo.

    The page has a header, a left column with the prompt controls and the
    advanced generation options, and a right column with the video output.
    The "Prompt Enhance" button is wired to ``prompt_enc`` and the
    "Generate Video" button to ``t2v_generation``.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    # NOTE(review): the whitespace inside this literal could not be recovered
    # from the flattened source -- confirm it against the original file.
    header_html = """
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
Wan2.1 (T2V-14B)
</div>
<div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
Wan: Open and Advanced Large-Scale Video Generative Models.
</div>
"""

    # Entries use the "<width>*<height>" format parsed by t2v_generation.
    resolution_choices = [
        '720*1280', '1280*720', '960*960', '1088*832',
        '832*1088', '480*832', '832*480', '624*624',
        '704*544', '544*704'
    ]

    with gr.Blocks() as demo:
        gr.Markdown(header_html)

        with gr.Row():
            # Left column: prompt input, prompt enhancement, generation knobs.
            with gr.Column():
                prompt_box = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the video you want to generate",
                )
                lang_radio = gr.Radio(
                    choices=["CH", "EN"],
                    label="Target language of prompt enhance",
                    value="CH")
                enhance_btn = gr.Button(value="Prompt Enhance")

                with gr.Accordion("Advanced Options", open=True):
                    resolution_dd = gr.Dropdown(
                        label='Resolution(Width*Height)',
                        choices=resolution_choices,
                        value='720*1280')

                    with gr.Row():
                        steps_slider = gr.Slider(
                            label="Diffusion steps",
                            minimum=1,
                            maximum=1000,
                            value=50,
                            step=1)
                        guide_slider = gr.Slider(
                            label="Guide scale",
                            minimum=0,
                            maximum=20,
                            value=5.0,
                            step=1)
                    with gr.Row():
                        shift_slider = gr.Slider(
                            label="Shift scale",
                            minimum=0,
                            maximum=10,
                            value=5.0,
                            step=1)
                        seed_slider = gr.Slider(
                            label="Seed",
                            minimum=-1,
                            maximum=2147483647,
                            step=1,
                            value=-1)
                    negative_box = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Describe the negative prompt you want to add"
                    )

                generate_btn = gr.Button("Generate Video")

            # Right column: the generated clip.
            with gr.Column():
                video_out = gr.Video(
                    label='Generated Video', interactive=False, height=600)

        # The enhanced prompt is written back into the same textbox.
        enhance_btn.click(
            fn=prompt_enc,
            inputs=[prompt_box, lang_radio],
            outputs=[prompt_box])

        # Input order must match the parameter order of t2v_generation.
        generate_btn.click(
            fn=t2v_generation,
            inputs=[
                prompt_box, resolution_dd, steps_slider, guide_slider,
                shift_slider, seed_slider, negative_box
            ],
            outputs=[video_out],
        )

    return demo
|
|
|
|
| |
| def _parse_args(): |
| parser = argparse.ArgumentParser( |
| description="Generate a video from a text prompt or image using Gradio") |
| parser.add_argument( |
| "--ckpt_dir", |
| type=str, |
| default="cache", |
| help="The path to the checkpoint directory.") |
| parser.add_argument( |
| "--prompt_extend_method", |
| type=str, |
| default="local_qwen", |
| choices=["dashscope", "local_qwen"], |
| help="The prompt extend method to use.") |
| parser.add_argument( |
| "--prompt_extend_model", |
| type=str, |
| default=None, |
| help="The prompt extend model to use.") |
|
|
| args = parser.parse_args() |
|
|
| return args |
|
|
|
|
if __name__ == '__main__':
    args = _parse_args()

    # Step 1: build the prompt expander selected on the command line. The
    # result is stored in the module-level name read by prompt_enc.
    print("Step1: Init prompt_expander...", end='', flush=True)
    if args.prompt_extend_method == "dashscope":
        prompt_expander = DashScopePromptExpander(
            model_name=args.prompt_extend_model, is_vl=False)
    elif args.prompt_extend_method == "local_qwen":
        prompt_expander = QwenPromptExpander(
            model_name=args.prompt_extend_model, is_vl=False, device=0)
    else:
        # argparse already restricts the value via ``choices``; this branch
        # only guards against the two lists drifting apart.
        raise NotImplementedError(
            f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
    print("done", flush=True)

    # Step 2: load the 14B text-to-video model on device 0 with FSDP and USP
    # disabled. The result is stored in the module-level name read by
    # t2v_generation.
    print("Step2: Init 14B t2v model...", end='', flush=True)
    cfg = WAN_CONFIGS['t2v-14B']
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
    )
    print("done", flush=True)

    # Step 3: build the UI and serve it on port 7860. "0.0.0.0" makes the
    # server listen on all network interfaces; no public share link is made.
    demo = gradio_interface()
    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
|
|