OzzyGT HF Staff committed on
Commit
1fcbe69
·
1 Parent(s): 2100afd
Files changed (3) hide show
  1. README.md +0 -1
  2. app.py +77 -59
  3. requirements.txt +3 -4
README.md CHANGED
@@ -4,7 +4,6 @@ emoji: 👀
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.42.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  license: apache-2.0
app.py CHANGED
@@ -1,58 +1,66 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
4
- from diffusers import AutoencoderKL, TCDScheduler
5
- from diffusers.models.model_loading_utils import load_state_dict
6
- from gradio_imageslider import ImageSlider
7
- from huggingface_hub import hf_hub_download
8
 
9
- from controlnet_union import ControlNetModel_Union
10
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  MODELS = {
13
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
14
  }
15
 
16
- config_file = hf_hub_download(
17
- "xinsir/controlnet-union-sdxl-1.0",
18
- filename="config_promax.json",
19
- )
20
-
21
- config = ControlNetModel_Union.load_config(config_file)
22
- controlnet_model = ControlNetModel_Union.from_config(config)
23
- model_file = hf_hub_download(
24
- "xinsir/controlnet-union-sdxl-1.0",
25
- filename="diffusion_pytorch_model_promax.safetensors",
26
- )
27
- state_dict = load_state_dict(model_file)
28
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
29
- controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
30
  )
31
- model.to(device="cuda", dtype=torch.float16)
 
32
 
33
- vae = AutoencoderKL.from_pretrained(
34
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
35
- ).to("cuda")
36
-
37
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
38
  "SG161222/RealVisXL_V5.0_Lightning",
39
  torch_dtype=torch.float16,
40
  vae=vae,
41
- controlnet=model,
42
- variant="fp16",
43
  ).to("cuda")
44
 
45
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
46
 
47
 
48
  @spaces.GPU(duration=24)
49
- def fill_image(prompt, image, model_selection, paste_back):
50
  (
51
  prompt_embeds,
52
  negative_prompt_embeds,
53
  pooled_prompt_embeds,
54
  negative_pooled_prompt_embeds,
55
- ) = pipe.encode_prompt(prompt, "cuda", True)
56
 
57
  source = image["background"]
58
  mask = image["layers"][0]
@@ -62,17 +70,25 @@ def fill_image(prompt, image, model_selection, paste_back):
62
  cnet_image = source.copy()
63
  cnet_image.paste(0, (0, 0), binary_mask)
64
 
65
- for image in pipe(
66
  prompt_embeds=prompt_embeds,
67
  negative_prompt_embeds=negative_prompt_embeds,
68
  pooled_prompt_embeds=pooled_prompt_embeds,
69
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
70
- image=cnet_image,
71
- ):
72
- yield image, cnet_image
73
-
74
- print(f"{model_selection=}")
75
- print(f"{paste_back=}")
 
 
 
 
 
 
 
 
76
 
77
  if paste_back:
78
  image = image.convert("RGBA")
@@ -87,10 +103,8 @@ def clear_result():
87
  return gr.update(value=None)
88
 
89
 
90
- title = """<h1 align="center">Diffusers Fast Inpaint</h1>
91
  <div align="center">Draw the mask over the subject you want to erase or change and write what you want to inpaint it with.</div>
92
- <div align="center">This is a lighting model with almost no CFG and 12 steps, so don't expect high quality generations.</div>
93
- <div align="center">This space is a PoC made for the guide <a href='https://huggingface.co/blog/OzzyGT/diffusers-image-fill'>Diffusers Image Fill</a>.</div>
94
  """
95
 
96
  with gr.Blocks() as demo:
@@ -99,41 +113,45 @@ with gr.Blocks() as demo:
99
  with gr.Column():
100
  prompt = gr.Textbox(
101
  label="Prompt",
102
- info="Describe what to inpaint the mask with",
103
- lines=3,
104
  )
105
  with gr.Column():
106
- model_selection = gr.Dropdown(
107
- choices=list(MODELS.keys()),
108
- value="RealVisXL V5.0 Lightning",
109
- label="Model",
110
- )
111
-
112
  with gr.Row():
113
- with gr.Column():
114
- run_button = gr.Button("Generate")
 
 
115
 
116
- with gr.Column():
117
- paste_back = gr.Checkbox(True, label="Paste back original")
 
 
 
 
118
 
119
  with gr.Row():
