# wan-api / app.py
# Page metadata captured with this file: author eggman-poff; commit caf4acf
# (verified, message "update py"); size 1.39 kB.
import logging
import tempfile

import gradio as gr
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video
# Hub repository id of the Stable Video Diffusion XT image-to-video checkpoint.
MODEL = "stabilityai/stable-video-diffusion-img2vid-xt"

# Build the pipeline once at import time with float16 weights and move it onto
# the CUDA device.
# NOTE(review): the hard-coded "cuda" target means import presumably fails on a
# CPU-only host — confirm this is the intended deployment constraint.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    MODEL,
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
def infer(first_image, last_image, prompt, guidance=7.5, frames=25):
    """Generate a short MP4 clip from ``first_image`` with Stable Video Diffusion.

    ``StableVideoDiffusionPipeline`` is conditioned on a single image. Its
    ``__call__`` has no ``last_image``, ``prompt`` or ``guidance_scale``
    parameters, so forwarding them raised ``TypeError`` on every request.
    ``last_image`` and ``prompt`` are still accepted so the Gradio UI/API
    signature is unchanged, but they are ignored and a warning is logged.

    Args:
        first_image: Conditioning (start) frame; presumably a PIL image, since
            the UI uses ``gr.Image(type="pil")``. Required.
        last_image: Accepted for API compatibility; not used by SVD.
        prompt: Accepted for API compatibility; not used by SVD.
        guidance: Classifier-free guidance strength, forwarded as SVD's
            ``max_guidance_scale``.
        frames: Number of frames to generate; coerced to ``int`` because UI
            sliders may deliver floats.

    Returns:
        Path to a temporary ``.mp4`` file. It is created with ``delete=False``,
        so it is not removed automatically.

    Raises:
        gr.Error: If no start frame was supplied.
    """
    if first_image is None:
        raise gr.Error("A start frame is required.")
    if last_image is not None or prompt:
        logging.getLogger(__name__).warning(
            "Stable Video Diffusion conditions only on the start frame; "
            "the end frame and prompt inputs are ignored."
        )

    guidance = float(guidance)
    # SVD spreads guidance between min_guidance_scale and max_guidance_scale
    # across frames. Keep min <= max so slider values below 1 don't invert it.
    result = pipe(
        image=first_image,
        num_frames=int(frames),
        min_guidance_scale=min(1.0, guidance),
        max_guidance_scale=guidance,
    )
    # `.frames` is batched (one frame list per generated video); we asked for one.
    video = result.frames[0]

    # Only the path is needed. Close the handle right away instead of leaking
    # it; delete=False keeps the file on disk for export and for Gradio.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        mp4_path = tmp.name
    export_to_video(video, mp4_path, fps=15)
    return mp4_path
# Minimal Gradio front end. Input components are passed to infer() positionally,
# so their order must match its parameters:
# first_image, last_image, prompt, guidance, frames.
demo = gr.Interface(
    infer,
    [
        gr.Image(label="Start frame", type="pil"),
        gr.Image(label="End frame", type="pil"),
        gr.Textbox(placeholder="Prompt (optional)"),
        gr.Slider(minimum=0, maximum=12, value=7.5, label="Guidance scale"),
        gr.Slider(minimum=8, maximum=48, value=25, step=1, label="Num frames"),
    ],
    "video",
    title="Eggman – 2-Frame SVD API",
)
# Queue requests so that only one generation runs at a time.
# Gradio 4 replaced queue(concurrency_count=...) with
# queue(default_concurrency_limit=...) and raises on the old keyword, while
# Gradio 3.x rejects the new keyword with TypeError. Try the new API first.
try:
    demo.queue(default_concurrency_limit=1)
except TypeError:
    demo.queue(concurrency_count=1)
# show_api=True keeps the auto-generated API docs link in the app footer.
demo.launch(show_api=True)