kevalfst committed
Commit 3455f8c · verified · 1 Parent(s): 78ec26d

Update app.py

Files changed (1)
  1. app.py +95 -7
app.py CHANGED
@@ -1,13 +1,101 @@
+import torch
 import gradio as gr
+from diffusers import (
+    StableDiffusionPipeline,
+    StableDiffusionInstructPix2PixPipeline,
+    StableVideoDiffusionPipeline,
+    WanPipeline,
+)
+from diffusers.utils import export_to_video, load_image

-def greet(name):
-    return f"Hello, {name}!"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32

+# Pipeline factory
+def make_pipe(cls, model_id, **kwargs):
+    pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
+    pipe.enable_model_cpu_offload()
+    return pipe
+
+# Global model caches
+TXT2IMG_PIPE = None
+IMG2IMG_PIPE = None
+TXT2VID_PIPE = None
+IMG2VID_PIPE = None
+
+# Text → Image
+def generate_image_from_text(prompt):
+    global TXT2IMG_PIPE
+    if TXT2IMG_PIPE is None:
+        TXT2IMG_PIPE = make_pipe(
+            StableDiffusionPipeline,
+            "stabilityai/stable-diffusion-2-1-base"
+        ).to(device)
+    return TXT2IMG_PIPE(prompt, num_inference_steps=20).images[0]
+
+# Image → Image
+def generate_image_from_image_and_prompt(image, prompt):
+    global IMG2IMG_PIPE
+    if IMG2IMG_PIPE is None:
+        IMG2IMG_PIPE = make_pipe(
+            StableDiffusionInstructPix2PixPipeline,
+            "timbrooks/instruct-pix2pix"
+        ).to(device)
+    out = IMG2IMG_PIPE(prompt=prompt, image=image, num_inference_steps=8)
+    return out.images[0]
+
+# Text → Video
+def generate_video_from_text(prompt):
+    global TXT2VID_PIPE
+    if TXT2VID_PIPE is None:
+        TXT2VID_PIPE = make_pipe(
+            WanPipeline,
+            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+        ).to(device)
+    frames = TXT2VID_PIPE(prompt=prompt, num_frames=12).frames[0]
+    return export_to_video(frames, "/tmp/wan_video.mp4", fps=8)
+
+# Image → Video
+def generate_video_from_image(image):
+    global IMG2VID_PIPE
+    if IMG2VID_PIPE is None:
+        IMG2VID_PIPE = make_pipe(
+            StableVideoDiffusionPipeline,
+            "stabilityai/stable-video-diffusion-img2vid-xt",
+            variant="fp16" if dtype == torch.float16 else None
+        ).to(device)
+    image = load_image(image).resize((512, 288))
+    frames = IMG2VID_PIPE(image, num_inference_steps=16).frames[0]
+    return export_to_video(frames, "/tmp/svd_video.mp4", fps=8)
+
+# Gradio Interface
 with gr.Blocks() as demo:
-    name_input = gr.Textbox(label="Enter your name")
-    greet_button = gr.Button("Greet")
-    output_text = gr.Textbox(label="Greeting")
+    gr.Markdown("# 🧠 Lightweight Any‑to‑Any AI Playground")
+
+    with gr.Tab("Text → Image"):
+        text_prompt = gr.Textbox(label="Prompt")
+        output_image = gr.Image(label="Generated Image")
+        text2img_button = gr.Button("Generate")
+        text2img_button.click(generate_image_from_text, inputs=text_prompt, outputs=output_image)
+
+    with gr.Tab("Image → Image"):
+        input_image = gr.Image(label="Input Image")
+        edit_prompt = gr.Textbox(label="Edit Prompt")
+        edited_image = gr.Image(label="Edited Image")
+        img2img_button = gr.Button("Generate")
+        img2img_button.click(generate_image_from_image_and_prompt, inputs=[input_image, edit_prompt], outputs=edited_image)
+
+    with gr.Tab("Text → Video"):
+        video_prompt = gr.Textbox(label="Prompt")
+        video_output = gr.Video(label="Generated Video")
+        txt2vid_button = gr.Button("Generate")
+        txt2vid_button.click(generate_video_from_text, inputs=video_prompt, outputs=video_output)

-    greet_button.click(fn=greet, inputs=name_input, outputs=output_text)
+    with gr.Tab("Image → Video"):
+        video_input_img = gr.Image(label="Input Image")
+        anim_video_output = gr.Video(label="Animated Video")
+        img2vid_button = gr.Button("Animate")
+        img2vid_button.click(generate_video_from_image, inputs=video_input_img, outputs=anim_video_output)

-demo.launch()
+demo.queue()
+demo.launch(show_error=True)