120
  input_image = gr.ImageMask(
121
- type="pil", label="Input Image", crop_size=(1024, 1024), layers=False
 
 
 
 
 
122
  )
123
 
124
- result = ImageSlider(
125
  interactive=False,
126
  label="Generated Image",
127
  )
128
 
129
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
130
 
 
 
131
  def use_output_as_input(output_image):
132
  return gr.update(value=output_image[1])
133
 
134
- use_as_input_button.click(
135
- fn=use_output_as_input, inputs=[result], outputs=[input_image]
136
- )
137
 
138
  run_button.click(
139
  fn=clear_result,
@@ -145,7 +163,7 @@ with gr.Blocks() as demo:
145
  outputs=use_as_input_button,
146
  ).then(
147
  fn=fill_image,
148
- inputs=[prompt, input_image, model_selection, paste_back],
149
  outputs=result,
150
  ).then(
151
  fn=lambda: gr.update(visible=True),
@@ -163,7 +181,7 @@ with gr.Blocks() as demo:
163
  outputs=use_as_input_button,
164
  ).then(
165
  fn=fill_image,
166
- inputs=[prompt, input_image, model_selection, paste_back],
167
  outputs=result,
168
  ).then(
169
  fn=lambda: gr.update(visible=True),
 
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
 
 
4
 
5
+ from diffusers import AutoencoderKL, ControlNetUnionModel, DiffusionPipeline, TCDScheduler
6
+
7
+
8
def callback_cfg_cutoff(pipeline, step_index, timestep, callback_kwargs, *, cutoff_ratio=0.2):
    """Disable guidance part-way through sampling and shrink the batched inputs.

    Written to be passed as a pipeline's ``callback_on_step_end`` hook. On the
    one step whose index equals ``int(pipeline.num_timesteps * cutoff_ratio)``
    the callback:

    * keeps only the last batch entry of ``prompt_embeds``, ``add_text_embeds``,
      ``add_time_ids`` and ``control_type``;
    * keeps only the last batch entry of *every* item in ``control_image``
      (previously only the first item was truncated, which would leave any
      additional control images with a mismatched batch size);
    * sets ``pipeline._guidance_scale`` to ``0.0``.

    On every other step ``callback_kwargs`` is returned untouched.

    NOTE(review): the last batch entry is presumably the conditional half of
    the classifier-free-guidance batch - confirm against the pipeline in use.

    Args:
        pipeline: Object exposing ``num_timesteps`` and ``_guidance_scale``.
        step_index: Index of the denoising step that just finished.
        timestep: Current timestep; unused, accepted to match the hook signature.
        callback_kwargs: Mapping holding the five entries listed above.
            ``control_image`` is expected to be a mutable sequence (list).
        cutoff_ratio: Fraction of ``num_timesteps`` at which to cut guidance.
            Defaults to ``0.2``, the value that used to be hard-coded.

    Returns:
        The same ``callback_kwargs`` mapping, updated in place when the cutoff
        step is hit.

    Raises:
        KeyError: If one of the expected entries is missing on the cutoff step.
    """
    if step_index != int(pipeline.num_timesteps * cutoff_ratio):
        return callback_kwargs

    # Entries that are directly sliceable along the batch dimension.
    for key in ("prompt_embeds", "add_text_embeds", "add_time_ids", "control_type"):
        callback_kwargs[key] = callback_kwargs[key][-1:]

    # control_image is a list of per-control inputs; truncate each one in
    # place so the list object held by the caller stays the same.
    control_image = callback_kwargs["control_image"]
    for index, image in enumerate(control_image):
        control_image[index] = image[-1:]
    callback_kwargs["control_image"] = control_image

    # Only flip guidance off once every input has been shrunk successfully.
    pipeline._guidance_scale = 0.0

    return callback_kwargs
33
+
34
 
35
  MODELS = {
36
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
37
  }
38
 
39
+ controlnet_model = ControlNetUnionModel.from_pretrained(
40
+ "OzzyGT/controlnet-union-promax-sdxl-1.0", variant="fp16", torch_dtype=torch.float16
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
42
+ controlnet_model.to(device="cuda", dtype=torch.float16)
43
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
44
 
45
+ pipe = DiffusionPipeline.from_pretrained(
 
 
 
 
46
  "SG161222/RealVisXL_V5.0_Lightning",
47
  torch_dtype=torch.float16,
48
  vae=vae,
49
+ controlnet=controlnet_model,
50
+ custom_pipeline="OzzyGT/custom_sdxl_cnet_union",
51
  ).to("cuda")
52
 
53
  pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
54
 
55
 
56
  @spaces.GPU(duration=24)
57
+ def fill_image(prompt, negative_prompt, image, model_selection, paste_back):
58
  (
59
  prompt_embeds,
60
  negative_prompt_embeds,
61
  pooled_prompt_embeds,
62
  negative_pooled_prompt_embeds,
63
+ ) = pipe.encode_prompt(prompt, device="cuda", negative_prompt=negative_prompt)
64
 
65
  source = image["background"]
66
  mask = image["layers"][0]
 
70
  cnet_image = source.copy()
71
  cnet_image.paste(0, (0, 0), binary_mask)
72
 
73
+ image = pipe(
74
  prompt_embeds=prompt_embeds,
75
  negative_prompt_embeds=negative_prompt_embeds,
76
  pooled_prompt_embeds=pooled_prompt_embeds,
77
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
78
+ control_image=[cnet_image],
79
+ controlnet_conditioning_scale=[1.0],
80
+ control_mode=[7],
81
+ num_inference_steps=8,
82
+ guidance_scale=1.5,
83
+ callback_on_step_end=callback_cfg_cutoff,
84
+ callback_on_step_end_tensor_inputs=[
85
+ "prompt_embeds",
86
+ "add_text_embeds",
87
+ "add_time_ids",
88
+ "control_image",
89
+ "control_type",
90
+ ],
91
+ ).images[0]
92
 
93
  if paste_back:
94
  image = image.convert("RGBA")
 
103
  return gr.update(value=None)
104
 
105
 
106
+ title = """<h2 align="center">Diffusers Fast Inpaint</h2>
107
  <div align="center">Draw the mask over the subject you want to erase or change and write what you want to inpaint it with.</div>
 
 
108
  """
109
 
110
  with gr.Blocks() as demo:
 
113
  with gr.Column():
114
  prompt = gr.Textbox(
115
  label="Prompt",
116
+ lines=1,
 
117
  )
118
  with gr.Column():
 
 
 
 
 
 
119
  with gr.Row():
120
+ negative_prompt = gr.Textbox(
121
+ label="Negative Prompt",
122
+ lines=1,
123
+ )
124
 
125
+ with gr.Row():
126
+ with gr.Column():
127
+ run_button = gr.Button("Generate")
128
+
129
+ with gr.Column():
130
+ paste_back = gr.Checkbox(True, label="Paste back original")
131
 
132
  with gr.Row():
133
  input_image = gr.ImageMask(
134
+ type="pil",
135
+ label="Input Image",
136
+ crop_size=(1024, 1024),
137
+ canvas_size=(1024, 1024),
138
+ layers=False,
139
+ height=512,
140
  )
141
 
142
+ result = gr.ImageSlider(
143
  interactive=False,
144
  label="Generated Image",
145
  )
146
 
147
  use_as_input_button = gr.Button("Use as Input Image", visible=False)
148
 
149
+ model_selection = gr.Dropdown(choices=list(MODELS.keys()), value="RealVisXL V5.0 Lightning", label="Model")
150
+
151
  def use_output_as_input(output_image):
152
  return gr.update(value=output_image[1])
153
 
154
+ use_as_input_button.click(fn=use_output_as_input, inputs=[result], outputs=[input_image])
 
 
155
 
156
  run_button.click(
157
  fn=clear_result,
 
163
  outputs=use_as_input_button,
164
  ).then(
165
  fn=fill_image,
166
+ inputs=[prompt, negative_prompt, input_image, model_selection, paste_back],
167
  outputs=result,
168
  ).then(
169
  fn=lambda: gr.update(visible=True),
 
181
  outputs=use_as_input_button,
182
  ).then(
183
  fn=fill_image,
184
+ inputs=[prompt, negative_prompt, input_image, model_selection, paste_back],
185
  outputs=result,
186
  ).then(
187
  fn=lambda: gr.update(visible=True),
requirements.txt CHANGED
@@ -1,10 +1,9 @@
1
  torch
2
  spaces
3
- gradio==4.42.0
4
- gradio-imageslider
5
- numpy==1.26.4
6
  transformers
7
  accelerate
8
  diffusers
9
- fastapi<0.113.0
10
  opencv-python
 
1
  torch
2
  spaces
3
+ gradio
4
+ numpy
 
5
  transformers
6
  accelerate
7
  diffusers
8
+ fastapi
9
  opencv-python