kevalfst committed
Commit 4b17c2f · verified · 1 parent: fd63bd0

Update app.py

Files changed (1)
  1. app.py  +61 -61
app.py CHANGED
@@ -8,86 +8,86 @@ from diffusers import (
 )
 from diffusers.utils import export_to_video, load_image

-# Set dtype and device
-dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+# Detect device & dtype
 device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if device == "cuda" else torch.float32

-# -------- Text to Image: Stable Diffusion --------
-txt2img_pipe = StableDiffusionPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-2-1-base", torch_dtype=dtype
-)
-txt2img_pipe.to(device)
-
-def generate_image_from_text(prompt):
-    image = txt2img_pipe(prompt, num_inference_steps=30).images[0]
-    return image
+# Factory to load & offload a pipeline
+def make_pipe(cls, model_id, **kwargs):
+    pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
+    # Enables CPU offload of model parts not in use
+    pipe.enable_model_cpu_offload()
+    return pipe

+# Hold pipelines in globals but don't load yet
+TXT2IMG_PIPE = None
+IMG2IMG_PIPE = None
+TXT2VID_PIPE = None
+IMG2VID_PIPE = None

-# -------- Image to Image: Instruct Pix2Pix --------
-pix2pix_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
-    "timbrooks/instruct-pix2pix", torch_dtype=dtype
-)
-pix2pix_pipe.to(device)
+def generate_image_from_text(prompt):
+    global TXT2IMG_PIPE
+    if TXT2IMG_PIPE is None:
+        TXT2IMG_PIPE = make_pipe(
+            StableDiffusionPipeline,
+            "stabilityai/stable-diffusion-2-1-base"
+        ).to(device)
+    return TXT2IMG_PIPE(prompt, num_inference_steps=20).images[0]

 def generate_image_from_image_and_prompt(image, prompt):
-    result = pix2pix_pipe(prompt=prompt, image=image, num_inference_steps=10)
-    return result.images[0]
-
-
-# -------- Text to Video: Wan T2V --------
-wan_pipe = WanPipeline.from_pretrained(
-    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
-)
-wan_pipe.to(device)
+    global IMG2IMG_PIPE
+    if IMG2IMG_PIPE is None:
+        IMG2IMG_PIPE = make_pipe(
+            StableDiffusionInstructPix2PixPipeline,
+            "timbrooks/instruct-pix2pix"
+        ).to(device)
+    out = IMG2IMG_PIPE(prompt=prompt, image=image, num_inference_steps=8)
+    return out.images[0]

 def generate_video_from_text(prompt):
-    frames = wan_pipe(prompt=prompt, num_frames=16).frames[0]
-    video_path = export_to_video(frames, "wan_video.mp4", fps=8)
-    return video_path
-
-
-# -------- Image to Video: Stable Video Diffusion --------
-svd_pipe = StableVideoDiffusionPipeline.from_pretrained(
-    "stabilityai/stable-video-diffusion-img2vid-xt",
-    torch_dtype=dtype,
-    variant="fp16" if dtype == torch.float16 else None,
-)
-svd_pipe.to(device)
+    global TXT2VID_PIPE
+    if TXT2VID_PIPE is None:
+        TXT2VID_PIPE = make_pipe(
+            WanPipeline,
+            "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+        ).to(device)
+    frames = TXT2VID_PIPE(prompt=prompt, num_frames=12).frames[0]
+    return export_to_video(frames, "wan_video.mp4", fps=8)

 def generate_video_from_image(image):
-    image = image.resize((1024, 576))
-    frames = svd_pipe(image, num_inference_steps=25).frames[0]
-    video_path = export_to_video(frames, "svd_video.mp4", fps=8)
-    return video_path
-
+    global IMG2VID_PIPE
+    if IMG2VID_PIPE is None:
+        IMG2VID_PIPE = make_pipe(
+            StableVideoDiffusionPipeline,
+            "stabilityai/stable-video-diffusion-img2vid-xt",
+            variant="fp16" if dtype==torch.float16 else None
+        ).to(device)
+    image = load_image(image).resize((512, 288))
+    frames = IMG2VID_PIPE(image, num_inference_steps=16).frames[0]
+    return export_to_video(frames, "svd_video.mp4", fps=8)

-# -------- Gradio Interface --------
 with gr.Blocks() as demo:
-    gr.Markdown("# 🧠 Multimodal Any-to-Any AI Playground")
+    gr.Markdown("# 🧠 Lightweight Any-to-Any AI Playground")

     with gr.Tab("Text → Image"):
-        prompt = gr.Textbox(label="Prompt")
-        output_image = gr.Image()
-        btn1 = gr.Button("Generate")
-        btn1.click(fn=generate_image_from_text, inputs=prompt, outputs=output_image)
+        inp = gr.Textbox(label="Prompt")
+        out = gr.Image()
+        gr.Button("Generate").click(generate_image_from_text, inp, out)

     with gr.Tab("Image → Image"):
-        in_image = gr.Image(label="Input Image")
-        edit_prompt = gr.Textbox(label="Edit Prompt")
-        out_image = gr.Image()
-        btn2 = gr.Button("Generate")
-        btn2.click(fn=generate_image_from_image_and_prompt, inputs=[in_image, edit_prompt], outputs=out_image)
+        img = gr.Image(label="Input Image")
+        prm = gr.Textbox(label="Edit Prompt")
+        out2 = gr.Image()
+        gr.Button("Generate").click(generate_image_from_image_and_prompt, [img, prm], out2)

     with gr.Tab("Text → Video"):
-        vid_prompt = gr.Textbox(label="Prompt")
-        output_vid = gr.Video()
-        btn3 = gr.Button("Generate")
-        btn3.click(fn=generate_video_from_text, inputs=vid_prompt, outputs=output_vid)
+        inp2 = gr.Textbox(label="Prompt")
+        out_vid = gr.Video()
+        gr.Button("Generate").click(generate_video_from_text, inp2, out_vid)

     with gr.Tab("Image → Video"):
-        img_input = gr.Image(label="Input Image")
-        vid_out = gr.Video()
-        btn4 = gr.Button("Animate")
-        btn4.click(fn=generate_video_from_image, inputs=img_input, outputs=vid_out)
+        img2 = gr.Image(label="Input Image")
+        out_vid2 = gr.Video()
+        gr.Button("Animate").click(generate_video_from_image, img2, out_vid2)

 demo.launch()
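
The updated app.py builds each Diffusers pipeline lazily on first use and relies on enable_model_cpu_offload() to keep idle submodules off the GPU. Below is a minimal, self-contained sketch of that lazy-load-and-offload pattern; the get_pipe helper, the _PIPES cache, and the demo prompt are illustrative and not part of this commit, and CPU offload assumes accelerate is installed and a CUDA device is available, as on the Space.

import torch
from diffusers import StableDiffusionPipeline

_PIPES = {}  # cache: model_id -> already-loaded pipeline

def get_pipe(cls, model_id, **kwargs):
    """Load a pipeline once, enable CPU offload, and reuse it afterwards (hypothetical helper)."""
    if model_id not in _PIPES:
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        pipe = cls.from_pretrained(model_id, torch_dtype=dtype, **kwargs)
        # Requires accelerate and a CUDA device; idle submodules are kept on CPU
        pipe.enable_model_cpu_offload()
        _PIPES[model_id] = pipe
    return _PIPES[model_id]

if __name__ == "__main__":
    # Example invocation; the prompt and output path are placeholders
    pipe = get_pipe(StableDiffusionPipeline, "stabilityai/stable-diffusion-2-1-base")
    image = pipe("a watercolor fox", num_inference_steps=20).images[0]
    image.save("fox.png")

Caching by model id means a second request against the same tab reuses the already-loaded weights rather than reloading them, which is the effect the commit achieves with its module-level *_PIPE globals.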