| 
							 | 
						 | 
					
					
						
						| 
							 | 
						import argparse | 
					
					
						
						| 
							 | 
						import gc | 
					
					
						
						| 
							 | 
						import os.path as osp | 
					
					
						
						| 
							 | 
						import os | 
					
					
						
						| 
							 | 
						import sys | 
					
					
						
						| 
							 | 
						import warnings | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import gradio as gr | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						warnings.filterwarnings('ignore') | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						sys.path.insert(0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) | 
					
					
						
						| 
							 | 
						import wan | 
					
					
						
						| 
							 | 
						from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS | 
					
					
						
						| 
							 | 
						from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander | 
					
					
						
						| 
							 | 
						from wan.utils.utils import cache_video | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						# Module-level state shared by the Gradio callbacks in this file.
						# ``prompt_expander`` is assigned in the ``__main__`` block; the two
						# pipeline slots are filled by ``load_model``.
						prompt_expander = None
						wan_i2v_480P = wan_i2v_720P = None
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def _load_i2v_pipeline(checkpoint_dir):
						    """Build a ``wan.WanI2V`` pipeline for the 'i2v-14B' config from ``checkpoint_dir``."""
						    return wan.WanI2V(
						        config=WAN_CONFIGS['i2v-14B'],
						        checkpoint_dir=checkpoint_dir,
						        device_id=0,
						        rank=0,
						        t5_fsdp=False,
						        dit_fsdp=False,
						        use_usp=False,
						    )


						def load_model(value):
						    """Load the I2V pipeline that matches the selected resolution.

						    Only one pipeline is kept at a time: before a new one is created, the
						    slot of the other resolution is cleared and the garbage collector is run.

						    Args:
						        value: Dropdown selection, one of '------', '720P' or '480P'.

						    Returns:
						        The value the resolution dropdown should display: the requested
						        resolution when its pipeline is available (freshly loaded or already
						        loaded), otherwise '------'.
						    """
						    global wan_i2v_480P, wan_i2v_720P

						    if value == '------':
						        print("No model loaded")
						        return '------'

						    if value == '720P':
						        if args.ckpt_dir_720p is None:
						            print("Please specify the checkpoint directory for 720P model")
						            return '------'
						        if wan_i2v_720P is None:
						            # Drop the 480P pipeline first so both are never held at once.
						            wan_i2v_480P = None
						            gc.collect()

						            print("load 14B-720P i2v model...", end='', flush=True)
						            wan_i2v_720P = _load_i2v_pipeline(args.ckpt_dir_720p)
						            print("done", flush=True)
						        # Also reached when the pipeline was already loaded; returning the
						        # resolution keeps the dropdown selection instead of blanking it.
						        return '720P'

						    if value == '480P':
						        if args.ckpt_dir_480p is None:
						            print("Please specify the checkpoint directory for 480P model")
						            return '------'
						        if wan_i2v_480P is None:
						            # Drop the 720P pipeline first so both are never held at once.
						            wan_i2v_720P = None
						            gc.collect()

						            print("load 14B-480P i2v model...", end='', flush=True)
						            wan_i2v_480P = _load_i2v_pipeline(args.ckpt_dir_480p)
						            print("done", flush=True)
						        return '480P'

						    # Any other value (e.g. a cleared dropdown) maps to the "nothing loaded"
						    # entry rather than an implicit None.
						    print(f"Unsupported resolution: {value}")
						    return '------'
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def prompt_enc(prompt, img, tar_lang):
						    """Expand ``prompt`` using the module-level ``prompt_expander``.

						    Args:
						        prompt: The user's prompt text.
						        img: The uploaded image passed to the expander as ``image``; when it
						            is None no expansion is attempted.
						        tar_lang: Target language label; it is lower-cased before being
						            handed to the expander.

						    Returns:
						        The expanded prompt when the expander reports success, otherwise the
						        original ``prompt`` unchanged.
						    """
						    print('prompt extend...')
						    if img is None:
						        print('Please upload an image')
						        return prompt

						    # The expander is only created in the ``__main__`` block; without this
						    # guard a None expander would fail with "'NoneType' is not callable".
						    if prompt_expander is None:
						        print('Prompt expander is not initialized')
						        return prompt

						    prompt_output = prompt_expander(
						        prompt, image=img, tar_lang=tar_lang.lower())
						    # Fall back to the user's own prompt when expansion did not succeed.
						    return prompt_output.prompt if prompt_output.status else prompt
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
						                   guide_scale, shift_scale, seed, n_prompt):
						    """Generate a video from an image and prompt and save it to disk.

						    Args:
						        img2vid_prompt: Prompt passed to the pipeline's ``generate``.
						        img2vid_image: Input image passed to ``generate``.
						        resolution: '------', '720P', or anything else (treated as 480P).
						        sd_steps: Forwarded as ``sampling_steps``.
						        guide_scale: Forwarded as ``guide_scale``.
						        shift_scale: Forwarded as ``shift``.
						        seed: Forwarded as ``seed``.
						        n_prompt: Forwarded as ``n_prompt``.

						    Returns:
						        "example.mp4" after the video has been written, or None when no
						        resolution is selected, no image was provided, or the pipeline for
						        the selected resolution has not been loaded.
						    """
						    if resolution == '------':
						        print(
						            'Please specify at least one resolution ckpt dir or specify the resolution'
						        )
						        return None

						    if img2vid_image is None:
						        print('Please upload an image')
						        return None

						    # Pick the pipeline and the matching MAX_AREA_CONFIGS key together.
						    if resolution == '720P':
						        pipeline, max_area_key = wan_i2v_720P, '720*1280'
						    else:
						        pipeline, max_area_key = wan_i2v_480P, '480*832'

						    # The pipeline slots start as None until ``load_model`` fills them.
						    if pipeline is None:
						        print(f'The {resolution} model is not loaded')
						        return None

						    video = pipeline.generate(
						        img2vid_prompt,
						        img2vid_image,
						        max_area=MAX_AREA_CONFIGS[max_area_key],
						        shift=shift_scale,
						        sampling_steps=sd_steps,
						        guide_scale=guide_scale,
						        n_prompt=n_prompt,
						        seed=seed,
						        offload_model=True)

						    # NOTE(review): the output name is fixed, so every request writes to the
						    # same file; concurrent requests would overwrite each other's result.
						    cache_video(
						        tensor=video[None],
						        save_file="example.mp4",
						        fps=16,
						        nrow=1,
						        normalize=True,
						        value_range=(-1, 1))

						    return "example.mp4"
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def gradio_interface():
						    """Build the Gradio Blocks UI for the I2V demo and return it.

						    The left column holds the resolution selector, image upload, prompt box,
						    prompt-enhance controls, advanced sampling options and the generate
						    button; the right column shows the generated video. The resolution
						    dropdown is wired to ``load_model``, the enhance button to ``prompt_enc``
						    and the generate button to ``i2v_generation``.

						    Returns:
						        The constructed ``gr.Blocks`` object (not yet launched).
						    """
						    # Same text as the original triple-quoted literal, spelled line by line.
						    header_html = (
						        '\n'
						        '                    <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">\n'
						        '                        Wan2.1 (I2V-14B)\n'
						        '                    </div>\n'
						        '                    <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">\n'
						        '                        Wan: Open and Advanced Large-Scale Video Generative Models.\n'
						        '                    </div>\n'
						        '                    '
						    )

						    with gr.Blocks() as demo:
						        gr.Markdown(header_html)

						        with gr.Row():
						            # Left column: all inputs.
						            with gr.Column():
						                resolution = gr.Dropdown(
						                    choices=['------', '720P', '480P'],
						                    value='------',
						                    label='Resolution',
						                )

						                img2vid_image = gr.Image(
						                    label="Upload Input Image",
						                    type="pil",
						                    elem_id="image_upload",
						                )
						                img2vid_prompt = gr.Textbox(
						                    placeholder="Describe the video you want to generate",
						                    label="Prompt",
						                )
						                tar_lang = gr.Radio(
						                    label="Target language of prompt enhance",
						                    choices=["CH", "EN"],
						                    value="CH",
						                )
						                enhance_button = gr.Button(value="Prompt Enhance")

						                with gr.Accordion("Advanced Options", open=True):
						                    with gr.Row():
						                        sd_steps = gr.Slider(
						                            minimum=1,
						                            maximum=1000,
						                            step=1,
						                            value=50,
						                            label="Diffusion steps",
						                        )
						                        guide_scale = gr.Slider(
						                            minimum=0,
						                            maximum=20,
						                            step=1,
						                            value=5.0,
						                            label="Guide scale",
						                        )
						                    with gr.Row():
						                        shift_scale = gr.Slider(
						                            minimum=0,
						                            maximum=10,
						                            step=1,
						                            value=5.0,
						                            label="Shift scale",
						                        )
						                        seed = gr.Slider(
						                            minimum=-1,
						                            maximum=2147483647,
						                            step=1,
						                            value=-1,
						                            label="Seed",
						                        )
						                    n_prompt = gr.Textbox(
						                        placeholder="Describe the negative prompt you want to add",
						                        label="Negative Prompt",
						                    )

						                generate_button = gr.Button("Generate Video")

						            # Right column: the result.
						            with gr.Column():
						                video_output = gr.Video(
						                    height=600, interactive=False, label='Generated Video')

						        # Event wiring.
						        resolution.input(
						            fn=load_model, inputs=[resolution], outputs=[resolution])

						        enhance_inputs = [img2vid_prompt, img2vid_image, tar_lang]
						        enhance_button.click(
						            fn=prompt_enc, inputs=enhance_inputs, outputs=[img2vid_prompt])

						        generation_inputs = [
						            img2vid_prompt,
						            img2vid_image,
						            resolution,
						            sd_steps,
						            guide_scale,
						            shift_scale,
						            seed,
						            n_prompt,
						        ]
						        generate_button.click(
						            fn=i2v_generation,
						            inputs=generation_inputs,
						            outputs=[video_output],
						        )

						    return demo
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						def _parse_args():
						    """Parse and validate the command-line arguments of the demo.

						    Returns:
						        The parsed ``argparse.Namespace`` with ``ckpt_dir_720p``,
						        ``ckpt_dir_480p``, ``prompt_extend_method`` and
						        ``prompt_extend_model``.

						    Raises:
						        SystemExit: Raised by argparse on invalid arguments, including when
						            neither checkpoint directory is given.
						    """
						    parser = argparse.ArgumentParser(
						        description="Generate a video from a text prompt or image using Gradio")
						    parser.add_argument(
						        "--ckpt_dir_720p",
						        type=str,
						        default=None,
						        help="The path to the 720P checkpoint directory.")
						    parser.add_argument(
						        "--ckpt_dir_480p",
						        type=str,
						        default=None,
						        help="The path to the 480P checkpoint directory.")
						    parser.add_argument(
						        "--prompt_extend_method",
						        type=str,
						        default="local_qwen",
						        choices=["dashscope", "local_qwen"],
						        help="The prompt extend method to use.")
						    parser.add_argument(
						        "--prompt_extend_model",
						        type=str,
						        default=None,
						        help="The prompt extend model to use.")

						    args = parser.parse_args()
						    # Use parser.error rather than assert: asserts are stripped under
						    # ``python -O`` and would let the demo start without any checkpoint.
						    if args.ckpt_dir_720p is None and args.ckpt_dir_480p is None:
						        parser.error("Please specify at least one checkpoint directory.")

						    return args
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						if __name__ == '__main__':
						    # ``args`` and ``prompt_expander`` are module-level names read by the
						    # callbacks (``load_model`` and ``prompt_enc``).
						    args = _parse_args()

						    print("Step1: Init prompt_expander...", end='', flush=True)
						    if args.prompt_extend_method not in ("dashscope", "local_qwen"):
						        raise NotImplementedError(
						            f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
						    if args.prompt_extend_method == "local_qwen":
						        prompt_expander = QwenPromptExpander(
						            model_name=args.prompt_extend_model, is_vl=True, device=0)
						    else:
						        prompt_expander = DashScopePromptExpander(
						            model_name=args.prompt_extend_model, is_vl=True)
						    print("done", flush=True)

						    # Build the UI and serve it on port 7860 on all interfaces.
						    demo = gradio_interface()
						    demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
					
					
						
						| 
							 | 
						
 |