File size: 3,112 Bytes
556d852
cce6ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556d852
cce6ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556d852
cce6ff1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707450c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import gradio as gr
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableVideoDiffusionPipeline,
    WanPipeline,
)
from diffusers.utils import export_to_video, load_image

# Choose the compute device and matching precision in one place:
# half precision on CUDA, full precision when only the CPU is available.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# -------- Text to Image: Stable Diffusion --------
# Load the Stable Diffusion 2.1 (base) checkpoint in the selected dtype and
# move it to the selected device; `.to()` hands back the same pipeline object.
txt2img_pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    torch_dtype=dtype,
).to(device)

def generate_image_from_text(prompt):
    """Generate one image from a text prompt.

    Calls the module-level ``txt2img_pipe`` with 30 inference steps and
    returns the first entry of the result's ``images``.
    """
    generated = txt2img_pipe(prompt, num_inference_steps=30)
    return generated.images[0]


# -------- Image to Image: Instruct Pix2Pix --------
# Load the InstructPix2Pix checkpoint in the selected dtype and move it to the
# selected device; `.to()` hands back the same pipeline object.
pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=dtype,
).to(device)

def generate_image_from_image_and_prompt(image, prompt):
    """Edit ``image`` according to the instruction in ``prompt``.

    Calls the module-level ``pix2pix_pipe`` with 10 inference steps and
    returns the first entry of the result's ``images``.
    """
    return pix2pix_pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=10,
    ).images[0]


# -------- Text to Video: Wan T2V --------
# Load the Wan 2.1 T2V 1.3B checkpoint. It is always loaded in bfloat16,
# independent of the module-level `dtype`, then moved to the selected device.
wan_pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    torch_dtype=torch.bfloat16,
).to(device)

def generate_video_from_text(prompt):
    """Generate a short video clip from a text prompt.

    Requests 16 frames from the module-level ``wan_pipe``, writes the first
    returned frame sequence to ``wan_video.mp4`` at 8 fps, and returns the
    value ``export_to_video`` gives back for that file.
    """
    result = wan_pipe(prompt=prompt, num_frames=16)
    clip = result.frames[0]
    return export_to_video(clip, "wan_video.mp4", fps=8)


# -------- Image to Video: Stable Video Diffusion --------
# Load Stable Video Diffusion (img2vid-xt) in the selected dtype. The "fp16"
# weight variant is requested only when running in float16; otherwise no
# variant is specified. The pipeline is then moved to the selected device.
svd_pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
).to(device)

def generate_video_from_image(image):
    """Animate a still image with the image-to-video pipeline.

    Args:
        image: The conditioning image. It must provide a PIL-style
            ``resize((width, height))`` method that returns the resized
            image (a NumPy array does not satisfy this).

    Returns:
        The value ``export_to_video`` gives back for ``svd_video.mp4``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Gradio passes None when the button is pressed before an upload;
        # report that to the user instead of failing on `None.resize`.
        raise gr.Error("Please upload an input image first.")

    # The pipeline is always fed a fixed 1024x576 frame.
    resized = image.resize((1024, 576))
    frames = svd_pipe(resized, num_inference_steps=25).frames[0]
    return export_to_video(frames, "svd_video.mp4", fps=8)


# -------- Gradio Interface --------
# Build the four-tab UI; each tab wires one button to one generation function.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Multimodal Any-to-Any AI Playground")

    with gr.Tab("Text → Image"):
        prompt = gr.Textbox(label="Prompt")
        output_image = gr.Image()
        btn1 = gr.Button("Generate")
        btn1.click(fn=generate_image_from_text, inputs=prompt, outputs=output_image)

    with gr.Tab("Image → Image"):
        # type="pil": pass the callback a PIL image instead of Gradio's
        # default NumPy array, which the editing pipeline would mis-scale.
        in_image = gr.Image(label="Input Image", type="pil")
        edit_prompt = gr.Textbox(label="Edit Prompt")
        out_image = gr.Image()
        btn2 = gr.Button("Generate")
        btn2.click(
            fn=generate_image_from_image_and_prompt,
            inputs=[in_image, edit_prompt],
            outputs=out_image,
        )

    with gr.Tab("Text → Video"):
        vid_prompt = gr.Textbox(label="Prompt")
        output_vid = gr.Video()
        btn3 = gr.Button("Generate")
        btn3.click(fn=generate_video_from_text, inputs=vid_prompt, outputs=output_vid)

    with gr.Tab("Image → Video"):
        # type="pil": generate_video_from_image calls image.resize((w, h)) and
        # uses its return value, which requires a PIL image, not an ndarray.
        img_input = gr.Image(label="Input Image", type="pil")
        vid_out = gr.Video()
        btn4 = gr.Button("Animate")
        btn4.click(fn=generate_video_from_image, inputs=img_input, outputs=vid_out)

demo.launch()