File size: 3,625 Bytes
f07e098
 
8b742a3
 
 
f07e098
2e87863
f07e098
 
8b742a3
 
 
 
 
 
 
 
f07e098
8b742a3
 
 
f07e098
8b742a3
 
 
 
 
 
 
 
 
 
 
 
f07e098
 
 
8b742a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07e098
8b742a3
 
 
 
 
 
 
 
 
 
 
f07e098
8b742a3
 
f07e098
 
 
8b742a3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import torch
from diffusers.utils import export_to_video, load_image
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from transformers import CLIPVisionModel
import numpy as np
import os


# Install necessary libraries if they are missing.
# NOTE(review): `diffusers` is already imported at the top of this file, so if it
# were missing the script would fail there before reaching this check; the
# fallback below only takes effect if those top-level imports are moved after it.
try:
    import diffusers
    print("diffusers is already installed.")
except ImportError:
    import importlib
    import subprocess
    import sys

    print("Installing diffusers...")
    # Use the running interpreter's pip (a bare `pip` on PATH may belong to a
    # different Python) and raise if the install fails instead of ignoring the
    # exit status like os.system would.
    subprocess.check_call(
        [
            sys.executable, "-m", "pip", "install",
            "git+https://github.com/huggingface/diffusers.git",
            "transformers",
            "accelerate",
        ]
    )
    # Packages installed after interpreter start-up may be invisible to the
    # import system until its finder caches are refreshed.
    importlib.invalidate_caches()
    import diffusers  # try importing again after installation.

# Checkpoint repo id passed to from_pretrained, and the LoRA weights loaded on top.
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
lora_weights = "Remade/Squish"

def load_models(cpu_offload=True):
    """Load the Wan image-to-video pipeline with the LoRA weights applied.

    Reads the module-level ``model_id`` and ``lora_weights``. The image encoder
    and VAE are loaded in float32; the rest of the pipeline in bfloat16.

    Args:
        cpu_offload: If True (default), enable model CPU offloading so the
            pipeline's sub-models are placed on the GPU only while in use
            (low-VRAM mode). If False, move the whole pipeline to CUDA up front.

    Returns:
        The ready pipeline, or None if anything failed while loading (the error
        is printed); callers must check for None.
    """
    try:
        image_encoder = CLIPVisionModel.from_pretrained(
            model_id, subfolder="image_encoder", torch_dtype=torch.float32
        )
        vae = AutoencoderKLWan.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float32
        )
        pipe = WanImageToVideoPipeline.from_pretrained(
            model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
        )
        pipe.load_lora_weights(lora_weights)
        if cpu_offload:
            # Do NOT call pipe.to("cuda") first: enable_model_cpu_offload()
            # manages device placement itself, and moving the whole (14B)
            # pipeline to the GPU beforehand is exactly the memory spike that
            # offloading is meant to avoid.
            pipe.enable_model_cpu_offload()
        else:
            pipe.to("cuda")
        return pipe
    except Exception as e:
        # Best effort: keep the app running; generate_video reports the failure.
        print(f"Error loading models: {e}")
        return None


# Load once at import time so every request reuses the same pipeline.
pipe = load_models()

def _fit_to_area(img_width, img_height, mod_value, max_area=480 * 832):
    """Return ``(height, width)`` with about ``max_area`` pixels and the input's
    aspect ratio, each side a multiple of ``mod_value`` and at least ``mod_value``."""
    aspect_ratio = img_height / img_width
    # int() guards against NumPy versions where round(np.float64) stays a float,
    # which PIL's resize() would reject.
    height = int(round(np.sqrt(max_area * aspect_ratio))) // mod_value * mod_value
    width = int(round(np.sqrt(max_area / aspect_ratio))) // mod_value * mod_value
    # Very elongated images would otherwise round one side down to 0.
    return max(height, mod_value), max(width, mod_value)


def generate_video(image_url, prompt, num_frames, guidance_scale, num_inference_steps,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate a video from an input image and a text prompt.

    Args:
        image_url: Path or URL of the input image (anything ``load_image`` accepts).
        prompt: Text prompt for the generation; must be non-blank.
        num_frames: Number of frames to request from the pipeline (cast to int).
        guidance_scale: Guidance scale passed through to the pipeline.
        num_inference_steps: Number of inference steps (cast to int).
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress bars in the UI.

    Returns:
        A ``(status, video_path)`` tuple. On success both items are the path of
        the written MP4; on failure ``status`` is an error message and
        ``video_path`` is None.
    """
    if pipe is None:
        return "Error: Model failed to load.  Check server logs for details.", None

    if not image_url or not (prompt and prompt.strip()):
        return "Error: Please provide both an image URL and a prompt.", None

    try:
        image = load_image(image_url)

        # Both output sides must be divisible by the VAE's spatial factor times
        # the transformer's patch size.
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height, width = _fit_to_area(image.width, image.height, mod_value)
        image = image.resize((width, height))

        frames = pipe(
            image=image,
            prompt=prompt,
            height=height,
            width=width,
            num_frames=int(num_frames),
            guidance_scale=guidance_scale,
            num_inference_steps=int(num_inference_steps),
        ).frames[0]

        # With no output path, export_to_video writes to a fresh temporary .mp4
        # and returns its path, so overlapping requests cannot overwrite each
        # other's result (the old fixed "output.mp4" was shared by all requests).
        video_path = export_to_video(frames, fps=16)
        return video_path, video_path

    except Exception as e:
        import traceback

        traceback.print_exc()  # keep the full stack trace in the server logs
        return f"An error occurred: {e}", None


# Gradio Interface: the inputs map positionally onto generate_video's parameters
# (image, prompt, num_frames, guidance_scale, num_inference_steps) and the
# outputs onto its (status, video_path) return tuple.
iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Image(type="filepath", label="Input Image URL (or upload)"), # allow local files
        gr.Textbox(label="Prompt"),
        # The diffusers Wan pipeline expects num_frames = 4k + 1 and rounds any
        # other value, so only offer counts it will actually generate.
        gr.Slider(minimum=9, maximum=97, step=4, value=81, label="Number of Frames"),
        gr.Slider(minimum=1, maximum=10, step=0.1, value=5.0, label="Guidance Scale"),
        gr.Slider(minimum=10, maximum=50, step=1, value=28, label="Inference Steps"),
    ],
    outputs=[
        gr.Textbox(label="Status/Error Message"),
        gr.Video(label="Generated Video"),  # Display the generated video
    ],
    title="Wan Image-to-Video Generator",
    description="Generate videos from an image and a text prompt using the Wan Image-to-Video model.",
)

if __name__ == "__main__":
    # Bind to every network interface (not only localhost) on port 7860 so the
    # UI is reachable from other machines. NOTE: launch() is called without any
    # auth, so anyone who can reach the port can use the GPU.
    host, port = "0.0.0.0", 7860
    iface.launch(server_name=host, server_port=port)