Gemini899 committed
Commit 1ac594b · verified · 1 Parent(s): 9a1289b

Update flux1_img2img.py

Files changed (1):
  1. flux1_img2img.py +64 -80
flux1_img2img.py CHANGED
@@ -1,6 +1,4 @@
 import os
-import re
-import sys
 import torch
 import gradio as gr
 from PIL import Image
@@ -9,20 +7,20 @@ import spaces
 from diffusers import FluxImg2ImgPipeline
 
 ###############################################################################
-# GLOBAL PIPE VARIABLE (lazy-loaded so the Space can start without OOM)
+# GLOBALS
 ###############################################################################
-pipe = None  # We will load this when the user triggers an inference
+pipe = None  # We'll load it lazily to avoid OOM during space startup
 
 ###############################################################################
-# OPTIONAL: Resize Helper for Lower VRAM Usage
+# Helper: Resize the input image
 ###############################################################################
-def resize_image(image, max_size=512):
+def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
     """
-    Resizes the image so that the max dimension is 'max_size',
-    which helps reduce GPU memory usage on a T4.
+    Resizes 'image' so that its largest dimension <= max_dim,
+    preserving aspect ratio. This helps reduce VRAM usage on T4.
     """
     w, h = image.size
-    ratio = min(max_size / w, max_size / h)
+    ratio = min(max_dim / w, max_dim / h)
     if ratio < 1.0:
         new_w = int(w * ratio)
         new_h = int(h * ratio)
@@ -30,96 +28,93 @@ def resize_image(image, max_size=512):
     return image
 
 ###############################################################################
-# PIPELINE LOADER: Loads FLUX.1-schnell with memory-saving features
+# Lazy-load function for FLUX.1-schnell pipeline in float16
 ###############################################################################
 def load_flux_pipeline():
-    """
-    Lazily loads the FluxImg2ImgPipeline with float16,
-    CPU offload, xFormers (if installed), etc.
-    """
     global pipe
     if pipe is not None:
         return  # Already loaded
-
-    print("Loading FluxImg2ImgPipeline in float16 mode ...")
-    # Use float16 for T4
+
+    print("Loading FLUX.1-schnell with float16 on T4...")
+
+    # 1) Load in float16 (NOT bfloat16)
     pipe_local = FluxImg2ImgPipeline.from_pretrained(
         "black-forest-labs/FLUX.1-schnell",
-        torch_dtype=torch.float16,
+        torch_dtype=torch.float16,  # crucial for T4
         low_cpu_mem_usage=True
     )
 
-    # Move to GPU
+    # 2) Move to GPU
     pipe_local.to("cuda")
 
-    # Try enabling xFormers for memory-efficient attention
+    # 3) Memory Efficient Attention (xFormers)
     try:
         pipe_local.enable_xformers_memory_efficient_attention()
-        print("Enabled xFormers memory efficient attention.")
+        print("xFormers memory efficient attention enabled.")
     except Exception as e:
         print("Could not enable xFormers:", e)
 
-    # Offload model chunks to CPU if VRAM is tight
+    # 4) CPU offload (keeps only active layers on GPU)
     try:
         pipe_local.enable_model_cpu_offload()
-        print("Enabled model CPU offload.")
+        print("Model CPU offload enabled.")
     except Exception as e:
         print("Could not enable model_cpu_offload:", e)
 
-    # VAE slicing can reduce peak memory usage
+    # 5) VAE slicing reduces peak memory usage
     pipe_local.enable_vae_slicing()
 
-    pipe_local.max_sequence_length = 256  # same as your original code suggestion
-    print("Flux pipeline loaded successfully.")
+    # Save to global
+    pipe_local.max_sequence_length = 256
     pipe = pipe_local
+    print("Flux pipeline loaded successfully.")
 
 ###############################################################################
-# MAIN INFERENCE FUNCTION
+# Main inference function
 ###############################################################################
 @spaces.GPU
 def process_image(
-    image,
-    mask_image,
-    prompt="a person",
-    model_id="black-forest-labs/FLUX.1-schnell",
+    image: Image.Image,
+    mask_image: Image.Image,
+    prompt="A person",
     strength=0.75,
     seed=0,
     num_inference_steps=4,
     progress=gr.Progress(track_tqdm=True)
 ):
     """
-    Runs Flux Img2Img with memory-optimized loading.
-    'mask_image' is not currently used.
+    Loads the pipeline if needed, resizes the input image,
+    then runs Flux Img2Img with minimal VRAM usage strategies.
     """
+    progress(0, desc="Preparing model")
 
-    # Let Gradio show progress
-    progress(0, desc="Starting Inference")
+    # 1) Ensure pipeline is loaded
+    load_flux_pipeline()
 
+    progress(20, desc="Resizing input image")
     if image is None:
         print("No input image provided.")
         return None
 
-    # 1) Load pipeline if not loaded
-    load_flux_pipeline()
-
-    # 2) Resize input to reduce VRAM usage
-    image = resize_image(image, max_size=512)
+    # 2) Resize the input image to reduce VRAM usage
+    image = resize_image(image, max_dim=512)
 
-    # 3) Prepare generator for reproducible results
+    # 3) Set up generator for reproducible results
     generator = torch.Generator("cuda").manual_seed(seed)
-
-    # 4) Actually run the pipeline
-    print(f"Running Flux with prompt: '{prompt}' (strength={strength}, steps={num_inference_steps})")
+
+    # 4) Run the pipeline
+    progress(50, desc="Running Flux Inference")
+    print(f"Prompt: {prompt} | Strength: {strength} | Steps: {num_inference_steps}")
    output = pipe(
         prompt=prompt,
         image=image,
         generator=generator,
         strength=strength,
-        guidance_scale=0,  # same as your original code
+        guidance_scale=0,  # matches your original code
         num_inference_steps=num_inference_steps
     )
-    progress(100, desc="Done")
 
+    progress(100, desc="Done")
     return output.images[0]
 
 ###############################################################################
@@ -137,57 +132,46 @@ css = """
 """
 
 with gr.Blocks(css=css) as demo:
-    gr.Markdown("## Flux Img2Img - Memory-Optimized for T4")
+    gr.Markdown("## FLUX Img2Img Memory-Optimized for T4\n"
+                "Using float16, CPU offload, xFormers, and image resizing to reduce VRAM usage.")
 
     with gr.Row():
         with gr.Column():
-            image_input = gr.Image(
+            # The main input image
+            input_image = gr.Image(
                 label="Input Image (Img2Img)",
                 type="pil",
                 image_mode="RGB",
                 height=512
             )
-            # The mask is not used in your original code, but we keep it in signature
-            mask_input = gr.Image(
+
+            # Mask is not used in your code, but we keep it to match your function signature
+            mask_image = gr.Image(
                 label="Mask (unused)",
                 type="pil",
                 image_mode="RGB",
-                height=512
-            )
-            prompt_input = gr.Textbox(label="Prompt", value="a person")
-            strength_slider = gr.Slider(
-                minimum=0.0,
-                maximum=1.0,
-                value=0.75,
-                step=0.05,
-                label="Strength"
-            )
-            seed_box = gr.Number(label="Seed", value=0)
-            steps_box = gr.Slider(
-                minimum=1,
-                maximum=50,
-                value=4,
-                step=1,
-                label="Inference Steps"
+                height=200
             )
-            run_button = gr.Button("Run Flux Img2Img")
+
+            prompt = gr.Textbox(label="Prompt", value="A person")
+            strength_slider = gr.Slider(0.0, 1.0, value=0.75, step=0.05, label="Strength")
+            seed_box = gr.Number(value=0, label="Seed", precision=0)
+            steps_box = gr.Slider(1, 50, value=4, step=1, label="Inference Steps")
+
+            run_button = gr.Button("Generate")
 
         with gr.Column():
-            output_image = gr.Image(label="Output", height=512)
+            result_image = gr.Image(
+                label="Output",
+                type="pil",
+                height=512
+            )
 
-    # Connect button -> process_image
+    # Tie the button to our inference function
     run_button.click(
         fn=process_image,
-        inputs=[
-            image_input,
-            mask_input,
-            prompt_input,
-            # model_id is default, so we won't pass it from UI
-            strength_slider,
-            seed_box,
-            steps_box
-        ],
-        outputs=[output_image]
+        inputs=[input_image, mask_image, prompt, strength_slider, seed_box, steps_box],
+        outputs=result_image
     )
 
 if __name__ == "__main__":
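
Review note: enable_model_cpu_offload() manages device placement on its own (via accelerate hooks), so calling pipe_local.to("cuda") first is at best redundant and can conflict with the offload on recent diffusers releases. A minimal sketch of the intended order, assuming the same pipeline as in this commit:

    import torch
    from diffusers import FluxImg2ImgPipeline

    pipe = FluxImg2ImgPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # Skip pipe.to("cuda"): the offload hooks move each submodule to the GPU
    # only while it runs, which is the point of offloading on a 16 GB T4.
    pipe.enable_model_cpu_offload()
    pipe.enable_vae_slicing()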
 
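Review note: the line pipe_local.max_sequence_length = 256 sets a plain attribute that FluxImg2ImgPipeline never reads; in diffusers, max_sequence_length is an argument of the pipeline call itself (it caps the T5 prompt tokens). A hedged sketch of passing it per call, reusing the names from this commit:

    output = pipe(
        prompt=prompt,
        image=image,
        generator=generator,
        strength=strength,
        guidance_scale=0,
        num_inference_steps=num_inference_steps,
        max_sequence_length=256,  # shorter T5 sequence -> lower memory use
    )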
 
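Review note: on PyTorch 2.x, diffusers routes attention through torch.nn.functional.scaled_dot_product_attention by default, so the xFormers call mostly matters on older PyTorch builds (the try/except already covers the case where xFormers is absent). A sketch that only reaches for xFormers when SDPA is unavailable:

    import torch

    # PyTorch >= 2.0 ships a memory-efficient SDPA kernel that diffusers
    # uses automatically; fall back to xFormers only without it.
    if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print("Could not enable xFormers:", e)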
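Review note: in diffusers img2img pipelines, strength also scales how many of the requested steps actually run (roughly int(num_inference_steps * strength)), so the defaults here execute about 3 denoising steps rather than 4. A small worked example under that assumption:

    num_inference_steps = 4
    strength = 0.75

    # Standard img2img timestep truncation in diffusers (hedged for Flux):
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 3
    t_start = max(num_inference_steps - init_timestep, 0)                          # 1
    effective_steps = num_inference_steps - t_start                                # 3 steps run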