aiqtech committed on
Commit
818d397
·
verified ·
1 Parent(s): 3bd9247

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +183 -53
app.py CHANGED
@@ -1,13 +1,28 @@
1
  import gradio as gr
2
  import spaces
3
  import torch
 
 
4
  from diffusers import AutoencoderKL, TCDScheduler
5
  from diffusers.models.model_loading_utils import load_state_dict
6
  from gradio_imageslider import ImageSlider
7
  from huggingface_hub import hf_hub_download
8
 
9
- from controlnet_union import ControlNetModel_Union
10
- from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  MODELS = {
13
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
@@ -21,75 +36,190 @@ def translate_if_korean(text):
21
  print("Translation is disabled - using original text")
22
  return text
23
 
24
- config_file = hf_hub_download(
25
- "xinsir/controlnet-union-sdxl-1.0",
26
- filename="config_promax.json",
27
- )
28
-
29
- config = ControlNetModel_Union.load_config(config_file)
30
- controlnet_model = ControlNetModel_Union.from_config(config)
31
- model_file = hf_hub_download(
32
- "xinsir/controlnet-union-sdxl-1.0",
33
- filename="diffusion_pytorch_model_promax.safetensors",
34
- )
 
 
 
 
 
 
 
 
 
 
35
  state_dict = load_state_dict(model_file)
36
- model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
37
- controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  model.to(device="cuda", dtype=torch.float16)
40
 
41
- vae = AutoencoderKL.from_pretrained(
42
- "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
43
- ).to("cuda")
44
 
45
- pipe = StableDiffusionXLFillPipeline.from_pretrained(
46
- "SG161222/RealVisXL_V5.0_Lightning",
47
- torch_dtype=torch.float16,
48
- vae=vae,
49
- controlnet=model,
50
- variant="fp16",
51
- ).to("cuda")
52
 
53
- pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  @spaces.GPU
56
  def fill_image(prompt, image, model_selection):
57
- # Translate prompt if needed
 
 
 
58
  translated_prompt = translate_if_korean(prompt)
59
 
60
  try:
61
- (
62
- prompt_embeds,
63
- negative_prompt_embeds,
64
- pooled_prompt_embeds,
65
- negative_pooled_prompt_embeds,
66
- ) = pipe.encode_prompt(translated_prompt, "cuda", True)
67
-
68
  source = image["background"]
69
  mask = image["layers"][0]
70
 
 
71
  alpha_channel = mask.split()[3]
72
  binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
73
- cnet_image = source.copy()
74
- cnet_image.paste(0, (0, 0), binary_mask)
75
-
76
- for image in pipe(
77
- prompt_embeds=prompt_embeds,
78
- negative_prompt_embeds=negative_prompt_embeds,
79
- pooled_prompt_embeds=pooled_prompt_embeds,
80
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
81
- image=cnet_image,
82
- ):
83
- yield image, cnet_image
84
-
85
- image = image.convert("RGBA")
86
- cnet_image.paste(image, (0, 0), binary_mask)
87
-
88
- yield source, cnet_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  except Exception as e:
90
- print(f"Error during image generation: {e}")
91
  # Return the original image in case of error
92
- return source, source
 
 
 
 
 
 
 
93
 
94
  def clear_result():
95
  return gr.update(value=None)
 
1
  import gradio as gr
2
  import spaces
3
  import torch
4
+ import sys
5
+ import traceback
6
  from diffusers import AutoencoderKL, TCDScheduler
7
  from diffusers.models.model_loading_utils import load_state_dict
8
  from gradio_imageslider import ImageSlider
9
  from huggingface_hub import hf_hub_download
10
 
11
+ # Add better error handling
12
def print_error(error_message):
    """Print a framed error report to stdout.

    The report contains ``error_message`` and, when called while an exception
    is being handled, the formatted traceback of that exception.

    Args:
        error_message: Human-readable description of what went wrong.
    """
    print("=" * 50)
    print(f"ERROR: {error_message}")
    # traceback.format_exc() returns the placeholder "NoneType: None" when no
    # exception is active, so only emit the traceback section when one exists.
    if sys.exc_info()[0] is not None:
        print("-" * 50)
        print(traceback.format_exc())
    print("=" * 50)
18
+
19
# Both names are used by the model and pipeline setup further down; if either
# import fails the script logs the failure and exits with status 1.
try:
    from controlnet_union import ControlNetModel_Union
    from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
except Exception as import_error:
    print_error(f"Failed to import required modules: {import_error}")
    print(
        "Ensure the controlnet_union and pipeline_fill_sd_xl modules are available"
    )
    sys.exit(1)
26
 
27
  MODELS = {
28
  "RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
 
36
  print("Translation is disabled - using original text")
37
  return text
38
 
39
# Load the ControlNet-Union ("promax") weights.
#
# Outcome contract for the code that follows: ``model`` is either a loaded
# ControlNet moved to CUDA/float16, or ``None`` when every loading strategy
# failed (so the pipeline setup can fall back to a plain SDXL pipeline).
_CONTROLNET_REPO = "xinsir/controlnet-union-sdxl-1.0"

try:
    config_file = hf_hub_download(
        _CONTROLNET_REPO,
        filename="config_promax.json",
    )

    config = ControlNetModel_Union.load_config(config_file)
    controlnet_model = ControlNetModel_Union.from_config(config)
    model_file = hf_hub_download(
        _CONTROLNET_REPO,
        filename="diffusion_pytorch_model_promax.safetensors",
    )
    # Read the checkpoint inside the guarded section: it must not run with
    # model_file=None after a failed download.
    state_dict = load_state_dict(model_file)
except Exception as e:
    print_error(f"Failed to load model configuration: {e}")
    print("Attempting to use direct model loading as fallback...")
    # None marks "not available"; the steps below check for it.
    config_file = None
    config = None
    controlnet_model = None
    model_file = None
    state_dict = None

model = None

if state_dict is not None:
    try:
        # Original call shape of the private loader.
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, _CONTROLNET_REPO
        )
    except TypeError:
        # The private loader's signature may differ between library versions;
        # retry with the extra ``loaded_keys`` positional argument.
        print("Using alternative model loading approach...")
        loaded_keys = list(state_dict.keys())

        try:
            model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
                controlnet_model,
                state_dict,
                model_file,
                _CONTROLNET_REPO,
                loaded_keys,
            )
        except Exception as e:
            print(f"Advanced loading failed: {e}")
            print("Falling back to direct loading...")

            try:
                # Copy the tensors straight into the freshly built model.
                controlnet_model.load_state_dict(state_dict)
                model = controlnet_model
            except Exception as load_err:
                print(f"Direct loading failed: {load_err}")

