File size: 4,500 Bytes
9c5d561
327dc47
83d1db7
327dc47
48b8b8d
83d1db7
 
9c5d561
83d1db7
 
 
 
 
48b8b8d
fe92109
 
48b8b8d
83d1db7
 
327dc47
83d1db7
 
 
 
 
fe92109
83d1db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe92109
83d1db7
 
 
 
 
 
48b8b8d
83d1db7
 
f53e43a
83d1db7
 
 
 
 
 
 
 
 
fe92109
83d1db7
fe92109
83d1db7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe92109
83d1db7
fe92109
83d1db7
48b8b8d
 
83d1db7
 
 
fe92109
327dc47
 
 
 
 
9c5d561
 
 
 
 
83d1db7
 
f30b01b
327dc47
fe92109
83d1db7
 
fe92109
83d1db7
 
fe92109
83d1db7
 
 
fe92109
83d1db7
 
 
 
 
 
 
 
 
48b8b8d
83d1db7
 
 
 
 
9c5d561
 
 
83d1db7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import re
import torch
from PIL import Image

import spaces
from diffusers import StableDiffusionXLImg2ImgPipeline

#
# Load the two SDXL pipelines (base + refiner) globally, so they only load once.
#
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"

# Pick the device first so the numeric precision can follow it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision is only requested on the GPU. On the CPU fallback we load in
# float32, because float16 inference on CPU is poorly supported by PyTorch.
dtype = torch.float16 if device == "cuda" else torch.float32

# Both pipelines are created at import time so every request reuses the same
# loaded weights instead of reloading them per call.
pipe_base = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL_ID, torch_dtype=dtype
).to(device)
pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    REFINER_MODEL_ID, torch_dtype=dtype
).to(device)

#
# Helper functions
#
def sanitize_prompt(prompt: str) -> str:
    """Return *prompt* with every non-whitelisted character removed.

    Only ASCII letters, digits, whitespace and the punctuation ``. , ! ? -``
    survive; anything else is dropped (not replaced).
    """
    # The negated character class matches exactly the characters to discard.
    return re.sub(r"[^a-zA-Z0-9\s.,!?-]", "", prompt)

def resize_to_multiple_of_64(image: Image.Image, max_dim: int = 1024):
    """
    Return a LANCZOS-resized copy of ``image`` whose sides are multiples of 64.

    The image is first scaled down (never up) so that neither side exceeds
    ``max_dim``; each side is then rounded down to a multiple of 64, with a
    floor of 64 pixels per side. Because the two sides are rounded
    independently, the aspect ratio may shift slightly.
    """
    def _snap(side: int, scale: float) -> int:
        # Scale, truncate to an int, drop the remainder modulo 64, and keep
        # at least one 64-pixel tile.
        scaled = int(side * scale)
        return max(scaled - scaled % 64, 64)

    width, height = image.size

    # The 1.0 cap means images already within max_dim are not enlarged.
    scale = min(max_dim / width, max_dim / height, 1.0)

    target_size = (_snap(width, scale), _snap(height, scale))
    return image.resize(target_size, Image.LANCZOS)

def _denoising_steps(num_inference_steps: int, strength: float) -> int:
    """Return how many denoising steps an img2img pass would actually run."""
    # NOTE(review): mirrors the "steps * strength, truncated" rule documented
    # for the diffusers img2img pipelines -- confirm against the installed
    # diffusers version.
    return min(int(num_inference_steps * strength), num_inference_steps)


@spaces.GPU(duration=240)  # Increase time if needed (SDXL can be slow)
def run_img2img_sdxl(
    init_image,
    prompt: str,
    strength: float,
    seed: int,
    steps_base: int,
    steps_refiner: int,
    refiner_strength: float = 0.3,
):
    """
    Runs a two-step SDXL (base + refiner) pass for high-quality img2img.

    Args:
        init_image: Input image (the UI supplies a PIL image), or None.
        prompt: Text prompt; it is sanitized before use.
        strength: Img2img strength for the base pass (0.0 - 1.0).
        seed: Seed for the torch generator shared by both passes.
        steps_base: Number of inference steps requested for the base pass.
        steps_refiner: Number of inference steps requested for the refiner.
        refiner_strength: Img2img strength for the refiner pass. It must be
            greater than 0 for the refiner to do any work; a strength of 0.0
            leaves the pipeline with zero denoising steps.

    Returns:
        The final image, or None when no input image was provided.
    """
    if init_image is None:
        print("No input image provided.")
        return None

    # Clean up prompt (tolerate a missing prompt from the UI)
    prompt = sanitize_prompt(prompt or "")

    # UI widgets may hand over floats, and a cleared seed box may yield None;
    # the generator and the pipelines need real integers.
    seed = int(seed or 0)
    steps_base = int(steps_base)
    steps_refiner = int(steps_refiner)

    # Ensure reproducibility
    generator = torch.Generator(device).manual_seed(seed)

    # Possibly resize the input to a smaller multiple-of-64 dimension
    # (1024x1024 or smaller is typical for SDXL)
    init_image = resize_to_multiple_of_64(init_image, max_dim=1024)

    # 1) Base pass -- skipped when the requested strength would leave no
    #    denoising steps, in which case the (resized) input is kept as-is.
    if _denoising_steps(steps_base, strength) >= 1:
        base_output = pipe_base(
            prompt=prompt,
            image=init_image,
            strength=strength,
            guidance_scale=8.0,  # Adjust for more or less adherence to prompt
            num_inference_steps=steps_base,
            generator=generator,
        )
        base_image = base_output.images[0]
    else:
        base_image = init_image

    # 2) Refiner pass -- a light img2img pass over the base result. With too
    #    few steps for the chosen strength there is nothing to run, so the
    #    base result is returned directly.
    if _denoising_steps(steps_refiner, refiner_strength) < 1:
        return base_image

    refiner_output = pipe_refiner(
        prompt=prompt,
        image=base_image,
        strength=refiner_strength,
        guidance_scale=9.0,
        num_inference_steps=steps_refiner,
        generator=generator,
    )
    return refiner_output.images[0]


#
# Gradio UI
#
# Styling for the two layout columns; the selectors match the elem_id values
# assigned to the gr.Column components below.
css = """
#col-left {
    margin: 0 auto;
    max-width: 640px;
}
#col-right {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("## SDXL Img2Img (Base + Refiner) — High Quality Demo")

    with gr.Row():
        # Left column: inputs. elem_id ties it to the #col-left CSS rule.
        with gr.Column(elem_id="col-left"):
            init_image = gr.Image(
                label="Init Image (Img2Img)",
                type="pil",
                image_mode="RGB",
                height=512,
            )
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="Describe what you want to see",
            )
            run_button = gr.Button("Generate")
            with gr.Accordion("Advanced Options", open=False):
                strength = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Strength (img2img)")
                # precision=0 makes the seed field integer-valued
                seed = gr.Number(value=42, label="Seed", precision=0)
                steps_base = gr.Slider(1, 100, value=50, step=1, label="Steps (Base)")
                steps_refiner = gr.Slider(1, 100, value=30, step=1, label="Steps (Refiner)")

        # Right column: output. elem_id ties it to the #col-right CSS rule.
        with gr.Column(elem_id="col-right"):
            result_image = gr.Image(label="Result", height=512)

    # Link the button to our function; the inputs list order must match the
    # positional parameters of run_img2img_sdxl.
    run_button.click(
        fn=run_img2img_sdxl,
        inputs=[init_image, prompt, strength, seed, steps_base, steps_refiner],
        outputs=[result_image],
    )

if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)