Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,28 @@
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
import torch
|
|
|
|
|
4 |
from diffusers import AutoencoderKL, TCDScheduler
|
5 |
from diffusers.models.model_loading_utils import load_state_dict
|
6 |
from gradio_imageslider import ImageSlider
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
MODELS = {
|
13 |
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
@@ -21,75 +36,190 @@ def translate_if_korean(text):
|
|
21 |
print("Translation is disabled - using original text")
|
22 |
return text
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
state_dict = load_state_dict(model_file)
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
model.to(device="cuda", dtype=torch.float16)
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
).to("cuda")
|
44 |
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
variant="fp16",
|
51 |
-
).to("cuda")
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
@spaces.GPU
|
56 |
def fill_image(prompt, image, model_selection):
|
57 |
-
#
|
|
|
|
|
|
|
58 |
translated_prompt = translate_if_korean(prompt)
|
59 |
|
60 |
try:
|
61 |
-
|
62 |
-
prompt_embeds,
|
63 |
-
negative_prompt_embeds,
|
64 |
-
pooled_prompt_embeds,
|
65 |
-
negative_pooled_prompt_embeds,
|
66 |
-
) = pipe.encode_prompt(translated_prompt, "cuda", True)
|
67 |
-
|
68 |
source = image["background"]
|
69 |
mask = image["layers"][0]
|
70 |
|
|
|
71 |
alpha_channel = mask.split()[3]
|
72 |
binary_mask = alpha_channel.point(lambda p: p > 0 and 255)
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
except Exception as e:
|
90 |
-
|
91 |
# Return the original image in case of error
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
def clear_result():
    """Build the Gradio update that resets a component's value to None."""
    cleared = gr.update(value=None)
    return cleared
|
|
|
1 |
import gradio as gr
|
2 |
import spaces
|
3 |
import torch
|
4 |
+
import sys
|
5 |
+
import traceback
|
6 |
from diffusers import AutoencoderKL, TCDScheduler
|
7 |
from diffusers.models.model_loading_utils import load_state_dict
|
8 |
from gradio_imageslider import ImageSlider
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
|
11 |
+
# Add better error handling
|
12 |
+
def print_error(error_message):
    """Print an error banner followed by the current traceback.

    Args:
        error_message: Text shown after the ``ERROR:`` prefix.

    The traceback text comes from ``traceback.format_exc()``, so the call is
    only informative while an exception is being handled.
    """
    heavy_rule = "=" * 50
    light_rule = "-" * 50
    report = (
        heavy_rule,
        f"ERROR: {error_message}",
        light_rule,
        traceback.format_exc(),
        heavy_rule,
    )
    for entry in report:
        print(entry)
|
18 |
+
|
19 |
+
try:
|
20 |
+
from controlnet_union import ControlNetModel_Union
|
21 |
+
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
|
22 |
+
except Exception as e:
|
23 |
+
print_error(f"Failed to import required modules: {e}")
|
24 |
+
print("Ensure the controlnet_union and pipeline_fill_sd_xl modules are available")
|
25 |
+
sys.exit(1)
|
26 |
|
27 |
MODELS = {
|
28 |
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
|
|
|
36 |
print("Translation is disabled - using original text")
|
37 |
return text
|
38 |
|
39 |
+
# Load the ControlNet-Union weights.
#
# Every strategy below is best-effort: when all of them fail, `model` is left
# as None so that later code can test `model is not None` instead of the
# module crashing at import time.
model = None
try:
    config_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="config_promax.json",
    )

    config = ControlNetModel_Union.load_config(config_file)
    controlnet_model = ControlNetModel_Union.from_config(config)
    model_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="diffusion_pytorch_model_promax.safetensors",
    )
    # Reading the checkpoint belongs inside the guarded section: it cannot
    # run when the download above failed.
    state_dict = load_state_dict(model_file)
except Exception as e:
    print_error(f"Failed to load model configuration: {e}")
    print("Attempting to use direct model loading as fallback...")
    # None marks the manual loading path as unavailable.
    config_file = None
    config = None
    controlnet_model = None
    model_file = None
    state_dict = None

if state_dict is not None:
    # Strategy 1: the original private-API call signature.
    try:
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
        )
    except Exception:
        print("Using alternative model loading approach...")

    # Strategy 2: the same call with the extra `loaded_keys` argument.
    if model is None:
        loaded_keys = list(state_dict.keys())
        try:
            model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
                controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
            )
        except Exception as e:
            print(f"Advanced loading failed: {e}")
            print("Falling back to direct loading...")

    # Strategy 3: load the state dict straight into the configured model.
    if model is None:
        try:
            controlnet_model.load_state_dict(state_dict)
            model = controlnet_model
        except Exception as load_err:
            print(f"Direct loading failed: {load_err}")