if model is None:
    # Last resort: let the library resolve config and weights from the repo.
    try:
        model = ControlNetModel_Union.from_pretrained(
            _CONTROLNET_REPO,
            torch_dtype=torch.float16,
        )
    except Exception as e:
        print_error(f"ControlNet could not be loaded: {e}")
        model = None

# Convert model to GPU with float16 (skipped when no ControlNet is available).
if model is not None:
    model.to(device="cuda", dtype=torch.float16)
101
 
102
# True when the app runs on the plain SDXL pipeline (no ControlNet).
using_fallback = False

# Bound before the try so the name always exists: if construction fails,
# ``pipe`` stays None instead of being undefined.
pipe = None

try:
    # VAE shared by both pipeline variants.
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to("cuda")

    if model is not None:
        # Normal path: fill pipeline driven by the ControlNet loaded above.
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to("cuda")
    else:
        # Fallback to regular StableDiffusionXLPipeline if controlnet failed
        print("Loading without ControlNet as fallback")
        using_fallback = True
        from diffusers import StableDiffusionXLPipeline

        pipe = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            variant="fp16",
        ).to("cuda")

    # Swap in the TCD scheduler, reusing the pipeline's scheduler config.
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
except Exception as e:
    print_error(f"Failed to initialize pipeline: {e}")
    # Neither pipeline variant could be set up. fill_image() catches the
    # resulting failure at request time and yields the unmodified source image.
138
 
139
  @spaces.GPU
140
  def fill_image(prompt, image, model_selection):
141
+ # Check if we're in fallback mode (no ControlNet)
142
+ global using_fallback
143
+
144
+ # Get the translated prompt
145
  translated_prompt = translate_if_korean(prompt)
146
 
147
  try:
148
+ # Extract the source image and mask
 
 
 
 
 
 
149
  source = image["background"]
150
  mask = image["layers"][0]
151
 
152
+ # Create a binary mask from the alpha channel
153
  alpha_channel = mask.split()[3]
154
  binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
155
+
156
+ # Handle based on whether we're using regular pipeline or ControlNet
157
+ if using_fallback:
158
+ # Using regular StableDiffusionXLPipeline without ControlNet
159
+ print("Using fallback pipeline without ControlNet")
160
+
161
+ # For fallback mode, we'll just use the regular pipeline
162
+ # and inpaint as best we can
163
+ try:
164
+ # Generate a new image based on the prompt
165
+ generated = pipe(
166
+ prompt=translated_prompt,
167
+ negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
168
+ num_inference_steps=30,
169
+ guidance_scale=7.5,
170
+ ).images[0]
171
+
172
+ # Composite the generated image into the masked area
173
+ result = source.copy()
174
+ result.paste(generated, (0, 0), binary_mask)
175
+
176
+ # Return both the original and the result
177
+ yield source, result
178
+ except Exception as e:
179
+ print_error(f"Fallback generation failed: {e}")
180
+ # If even this fails, just return the source image
181
+ yield source, source
182
+ else:
183
+ # Normal operation with ControlNet
184
+ # Prepare the controlnet input image
185
+ cnet_image = source.copy()
186
+ cnet_image.paste(0, (0, 0), binary_mask)
187
+
188
+ # Encode the prompt
189
+ (
190
+ prompt_embeds,
191
+ negative_prompt_embeds,
192
+ pooled_prompt_embeds,
193
+ negative_pooled_prompt_embeds,
194
+ ) = pipe.encode_prompt(translated_prompt, "cuda", True)
195
+
196
+ # Generate the image
197
+ for image in pipe(
198
+ prompt_embeds=prompt_embeds,
199
+ negative_prompt_embeds=negative_prompt_embeds,
200
+ pooled_prompt_embeds=pooled_prompt_embeds,
201
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
202
+ image=cnet_image,
203
+ ):
204
+ yield image, cnet_image
205
+
206
+ # Composite the final result
207
+ image = image.convert("RGBA")
208
+ cnet_image.paste(image, (0, 0), binary_mask)
209
+
210
+ yield source, cnet_image
211
+
212
  except Exception as e:
213
+ print_error(f"Error during image generation: {e}")
214
  # Return the original image in case of error
215
+ if 'source' in locals():
216
+ yield source, source
217
+ else:
218
+ print("Critical error: Source image not available")
219
+ # Create a blank image if we can't get the source
220
+ from PIL import Image
221
+ blank = Image.new('RGB', (512, 512), color=(255, 255, 255))
222
+ yield blank, blank
223
 
224
  def clear_result():
225
  return gr.update(value=None)