Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,750 Bytes
7890545 818d397 7890545 818d397 7890545 3bd9247 00d9bae 818d397 7890545 818d397 7890545 818d397 7890545 818d397 00d9bae 3bd9247 818d397 3bd9247 818d397 3bd9247 818d397 3bd9247 818d397 3bd9247 818d397 7890545 75849b3 bb9bc7b 7890545 e8b8f38 7890545 3bd9247 7890545 3bd9247 7890545 3bd9247 7890545 3bd9247 7890545 5e46a89 7890545 3bd9247 7890545 3bd9247 5e46a89 3bd9247 bb9bc7b 3bd9247 bb9bc7b 5e46a89 7890545 5e46a89 7890545 5e46a89 7890545 5e46a89 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 |
import gradio as gr
import spaces
import torch
import sys
import traceback
from diffusers import AutoencoderKL, TCDScheduler
from diffusers.models.model_loading_utils import load_state_dict
from gradio_imageslider import ImageSlider
from huggingface_hub import hf_hub_download
# Console error reporting helper
def print_error(error_message):
    """Print *error_message* and the current traceback inside a banner.

    The banner is a 50-character ``=`` rule, the message prefixed with
    ``ERROR:``, a 50-character ``-`` rule, the text returned by
    ``traceback.format_exc()``, and a closing ``=`` rule. Everything is
    written to standard output.

    Args:
        error_message: Text (or any object usable in an f-string) that
            describes what failed.
    """
    heavy_rule = "=" * 50
    light_rule = "-" * 50
    report_lines = [
        heavy_rule,
        f"ERROR: {error_message}",
        light_rule,
        traceback.format_exc(),
        heavy_rule,
    ]
    # One print of the joined lines emits the same text as five prints.
    print("\n".join(report_lines))
# The two project-local modules below are mandatory: if either import fails,
# report the problem and terminate the process with exit status 1.
try:
    from controlnet_union import ControlNetModel_Union
    from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline
except Exception as import_error:
    print_error(f"Failed to import required modules: {import_error}")
    print("Ensure the controlnet_union and pipeline_fill_sd_xl modules are available")
    raise SystemExit(1)
# Display name -> model repository id. The keys populate the "Model" dropdown
# in the UI; note that the pipeline setup in this file passes the repository
# id as a literal rather than looking it up here.
MODELS = {
"RealVisXL V5.0 Lightning": "SG161222/RealVisXL_V5.0_Lightning",
}
# Stand-in for the former translation model: detection only, no translation.
def translate_if_korean(text):
    """Return *text* unchanged, logging when it contains Korean characters.

    A character counts as Korean when it lies in U+3131..U+318E or in
    U+AC00..U+D7A3. When at least one such character is present, two
    informational lines are printed; no translation is performed.

    Args:
        text: The prompt string to inspect.

    Returns:
        The same ``text`` object that was passed in.
    """
    def _is_korean(ch):
        return '\u3131' <= ch <= '\u318E' or '\uAC00' <= ch <= '\uD7A3'

    for ch in text:
        if _is_korean(ch):
            print(f"Korean text detected: {text}")
            print("Translation is disabled - using original text")
            # One notice per call is enough; stop at the first match.
            break
    return text
# ---------------------------------------------------------------------------
# ControlNet loading
#
# Step 1: download the "promax" config and checkpoint and build a model from
# the config. On any failure the four names are set to None so the steps
# below can tell that the checkpoint route is unavailable.
# ---------------------------------------------------------------------------
try:
    config_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="config_promax.json",
    )
    config = ControlNetModel_Union.load_config(config_file)
    controlnet_model = ControlNetModel_Union.from_config(config)
    model_file = hf_hub_download(
        "xinsir/controlnet-union-sdxl-1.0",
        filename="diffusion_pytorch_model_promax.safetensors",
    )
except Exception as e:
    print_error(f"Failed to load model configuration: {e}")
    print("Attempting to use direct model loading as fallback...")
    config_file = None
    config = None
    controlnet_model = None
    model_file = None

# `model` stays None when every loading strategy fails; the pipeline setup
# further down checks `model is not None` and otherwise builds a pipeline
# without ControlNet.
model = None
state_dict = None

