# wan2.1 / app.py
# alexanderbaikal
# fix indentation
# cee1dc4
# raw
# history blame
# 2.38 kB
import gradio as gr
import torch
import numpy as np
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
_MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"

# Pipeline shared by every request. It is built on first use so that the
# checkpoint is not reloaded and re-uploaded to the GPU on each call.
_PIPELINE = None


def _get_pipeline():
    """Return the shared Wan FLF2V pipeline, building it on first use.

    The image encoder and VAE are loaded in float32 and the rest of the
    pipeline in bfloat16, then the whole pipeline is moved to ``"cuda"``.
    """
    global _PIPELINE
    if _PIPELINE is None:
        image_encoder = CLIPVisionModel.from_pretrained(
            _MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
        )
        vae = AutoencoderKLWan.from_pretrained(
            _MODEL_ID, subfolder="vae", torch_dtype=torch.float32
        )
        # The positional model id must come before any keyword argument;
        # the previous ordering was a SyntaxError.
        pipe = WanImageToVideoPipeline.from_pretrained(
            _MODEL_ID,
            vae=vae,
            image_encoder=image_encoder,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
        pipe.to("cuda")
        _PIPELINE = pipe
    return _PIPELINE


def _aspect_ratio_resize(image, pipe, max_area=720 * 1280):
    """Resize *image* to roughly *max_area* pixels, keeping its aspect ratio.

    Both output dimensions are rounded down to a multiple of
    ``pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]``
    (presumably so the pipeline accepts them unchanged -- confirm against
    the pipeline documentation).

    Returns:
        A ``(resized_image, height, width)`` tuple.
    """
    aspect_ratio = image.height / image.width
    mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
    height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return image.resize((width, height)), height, width


def _center_crop_resize(image, height, width):
    """Scale *image* to cover ``width x height``, then center-crop to that size.

    The scale factor is the larger of the two per-axis ratios, so the scaled
    image is at least as large as the target on both axes and the crop only
    trims the overflow.

    Returns:
        The cropped image, whose size is ``(width, height)``.
    """
    resize_ratio = max(width / image.width, height / image.height)
    scaled = image.resize(
        (round(image.width * resize_ratio), round(image.height * resize_ratio))
    )
    # torchvision's center_crop takes the output size as (height, width).
    return TF.center_crop(scaled, [height, width])


def generate_video(first_frame_url, last_frame_url, prompt):
    """Generate a video that transitions from a first frame to a last frame.

    Args:
        first_frame_url: Location of the first frame, passed to
            ``diffusers.utils.load_image``.
        last_frame_url: Location of the last frame, passed to
            ``diffusers.utils.load_image``.
        prompt: Text prompt forwarded to the pipeline.

    Returns:
        Path of the written video file (``"wan_output.mp4"``, exported at
        16 fps). The same path is reused, so each call overwrites the
        previous output.
    """
    pipe = _get_pipeline()

    first_frame = load_image(first_frame_url)
    last_frame = load_image(last_frame_url)

    # The first frame dictates the output resolution; the last frame is
    # fitted to it only when the two sizes differ.
    first_frame, height, width = _aspect_ratio_resize(first_frame, pipe)
    if last_frame.size != first_frame.size:
        last_frame = _center_crop_resize(last_frame, height, width)

    output = pipe(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        height=height,
        width=width,
        guidance_scale=5.5,
    ).frames[0]

    video_path = "wan_output.mp4"
    export_to_video(output, video_path, fps=16)
    return video_path
# Web form: three free-text fields handed to generate_video in this order,
# and a video player that shows the file path it returns.
iface = gr.Interface(
    title="Wan2.1 FLF2V Video Generator",
    fn=generate_video,
    inputs=[
        gr.Textbox(label=caption)
        for caption in ("First Frame URL", "Last Frame URL", "Prompt")
    ],
    outputs=gr.Video(label="Generated Video"),
)

# Launch the Gradio app as soon as this module is executed.
iface.launch()