import torch
import gradio as gr
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableVideoDiffusionPipeline,
    WanPipeline,
)
from diffusers.utils import export_to_video, load_image
import random
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
MAX_SEED = np.iinfo(np.int32).max
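# The randomized seed is returned alongside each result and fed back into the
# seed slider, so any output can be reproduced by unchecking "Randomize seed".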

# Model cache
TXT2IMG_PIPE = None
IMG2IMG_PIPE = None
TXT2VID_PIPE = None
IMG2VID_PIPE = None

def make_pipe(cls, model_id, **kwargs):
    pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
    if device == "cuda":
        # Streams submodules to the GPU on demand (requires `accelerate`) and
        # manages placement itself, so callers must not also call .to(device).
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(device)
    return pipe
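# Pipelines are built lazily on first use so startup stays fast and only the
# tabs that are actually exercised pay their memory cost. With CPU offload,
# weights live in RAM and each submodule is moved to the GPU only while it
# runs, trading some speed for a much smaller VRAM footprint.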

# Generation callbacks, one per tab
def generate_image_from_text(prompt, seed, randomize_seed):
    global TXT2IMG_PIPE
    if TXT2IMG_PIPE is None:
        # make_pipe handles device placement; calling .to(device) on a pipeline
        # after enable_model_cpu_offload() raises in recent diffusers versions.
        TXT2IMG_PIPE = make_pipe(StableDiffusionPipeline, "stabilityai/stable-diffusion-2-1-base")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = TXT2IMG_PIPE(prompt=prompt, num_inference_steps=20, generator=generator).images[0]
    return image, seed

def generate_image_from_image_and_prompt(image, prompt, seed, randomize_seed):
    global IMG2IMG_PIPE
    if IMG2IMG_PIPE is None:
        IMG2IMG_PIPE = make_pipe(StableDiffusionInstructPix2PixPipeline, "timbrooks/instruct-pix2pix")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    out = IMG2IMG_PIPE(prompt=prompt, image=image, num_inference_steps=8, generator=generator)
    return out.images[0], seed

def generate_video_from_text(prompt, seed, randomize_seed):
    global TXT2VID_PIPE
    if TXT2VID_PIPE is None:
        TXT2VID_PIPE = make_pipe(WanPipeline, "Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    # Wan's VAE compresses time 4x, so num_frames must be 4k + 1; 12 would be
    # rounded with a warning, so request 13 frames explicitly.
    frames = TXT2VID_PIPE(prompt=prompt, num_frames=13, generator=generator).frames[0]
    return export_to_video(frames, "/tmp/wan_video.mp4", fps=8), seed

def generate_video_from_image(image, seed, randomize_seed):
    global IMG2VID_PIPE
    if IMG2VID_PIPE is None:
        IMG2VID_PIPE = make_pipe(
            StableVideoDiffusionPipeline,
            "stabilityai/stable-video-diffusion-img2vid-xt",
            variant="fp16" if dtype == torch.float16 else None,
        )
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    image = load_image(image).resize((512, 288))
    # Pass height/width to match the resize; the pipeline would otherwise scale
    # the conditioning image back up to its 576x1024 default.
    frames = IMG2VID_PIPE(image=image, height=288, width=512, num_inference_steps=16, generator=generator).frames[0]
    return export_to_video(frames, "/tmp/svd_video.mp4", fps=8), seed

# UI
with gr.Blocks(css="footer {display:none !important}") as demo:
    gr.Markdown("# 🧠 AI Playground – Multi-Mode Generator")

    with gr.Tabs():
        # Text β†’ Image
        with gr.Tab("Text β†’ Image"):
            with gr.Row():
                prompt_txt = gr.Textbox(label="Prompt")
                generate_btn = gr.Button("Generate")
            result_img = gr.Image()
            seed_txt = gr.Slider(0, MAX_SEED, value=42, label="Seed")
            rand_seed_txt = gr.Checkbox(label="Randomize seed", value=True)
            generate_btn.click(
                fn=generate_image_from_text,
                inputs=[prompt_txt, seed_txt, rand_seed_txt],
                outputs=[result_img, seed_txt]
            )

        # Image β†’ Image
        with gr.Tab("Image β†’ Image"):
            with gr.Row():
                image_in = gr.Image(label="Input Image", type="pil")  # pipelines expect PIL, not gradio's numpy default
                prompt_img = gr.Textbox(label="Edit Prompt")
                generate_btn2 = gr.Button("Generate")
            result_img2 = gr.Image()
            seed_img = gr.Slider(0, MAX_SEED, value=123, label="Seed")
            rand_seed_img = gr.Checkbox(label="Randomize seed", value=True)
            generate_btn2.click(
                fn=generate_image_from_image_and_prompt,
                inputs=[image_in, prompt_img, seed_img, rand_seed_img],
                outputs=[result_img2, seed_img]
            )

        # Text β†’ Video
        with gr.Tab("Text β†’ Video"):
            with gr.Row():
                prompt_vid = gr.Textbox(label="Prompt")
                generate_btn3 = gr.Button("Generate")
            result_vid = gr.Video()
            seed_vid = gr.Slider(0, MAX_SEED, value=555, label="Seed")
            rand_seed_vid = gr.Checkbox(label="Randomize seed", value=True)
            generate_btn3.click(
                fn=generate_video_from_text,
                inputs=[prompt_vid, seed_vid, rand_seed_vid],
                outputs=[result_vid, seed_vid]
            )

        # Image β†’ Video
        with gr.Tab("Image β†’ Video"):
            with gr.Row():
                image_in_vid = gr.Image(label="Input Image", type="pil")  # load_image() accepts PIL, not numpy
                generate_btn4 = gr.Button("Animate")
            result_vid2 = gr.Video()
            seed_vid2 = gr.Slider(0, MAX_SEED, value=999, label="Seed")
            rand_seed_vid2 = gr.Checkbox(label="Randomize seed", value=True)
            generate_btn4.click(
                fn=generate_video_from_image,
                inputs=[image_in_vid, seed_vid2, rand_seed_vid2],
                outputs=[result_vid2, seed_vid2]
            )

demo.queue()
demo.launch(show_error=True)
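
# To run locally (assuming this file is saved as app.py):
#   pip install torch diffusers transformers accelerate gradio opencv-python
#   python app.py
# export_to_video depends on OpenCV, and the two video tabs realistically need
# a CUDA GPU; expect large first-run downloads for each model.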