Spaces:
Runtime error
Runtime error
File size: 4,500 Bytes
9c5d561 327dc47 83d1db7 327dc47 48b8b8d 83d1db7 9c5d561 83d1db7 48b8b8d fe92109 48b8b8d 83d1db7 327dc47 83d1db7 fe92109 83d1db7 fe92109 83d1db7 48b8b8d 83d1db7 f53e43a 83d1db7 fe92109 83d1db7 fe92109 83d1db7 fe92109 83d1db7 fe92109 83d1db7 48b8b8d 83d1db7 fe92109 327dc47 9c5d561 83d1db7 f30b01b 327dc47 fe92109 83d1db7 fe92109 83d1db7 fe92109 83d1db7 fe92109 83d1db7 48b8b8d 83d1db7 9c5d561 83d1db7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import re
import torch
from PIL import Image
import spaces
from diffusers import StableDiffusionXLImg2ImgPipeline
#
# Load the two SDXL pipelines (base + refiner) globally, so they only load once.
#
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision halves GPU memory use, but float16 pipelines are not usable
# on CPU, so fall back to full precision when no GPU is available.
dtype = torch.float16 if device == "cuda" else torch.float32

# Both stages use the img2img pipeline class: the base model transforms the
# user's image and the refiner then polishes the base model's output.
pipe_base = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=dtype,
).to(device)
pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    REFINER_MODEL_ID,
    torch_dtype=dtype,
).to(device)
#
# Helper functions
#
def sanitize_prompt(prompt: str) -> str:
    """Strip every character outside a conservative whitelist from *prompt*.

    Only ASCII letters, digits, whitespace and the punctuation marks
    ``. , ! ? -`` are kept; anything else is deleted without replacement.
    """
    # Negated character class: it matches exactly the characters to drop.
    disallowed_pattern = r"[^a-zA-Z0-9\s.,!?-]"
    return re.sub(disallowed_pattern, "", prompt)
def resize_to_multiple_of_64(image: Image.Image, max_dim: int = 1024) -> Image.Image:
    """Return a copy of *image* whose sides are multiples of 64.

    The image is first scaled down (never up), preserving aspect ratio, so
    that its longer side is at most ``max_dim``. Each side is then rounded
    down to a multiple of 64, with a floor of 64 pixels per side.

    Note that the 64-pixel floor wins over ``max_dim``: a very thin image, or
    a ``max_dim`` below 64, yields a side of 64, which may exceed ``max_dim``
    and alter the aspect ratio.

    Args:
        image: Source PIL image.
        max_dim: Upper bound for the longer side before rounding.

    Returns:
        A new image resampled with the LANCZOS filter.
    """
    width, height = image.size
    longest_side = max(width, height)

    if longest_side > max_dim:
        # Exact integer arithmetic. The float form int(w * (max_dim / w)) can
        # come out one short (e.g. w=1568 gives 1023), which the rounding
        # below would turn into 960 instead of 1024.
        width = int(width * max_dim // longest_side)
        height = int(height * max_dim // longest_side)

    # Snap down to the 64-pixel grid, but never below one grid cell.
    width = max(width - width % 64, 64)
    height = max(height - height % 64, 64)
    return image.resize((width, height), Image.LANCZOS)
@spaces.GPU(duration=240)  # Increase time if needed (SDXL can be slow)
def run_img2img_sdxl(
    init_image,
    prompt: str,
    strength: float,
    seed: int,
    steps_base: int,
    steps_refiner: int,
    refiner_strength: float = 0.3,
):
    """Run a two-stage SDXL img2img pass (base model, then refiner).

    Args:
        init_image: PIL image to transform, or ``None`` if nothing was uploaded.
        prompt: Text prompt; it is passed through ``sanitize_prompt`` first.
        strength: Img2img strength for the base stage, in ``[0, 1]``.
        seed: Seed for the random generator shared by both stages.
        steps_base: ``num_inference_steps`` for the base stage.
        steps_refiner: ``num_inference_steps`` for the refiner stage.
        refiner_strength: Img2img strength for the refiner stage. Keep it
            small so the refiner only touches up details.

    Returns:
        The final PIL image, or ``None`` when no input image was provided.
        A stage that would run zero denoising steps is skipped, so with
        ``strength == 0`` and a skipped refiner the resized input is returned.
    """
    if init_image is None:
        print("No input image provided.")
        return None

    # UI widgets may hand numbers over as floats; the generator seed and the
    # step counts must be integers.
    seed = int(seed)
    steps_base = int(steps_base)
    steps_refiner = int(steps_refiner)

    # Clean up prompt (an empty textbox is treated as an empty prompt).
    prompt = sanitize_prompt(prompt or "")

    # Ensure reproducibility
    generator = torch.Generator(device).manual_seed(seed)

    # Bring the input onto SDXL-friendly dimensions (multiples of 64, at most
    # 1024 on the longer side).
    init_image = resize_to_multiple_of_64(init_image, max_dim=1024)
    result = init_image

    # The img2img pipelines denoise for int(num_inference_steps * strength)
    # steps. When that is zero there is nothing to run, so the stage is
    # skipped instead of calling the pipeline with an empty schedule.

    # 1) Base pass
    if int(steps_base * strength) >= 1:
        base_output = pipe_base(
            prompt=prompt,
            image=result,
            strength=strength,
            guidance_scale=8.0,  # Adjust for more or less adherence to prompt
            num_inference_steps=steps_base,
            generator=generator,
        )
        result = base_output.images[0]

    # 2) Refiner pass: a low, non-zero strength adds detail without
    # repainting the base result. A strength of exactly 0.0 would leave the
    # refiner with no denoising steps at all.
    if int(steps_refiner * refiner_strength) >= 1:
        refiner_output = pipe_refiner(
            prompt=prompt,
            image=result,
            strength=refiner_strength,
            guidance_scale=9.0,
            num_inference_steps=steps_refiner,
            generator=generator,
        )
        result = refiner_output.images[0]

    return result
#
# Gradio UI
#
# Page-level CSS: centre each of the two columns and cap its width.
# The selectors match the elem_id values given to the columns below.
css = """
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## SDXL Img2Img (Base + Refiner) — High Quality Demo")
    with gr.Row():
        # Left column: inputs. The elem_id ties it to the #col-left CSS rule,
        # which otherwise would not apply to anything.
        with gr.Column(elem_id="col-left"):
            init_image = gr.Image(
                label="Init Image (Img2Img)",
                type="pil",
                image_mode="RGB",
                height=512,
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe what you want to see",
            )
            run_button = gr.Button("Generate")
            with gr.Accordion("Advanced Options", open=False):
                strength = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Strength (img2img)")
                seed = gr.Number(value=42, label="Seed", precision=0)
                steps_base = gr.Slider(1, 100, value=50, step=1, label="Steps (Base)")
                steps_refiner = gr.Slider(1, 100, value=30, step=1, label="Steps (Refiner)")
        # Right column: output, tied to the #col-right CSS rule.
        with gr.Column(elem_id="col-right"):
            result_image = gr.Image(label="Result", height=512)

    # Link the button to our function; the input order must match the
    # positional parameters of run_img2img_sdxl.
    run_button.click(
        fn=run_img2img_sdxl,
        inputs=[init_image, prompt, strength, seed, steps_base, steps_refiner],
        outputs=[result_image],
    )

if __name__ == "__main__":
    demo.launch(share=True)
|