# Strategy 4 (last resort): let the model class resolve the repository itself.
if model is None:
    try:
        model = ControlNetModel_Union.from_pretrained(
            "xinsir/controlnet-union-sdxl-1.0",
            torch_dtype=torch.float16
        )
    except Exception as e:
        print_error(f"Failed to load ControlNet from pretrained weights: {e}")
        model = None

# Convert model to GPU with float16 (only when a model was actually loaded)
if model is not None:
    model.to(device="cuda", dtype=torch.float16)
|
101 |
|
102 |
+
# Flag read by fill_image: True when the plain SDXL pipeline (no ControlNet)
# is the one that was loaded.
using_fallback = False

# Defined up front so the name always exists; it stays None when no pipeline
# could be built at all.
pipe = None

try:
    # Try to load the VAE
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to("cuda")

    # Set up the pipeline with controlnet if available
    if model is not None:
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to("cuda")
    else:
        # Fallback to regular StableDiffusionXLPipeline if controlnet failed
        print("Loading without ControlNet as fallback")
        from diffusers import StableDiffusionXLPipeline
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            variant="fp16",
        ).to("cuda")
        # Only flag fallback mode once the fallback pipeline really exists.
        using_fallback = True

    # Set scheduler
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
except Exception as e:
    print_error(f"Failed to initialize pipeline: {e}")
    # `pipe` keeps whatever was assigned before the failure: None when the
    # pipeline itself could not be loaded, or the loaded pipeline with its
    # default scheduler when only the scheduler swap failed.
|
138 |
|
139 |
def _fill_without_controlnet(prompt, source, binary_mask):
    """Yield one (source, result) pair produced by the plain SDXL pipeline.

    The image is generated from the prompt alone and then composited into
    the masked area of a copy of `source`. On any failure the pair
    (source, source) is yielded instead.
    """
    try:
        # Generate a new image based on the prompt
        generated = pipe(
            prompt=prompt,
            negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
            num_inference_steps=30,
            guidance_scale=7.5,
        ).images[0]

        # paste() requires the pasted image and its mask to have the same
        # size; the pipeline is called without width/height, so its output
        # need not match the source.
        if generated.size != source.size:
            generated = generated.resize(source.size)

        # Composite the generated image into the masked area
        result = source.copy()
        result.paste(generated, (0, 0), binary_mask)
    except Exception as e:
        print_error(f"Fallback generation failed: {e}")
        # If even this fails, just return the source image
        result = source

    yield source, result


def _fill_with_controlnet(prompt, source, binary_mask):
    """Yield intermediate (frame, masked_input) pairs, then (source, composite).

    Exceptions are not handled here; the caller's handler deals with them.
    """
    # Prepare the controlnet input image: the source with the masked area
    # painted over with 0.
    cnet_image = source.copy()
    cnet_image.paste(0, (0, 0), binary_mask)

    # Encode the prompt
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt, "cuda", True)

    # Stream every image the pipeline produces. A dedicated loop variable is
    # used so the caller's `image` argument is never shadowed.
    frame = None
    for frame in pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        image=cnet_image,
    ):
        yield frame, cnet_image

    if frame is None:
        raise RuntimeError("Pipeline produced no images")

    # Composite the final result
    cnet_image.paste(frame.convert("RGBA"), (0, 0), binary_mask)

    yield source, cnet_image


@spaces.GPU
def fill_image(prompt, image, model_selection):
    """Fill the masked region of an image according to a text prompt.

    Args:
        prompt: Text prompt; passed through translate_if_korean first.
        image: Mapping with a "background" image and a "layers" list whose
            first entry carries the mask in its alpha channel (4th band).
        model_selection: Accepted for interface compatibility; not read.

    Yields:
        Pairs of images. The last pair is (source, filled result). On error
        it is (source, source), or a pair of blank white 512x512 images when
        the source image itself could not be read.
    """
    # Get the translated prompt
    translated_prompt = translate_if_korean(prompt)

    # Bound before the try block so the error handler can test it directly.
    source = None
    try:
        # Extract the source image and mask
        source = image["background"]
        mask = image["layers"][0]

        # Create a binary mask from the alpha channel: 255 wherever the
        # alpha value is non-zero, 0 elsewhere.
        alpha_channel = mask.split()[3]
        binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)

        # Handle based on whether we're using regular pipeline or ControlNet
        if using_fallback:
            print("Using fallback pipeline without ControlNet")
            yield from _fill_without_controlnet(translated_prompt, source, binary_mask)
        else:
            yield from _fill_with_controlnet(translated_prompt, source, binary_mask)

    except Exception as e:
        print_error(f"Error during image generation: {e}")
        # Return the original image in case of error
        if source is not None:
            yield source, source
        else:
            print("Critical error: Source image not available")
            # Create a blank image if we can't get the source
            from PIL import Image
            blank = Image.new('RGB', (512, 512), color=(255, 255, 255))
            yield blank, blank
|
223 |
|
224 |
def clear_result():
    """Build the Gradio update that resets a component's value to None."""
    cleared = gr.update(value=None)
    return cleared
|