Gemini899 committed on
Commit e05e986 · verified · 1 Parent(s): 8bd5dc7

Update app.py

Files changed (1):
  1. app.py (+118 −201)

app.py CHANGED
@@ -3,34 +3,42 @@ import gradio as gr
  import re
  from PIL import Image
  import os
  import numpy as np
  import torch
-
- # We'll lazy-load FluxImg2ImgPipeline
  from diffusers import FluxImg2ImgPipeline

- ###############################################################################
- # GLOBAL PIPELINE REFERENCE (start as None, so we only load on first inference)
- ###############################################################################
  pipe = None

- ###############################################################################
- # HELPER FUNCTIONS
- ###############################################################################
  def sanitize_prompt(prompt):
-     # Allow only alphanumeric characters, spaces, and basic punctuation
-     allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
-     return allowed_chars.sub("", prompt)

- def convert_to_fit_size(original_width_and_height, maximum_size=512):
-     """
-     Resizes the image so its largest dimension = maximum_size (default 512).
-     Lower resolution => less VRAM usage.
-     """
      width, height = original_width_and_height
      if width <= maximum_size and height <= maximum_size:
          return width, height
-
      if width > height:
          scaling_factor = maximum_size / width
      else:
@@ -41,123 +49,72 @@ def convert_to_fit_size(original_width_and_height, maximum_size=512):
      return new_width, new_height

  def adjust_to_multiple_of_32(width: int, height: int):
-     """
-     Snap dimensions down to multiples of 32 (common for diffusion pipelines).
-     """
      width = width - (width % 32)
      height = height - (height % 32)
-     return max(width, 32), max(height, 32)
-
- def load_flux_pipeline():
-     """
-     Lazy-load the FluxImg2ImgPipeline in float16 with memory-saving features.
-     """
-     global pipe
-     if pipe is not None:
-         return pipe  # Already loaded
-
-     print("Loading FluxImg2ImgPipeline in float16...")
-
-     # 1) Load the pipeline using float16
-     local_pipe = FluxImg2ImgPipeline.from_pretrained(
-         "black-forest-labs/FLUX.1-schnell",
-         torch_dtype=torch.float16,  # IMPORTANT: no bfloat16
-         low_cpu_mem_usage=True
-     )
-     local_pipe.to("cuda")
-
-     # 2) Enable memory-efficient attention (xFormers), if installed
-     try:
-         local_pipe.enable_xformers_memory_efficient_attention()
-         print("xFormers memory efficient attention enabled.")
-     except Exception as e:
-         print("Could not enable xFormers:", e)
-
-     # 3) CPU offload (keeps only active layers on GPU)
-     try:
-         local_pipe.enable_model_cpu_offload()
-         print("CPU offload enabled.")
-     except Exception as e:
-         print("Could not enable model_cpu_offload:", e)
-
-     # 4) VAE slicing reduces peak memory usage
-     local_pipe.enable_vae_slicing()
-
-     # 5) Optionally set max sequence length (like your original code)
-     local_pipe.max_sequence_length = 256
-
-     pipe = local_pipe
-     print("Flux pipeline loaded successfully (float16).")
-     return pipe
-
- ###############################################################################
- # MAIN INFERENCE FUNCTION
- ###############################################################################
  @spaces.GPU(duration=120)
- def process_images(
-     image,
-     prompt="a girl",
-     strength=0.75,
-     seed=0,
-     inference_step=4,
-     progress=gr.Progress(track_tqdm=True)
- ):
      progress(0, desc="Starting")

-     # 1) Lazy-load the pipeline
-     local_pipe = load_flux_pipeline()
-
-     # 2) If no image provided
-     if image is None:
-         print("No input image provided.")
-         return None
-
-     # 3) Resize input to reduce VRAM usage
-     fit_width, fit_height = convert_to_fit_size(image.size, maximum_size=512)
-     width, height = adjust_to_multiple_of_32(fit_width, fit_height)
-
-     # Use high-quality Lanczos resizing
-     image = image.resize((width, height), Image.LANCZOS)
-
-     # 4) Create generator for reproducibility
-     generator = torch.Generator("cuda").manual_seed(seed)
-
-     # 5) Actually run flux img2img
-     progress(50, desc="Running flux img2img")
-     print(f"Prompt: {prompt}, strength={strength}, steps={inference_step}")
-
-     output = local_pipe(
-         prompt=prompt,
-         image=image,
-         generator=generator,
-         strength=strength,
-         guidance_scale=0,  # same as your original code
-         num_inference_steps=inference_step,
-         # We don't explicitly pass width & height. If you want, remove them or keep them:
-         # width=width,
-         # height=height,
-     )
-
-     pil_image = output.images[0]
-
-     # 6) If the new image was forcibly changed shape by the model,
-     #    we can re-resize back to (fit_width, fit_height).
-     #    Usually not necessary with flux, but keep the logic if you want.
-     new_w, new_h = pil_image.size
-     if (new_w != fit_width) or (new_h != fit_height):
-         pil_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
-
-     progress(100, desc="Done")
-     return pil_image
-
- ###############################################################################
- # GRADIO APP
- ###############################################################################
  def read_file(path: str) -> str:
      with open(path, 'r', encoding='utf-8') as f:
-         return f.read()

- css = """
  #col-left {
      margin: 0 auto;
      max-width: 640px;
@@ -172,100 +129,60 @@ css = """
      justify-content: center;
      gap:10px
  }
  .image {
      width: 128px;
      height: 128px;
      object-fit: cover;
  }
  .text {
      font-size: 16px;
  }
  """

  with gr.Blocks(css=css, elem_id="demo-container") as demo:
-     # Optionally load some HTML from files
-     try:
          gr.HTML(read_file("demo_header.html"))
-     except:
-         pass
-     try:
          gr.HTML(read_file("demo_tools.html"))
-     except:
-         pass
-
      with gr.Row():
-         with gr.Column():
-             image = gr.Image(
-                 height=800,
-                 sources=['upload', 'clipboard'],
-                 image_mode='RGB',
-                 elem_id="image_upload",
-                 type="pil",
-                 label="Upload"
-             )
-             with gr.Row(elem_id="prompt-container", equal_height=False):
-                 prompt = gr.Textbox(
-                     label="Prompt",
-                     value="a woman",
-                     placeholder="Enter your prompt here",
-                     elem_id="prompt"
-                 )
-                 btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
-
-             with gr.Accordion(label="Advanced Settings", open=False):
-                 with gr.Row(equal_height=True):
-                     strength = gr.Number(
-                         value=0.75,
-                         minimum=0,
-                         maximum=0.75,
-                         step=0.01,
-                         label="strength"
-                     )
-                     seed = gr.Number(
-                         value=100,
-                         minimum=0,
-                         step=1,
-                         label="seed"
-                     )
-                     inference_step = gr.Number(
-                         value=4,
-                         minimum=1,
-                         step=1,
-                         label="inference_step"
-                     )
-                 id_input = gr.Text(label="Name", visible=False)
-
-         with gr.Column():
-             image_out = gr.Image(
-                 height=800,
-                 sources=[],
-                 label="Output",
-                 elem_id="output-img",
-                 format="jpg"
-             )
-
-     # Provide example inputs if desired
      gr.Examples(
-         examples=[
-             ["examples/draw_input.jpg", None, "a woman, eyes closed, mouth opened"],
-             ["examples/gimp_input.jpg", None, "a woman, hand on neck"]
-         ],
-         inputs=[image, image_out, prompt],
      )
-
-     # Possibly load a footer HTML
-     try:
-         gr.HTML(read_file("demo_footer.html"))
-     except:
-         pass
-
-     # Link UI events to process_images
      gr.on(
          triggers=[btn.click, prompt.submit],
-         fn=process_images,
-         inputs=[image, prompt, strength, seed, inference_step],
-         outputs=[image_out]
      )

  if __name__ == "__main__":
-     demo.launch(share=True, show_error=True)
  import re
  from PIL import Image
  import os
+
+ # Set memory optimization flags
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+
  import numpy as np
  import torch
  from diffusers import FluxImg2ImgPipeline

+ # Global pipe variable for lazy loading
  pipe = None

+ # Use float16 instead of bfloat16 for T4 compatibility
+ dtype = torch.float16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def get_pipe():
+     global pipe
+     if pipe is None:
+         pipe = FluxImg2ImgPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-schnell",
+             torch_dtype=torch.float16,
+             variant="fp16"
+         ).to(device)
+     return pipe
+
  def sanitize_prompt(prompt):
+     # Allow only alphanumeric characters, spaces, and basic punctuation
+     allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
+     sanitized_prompt = allowed_chars.sub("", prompt)
+     return sanitized_prompt

+ def convert_to_fit_size(original_width_and_height, maximum_size=1024):
      width, height = original_width_and_height
      if width <= maximum_size and height <= maximum_size:
          return width, height
+
      if width > height:
          scaling_factor = maximum_size / width
      else:

      return new_width, new_height

  def adjust_to_multiple_of_32(width: int, height: int):
      width = width - (width % 32)
      height = height - (height % 32)
+     return width, height
+
+ def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
+     """Resizes image to fit within max_dim while preserving aspect ratio"""
+     w, h = image.size
+     ratio = min(max_dim / w, max_dim / h)
+     if ratio < 1.0:
+         new_w = int(w * ratio)
+         new_h = int(h * ratio)
+         image = image.resize((new_w, new_h), Image.LANCZOS)
+     return image
  @spaces.GPU(duration=120)
+ def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
      progress(0, desc="Starting")
+
+     # Get the model using lazy loading
+     model = get_pipe()
+
+     def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
+         if image is None:
+             print("empty input image returned")
+             return None
+
+         # Resize image to reduce memory usage
+         image = resize_image(image, max_dim=512)
+
+         generator = torch.Generator(device).manual_seed(seed)
+         fit_width, fit_height = convert_to_fit_size(image.size, maximum_size=512)
+         width, height = adjust_to_multiple_of_32(fit_width, fit_height)
+         image = image.resize((width, height), Image.LANCZOS)
+
+         # Use autocast for better memory efficiency
+         with torch.cuda.amp.autocast(dtype=torch.float16):
+             with torch.no_grad():
+                 output = model(
+                     prompt=prompt,
+                     image=image,
+                     generator=generator,
+                     strength=strength,
+                     width=width,
+                     height=height,
+                     guidance_scale=0,
+                     num_inference_steps=num_inference_steps,
+                     max_sequence_length=256
+                 )
+
+         pil_image = output.images[0]
+         new_width, new_height = pil_image.size
+
+         if (new_width != fit_width) or (new_height != fit_height):
+             resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
+             return resized_image
+         return pil_image
+
+     output = process_img2img(image, prompt, strength, seed, inference_step)
+     return output
  def read_file(path: str) -> str:
      with open(path, 'r', encoding='utf-8') as f:
+         content = f.read()
+         return content

+ css = """
  #col-left {
      margin: 0 auto;
      max-width: 640px;

      justify-content: center;
      gap:10px
  }
+
  .image {
      width: 128px;
      height: 128px;
      object-fit: cover;
  }
+
  .text {
      font-size: 16px;
  }
  """

  with gr.Blocks(css=css, elem_id="demo-container") as demo:
+     with gr.Column():
          gr.HTML(read_file("demo_header.html"))
          gr.HTML(read_file("demo_tools.html"))
      with gr.Row():
+         with gr.Column():
+             image = gr.Image(height=800, sources=['upload', 'clipboard'], image_mode='RGB', elem_id="image_upload", type="pil", label="Upload")
+             with gr.Row(elem_id="prompt-container", equal_height=False):
+                 with gr.Row():
+                     prompt = gr.Textbox(label="Prompt", value="a woman", placeholder="Your prompt (what you want in place of what is erased)", elem_id="prompt")
+
+             btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
+
+             with gr.Accordion(label="Advanced Settings", open=False):
+                 with gr.Row(equal_height=True):
+                     strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
+                     seed = gr.Number(value=100, minimum=0, step=1, label="seed")
+                     inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
+                 id_input = gr.Text(label="Name", visible=False)
+
+         with gr.Column():
+             image_out = gr.Image(height=800, sources=[], label="Output", elem_id="output-img", format="jpg")
+
      gr.Examples(
+         examples=[
+             ["examples/draw_input.jpg", "examples/draw_output.jpg", "a woman, eyes closed, mouth opened"],
+             ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a woman, eyes closed, mouth opened"],
+             ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a woman, hand on neck"],
+             ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a woman, hand on neck"]
+         ],
+         inputs=[image, image_out, prompt],
      )
+     gr.HTML(read_file("demo_footer.html"))
      gr.on(
          triggers=[btn.click, prompt.submit],
+         fn=process_images,
+         inputs=[image, prompt, strength, seed, inference_step],
+         outputs=[image_out]
      )

  if __name__ == "__main__":
+     demo.launch(share=True, show_error=True)
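
The sizing helpers this commit keeps (convert_to_fit_size, adjust_to_multiple_of_32) and the new resize_image are plain Python, so the dimension math can be sanity-checked without the Flux weights or a GPU. A minimal sketch, assuming only Pillow is installed; the two helpers below are a rough re-implementation of the ones above (the pipeline call itself is omitted):

```python
from PIL import Image

def convert_to_fit_size(original_width_and_height, maximum_size=1024):
    # Scale the longer edge down to maximum_size, keeping the aspect ratio.
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    scaling_factor = maximum_size / max(width, height)
    return int(width * scaling_factor), int(height * scaling_factor)

def adjust_to_multiple_of_32(width: int, height: int):
    # Snap both dimensions down to the nearest multiple of 32,
    # which diffusion pipelines generally expect.
    return width - (width % 32), height - (height % 32)

# A 3:2 photo fits to 512x341, then snaps to 512x320 before inference.
img = Image.new("RGB", (1024, 683))
fit_w, fit_h = convert_to_fit_size(img.size, maximum_size=512)
w, h = adjust_to_multiple_of_32(fit_w, fit_h)
print((fit_w, fit_h), (w, h))  # (512, 341) (512, 320)
```

In process_images above, the pipeline runs at the snapped size and the output is resized back to (fit_width, fit_height) whenever the returned image has different dimensions.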