import os

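# Install a PyTorch nightly build (CUDA 12.6, pinned below 2.9) plus the `spaces`
# package at startup, before torch is imported below.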
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')


import spaces
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc

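# Local helper module (optimization.py in this repo) that applies ahead-of-time
# optimizations such as warm-up/compilation to the pipeline.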
from optimization import optimize_pipeline_


MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"

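# Resolution limits for generated videos; final dimensions are snapped to
# multiples of DIMENSION_MULTIPLE.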
MAX_DIMENSION = 720
MIN_DIMENSION = 480
DIMENSION_MULTIPLE = 16
SQUARE_SIZE = 480

MAX_SEED = np.iinfo(np.int32).max

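# Generation runs at a fixed 16 fps; the model supports between 8 and 81 frames per clip.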
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81

MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

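# Default negative prompt, kept in Chinese as commonly used with Wan models; it lists
# artifacts to avoid (garish colors, overexposure, blurry details, JPEG artifacts, extra
# fingers, deformed limbs, static frames, cluttered backgrounds, etc.).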
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"

print("Loading models into memory. This may take a few minutes...")

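# Build the image-to-video pipeline. Both denoising transformers (the high-noise and
# low-noise experts) are loaded in bfloat16 from a pre-converted checkpoint and placed
# directly on the GPU.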
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    transformer_2=WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer_2',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
pipe.to('cuda')

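# Release temporary CPU/GPU allocations left over from loading before optimizing the pipeline.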
| print("Optimizing pipeline...") |
| for i in range(3): |
| gc.collect() |
| torch.cuda.synchronize() |
| torch.cuda.empty_cache() |
|
|
| |
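# Warm up / optimize the pipeline once with representative inputs (widest supported
# landscape resolution and the maximum frame count) so later calls hit the optimized path.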
optimize_pipeline_(pipe,
    image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
    prompt='prompt',
    height=MIN_DIMENSION,
    width=MAX_DIMENSION,
    num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.")


def process_image_for_video(image: Image.Image) -> Image.Image:
    """
    Resizes an image according to the following rules for video generation:
    1. The longer side is scaled down to MAX_DIMENSION if it is larger.
    2. The shorter side is scaled up to MIN_DIMENSION if it is smaller.
    3. The final dimensions are rounded to the nearest multiple of DIMENSION_MULTIPLE.
    4. Square images are resized to a fixed SQUARE_SIZE.
    The aspect ratio is preserved as closely as possible.
    """
    width, height = image.size

    if width == height:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    aspect_ratio = width / height
    new_width, new_height = width, height

    # Scale down if the longer side exceeds MAX_DIMENSION.
    if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
        if aspect_ratio > 1:
            scale = MAX_DIMENSION / new_width
        else:
            scale = MAX_DIMENSION / new_height
        new_width *= scale
        new_height *= scale

    # Scale up if the shorter side is below MIN_DIMENSION.
    if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
        if aspect_ratio > 1:
            scale = MIN_DIMENSION / new_height
        else:
            scale = MIN_DIMENSION / new_width
        new_width *= scale
        new_height *= scale

    # Snap to the nearest multiple of DIMENSION_MULTIPLE, then re-apply the lower bounds
    # in case rounding pushed a side below its minimum.
    final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)

    final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
    final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)

    return image.resize((final_width, final_height), Image.Resampling.LANCZOS)


def resize_and_crop_to_match(target_image, reference_image):
    """Resizes and center-crops the target image to match the reference image's dimensions."""
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    new_width, new_height = int(target_width * scale), int(target_height * scale)
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))


@spaces.GPU(duration=75)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.5,
    steps=5,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Generates a video by interpolating between a start and end image, guided by a text prompt,
    using the diffusers Wan 2.2 pipeline.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start and an end image.")

    progress(0.1, desc="Preprocessing images...")

    # Resize the start image to a supported resolution, then resize/crop the end image
    # so that both conditioning frames share exactly the same dimensions.
    processed_start_image = process_image_for_video(start_image_pil)
    processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
    target_height, target_width = processed_start_image.height, processed_start_image.width

    # Pick a seed and convert the requested duration into a frame count within the model's range.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")

    # The start image conditions the first frame and `last_image` conditions the final frame;
    # the two guidance scales apply to the high-noise and low-noise experts respectively.
    output_frames_list = pipe(
        image=processed_start_image,
        last_image=processed_end_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=target_height,
        width=target_width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    progress(0.9, desc="Encoding and saving video...")

    # Write the frames to a temporary .mp4 file that Gradio can serve.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name

    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

    progress(1.0, desc="Done!")
    return video_path, current_seed


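# Gradio interface: start/end frame inputs, prompt, advanced settings, and the output video.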
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
'''
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
    gr.Markdown("Based on the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/), applied to 🧨 Diffusers + [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA")

    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row():
                    start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
                    end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])

                prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")

            with gr.Accordion("Advanced Settings", open=False):
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=4, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_slider = gr.Slider(minimum=4, maximum=30, step=1, value=8, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
                with gr.Row():
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)

            generate_button = gr.Button("Generate Video", variant="primary")

        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)

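    # Component order here must match the parameter order of generate_video.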
    ui_inputs = [
        start_image,
        end_image,
        prompt,
        negative_prompt_input,
        duration_seconds_input,
        steps_slider,
        guidance_scale_input,
        guidance_scale_2_input,
        seed_input,
        randomize_seed_checkbox
    ]
    ui_outputs = [output_video, seed_input]

    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs
    )


if __name__ == "__main__":
    app.launch(share=True)