# Step 2: load the checkpoint into the model. Only attempted when step 1
# succeeded -- previously load_state_dict() was called with None here, which
# crashed start-up instead of reaching any fallback.
if controlnet_model is not None and model_file is not None:
    state_dict = load_state_dict(model_file)
    try:
        # Original call signature of the private loader.
        model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
            controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0"
        )
    except TypeError:
        # A TypeError is taken to mean the loader's signature differs
        # (e.g. it also expects the list of loaded keys).
        print("Using alternative model loading approach...")
        loaded_keys = list(state_dict.keys())
        try:
            model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model(
                controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0", loaded_keys
            )
        except Exception as e:
            print(f"Advanced loading failed: {e}")
            print("Falling back to direct loading...")
            try:
                # Bypass the private loader and copy the weights directly.
                controlnet_model.load_state_dict(state_dict)
                model = controlnet_model
            except Exception as load_err:
                print(f"Direct loading failed: {load_err}")
                model = None

# Step 3: last resort shared by both failure paths above -- let the library
# resolve and load the repository on its own.
if model is None:
    try:
        model = ControlNetModel_Union.from_pretrained(
            "xinsir/controlnet-union-sdxl-1.0",
            torch_dtype=torch.float16
        )
    except Exception as e:
        print_error(f"Failed to load ControlNet from pretrained weights: {e}")
        model = None

# Step 4: move the model to the GPU in half precision, if one was loaded.
if model is not None:
    model.to(device="cuda", dtype=torch.float16)
# ---------------------------------------------------------------------------
# Pipeline setup
#
# `using_fallback` records whether the plain SDXL pipeline (no ControlNet) is
# in use; fill_image reads it to choose its code path.
# ---------------------------------------------------------------------------
using_fallback = False

try:
    # Half-precision VAE shared by both pipeline variants.
    vae = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    ).to("cuda")

    if model is None:
        # No ControlNet available: use the stock text-to-image pipeline.
        print("Loading without ControlNet as fallback")
        using_fallback = True
        from diffusers import StableDiffusionXLPipeline

        pipe = StableDiffusionXLPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            variant="fp16",
        ).to("cuda")
    else:
        # Regular path: fill pipeline driven by the loaded ControlNet.
        pipe = StableDiffusionXLFillPipeline.from_pretrained(
            "SG161222/RealVisXL_V5.0_Lightning",
            torch_dtype=torch.float16,
            vae=vae,
            controlnet=model,
            variant="fp16",
        ).to("cuda")

    # Replace the pipeline's scheduler with a TCDScheduler built from the
    # existing scheduler's config.
    pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
except Exception as e:
    print_error(f"Failed to initialize pipeline: {e}")
    # NOTE: `pipe` may be left undefined here. Nothing else is set up in this
    # block; any later use of `pipe` raises, and fill_image catches such
    # errors and returns the input image.
