# Spaces status banner captured together with this file: "Runtime error"
# (page residue, not part of the program).
import gradio as gr | |
import re | |
import torch | |
from PIL import Image | |
import spaces | |
from diffusers import StableDiffusionXLImg2ImgPipeline | |
# | |
# Load the two SDXL pipelines (base + refiner) globally, so they only load once. | |
# | |
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Choose the device first so the precision can follow it: half precision is
# only used when CUDA is available, while the CPU fallback uses float32
# (many CPU kernels are missing or very slow for float16).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Both pipelines are created at import time so every request reuses them.
pipe_base = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=dtype,
).to(device)
pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    REFINER_MODEL_ID,
    torch_dtype=dtype,
).to(device)
# | |
# Helper functions | |
# | |
# Matches every character that is NOT an ASCII letter, a digit, whitespace,
# or one of the punctuation marks . , ! ? -
# Compiled once at import time instead of on every call.
_DISALLOWED_PROMPT_CHARS = re.compile(r"[^a-zA-Z0-9\s.,!?-]")


def sanitize_prompt(prompt: str) -> str:
    """Return *prompt* with every character outside the allow-list removed.

    Only ASCII letters, digits, whitespace and the characters ``. , ! ? -``
    are kept; everything else (including non-ASCII letters, brackets and
    colons) is dropped without replacement.

    Args:
        prompt: Raw user text. ``None`` or an empty string is accepted.

    Returns:
        The filtered text, or ``""`` when *prompt* is ``None`` or empty.
    """
    # A missing prompt is treated like an empty one rather than letting
    # ``re.sub`` raise ``TypeError`` on ``None``.
    if not prompt:
        return ""
    return _DISALLOWED_PROMPT_CHARS.sub("", prompt)
