aiqtech commited on
Commit
5849473
·
verified ·
1 Parent(s): 3b75af7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -1,7 +1,9 @@
1
  import gradio as gr
 
2
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
3
  import torch
4
 
 
5
  controlnet = ControlNetModel.from_pretrained(
6
  "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
7
  )
@@ -12,20 +14,22 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
12
  torch_dtype=torch.float16,
13
  safety_checker=None
14
  )
15
-
16
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
17
- pipe.to("cuda" if torch.cuda.is_available() else "cpu")
18
 
 
 
19
  def generate(image, prompt="a person posing"):
 
20
  result = pipe(prompt=prompt, image=image, num_inference_steps=20).images[0]
21
  return result
22
 
23
  demo = gr.Interface(
24
  fn=generate,
25
- inputs=[gr.Image(type="pil"), gr.Textbox(label="Prompt")],
26
  outputs="image",
27
  title="Pose Generator",
28
  description="Upload an image and enter a prompt to generate a ControlNet-based pose output."
29
  )
30
 
31
- demo.launch()
 
 
1
  import gradio as gr
2
+ import spaces
3
  from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
4
  import torch
5
 
6
+ # Initialize models outside the GPU function
7
  controlnet = ControlNetModel.from_pretrained(
8
  "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
9
  )
 
14
  torch_dtype=torch.float16,
15
  safety_checker=None
16
  )
 
17
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
18
 
19
+ # Move model to GPU inside the decorated function
20
+ @spaces.GPU(duration=60) # Request GPU for 60 seconds per call
21
  def generate(image, prompt="a person posing"):
22
+ pipe.to("cuda")
23
  result = pipe(prompt=prompt, image=image, num_inference_steps=20).images[0]
24
  return result
25
 
26
  demo = gr.Interface(
27
  fn=generate,
28
+ inputs=[gr.Image(type="pil"), gr.Textbox(label="Prompt", value="a person posing")],
29
  outputs="image",
30
  title="Pose Generator",
31
  description="Upload an image and enter a prompt to generate a ControlNet-based pose output."
32
  )
33
 
34
+ if __name__ == "__main__":
35
+ demo.launch()