@spaces.GPU
def fill_image(prompt, image, model_selection):
    """Fill the masked area of an image according to a text prompt.

    This is a generator: every yielded value is a ``(left, right)`` pair of
    images for the slider output.

    Args:
        prompt: Text describing what to paint into the masked area. It is
            passed through ``translate_if_korean`` before use.
        image: Editor value read as a mapping with a ``"background"`` image
            and a ``"layers"`` list; the alpha band (fourth band) of the
            first layer marks the area to fill.
        model_selection: Accepted for interface compatibility; not used.

    Yields:
        In ControlNet mode, ``(frame, masked_input)`` for every image the
        pipeline produces, then ``(source, composited_result)``. In fallback
        mode, a single ``(source, composited_result)``. On any error,
        ``(source, source)`` when the source image is available, otherwise
        a pair of blank white 512x512 images.
    """
    translated_prompt = translate_if_korean(prompt)

    # Bound before the try block so the error handler can test it directly
    # instead of inspecting locals().
    source = None
    try:
        source = image["background"]
        mask_layer = image["layers"][0]

        # Binary mask: 255 wherever the drawn layer has any opacity, else 0.
        alpha_channel = mask_layer.split()[3]
        binary_mask = alpha_channel.point(lambda p: 255 if p > 0 else 0)

        if using_fallback:
            # Plain SDXL pipeline: generate from the prompt alone and paste
            # the generation into the masked area.
            print("Using fallback pipeline without ControlNet")
            try:
                generated = pipe(
                    prompt=translated_prompt,
                    negative_prompt="low quality, worst quality, bad anatomy, bad composition, poor, low effort",
                    num_inference_steps=30,
                    guidance_scale=7.5,
                ).images[0]

                # paste() needs the pasted image and the mask to have the
                # same size; the generation is not sized from the source.
                if generated.size != source.size:
                    generated = generated.resize(source.size)

                result = source.copy()
                result.paste(generated, (0, 0), binary_mask)
            except Exception as e:
                print_error(f"Fallback generation failed: {e}")
                result = source
            yield source, result
        else:
            # ControlNet input: the source with the masked area blacked out.
            cnet_image = source.copy()
            cnet_image.paste(0, (0, 0), binary_mask)

            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(translated_prompt, "cuda", True)

            # Stream every image the pipeline yields; a separate name keeps
            # the `image` argument intact.
            last_frame = None
            for frame in pipe(
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
                image=cnet_image,
            ):
                last_frame = frame
                yield frame, cnet_image

            if last_frame is None:
                raise RuntimeError("Pipeline produced no images")

            # Composite the final frame into the masked area only.
            last_frame = last_frame.convert("RGBA")
            cnet_image.paste(last_frame, (0, 0), binary_mask)
            yield source, cnet_image
    except Exception as e:
        print_error(f"Error during image generation: {e}")
        if source is not None:
            # Show the untouched input on both sides of the slider.
            yield source, source
        else:
            print("Critical error: Source image not available")
            from PIL import Image

            blank = Image.new('RGB', (512, 512), color=(255, 255, 255))
            yield blank, blank
def clear_result():
    """Return a Gradio update that sets the target component's value to None."""
    cleared = gr.update(value=None)
    return cleared
# Stylesheet handed to gr.Blocks below: hides the page footer and centres the
# contents of elements that carry the "sample-image" class.
css = """
footer {
visibility: hidden;
}
.sample-image {
display: flex;
justify-content: center;
margin-top: 20px;
}
"""
# ---------------------------------------------------------------------------
# Gradio interface
# NOTE(review): the component nesting below was reconstructed from a source
# whose indentation was lost -- confirm the layout against the running app.
# ---------------------------------------------------------------------------
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    # Top row: prompt on the left, model choice and run button on the right.
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe what to fill in the mask area (Korean or English)",
                lines=3,
            )
        with gr.Column():
            model_selection = gr.Dropdown(
                choices=list(MODELS),
                value="RealVisXL V5.0 Lightning",
                label="Model",
            )
            run_button = gr.Button("Generate")

    # Middle row: mask editor next to the before/after slider.
    with gr.Row():
        input_image = gr.ImageMask(
            type="pil",
            label="Input Image",
            crop_size=(1024, 1024),
            layers=False,
        )
        result = ImageSlider(
            interactive=False,
            label="Generated Image",
        )

    # Hidden until a generation finishes.
    use_as_input_button = gr.Button("Use as Input Image", visible=False)

    # Sample image
    with gr.Row(elem_classes="sample-image"):
        sample_image = gr.Image("sample.png", label="Sample Image", height=256, width=256)

    def use_output_as_input(output_image):
        """Return an update that loads the slider's second image into the editor."""
        return gr.update(value=output_image[1])

    use_as_input_button.click(
        fn=use_output_as_input,
        inputs=[result],
        outputs=[input_image],
    )

    # The button click and pressing Enter in the prompt box start the same
    # four-step chain: clear the result, hide the reuse button, generate,
    # then show the reuse button again.
    for start_generation in (run_button.click, prompt.submit):
        start_generation(
            fn=clear_result,
            inputs=None,
            outputs=result,
        ).then(
            fn=lambda: gr.update(visible=False),
            inputs=None,
            outputs=use_as_input_button,
        ).then(
            fn=fill_image,
            inputs=[prompt, input_image, model_selection],
            outputs=result,
        ).then(
            fn=lambda: gr.update(visible=True),
            inputs=None,
            outputs=use_as_input_button,
        )

demo.launch(share=False)