Gemini899 committed on
Commit
9a1289b
·
verified ·
1 Parent(s): b1bb2b0

Update flux1_img2img.py

Files changed (1):
  1. flux1_img2img.py +158 -41
flux1_img2img.py CHANGED
@@ -1,18 +1,81 @@
+import os
+import re
+import sys
 import torch
-from diffusers import FluxImg2ImgPipeline
+import gradio as gr
 from PIL import Image
-import sys
+
 import spaces
+from diffusers import FluxImg2ImgPipeline
 
-def resize_image(image, max_res=512):
+###############################################################################
+# GLOBAL PIPE VARIABLE (lazy-loaded so the Space can start without OOM)
+###############################################################################
+pipe = None  # We will load this when the user triggers an inference
+
+###############################################################################
+# OPTIONAL: Resize Helper for Lower VRAM Usage
+###############################################################################
+def resize_image(image, max_size=512):
+    """
+    Resizes the image so that the max dimension is 'max_size',
+    which helps reduce GPU memory usage on a T4.
+    """
     w, h = image.size
-    ratio = min(max_res / w, max_res / h)
+    ratio = min(max_size / w, max_size / h)
     if ratio < 1.0:
         new_w = int(w * ratio)
         new_h = int(h * ratio)
         image = image.resize((new_w, new_h), Image.LANCZOS)
     return image
 
+###############################################################################
+# PIPELINE LOADER: Loads FLUX.1-schnell with memory-saving features
+###############################################################################
+def load_flux_pipeline():
+    """
+    Lazily loads the FluxImg2ImgPipeline with float16,
+    CPU offload, xFormers (if installed), etc.
+    """
+    global pipe
+    if pipe is not None:
+        return  # Already loaded
+
+    print("Loading FluxImg2ImgPipeline in float16 mode ...")
+    # Use float16 for T4
+    pipe_local = FluxImg2ImgPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-schnell",
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True
+    )
+
+    # Move to GPU (model CPU offload below re-manages placement if it succeeds)
+    pipe_local.to("cuda")
+
+    # Try enabling xFormers for memory-efficient attention
+    try:
+        pipe_local.enable_xformers_memory_efficient_attention()
+        print("Enabled xFormers memory efficient attention.")
+    except Exception as e:
+        print("Could not enable xFormers:", e)
+
+    # Offload model chunks to CPU if VRAM is tight
+    try:
+        pipe_local.enable_model_cpu_offload()
+        print("Enabled model CPU offload.")
+    except Exception as e:
+        print("Could not enable model_cpu_offload:", e)
+
+    # VAE slicing can reduce peak memory usage
+    pipe_local.enable_vae_slicing()
+
+    print("Flux pipeline loaded successfully.")
+    pipe = pipe_local
+
+###############################################################################
+# MAIN INFERENCE FUNCTION
+###############################################################################
 @spaces.GPU
 def process_image(
     image,
@@ -21,57 +84,111 @@ def process_image(
     model_id="black-forest-labs/FLUX.1-schnell",
     strength=0.75,
     seed=0,
-    num_inference_steps=4
+    num_inference_steps=4,
+    progress=gr.Progress(track_tqdm=True)
 ):
-    print("start process image process_image")
+    """
+    Runs Flux Img2Img with memory-optimized loading.
+    'mask_image' is not currently used.
+    """
+
+    # Let Gradio show progress
+    progress(0, desc="Starting Inference")
+
     if image is None:
-        print("empty input image returned")
+        print("No input image provided.")
         return None
 
-    # Try resizing input to reduce VRAM usage
-    image = resize_image(image, 512)
-
-    # Load with float16
-    pipe = FluxImg2ImgPipeline.from_pretrained(
-        model_id,
-        torch_dtype=torch.float16
-    ).to("cuda")
+    # 1) Load pipeline if not loaded
+    load_flux_pipeline()
 
-    # If xFormers installed, enable memory efficient attention
-    try:
-        pipe.enable_xformers_memory_efficient_attention()
-        print("Enabled xFormers memory efficient attention.")
-    except Exception as e:
-        print("Could not enable xFormers:", e)
-
-    # Enable CPU offload to reduce VRAM usage
-    # (Pick either model_cpu_offload or sequential_cpu_offload)
-    try:
-        pipe.enable_model_cpu_offload()
-    except Exception as e:
-        print("Could not enable model_cpu_offload:", e)
-
-    # Optional: enable VAE slicing
-    pipe.enable_vae_slicing()
+    # 2) Resize input to reduce VRAM usage
+    image = resize_image(image, max_size=512)
 
+    # 3) Prepare generator for reproducible results
     generator = torch.Generator("cuda").manual_seed(seed)
-
-    print(f"Prompt: {prompt}")
+
+    # 4) Actually run the pipeline
+    print(f"Running Flux with prompt: '{prompt}' (strength={strength}, steps={num_inference_steps})")
     output = pipe(
         prompt=prompt,
         image=image,
         generator=generator,
         strength=strength,
-        guidance_scale=0,
-        num_inference_steps=num_inference_steps,
-        max_sequence_length=256
+        guidance_scale=0,  # unchanged from the original code
+        num_inference_steps=num_inference_steps,
+        max_sequence_length=256  # kept from the original code; this is a per-call argument
     )
+    progress(1.0, desc="Done")
 
     return output.images[0]
 
+###############################################################################
+# BUILD THE GRADIO UI
+###############################################################################
+css = """
+#col-left {
+    margin: 0 auto;
+    max-width: 640px;
+}
+#col-right {
+    margin: 0 auto;
+    max-width: 640px;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    gr.Markdown("## Flux Img2Img - Memory-Optimized for T4")
+
+    with gr.Row():
+        with gr.Column():
+            image_input = gr.Image(
+                label="Input Image (Img2Img)",
+                type="pil",
+                image_mode="RGB",
+                height=512
+            )
+            # The mask is not used yet, but it stays in the signature
+            mask_input = gr.Image(
+                label="Mask (unused)",
+                type="pil",
+                image_mode="RGB",
+                height=512
+            )
+            prompt_input = gr.Textbox(label="Prompt", value="a person")
+            strength_slider = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                value=0.75,
+                step=0.05,
+                label="Strength"
+            )
+            seed_box = gr.Number(label="Seed", value=0)
+            steps_box = gr.Slider(
+                minimum=1,
+                maximum=50,
+                value=4,
+                step=1,
+                label="Inference Steps"
+            )
+            run_button = gr.Button("Run Flux Img2Img")
+
+        with gr.Column():
+            output_image = gr.Image(label="Output", height=512)
+
+    # Gradio passes `inputs` positionally, so a hidden textbox keeps the
+    # default model_id aligned with process_image's signature.
+    model_id_box = gr.Textbox(value="black-forest-labs/FLUX.1-schnell", visible=False)
+
+    # Connect button -> process_image
+    run_button.click(
+        fn=process_image,
+        inputs=[
+            image_input,
+            mask_input,
+            prompt_input,
+            model_id_box,
+            strength_slider,
+            seed_box,
+            steps_box
+        ],
+        outputs=[output_image]
+    )
+
 if __name__ == "__main__":
-    image = Image.open(sys.argv[1]).convert("RGB")
-    mask = Image.open(sys.argv[2]).convert("RGB")  # unused
-    result = process_image(image, mask)
-    if result:
-        result.save(sys.argv[3])
+    demo.launch(share=True)
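
For a quick sanity check outside the Space, the updated module can be driven directly from Python. This is a minimal sketch, not part of the commit: it assumes the file is importable as flux1_img2img, that a CUDA GPU is available, and that the input/output paths (which are hypothetical) exist; outside Spaces the @spaces.GPU decorator and the gr.Progress default are expected to be inert.

    # Minimal local smoke test for the updated module (hypothetical paths).
    from PIL import Image

    import flux1_img2img as flux

    img = Image.open("input.jpg").convert("RGB")   # hypothetical sample image
    result = flux.process_image(
        img,
        None,            # mask_image: accepted but unused
        "a person",      # prompt
        strength=0.6,
        seed=42,
        num_inference_steps=4,
    )
    if result is not None:
        result.save("output.png")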
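
The resize helper itself is pure PIL, so its downscale-only contract can be checked without a GPU. The assertions below are new, with the expected sizes worked out from ratio = min(max_size/w, max_size/h); they assume flux1_img2img.py is on the import path (importing it builds, but does not launch, the Gradio UI).

    # Quick check of resize_image's downscale-only behavior (no GPU needed).
    from PIL import Image

    from flux1_img2img import resize_image

    big = Image.new("RGB", (1024, 768))
    small = Image.new("RGB", (300, 200))

    # ratio = min(512/1024, 512/768) = 0.5, so 1024x768 -> 512x384
    assert resize_image(big, max_size=512).size == (512, 384)
    # ratio = min(512/300, 512/200) > 1.0, so the image is returned unchanged
    assert resize_image(small, max_size=512).size == (300, 200)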