Gemini899 committed on
Commit 83d1db7 · verified · 1 Parent(s): 193cf00

Update app.py

Files changed (1):
  app.py (+109 -191)
app.py CHANGED
Old version (removed lines are prefixed with "-"):

@@ -1,118 +1,108 @@
-import spaces
 import gradio as gr
 import re
 from PIL import Image
 
-import os
-import numpy as np
-import torch
-from diffusers import StableDiffusionImg2ImgPipeline
 
-# Choose a higher-quality or specialized model.
-model_id = "SG161222/Realistic_Vision_V2.0"  # e.g. "runwayml/stable-diffusion-v1-5"
 
-# Typically use float16 to reduce memory usage if on GPU
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=dtype).to(device)
 
-def sanitize_prompt(prompt):
-    # Allow only alphanumeric characters, spaces, and basic punctuation
     allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
-    sanitized_prompt = allowed_chars.sub("", prompt)
-    return sanitized_prompt
-
-def convert_to_fit_size(original_width_and_height, maximum_size=2048):
-    width, height = original_width_and_height
-    # If within maximum size on both sides, no need to shrink
-    if width <= maximum_size and height <= maximum_size:
-        return width, height
-
-    # Otherwise, scale down so the largest dimension = maximum_size
-    if width > height:
-        scaling_factor = maximum_size / width
-    else:
-        scaling_factor = maximum_size / height
-
-    new_width = int(width * scaling_factor)
-    new_height = int(height * scaling_factor)
-    return new_width, new_height
-
-def adjust_to_multiple_of_32(width: int, height: int):
-    # Stable Diffusion pipelines typically work best with dims multiple-of-32
-    width = width - (width % 32)
-    height = height - (height % 32)
-    # Ensure not to drop to zero
-    width = max(width, 32)
-    height = max(height, 32)
-    return width, height
-
-@spaces.GPU(duration=120)
-def process_images(
-    image,
-    prompt="a girl",
-    strength=0.75,
-    seed=0,
-    inference_step=30,
-    progress=gr.Progress(track_tqdm=True)
 ):
-    # Provide feedback in the Gradio UI
-    progress(0, desc="Starting")
 
-    def process_img2img(img, prompt="a person", strength=0.75, seed=0, num_inference_steps=30):
-        if img is None:
-            print("empty input image returned")
-            return None
-
-        # Make results reproducible
-        generator = torch.Generator(device).manual_seed(seed)
-
-        # 1) Resize the input image to fit within a maximum dimension
-        fit_width, fit_height = convert_to_fit_size(img.size)
-        # 2) Adjust final dimensions to multiples of 32
-        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
-
-        # Use high-quality Lanczos downsampling
-        img = img.resize((width, height), Image.LANCZOS)
-
-        # For better quality, let's set guidance_scale ~7 and steps ~30
-        output = pipe(
-            prompt=prompt,
-            image=img,
-            generator=generator,
-            strength=strength,
-            guidance_scale=7.0,  # typical, can tune to 5-10
-            num_inference_steps=num_inference_steps,
-        )
-
-        pil_image = output.images[0]
-
-        # If we forcibly down/up scaled to multiple-of-32, let's restore to the "fit" size
-        # (not strictly necessary, but can preserve original aspect ratio exactly)
-        new_width, new_height = pil_image.size
-        if (new_width != fit_width) or (new_height != fit_height):
-            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
-            return resized_image
-
-        return pil_image
 
-    # Actually run the process
-    output = process_img2img(
-        img=image,
         prompt=prompt,
         strength=strength,
-        seed=seed,
-        num_inference_steps=inference_step
     )
 
-    return output
 
-def read_file(path: str) -> str:
-    with open(path, 'r', encoding='utf-8') as f:
-        content = f.read()
-    return content
 
 css = """
 #col-left {
     margin: 0 auto;
@@ -122,111 +112,39 @@ css = """
     margin: 0 auto;
     max-width: 640px;
 }
-.grid-container {
-    display: flex;
-    align-items: center;
-    justify-content: center;
-    gap:10px
-}
-
-.image {
-    width: 128px;
-    height: 128px;
-    object-fit: cover;
-}
-
-.text {
-    font-size: 16px;
-}
 """
 
-with gr.Blocks(css=css, elem_id="demo-container") as demo:
-    with gr.Column():
-        # Replace "demo_header.html" and "demo_tools.html" with your actual files or remove if not needed
-        try:
-            gr.HTML(read_file("demo_header.html"))
-        except:
-            pass
-        try:
-            gr.HTML(read_file("demo_tools.html"))
-        except:
-            pass
 
     with gr.Row():
         with gr.Column():
-            image = gr.Image(
-                height=800,
-                sources=['upload','clipboard'],
-                image_mode='RGB',
-                elem_id="image_upload",
                 type="pil",
-                label="Upload"
             )
-            with gr.Row(elem_id="prompt-container", equal_height=False):
-                with gr.Row():
-                    prompt = gr.Textbox(
-                        label="Prompt",
-                        value="a portrait of a beautiful woman",
-                        placeholder="Your prompt",
-                        elem_id="prompt"
-                    )
-                    btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
-            with gr.Accordion(label="Advanced Settings", open=False):
-                with gr.Row(equal_height=True):
-                    strength = gr.Slider(
-                        value=0.75,
-                        minimum=0.0,
-                        maximum=1.0,
-                        step=0.01,
-                        label="strength"
-                    )
-                    seed = gr.Number(
-                        value=100,
-                        minimum=0,
-                        step=1,
-                        label="seed"
-                    )
-                    inference_step = gr.Number(
-                        value=30,
-                        minimum=1,
-                        step=1,
-                        label="num_inference_steps"
-                    )
-                    id_input = gr.Text(label="Name", visible=False)
-
-        with gr.Column():
-            image_out = gr.Image(
-                height=800,
-                sources=[],
-                label="Output",
-                elem_id="output-img",
-                format="jpg"
             )
 
-    # Optional examples. Replace with your own images or remove.
-    gr.Examples(
-        examples=[
-            ["examples/draw_input.jpg", None, "a woman, eyes closed, mouth opened"],
-            ["examples/gimp_input.jpg", None, "a woman, hand on neck"]
-        ],
-        inputs=[image, image_out, prompt]
-    )
-
-    # Maybe a footer file or custom HTML. If not present, remove.
-    try:
-        gr.HTML(gr.HTML(read_file("demo_footer.html")))
-    except:
-        pass
-
-    # When the "Img2Img" button is clicked or the prompt is submitted, run `process_images`.
-    gr.on(
-        triggers=[btn.click, prompt.submit],
-        fn=process_images,
-        inputs=[image, prompt, strength, seed, inference_step],
-        outputs=[image_out]
     )
 
 if __name__ == "__main__":
-    # Launch the Gradio app.
-    # If you set share=True, you'll get a public link.
-    demo.launch(share=True, show_error=True)
 
New version (added lines are prefixed with "+"):

@@ -1,118 +1,108 @@
 import gradio as gr
 import re
+import torch
 from PIL import Image
 
+import spaces
+from diffusers import StableDiffusionXLImg2ImgPipeline
 
+#
+# Load the two SDXL pipelines (base + refiner) globally, so they only load once.
+#
+BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+REFINER_MODEL_ID = "stabilityai/stable-diffusion-xl-refiner-1.0"
 
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
+pipe_base = StableDiffusionXLImg2ImgPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=dtype).to(device)
+pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(REFINER_MODEL_ID, torch_dtype=dtype).to(device)
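Note: loading two full SDXL pipelines keeps two copies of components the refiner actually shares with the base model (the second text encoder and the VAE). A minimal sketch of the usual trim, assuming a recent diffusers release and the variable names defined above:

    # Possible alternative to the pipe_refiner line above: reuse the components
    # that the refiner shares with the base pipeline instead of loading them twice.
    pipe_refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
        REFINER_MODEL_ID,
        text_encoder_2=pipe_base.text_encoder_2,  # shared between base and refiner
        vae=pipe_base.vae,                        # shared between base and refiner
        torch_dtype=dtype,
    ).to(device)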
 
+#
+# Helper functions
+#
+def sanitize_prompt(prompt: str) -> str:
+    # Simple sanitization: remove suspicious characters
     allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
+    return allowed_chars.sub("", prompt)
+
+def resize_to_multiple_of_64(image: Image.Image, max_dim: int = 1024):
+    """
+    Resizes the image so that width and height are both <= max_dim
+    and each dimension is a multiple of 64.
+    (SDXL often uses 1024x1024. You can use multiples of 128 if you prefer.)
+    """
+    w, h = image.size
+
+    # If the image is bigger than max_dim in either dimension, scale it down
+    ratio = min(max_dim / w, max_dim / h, 1.0)
+    new_w = int(w * ratio)
+    new_h = int(h * ratio)
+
+    # Round down to multiples of 64 for best results with SDXL
+    new_w = new_w - (new_w % 64)
+    new_h = new_h - (new_h % 64)
+
+    new_w = max(new_w, 64)
+    new_h = max(new_h, 64)
+    return image.resize((new_w, new_h), Image.LANCZOS)
+
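Note: a quick worked example of resize_to_multiple_of_64 as defined above. A 2048x1536 input gives a scale factor of exactly 0.5, so it lands on 1024x768, already multiples of 64. Be aware that int() truncates, so ratios that are not exactly representable in floating point can undershoot by a pixel before the floor-to-64 step; round() is the safer choice if hitting an exact target matters. A hypothetical check, not part of the commit:

    # Hypothetical smoke test for resize_to_multiple_of_64
    from PIL import Image
    img = Image.new("RGB", (2048, 1536))
    print(resize_to_multiple_of_64(img, max_dim=1024).size)  # (1024, 768)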
+@spaces.GPU(duration=240)  # Increase the duration if needed (SDXL can be slow)
+def run_img2img_sdxl(
+    init_image,
+    prompt: str,
+    strength: float,
+    seed: int,
+    steps_base: int,
+    steps_refiner: int,
 ):
+    """
+    Runs a two-step SDXL (base + refiner) pass for high-quality img2img.
+    """
+    if init_image is None:
+        print("No input image provided.")
+        return None
 
+    # Clean up the prompt
+    prompt = sanitize_prompt(prompt)
 
+    # Ensure reproducibility
+    generator = torch.Generator(device).manual_seed(seed)
+
+    # Possibly resize the input to a smaller multiple-of-64 dimension
+    # (1024x1024 or smaller is typical for SDXL)
+    init_image = resize_to_multiple_of_64(init_image, max_dim=1024)
+
+    # 1) Base pass
+    base_output = pipe_base(
         prompt=prompt,
+        image=init_image,
         strength=strength,
+        guidance_scale=8.0,  # Adjust for more or less adherence to the prompt
+        num_inference_steps=steps_base,
+        generator=generator
+    )
+    base_image = base_output.images[0]
+
+    # 2) Refiner pass, intended as final detailing at a slightly higher
+    # guidance scale (see the note after this function on the strength value).
+    refiner_output = pipe_refiner(
+        prompt=prompt,
+        image=base_image,
+        strength=0.0,  # caution: 0.0 schedules zero denoising steps
+        guidance_scale=9.0,
+        num_inference_steps=steps_refiner,
+        generator=generator
     )
+    final_image = refiner_output.images[0]
 
+    return final_image
 
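Note: in diffusers img2img pipelines, the number of denoising steps actually executed is roughly strength * num_inference_steps, so strength=0.0 schedules zero steps and the refiner pass above is close to an encode/decode round trip rather than a detailing pass. A small nonzero strength (around 0.2 to 0.3) behaves as the original comment intends. A sketch of the other common wiring, assuming the denoising_end / denoising_start parameters of the SDXL pipelines, where the base hands latents directly to the refiner:

    # Possible alternative: split one denoising schedule between base and refiner.
    base_latents = pipe_base(
        prompt=prompt,
        image=init_image,
        strength=strength,
        num_inference_steps=steps_base,
        denoising_end=0.8,        # base covers the first ~80% of the schedule
        output_type="latent",     # hand latents straight to the refiner
        generator=generator,
    ).images
    final_image = pipe_refiner(
        prompt=prompt,
        image=base_latents,
        num_inference_steps=steps_base,  # keep the schedule length consistent
        denoising_start=0.8,             # refiner finishes the last ~20%
        generator=generator,
    ).images[0]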
 
+#
+# Gradio UI
+#
 css = """
 #col-left {
     margin: 0 auto;
@@ -122,111 +112,39 @@ css = """
     margin: 0 auto;
     max-width: 640px;
 }
 """
 
+with gr.Blocks(css=css) as demo:
+    gr.Markdown("## SDXL Img2Img (Base + Refiner) — High Quality Demo")
 
    with gr.Row():
         with gr.Column():
+            init_image = gr.Image(
+                label="Init Image (Img2Img)",
                 type="pil",
+                image_mode="RGB",
+                height=512
             )
+            prompt = gr.Textbox(
+                label="Prompt",
+                placeholder="Describe what you want to see"
             )
+            run_button = gr.Button("Generate")
+            with gr.Accordion("Advanced Options", open=False):
+                strength = gr.Slider(0.0, 1.0, value=0.7, step=0.05, label="Strength (img2img)")
+                seed = gr.Number(value=42, label="Seed", precision=0)
+                steps_base = gr.Slider(1, 100, value=50, step=1, label="Steps (Base)")
+                steps_refiner = gr.Slider(1, 100, value=30, step=1, label="Steps (Refiner)")
+
+        with gr.Column():
+            result_image = gr.Image(label="Result", height=512)
 
+    # Link the button to our function
+    run_button.click(
+        fn=run_img2img_sdxl,
+        inputs=[init_image, prompt, strength, seed, steps_base, steps_refiner],
+        outputs=[result_image]
     )
 
 if __name__ == "__main__":
+    demo.launch(share=True)
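Note: for a quick check outside Spaces, the handler can be called directly, since the spaces.GPU decorator is documented to be a no-op when not running on ZeroGPU hardware. A hypothetical local smoke test (file names invented):

    # Hypothetical local smoke test, not part of the commit
    from PIL import Image
    img = Image.open("input.jpg").convert("RGB")
    result = run_img2img_sdxl(img, "a castle at sunset, highly detailed",
                              strength=0.7, seed=42, steps_base=50, steps_refiner=30)
    if result is not None:
        result.save("output.jpg")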