import os

# Install diffusers (and friends) if missing. This check has to run before the
# diffusers imports below, otherwise the script would fail before reaching it.
try:
    import diffusers
    print("diffusers is already installed.")
except ImportError:
    print("Installing diffusers...")
    os.system("pip install git+https://github.com/huggingface/diffusers.git transformers accelerate")  # install required packages
    import diffusers  # try importing again after installation

import gradio as gr
import torch
import numpy as np
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from transformers import CLIPVisionModel
# Model and LoRA identifiers (weights are downloaded on first use)
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
lora_weights = "Remade/Squish"
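# model_id refers to the 480p Wan2.1 14B image-to-video checkpoint in Diffusers
# format; lora_weights is presumably a "squish" motion-effect LoRA applied on top.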
def load_models():
    try:
        image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
        vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
        pipe = WanImageToVideoPipeline.from_pretrained(model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16)
        pipe.load_lora_weights(lora_weights)
        # For low-VRAM setups: offload submodules to CPU and move them to the GPU
        # only when needed. This replaces an explicit pipe.to("cuda"), which would
        # conflict with CPU offloading.
        pipe.enable_model_cpu_offload()
        return pipe
    except Exception as e:
        print(f"Error loading models: {e}")
        return None
pipe = load_models() # Load models outside the function, so they are loaded only once
def generate_video(image_path, prompt, num_frames, guidance_scale, num_inference_steps, progress=gr.Progress()):
    if pipe is None:
        return "Error: Model failed to load. Check server logs for details.", None
    if not image_path or not prompt:
        return "Error: Please provide both an image and a prompt.", None
    try:
        image = load_image(image_path)

        # Resize so the frame area stays close to 480x832 while preserving the
        # aspect ratio, then round height/width down to multiples the VAE and
        # transformer patching can handle.
        max_area = 480 * 832
        aspect_ratio = image.height / image.width
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        image = image.resize((width, height))

        output = pipe(
            image=image,
            prompt=prompt,
            height=height,
            width=width,
            num_frames=int(num_frames),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
        ).frames[0]

        export_to_video(output, "output.mp4", fps=16)  # save locally first
        # First value feeds the status Textbox, second the Video component.
        return "Video generated successfully.", "output.mp4"
    except Exception as e:
        return f"An error occurred: {e}", None
# Gradio Interface
iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),  # uploaded file is passed to load_image as a local path
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=10, maximum=100, step=1, value=81, label="Number of Frames"),
        gr.Slider(minimum=1, maximum=10, step=0.1, value=5.0, label="Guidance Scale"),
        gr.Slider(minimum=10, maximum=50, step=1, value=28, label="Inference Steps"),
    ],
    outputs=[
        gr.Textbox(label="Status/Error Message"),
        gr.Video(label="Generated Video"),  # display the generated video
    ],
    title="Wan Image-to-Video Generator",
    description="Generate videos from an image and a text prompt using the Wan Image-to-Video model.",
)
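# Note: for long-running generations it may help to enable Gradio's request queue
# with iface.queue() before launching; it is left out here to preserve the original
# behavior.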
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)  # make the app accessible on the network