# wan-api / app.py
import gradio as gr
import torch
import numpy as np
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import tempfile
# --- Load Model ---
model_id = "Wan-AI/Wan2.1-FLF2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
)
# Model CPU offload manages device placement on its own, so the pipeline is not
# also moved to CUDA first; fall back to plain CPU execution when no GPU is present.
if torch.cuda.is_available():
    pipe.enable_model_cpu_offload()
else:
    pipe.to("cpu")
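# Optional alternative for very low-VRAM machines, offered as a hedged sketch rather
# than something this app originally used: diffusers pipelines also expose sequential
# CPU offload, which trades generation speed for a smaller memory footprint:
#   pipe.enable_sequential_cpu_offload()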
# --- Helper Functions ---
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    # Scale the image to roughly max_area pixels while keeping its aspect ratio,
    # then snap both sides down to the model's spatial granularity.
    aspect_ratio = image.height / image.width
    # patch_size is (temporal, height, width); index 1 gives the spatial patch size.
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    image = image.resize((width, height))
    return image, height, width
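# Worked example (illustrative only; the real mod_value is read from the loaded pipe).
# Assuming vae_scale_factor_spatial == 8 and patch_size[1] == 2, mod_value == 16.
# For a 1080x1920 portrait input: aspect_ratio = 1920 / 1080 ≈ 1.778, so
#   height = round(sqrt(921600 * 1.778)) // 16 * 16 = 1280
#   width  = round(sqrt(921600 / 1.778)) // 16 * 16 = 720
# i.e. the frame is resized to 720x1280 before being fed to the pipeline.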
def center_crop_resize(image, height, width):
    # Compute the target size from the larger of the two scale ratios, then
    # center-crop (center_crop pads with zeros when the requested size exceeds
    # the image), so the last frame ends up matching the first frame's shape.
    resize_ratio = max(width / image.width, height / image.height)
    width = round(image.width * resize_ratio)
    height = round(image.height * resize_ratio)
    size = [width, height]
    image = TF.center_crop(image, size)
    return image, height, width
# --- Gradio Inference Function ---
def infer(first_image, last_image, prompt, guidance=5.5, frames=25):
    # Convert numpy arrays from Gradio to PIL images
    if not isinstance(first_image, Image.Image):
        first_image = Image.fromarray(first_image)
    if not isinstance(last_image, Image.Image):
        last_image = Image.fromarray(last_image)
    # Resize the first frame to the model's working resolution and make the
    # last frame match it if the sizes differ
    first_image, height, width = aspect_ratio_resize(first_image, pipe)
    if last_image.size != first_image.size:
        last_image, _, _ = center_crop_resize(last_image, height, width)
    # Run the first/last-frame-to-video pipeline; the start frame is passed as
    # `image` and the end frame as `last_image`
    output = pipe(
        image=first_image,
        last_image=last_image,
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=guidance,
        num_frames=frames,
    ).frames[0]
    # Export the generated clip to an MP4 that Gradio can serve
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        export_to_video(output, tmp.name, fps=16)
    return tmp.name
# --- Gradio Interface ---
demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Image(type="pil", label="Start Frame"),
        gr.Image(type="pil", label="End Frame"),
        gr.Textbox(placeholder="Prompt (optional)", label="Prompt"),
        gr.Slider(3, 12, value=5.5, step=0.1, label="Guidance Scale"),
        gr.Slider(8, 48, value=25, step=1, label="Num Frames"),
    ],
    outputs=gr.Video(label="Generated Video"),
    title="WAN Two-Frame Video Interpolation",
    description="Upload two images and (optionally) a prompt to create a smooth video transition.",
    delete_cache=(60, 60),  # every 60 s, delete cached outputs older than 60 s
)
if __name__ == "__main__":
    demo.launch(show_api=True)
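# Example client call, included as a hedged sketch: it assumes the app is running
# locally on Gradio's default URL, that gradio_client is installed, and that the
# files "start.png" / "end.png" exist; "/predict" is gr.Interface's default route.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   video_path = client.predict(
#       handle_file("start.png"),                 # first_image
#       handle_file("end.png"),                   # last_image
#       "a smooth morph between the two scenes",  # prompt
#       5.5,                                      # guidance
#       25,                                       # frames
#       api_name="/predict",
#   )
#   print(video_path)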