def resize_to_multiple_of_64(
    image: Image.Image,
    max_dim: int = 1024,
    multiple: int = 64,
) -> Image.Image:
    """Resize *image* so each side is a multiple of *multiple*.

    The image is first scaled down (never up) so that neither side exceeds
    ``max_dim``, keeping the aspect ratio. Each side is then rounded DOWN to
    the nearest multiple of *multiple* and clamped to at least one multiple.

    Consequences of that clamp and rounding:
      * a side shorter than *multiple* is enlarged to *multiple*;
      * the aspect ratio can shift slightly because the sides are rounded
        independently;
      * if ``max_dim < multiple`` the result is ``multiple`` and therefore
        larger than ``max_dim``.

    Args:
        image: Source image; only ``size`` and ``resize`` are used.
        max_dim: Upper bound for width and height before rounding.
        multiple: Grid the output sides snap to (64 by default; pass 128 to
            use multiples of 128).

    Returns:
        A resized image produced with LANCZOS resampling.

    Raises:
        ValueError: If *multiple* is not positive or the image has a
            zero-sized side.
    """
    if multiple <= 0:
        raise ValueError(f"multiple must be positive, got {multiple}")

    w, h = image.size
    if w <= 0 or h <= 0:
        # Without this guard the ratio below fails with ZeroDivisionError.
        raise ValueError(f"Cannot resize an image with empty dimensions: {w}x{h}")

    # Shrink only: the 1.0 term keeps images that already fit at their size.
    ratio = min(max_dim / w, max_dim / h, 1.0)

    # Round each side down to the grid, but never below one grid cell.
    new_w = max(int(w * ratio) // multiple * multiple, multiple)
    new_h = max(int(h * ratio) // multiple * multiple, multiple)

    return image.resize((new_w, new_h), Image.LANCZOS)
def _effective_denoising_steps(num_inference_steps: int, strength: float) -> int:
    """Estimate how many denoising steps an img2img pass will actually run.

    NOTE(review): this mirrors the ``min(int(steps * strength), steps)`` rule
    used by diffusers img2img pipelines - confirm against the installed
    diffusers version.
    """
    return max(min(int(num_inference_steps * strength), num_inference_steps), 0)


# Increase time if needed (SDXL can be slow)
# NOTE(review): `spaces` is imported at the top of the file, but no
# `@spaces.GPU(...)` decorator is visible on this function; the comment above
# presumably referred to its duration. Confirm whether the target hardware
# needs it.
def run_img2img_sdxl(
    init_image,
    prompt: str,
    strength: float,
    seed: int,
    steps_base: int,
    steps_refiner: int,
    refiner_strength: float = 0.3,
):
    """Run a two-stage SDXL (base + refiner) img2img pass.

    Args:
        init_image: Input image, or ``None`` when nothing was uploaded.
        prompt: User prompt; it is passed through ``sanitize_prompt``.
        strength: Img2img strength for the base pass.
        seed: Seed for the random generator shared by both passes. ``None``
            leaves the passes unseeded.
        steps_base: ``num_inference_steps`` for the base pass.
        steps_refiner: ``num_inference_steps`` for the refiner pass.
        refiner_strength: Img2img strength for the refiner pass. The previous
            hard-coded ``0.0`` left no denoising steps, so the refiner could
            not change the image.

    Returns:
        The first image of the last pass that ran, or ``None`` when
        *init_image* is ``None``.
    """
    if init_image is None:
        print("No input image provided.")
        return None

    # Clean up prompt
    prompt = sanitize_prompt(prompt)

    # UI widgets may deliver whole numbers as floats; the pipelines and
    # ``manual_seed`` need real ints.
    steps_base = int(steps_base)
    steps_refiner = int(steps_refiner)

    # Ensure reproducibility (a cleared seed field falls back to unseeded).
    generator = None
    if seed is not None:
        generator = torch.Generator(device).manual_seed(int(seed))

    # Possibly resize the input to a smaller multiple-of-64 dimension
    # (1024x1024 or smaller is typical for SDXL)
    init_image = resize_to_multiple_of_64(init_image, max_dim=1024)

    # 1) Base pass. Skipped when the strength leaves no denoising steps,
    #    because there would be nothing for the pipeline to run.
    base_image = init_image
    if _effective_denoising_steps(steps_base, strength) >= 1:
        base_output = pipe_base(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=8.0,  # Adjust if you want more or less adherence to prompt
            num_inference_steps=steps_base,
            generator=generator,
        )
        base_image = base_output.images[0]

    # 2) Refiner pass: a low, non-zero strength adds final detail while
    #    keeping the base result, with a slightly higher guidance scale.
    final_image = base_image
    if _effective_denoising_steps(steps_refiner, refiner_strength) >= 1:
        refiner_output = pipe_refiner(
            prompt=prompt,
            image=base_image,
            strength=refiner_strength,
            guidance_scale=9.0,
            num_inference_steps=steps_refiner,
            generator=generator,
        )
        final_image = refiner_output.images[0]

    return final_image
# | |
# Gradio UI | |
# | |
# Width limits for the two columns; the selectors match the `elem_id`
# values given to the columns below.
css = """
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## SDXL Img2Img (Base + Refiner) — High Quality Demo")

    with gr.Row():
        # Left column: inputs. The elem_id ties it to the #col-left rule,
        # which previously matched no element.
        with gr.Column(elem_id="col-left"):
            init_image = gr.Image(
                label="Init Image (Img2Img)",
                type="pil",
                image_mode="RGB",
                height=512,
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe what you want to see",
            )
            run_button = gr.Button("Generate")

            with gr.Accordion("Advanced Options", open=False):
                strength = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Strength (img2img)")
                seed = gr.Number(value=42, label="Seed", precision=0)
                steps_base = gr.Slider(1, 100, value=50, step=1, label="Steps (Base)")
                steps_refiner = gr.Slider(1, 100, value=30, step=1, label="Steps (Refiner)")

        # Right column: output, tied to the #col-right rule.
        with gr.Column(elem_id="col-right"):
            result_image = gr.Image(label="Result", height=512)

    # Link the button to our function; the input order must match the
    # parameter order of run_img2img_sdxl.
    run_button.click(
        fn=run_img2img_sdxl,
        inputs=[init_image, prompt, strength, seed, steps_base, steps_refiner],
        outputs=[result_image],
    )
# Start the web UI only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to the
    # local server.
    demo.launch(share=True)