import torch
import gradio as gr
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionInstructPix2PixPipeline,
    StableVideoDiffusionPipeline,
    WanPipeline,
)
from diffusers.utils import export_to_video

# Set dtype and device: half precision on GPU, full precision on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"
# -------- Text to Image: Stable Diffusion --------
txt2img_pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=dtype
)
txt2img_pipe.to(device)

def generate_image_from_text(prompt):
    image = txt2img_pipe(prompt, num_inference_steps=30).images[0]
    return image
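
# A hedged variant of the call above, not part of the original app: passing a
# seeded torch.Generator (a standard diffusers argument) makes runs reproducible.
# The function name and seed value are illustrative.
def generate_image_from_text_seeded(prompt, seed=42):
    generator = torch.Generator(device=device).manual_seed(seed)
    return txt2img_pipe(
        prompt, num_inference_steps=30, generator=generator
    ).images[0]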

# -------- Image to Image: Instruct Pix2Pix --------
pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=dtype
)
pix2pix_pipe.to(device)

def generate_image_from_image_and_prompt(image, prompt):
    result = pix2pix_pipe(prompt=prompt, image=image, num_inference_steps=10)
    return result.images[0]
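
# Sketch, not in the original call: InstructPix2Pix also accepts
# image_guidance_scale (higher values keep the result closer to the input), e.g.:
#
#   result = pix2pix_pipe(prompt=prompt, image=image,
#                         num_inference_steps=10, image_guidance_scale=1.5)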

# -------- Text to Video: Wan T2V --------
# bfloat16 here follows the Wan2.1 diffusers examples rather than the shared dtype
wan_pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
)
wan_pipe.to(device)

def generate_video_from_text(prompt):
    frames = wan_pipe(prompt=prompt, num_frames=16).frames[0]
    video_path = export_to_video(frames, "wan_video.mp4", fps=8)
    return video_path
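
# For reference, a sketch rather than the original settings: the Wan2.1 examples
# in the diffusers docs generate longer clips (e.g. num_frames=81 exported at
# fps=16); the num_frames=16 / fps=8 settings above trade quality for speed.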

# -------- Image to Video: Stable Video Diffusion --------
svd_pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt",
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
)
svd_pipe.to(device)
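
# Optional memory saver, a sketch and not part of the original app: holding all
# four pipelines on one GPU needs substantial VRAM. diffusers'
# enable_model_cpu_offload() keeps each sub-model on the CPU and moves it to the
# GPU only while it runs. The LOW_VRAM flag is illustrative, not an original name.
LOW_VRAM = False
if LOW_VRAM and device == "cuda":
    for pipe in (txt2img_pipe, pix2pix_pipe, wan_pipe, svd_pipe):
        pipe.enable_model_cpu_offload()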

def generate_video_from_image(image):
    # SVD expects its native 1024x576 resolution
    image = image.resize((1024, 576))
    frames = svd_pipe(image, num_inference_steps=25).frames[0]
    video_path = export_to_video(frames, "svd_video.mp4", fps=8)
    return video_path
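
# Sketch, not in the original call: decode_chunk_size (a standard SVD pipeline
# argument) bounds how many frames the VAE decodes at once, cutting peak VRAM:
#
#   frames = svd_pipe(image, num_inference_steps=25, decode_chunk_size=8).frames[0]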

# -------- Gradio Interface --------
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Multimodal Any-to-Any AI Playground")

    with gr.Tab("Text → Image"):
        prompt = gr.Textbox(label="Prompt")
        output_image = gr.Image()
        btn1 = gr.Button("Generate")
        btn1.click(fn=generate_image_from_text, inputs=prompt, outputs=output_image)

    with gr.Tab("Image → Image"):
        # type="pil" so the pipeline receives a PIL.Image, not Gradio's default numpy array
        in_image = gr.Image(label="Input Image", type="pil")
        edit_prompt = gr.Textbox(label="Edit Prompt")
        out_image = gr.Image()
        btn2 = gr.Button("Generate")
        btn2.click(
            fn=generate_image_from_image_and_prompt,
            inputs=[in_image, edit_prompt],
            outputs=out_image,
        )

    with gr.Tab("Text → Video"):
        vid_prompt = gr.Textbox(label="Prompt")
        output_vid = gr.Video()
        btn3 = gr.Button("Generate")
        btn3.click(fn=generate_video_from_text, inputs=vid_prompt, outputs=output_vid)

    with gr.Tab("Image → Video"):
        # type="pil" is required here: generate_video_from_image calls .resize()
        img_input = gr.Image(label="Input Image", type="pil")
        vid_out = gr.Video()
        btn4 = gr.Button("Animate")
        btn4.click(fn=generate_video_from_image, inputs=img_input, outputs=vid_out)

demo.launch()