Gemini899 committed
Commit 8bd5dc7 · verified · 1 Parent(s): 1ac594b

Update app.py

Files changed (1):
  1. app.py +197 -88

app.py CHANGED
@@ -2,30 +2,35 @@ import spaces
 import gradio as gr
 import re
 from PIL import Image
-
 import os
 import numpy as np
 import torch
-from diffusers import FluxImg2ImgPipeline
-
-dtype = torch.bfloat16
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(device)
 
+# We'll lazy-load FluxImg2ImgPipeline
+from diffusers import FluxImg2ImgPipeline
 
+###############################################################################
+# GLOBAL PIPELINE REFERENCE (start as None, so we only load on first inference)
+###############################################################################
+pipe = None
 
+###############################################################################
+# HELPER FUNCTIONS
+###############################################################################
 def sanitize_prompt(prompt):
-    # Allow only alphanumeric characters, spaces, and basic punctuation
-    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
-    sanitized_prompt = allowed_chars.sub("", prompt)
-    return sanitized_prompt
+    # Allow only alphanumeric characters, spaces, and basic punctuation
+    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
+    return allowed_chars.sub("", prompt)
 
-def convert_to_fit_size(original_width_and_height, maximum_size = 1024):
-    width, height =original_width_and_height
+def convert_to_fit_size(original_width_and_height, maximum_size=512):
+    """
+    Resizes the image so its largest dimension = maximum_size (default 512).
+    Lower resolution => less VRAM usage.
+    """
+    width, height = original_width_and_height
     if width <= maximum_size and height <= maximum_size:
-        return width,height
-
+        return width, height
+
     if width > height:
         scaling_factor = maximum_size / width
     else:
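
Side note: the resize math above is easy to sanity-check in isolation. The scaling lines between the two hunks are elided from the diff, so this standalone sketch assumes they compute new_width = int(width * scaling_factor) and the analogous height:

# Hypothetical standalone versions of the two resize helpers (pure functions).
def convert_to_fit_size(size, maximum_size=512):
    width, height = size
    if width <= maximum_size and height <= maximum_size:
        return width, height
    scaling_factor = maximum_size / (width if width > height else height)
    return int(width * scaling_factor), int(height * scaling_factor)

def adjust_to_multiple_of_32(width, height):
    # Snap down to multiples of 32, but never below 32 itself.
    return max(width - width % 32, 32), max(height - height % 32, 32)

print(convert_to_fit_size((1024, 600)))    # -> (512, 300)
print(adjust_to_multiple_of_32(512, 300))  # -> (512, 288)
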
@@ -36,52 +41,123 @@ def convert_to_fit_size(original_width_and_height, maximum_size = 1024):
     return new_width, new_height
 
 def adjust_to_multiple_of_32(width: int, height: int):
+    """
+    Snap dimensions down to multiples of 32 (common for diffusion pipelines).
+    """
     width = width - (width % 32)
     height = height - (height % 32)
-    return width, height
+    return max(width, 32), max(height, 32)
+
+def load_flux_pipeline():
+    """
+    Lazy-load the FluxImg2ImgPipeline in float16 with memory-saving features.
+    """
+    global pipe
+    if pipe is not None:
+        return pipe  # Already loaded
 
+    print("Loading FluxImg2ImgPipeline in float16...")
+
+    # 1) Load the pipeline using float16
+    local_pipe = FluxImg2ImgPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-schnell",
+        torch_dtype=torch.float16,  # IMPORTANT: no bfloat16
+        low_cpu_mem_usage=True
+    )
+    local_pipe.to("cuda")
 
+    # 2) Enable memory-efficient attention (xFormers), if installed
+    try:
+        local_pipe.enable_xformers_memory_efficient_attention()
+        print("xFormers memory efficient attention enabled.")
+    except Exception as e:
+        print("Could not enable xFormers:", e)
 
+    # 3) CPU offload (keeps only active layers on GPU)
+    try:
+        local_pipe.enable_model_cpu_offload()
+        print("CPU offload enabled.")
+    except Exception as e:
+        print("Could not enable model_cpu_offload:", e)
 
+    # 4) VAE slicing reduces peak memory usage
+    local_pipe.enable_vae_slicing()
+
+    # 5) Optionally set max sequence length (like your original code)
+    local_pipe.max_sequence_length = 256
+
+    pipe = local_pipe
+    print("Flux pipeline loaded successfully (float16).")
+    return pipe
+
+###############################################################################
+# MAIN INFERENCE FUNCTION
+###############################################################################
 @spaces.GPU(duration=120)
-def process_images(image,prompt="a girl",strength=0.75,seed=0,inference_step=4,progress=gr.Progress(track_tqdm=True)):
-    #print("start process_images")
+def process_images(
+    image,
+    prompt="a girl",
+    strength=0.75,
+    seed=0,
+    inference_step=4,
+    progress=gr.Progress(track_tqdm=True)
+):
     progress(0, desc="Starting")
 
+    # 1) Lazy-load the pipeline
+    local_pipe = load_flux_pipeline()
 
-    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
-        if image is None:
-            print("empty input image returned")
-            return None
-        generator = torch.Generator(device).manual_seed(seed)
-        fit_width, fit_height = convert_to_fit_size(image.size)
-        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
-        image = image.resize((width, height), Image.LANCZOS)
-
-        output = pipe(prompt=prompt, image=image, generator=generator, strength=strength, width=width, height=height,
-                      guidance_scale=0, num_inference_steps=num_inference_steps, max_sequence_length=256)
-
-        pil_image = output.images[0]
-        new_width, new_height = pil_image.size
-
-        if (new_width != fit_width) or (new_height != fit_height):
-            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
-            return resized_image
-        return pil_image
-
-    output = process_img2img(image, prompt, strength, seed, inference_step)
-    return output
-
-
+    # 2) If no image provided
+    if image is None:
+        print("No input image provided.")
+        return None
 
-def read_file(path: str) -> str:
-    with open(path, 'r', encoding='utf-8') as f:
-        content = f.read()
+    # 3) Resize input to reduce VRAM usage
+    fit_width, fit_height = convert_to_fit_size(image.size, maximum_size=512)
+    width, height = adjust_to_multiple_of_32(fit_width, fit_height)
+
+    # Use high-quality Lanczos resizing
+    image = image.resize((width, height), Image.LANCZOS)
 
-    return content
+    # 4) Create generator for reproducibility
+    generator = torch.Generator("cuda").manual_seed(seed)
 
+    # 5) Actually run flux img2img
+    progress(50, desc="Running flux img2img")
+    print(f"Prompt: {prompt}, strength={strength}, steps={inference_step}")
+
+    output = local_pipe(
+        prompt=prompt,
+        image=image,
+        generator=generator,
+        strength=strength,
+        guidance_scale=0,  # same as your original code
+        num_inference_steps=inference_step,
+        # We don't explicitly pass width & height. If you want, remove them or keep them:
+        # width=width,
+        # height=height,
+    )
 
-css="""
+    pil_image = output.images[0]
+
+    # 6) If the new image was forcibly changed shape by the model,
+    # we can re-resize back to (fit_width, fit_height).
+    # Usually not necessary with flux, but keep the logic if you want.
+    new_w, new_h = pil_image.size
+    if (new_w != fit_width) or (new_h != fit_height):
+        pil_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
+
+    progress(100, desc="Done")
+    return pil_image
+
+###############################################################################
+# GRADIO APP
+###############################################################################
+def read_file(path: str) -> str:
+    with open(path, 'r', encoding='utf-8') as f:
+        return f.read()
+
+css = """
 #col-left {
     margin: 0 auto;
     max-width: 640px;
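
Note on the loader above: it is a lazy singleton, so the GPU-heavy pipeline is built once on first request and cached in a module global. A minimal sketch of the same pattern, assuming a diffusers build that ships FluxImg2ImgPipeline; since enable_model_cpu_offload() manages device placement itself, the sketch skips the explicit .to("cuda"):

import torch
from diffusers import FluxImg2ImgPipeline  # assumption: diffusers with Flux support

_pipe = None  # module-level cache

def get_pipe():
    global _pipe
    if _pipe is None:
        _pipe = FluxImg2ImgPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        # Offload moves submodules to the GPU on demand, keeping VRAM low.
        _pipe.enable_model_cpu_offload()
    return _pipe
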
@@ -96,67 +172,100 @@ css="""
     justify-content: center;
     gap:10px
 }
-
 .image {
     width: 128px;
     height: 128px;
     object-fit: cover;
 }
-
 .text {
     font-size: 16px;
 }
-
 """
 
 with gr.Blocks(css=css, elem_id="demo-container") as demo:
-    with gr.Column():
+    # Optionally load some HTML from files
+    try:
         gr.HTML(read_file("demo_header.html"))
+    except:
+        pass
+    try:
         gr.HTML(read_file("demo_tools.html"))
+    except:
+        pass
+
     with gr.Row():
-        with gr.Column():
-            image = gr.Image(height=800,sources=['upload','clipboard'],image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
-            with gr.Row(elem_id="prompt-container", equal_height=False):
-                with gr.Row():
-                    prompt = gr.Textbox(label="Prompt",value="a women",placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
-
-            btn = gr.Button("Img2Img", elem_id="run_button",variant="primary")
-
-            with gr.Accordion(label="Advanced Settings", open=False):
-                with gr.Row( equal_height=True):
-                    strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
-                    seed = gr.Number(value=100, minimum=0, step=1, label="seed")
-                    inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
-                    id_input=gr.Text(label="Name", visible=False)
-
-        with gr.Column():
-            image_out = gr.Image(height=800,sources=[],label="Output", elem_id="output-img",format="jpg")
-
-
-
-
+        with gr.Column():
+            image = gr.Image(
+                height=800,
+                sources=['upload','clipboard'],
+                image_mode='RGB',
+                elem_id="image_upload",
+                type="pil",
+                label="Upload"
+            )
+            with gr.Row(elem_id="prompt-container", equal_height=False):
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    value="a woman",
+                    placeholder="Enter your prompt here",
+                    elem_id="prompt"
+                )
+            btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
+
+            with gr.Accordion(label="Advanced Settings", open=False):
+                with gr.Row(equal_height=True):
+                    strength = gr.Number(
+                        value=0.75,
+                        minimum=0,
+                        maximum=0.75,
+                        step=0.01,
+                        label="strength"
+                    )
+                    seed = gr.Number(
+                        value=100,
+                        minimum=0,
+                        step=1,
+                        label="seed"
+                    )
+                    inference_step = gr.Number(
+                        value=4,
+                        minimum=1,
+                        step=1,
+                        label="inference_step"
+                    )
+                    id_input = gr.Text(label="Name", visible=False)
+
+        with gr.Column():
+            image_out = gr.Image(
+                height=800,
+                sources=[],
+                label="Output",
+                elem_id="output-img",
+                format="jpg"
+            )
 
+    # Provide example inputs if desired
     gr.Examples(
-        examples=[
-            ["examples/draw_input.jpg", "examples/draw_output.jpg","a women ,eyes closed,mouth opened"],
-            ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg","a women ,eyes closed,mouth opened"],
-            ["examples/gimp_input.jpg", "examples/gimp_output.jpg","a women ,hand on neck"],
-            ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg","a women ,hand on neck"]
-        ]
-        ,
-        inputs=[image,image_out,prompt],
-    )
-    gr.HTML(
-        gr.HTML(read_file("demo_footer.html"))
+        examples=[
+            ["examples/draw_input.jpg", None, "a woman, eyes closed, mouth opened"],
+            ["examples/gimp_input.jpg", None, "a woman, hand on neck"]
+        ],
+        inputs=[image, image_out, prompt],
     )
+
+    # Possibly load a footer HTML
+    try:
+        gr.HTML(read_file("demo_footer.html"))
+    except:
+        pass
+
+    # Link UI events to process_images
     gr.on(
         triggers=[btn.click, prompt.submit],
-        fn = process_images,
-        inputs = [image,prompt,strength,seed,inference_step],
-        outputs = [image_out]
+        fn=process_images,
+        inputs=[image, prompt, strength, seed, inference_step],
+        outputs=[image_out]
     )
 
 if __name__ == "__main__":
     demo.launch(share=True, show_error=True)
-
-
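
On the seeding change: the new code builds torch.Generator("cuda") directly, so a fixed seed reproduces the same initial noise and hence the same img2img result. A small illustration (device chosen dynamically here so the snippet also runs on CPU, unlike the CUDA-only line in the diff):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
gen_a = torch.Generator(device).manual_seed(100)
gen_b = torch.Generator(device).manual_seed(100)
noise_a = torch.randn(4, generator=gen_a, device=device)
noise_b = torch.randn(4, generator=gen_b, device=device)
assert torch.equal(noise_a, noise_b)  # same seed -> identical noise draws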
 
2
  import gradio as gr
3
  import re
4
  from PIL import Image
 
5
  import os
6
  import numpy as np
7
  import torch
 
 
 
 
 
 
8
 
9
+ # We'll lazy-load FluxImg2ImgPipeline
10
+ from diffusers import FluxImg2ImgPipeline
11
 
12
+ ###############################################################################
13
+ # GLOBAL PIPELINE REFERENCE (start as None, so we only load on first inference)
14
+ ###############################################################################
15
+ pipe = None
16
 
17
+ ###############################################################################
18
+ # HELPER FUNCTIONS
19
+ ###############################################################################
20
  def sanitize_prompt(prompt):
21
+ # Allow only alphanumeric characters, spaces, and basic punctuation
22
+ allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
23
+ return allowed_chars.sub("", prompt)
 
24
 
25
+ def convert_to_fit_size(original_width_and_height, maximum_size=512):
26
+ """
27
+ Resizes the image so its largest dimension = maximum_size (default 512).
28
+ Lower resolution => less VRAM usage.
29
+ """
30
+ width, height = original_width_and_height
31
  if width <= maximum_size and height <= maximum_size:
32
+ return width, height
33
+
34
  if width > height:
35
  scaling_factor = maximum_size / width
36
  else:
 
41
  return new_width, new_height
42
 
43
  def adjust_to_multiple_of_32(width: int, height: int):
44
+ """
45
+ Snap dimensions down to multiples of 32 (common for diffusion pipelines).
46
+ """
47
  width = width - (width % 32)
48
  height = height - (height % 32)
49
+ return max(width, 32), max(height, 32)
50
+
51
+ def load_flux_pipeline():
52
+ """
53
+ Lazy-load the FluxImg2ImgPipeline in float16 with memory-saving features.
54
+ """
55
+ global pipe
56
+ if pipe is not None:
57
+ return pipe # Already loaded
58
 
59
+ print("Loading FluxImg2ImgPipeline in float16...")
60
+
61
+ # 1) Load the pipeline using float16
62
+ local_pipe = FluxImg2ImgPipeline.from_pretrained(
63
+ "black-forest-labs/FLUX.1-schnell",
64
+ torch_dtype=torch.float16, # IMPORTANT: no bfloat16
65
+ low_cpu_mem_usage=True
66
+ )
67
+ local_pipe.to("cuda")
68
 
69
+ # 2) Enable memory-efficient attention (xFormers), if installed
70
+ try:
71
+ local_pipe.enable_xformers_memory_efficient_attention()
72
+ print("xFormers memory efficient attention enabled.")
73
+ except Exception as e:
74
+ print("Could not enable xFormers:", e)
75
 
76
+ # 3) CPU offload (keeps only active layers on GPU)
77
+ try:
78
+ local_pipe.enable_model_cpu_offload()
79
+ print("CPU offload enabled.")
80
+ except Exception as e:
81
+ print("Could not enable model_cpu_offload:", e)
82
 
83
+ # 4) VAE slicing reduces peak memory usage
84
+ local_pipe.enable_vae_slicing()
85
+
86
+ # 5) Optionally set max sequence length (like your original code)
87
+ local_pipe.max_sequence_length = 256
88
+
89
+ pipe = local_pipe
90
+ print("Flux pipeline loaded successfully (float16).")
91
+ return pipe
92
+
93
+ ###############################################################################
94
+ # MAIN INFERENCE FUNCTION
95
+ ###############################################################################
96
  @spaces.GPU(duration=120)
97
+ def process_images(
98
+ image,
99
+ prompt="a girl",
100
+ strength=0.75,
101
+ seed=0,
102
+ inference_step=4,
103
+ progress=gr.Progress(track_tqdm=True)
104
+ ):
105
  progress(0, desc="Starting")
106
 
107
+ # 1) Lazy-load the pipeline
108
+ local_pipe = load_flux_pipeline()
109
 
110
+ # 2) If no image provided
111
+ if image is None:
112
+ print("No input image provided.")
113
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ # 3) Resize input to reduce VRAM usage
116
+ fit_width, fit_height = convert_to_fit_size(image.size, maximum_size=512)
117
+ width, height = adjust_to_multiple_of_32(fit_width, fit_height)
118
+
119
+ # Use high-quality Lanczos resizing
120
+ image = image.resize((width, height), Image.LANCZOS)
121
 
122
+ # 4) Create generator for reproducibility
123
+ generator = torch.Generator("cuda").manual_seed(seed)
124
 
125
+ # 5) Actually run flux img2img
126
+ progress(50, desc="Running flux img2img")
127
+ print(f"Prompt: {prompt}, strength={strength}, steps={inference_step}")
128
+
129
+ output = local_pipe(
130
+ prompt=prompt,
131
+ image=image,
132
+ generator=generator,
133
+ strength=strength,
134
+ guidance_scale=0, # same as your original code
135
+ num_inference_steps=inference_step,
136
+ # We don't explicitly pass width & height. If you want, remove them or keep them:
137
+ # width=width,
138
+ # height=height,
139
+ )
140
 
141
+ pil_image = output.images[0]
142
+
143
+ # 6) If the new image was forcibly changed shape by the model,
144
+ # we can re-resize back to (fit_width, fit_height).
145
+ # Usually not necessary with flux, but keep the logic if you want.
146
+ new_w, new_h = pil_image.size
147
+ if (new_w != fit_width) or (new_h != fit_height):
148
+ pil_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
149
+
150
+ progress(100, desc="Done")
151
+ return pil_image
152
+
153
+ ###############################################################################
154
+ # GRADIO APP
155
+ ###############################################################################
156
+ def read_file(path: str) -> str:
157
+ with open(path, 'r', encoding='utf-8') as f:
158
+ return f.read()
159
+
160
+ css = """
161
  #col-left {
162
  margin: 0 auto;
163
  max-width: 640px;
 
172
  justify-content: center;
173
  gap:10px
174
  }
 
175
  .image {
176
  width: 128px;
177
  height: 128px;
178
  object-fit: cover;
179
  }
 
180
  .text {
181
  font-size: 16px;
182
  }
 
183
  """
184
 
185
  with gr.Blocks(css=css, elem_id="demo-container") as demo:
186
+ # Optionally load some HTML from files
187
+ try:
188
  gr.HTML(read_file("demo_header.html"))
189
+ except:
190
+ pass
191
+ try:
192
  gr.HTML(read_file("demo_tools.html"))
193
+ except:
194
+ pass
195
+
196
  with gr.Row():
197
+ with gr.Column():
198
+ image = gr.Image(
199
+ height=800,
200
+ sources=['upload','clipboard'],
201
+ image_mode='RGB',
202
+ elem_id="image_upload",
203
+ type="pil",
204
+ label="Upload"
205
+ )
206
+ with gr.Row(elem_id="prompt-container", equal_height=False):
207
+ prompt = gr.Textbox(
208
+ label="Prompt",
209
+ value="a woman",
210
+ placeholder="Enter your prompt here",
211
+ elem_id="prompt"
212
+ )
213
+ btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
214
+
215
+ with gr.Accordion(label="Advanced Settings", open=False):
216
+ with gr.Row(equal_height=True):
217
+ strength = gr.Number(
218
+ value=0.75,
219
+ minimum=0,
220
+ maximum=0.75,
221
+ step=0.01,
222
+ label="strength"
223
+ )
224
+ seed = gr.Number(
225
+ value=100,
226
+ minimum=0,
227
+ step=1,
228
+ label="seed"
229
+ )
230
+ inference_step = gr.Number(
231
+ value=4,
232
+ minimum=1,
233
+ step=1,
234
+ label="inference_step"
235
+ )
236
+ id_input = gr.Text(label="Name", visible=False)
237
+
238
+ with gr.Column():
239
+ image_out = gr.Image(
240
+ height=800,
241
+ sources=[],
242
+ label="Output",
243
+ elem_id="output-img",
244
+ format="jpg"
245
+ )
246
 
247
+ # Provide example inputs if desired
248
  gr.Examples(
249
+ examples=[
250
+ ["examples/draw_input.jpg", None, "a woman, eyes closed, mouth opened"],
251
+ ["examples/gimp_input.jpg", None, "a woman, hand on neck"]
252
+ ],
253
+ inputs=[image, image_out, prompt],
 
 
 
 
 
 
254
  )
255
+
256
+ # Possibly load a footer HTML
257
+ try:
258
+ gr.HTML(read_file("demo_footer.html"))
259
+ except:
260
+ pass
261
+
262
+ # Link UI events to process_images
263
  gr.on(
264
  triggers=[btn.click, prompt.submit],
265
+ fn=process_images,
266
+ inputs=[image, prompt, strength, seed, inference_step],
267
+ outputs=[image_out]
268
  )
269
 
270
  if __name__ == "__main__":
271
  demo.launch(share=True, show_error=True)
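Finally, the gr.on wiring at the bottom of the new file binds one handler to several UI triggers at once (the button click and Enter in the prompt box). A toy, self-contained example of the same pattern; all names here are illustrative, not from the app:

import gradio as gr

def shout(text):
    return text.upper()

with gr.Blocks() as toy:
    box = gr.Textbox(label="Prompt")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Result")
    # One registration covers both triggers, mirroring the app's
    # gr.on(triggers=[btn.click, prompt.submit], ...).
    gr.on(triggers=[btn.click, box.submit], fn=shout, inputs=[box], outputs=[out])

if __name__ == "__main__":
    toy.launch()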