alexanz committed
Commit 8823f71 · verified · 1 parent: 4ce428e

add rembg support
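
Beyond the headline rembg hook, the diff replaces the SDXL-Turbo template model with Stable Diffusion v1.4/v1.5 plus optional Pusheen LoRA weights loaded through PEFT, and adds ControlNet (canny / openpose), IP-Adapter, and LoRA-strength controls to the Gradio UI. The background-removal step itself reduces to a single rembg call on the generated PIL image; a minimal sketch of that pattern (assuming rembg and Pillow are installed; file names are hypothetical):

    from PIL import Image
    from rembg import remove  # rembg's top-level API: PIL image in, RGBA PIL image out

    image = Image.open("pusheen.png").convert("RGB")   # hypothetical input image
    cutout = remove(image)                             # background pixels become transparent
    cutout.save("pusheen_cutout.png")                  # save as PNG to keep the alpha channel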

Files changed (1): app.py +201 -28
app.py CHANGED
@@ -1,43 +1,142 @@
 import gradio as gr
 import numpy as np
 import random
+from PIL import Image
+from rembg import remove
 
 # import spaces #[uncomment to use ZeroGPU]
-from diffusers import DiffusionPipeline
+from peft import PeftModel
+from diffusers import DiffusionPipeline, StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
+from diffusers.utils import load_image
 import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
+model_repo_id = "CompVis/stable-diffusion-v1-4"  # Replace with the model you would like to use
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
+torch_dtype = torch.float16
 
-pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
 pipe = pipe.to(device)
+# pipe.unet = PeftModel.from_pretrained(pipe.unet, "alexanz/SD14_lora_pusheen")
+pipe.safety_checker = None
+pipe.requires_safety_checker = False
 
 MAX_SEED = np.iinfo(np.int32).max
-MAX_IMAGE_SIZE = 1024
+MAX_IMAGE_SIZE = 512
 
 
 # @spaces.GPU #[uncomment to use ZeroGPU]
+def load_model(model_id, lora_strength, use_controlnet=False, control_mode="edge_detection", use_ip_adapter=False, control_strength_ip=0.0):
+    global pipe
+    if pipe is not None:
+        del pipe
+        torch.cuda.empty_cache()
+    try:
+        if control_mode == "edge_detection" and (model_id == "CompVis/stable-diffusion-v1-4" or model_id == "alexanz/SD14_lora_pusheen"):
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype)
+        elif control_mode == "pose_estimation" and (model_id == "CompVis/stable-diffusion-v1-4" or model_id == "alexanz/SD14_lora_pusheen"):
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch_dtype)
+        if control_mode == "edge_detection" and (model_id == "alexanz/SD15_lora_pusheen"):
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", torch_dtype=torch_dtype)
+        elif control_mode == "pose_estimation" and (model_id == "alexanz/SD15_lora_pusheen"):
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch_dtype)
+
+        if model_id == "CompVis/stable-diffusion-v1-4":
+            if use_controlnet:
+                pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                    model_id,
+                    safety_checker=None,
+                    controlnet=controlnet,
+                    torch_dtype=torch_dtype
+                )
+            else:
+                pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
+
+        elif model_id == "alexanz/SD14_lora_pusheen":
+            if use_controlnet:
+                pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                    "CompVis/stable-diffusion-v1-4",
+                    safety_checker=None,
+                    controlnet=controlnet,
+                    torch_dtype=torch_dtype
+                )
+                pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength, torch_dtype=torch_dtype)
+            else:
+                pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch_dtype)
+                pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength)
+
+        elif model_id == "alexanz/SD15_lora_pusheen":
+            if use_controlnet:
+                pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                    "stable-diffusion-v1-5/stable-diffusion-v1-5",
+                    safety_checker=None,
+                    controlnet=controlnet,
+                    torch_dtype=torch_dtype
+                )
+                pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength, torch_dtype=torch_dtype)
+            else:
+                pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch_dtype)
+                pipe.unet = PeftModel.from_pretrained(pipe.unet, model_id, scaling=lora_strength)
+
+        if use_ip_adapter:
+            pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
+            pipe.set_ip_adapter_scale(control_strength_ip)
+
+        pipe = pipe.to(device)
+        pipe.safety_checker = None
+        pipe.requires_safety_checker = False
+        pipe.enable_model_cpu_offload()
+        return f"Model {model_id} loaded with ControlNet: {use_controlnet}, mode: {control_mode}"
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+
 def infer(
-    prompt,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    num_inference_steps,
-    progress=gr.Progress(track_tqdm=True),
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    lora_strength,
+    guidance_scale,
+    num_inference_steps,
+    use_controlnet,
+    control_image_cont,
+    control_strength_cont,
+    model_dropdown,
+    control_mode,
+    use_ip_adapter,
+    control_strength_ip,
+    control_image_ip,
+    use_rmbg,
+    progress=gr.Progress(track_tqdm=True),
 ):
+    load_status = load_model(
+        model_dropdown,
+        lora_strength,
+        use_controlnet,
+        control_mode,
+        use_ip_adapter,
+        control_strength_ip
+    )
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
     generator = torch.Generator().manual_seed(seed)
 
+    if use_controlnet and control_image_cont is None:
+        return None, seed, "⚠️ ControlNet needs a control image!"
+
+    if use_ip_adapter and control_image_ip is None:
+        return None, seed, "⚠️ IP-Adapter needs a control image!"
+
+    if use_controlnet:
+        control_image_cont = Image.fromarray(control_image_cont)
+        control_strength_cont = float(control_strength_cont)
+    if use_ip_adapter:
+        control_image_ip = Image.fromarray(control_image_ip)
+
     image = pipe(
         prompt=prompt,
         negative_prompt=negative_prompt,
@@ -46,15 +145,21 @@ def infer(
         width=width,
         height=height,
         generator=generator,
+        image=control_image_cont if use_controlnet else None,
+        controlnet_conditioning_scale=control_strength_cont if use_controlnet else None,
+        ip_adapter_image=control_image_ip if use_ip_adapter else None
     ).images[0]
 
-    return image, seed
+    if use_rmbg:
+        image = remove(image)
+
+    return image, seed, "Model ready"
 
 
 examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
+    "Sticker of Pusheen. Cartoon image of a gray cat with a cup of tea.",
+    "Sticker of Pusheen. Gray cat holding a guitar, sitting under a disco ball, with colorful lights and a happy face.",
+    "Sticker of Pusheen. A cute cartoon fluffy cat.",
 ]
 
 css = """
@@ -67,6 +172,10 @@ css = """
 with gr.Blocks(css=css) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown(" # Text-to-Image Gradio Template")
+        model_dropdown = gr.Dropdown(label="Model ID",
+                                     choices=["alexanz/SD14_lora_pusheen", "CompVis/stable-diffusion-v1-4", "alexanz/SD15_lora_pusheen"],
+                                     value="CompVis/stable-diffusion-v1-4")
+        model_status = gr.Textbox(label="Model Status", interactive=False)
 
         with gr.Row():
             prompt = gr.Text(
@@ -86,7 +195,14 @@
                 label="Negative prompt",
                 max_lines=1,
                 placeholder="Enter a negative prompt",
-                visible=False,
+            )
+
+            lora_strength = gr.Slider(
+                label="LoRA strength",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=1.0,
             )
 
             seed = gr.Slider(
@@ -105,7 +221,7 @@
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=512,  # Replace with defaults that work for your model
                 )
 
                 height = gr.Slider(
@@ -113,7 +229,7 @@
                     minimum=256,
                     maximum=MAX_IMAGE_SIZE,
                     step=32,
-                    value=1024,  # Replace with defaults that work for your model
+                    value=512,  # Replace with defaults that work for your model
                )
 
             with gr.Row():
@@ -122,7 +238,7 @@
                     minimum=0.0,
                     maximum=10.0,
                     step=0.1,
-                    value=0.0,  # Replace with defaults that work for your model
+                    value=7.5,  # Replace with defaults that work for your model
                 )
 
                 num_inference_steps = gr.Slider(
@@ -130,10 +246,40 @@
                     minimum=1,
                     maximum=50,
                     step=1,
-                    value=2,  # Replace with defaults that work for your model
+                    value=20,  # Replace with defaults that work for your model
                 )
 
+            use_controlnet = gr.Checkbox(label="Use ControlNet", value=False)
+            with gr.Accordion("ControlNet Settings", open=True, visible=False) as controlnet_settings:
+                control_mode = gr.Dropdown(
+                    label="ControlNet Mode",
+                    choices=["edge_detection", "pose_estimation"],
+                    value="edge_detection"
+                )
+                control_strength_cont = gr.Slider(
+                    label="Control Strength",
+                    minimum=0.0,
+                    maximum=2.0,
+                    step=0.1,
+                    value=1.0
+                )
+                control_image_cont = gr.Image(label="Control Image", type="numpy")
+
+            use_ip_adapter = gr.Checkbox(label="Use IP-Adapter", value=False)
+            with gr.Accordion("IP-Adapter Settings", open=True, visible=False) as ip_adapter_settings:
+                control_strength_ip = gr.Slider(
+                    label="Control Strength",
+                    minimum=0.0,
+                    maximum=2.0,
+                    step=0.1,
+                    value=1.0
+                )
+                control_image_ip = gr.Image(label="Control Image (IP-Adapter)", type="numpy")
+
+            use_rmbg = gr.Checkbox(label="Remove background?", value=False)
+
         gr.Examples(examples=examples, inputs=[prompt])
+
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
@@ -144,11 +290,38 @@
             randomize_seed,
             width,
             height,
+            lora_strength,
             guidance_scale,
             num_inference_steps,
+            use_controlnet,
+            control_image_cont,
+            control_strength_cont,
+            model_dropdown,
+            control_mode,
+            use_ip_adapter,
+            control_strength_ip,
+            control_image_ip,
+            use_rmbg
         ],
-        outputs=[result, seed],
+        outputs=[result, seed, model_status],
+    )
+
+    use_controlnet.change(
+        fn=lambda x: gr.update(visible=x, value=None),
+        inputs=[use_controlnet],
+        outputs=[controlnet_settings]
+    )
+
+    use_ip_adapter.change(
+        fn=lambda x: gr.update(visible=x, value=None),
+        inputs=[use_ip_adapter],
+        outputs=[ip_adapter_settings]
+    )
+
+    use_rmbg.change(
+        fn=lambda x: gr.update(visible=x, value=None),
+        inputs=[use_rmbg]
    )
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
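
A note for trying the new ControlNet path: the canny checkpoints wired in above expect an edge map, not a raw photo, as the control image. A minimal sketch for preparing one, following the usual diffusers ControlNet recipe (assuming opencv-python, numpy, and Pillow are installed; file names are hypothetical):

    import cv2
    import numpy as np
    from PIL import Image

    img = np.array(Image.open("photo.png").convert("RGB"))  # hypothetical input image
    edges = cv2.Canny(img, 100, 200)                        # single-channel edge map
    edges = np.stack([edges] * 3, axis=-1)                  # replicate to 3 channels for the pipeline
    Image.fromarray(edges).save("control_canny.png")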