Gemini899 committed
Commit c20ce4a · verified · Parent: e05e986

Update flux1_img2img.py

Files changed (1)
  1. flux1_img2img.py +47 -151
flux1_img2img.py CHANGED
@@ -1,24 +1,18 @@
 import os
 import torch
-import gradio as gr
+from diffusers import FluxImg2ImgPipeline
 from PIL import Image
-
+import sys
 import spaces
-from diffusers import FluxImg2ImgPipeline
 
-###############################################################################
-# GLOBALS
-###############################################################################
-pipe = None  # We'll load it lazily to avoid OOM during space startup
+# Set memory optimization flags
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
+
+# Global pipe variable for lazy loading
+pipe = None
 
-###############################################################################
-# Helper: Resize the input image
-###############################################################################
 def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
-    """
-    Resizes 'image' so that its largest dimension <= max_dim,
-    preserving aspect ratio. This helps reduce VRAM usage on T4.
-    """
+    """Resizes image to fit within max_dim while preserving aspect ratio"""
     w, h = image.size
     ratio = min(max_dim / w, max_dim / h)
     if ratio < 1.0:
@@ -27,152 +21,54 @@ def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
         image = image.resize((new_w, new_h), Image.LANCZOS)
     return image
 
-###############################################################################
-# Lazy-load function for FLUX.1-schnell pipeline in float16
-###############################################################################
-def load_flux_pipeline():
+def get_pipe(model_id="black-forest-labs/FLUX.1-schnell"):
     global pipe
-    if pipe is not None:
-        return  # Already loaded
-
-    print("Loading FLUX.1-schnell with float16 on T4...")
-
-    # 1) Load in float16 (NOT bfloat16)
-    pipe_local = FluxImg2ImgPipeline.from_pretrained(
-        "black-forest-labs/FLUX.1-schnell",
-        torch_dtype=torch.float16,  # crucial for T4
-        low_cpu_mem_usage=True
-    )
-
-    # 2) Move to GPU
-    pipe_local.to("cuda")
+    if pipe is None:
+        pipe = FluxImg2ImgPipeline.from_pretrained(
+            model_id,
+            torch_dtype=torch.float16,
+            variant="fp16"
+        ).to("cuda")
+    return pipe
 
-    # 3) Memory Efficient Attention (xFormers)
-    try:
-        pipe_local.enable_xformers_memory_efficient_attention()
-        print("xFormers memory efficient attention enabled.")
-    except Exception as e:
-        print("Could not enable xFormers:", e)
-
-    # 4) CPU offload (keeps only active layers on GPU)
-    try:
-        pipe_local.enable_model_cpu_offload()
-        print("Model CPU offload enabled.")
-    except Exception as e:
-        print("Could not enable model_cpu_offload:", e)
-
-    # 5) VAE slicing reduces peak memory usage
-    pipe_local.enable_vae_slicing()
-
-    # Save to global
-    pipe_local.max_sequence_length = 256
-    pipe = pipe_local
-    print("Flux pipeline loaded successfully.")
-
-###############################################################################
-# Main inference function
-###############################################################################
 @spaces.GPU
-def process_image(
-    image: Image.Image,
-    mask_image: Image.Image,
-    prompt="A person",
-    strength=0.75,
-    seed=0,
-    num_inference_steps=4,
-    progress=gr.Progress(track_tqdm=True)
-):
-    """
-    Loads the pipeline if needed, resizes the input image,
-    then runs Flux Img2Img with minimal VRAM usage strategies.
-    """
-    progress(0, desc="Preparing model")
-
-    # 1) Ensure pipeline is loaded
-    load_flux_pipeline()
-
-    progress(20, desc="Resizing input image")
+def process_image(image, mask_image, prompt="a person", model_id="black-forest-labs/FLUX.1-schnell", strength=0.75, seed=0, num_inference_steps=4):
+    print("start process image process_image")
     if image is None:
-        print("No input image provided.")
+        print("empty input image returned")
         return None
-
-    # 2) Resize the input image to reduce VRAM usage
+
+    # Resize image to reduce memory usage
     image = resize_image(image, max_dim=512)
+
+    # Get model using lazy loading
+    model = get_pipe(model_id)
 
-    # 3) Set up generator for reproducible results
+    generators = []
     generator = torch.Generator("cuda").manual_seed(seed)
-
-    # 4) Run the pipeline
-    progress(50, desc="Running Flux Inference")
-    print(f"Prompt: {prompt} | Strength: {strength} | Steps: {num_inference_steps}")
-    output = pipe(
-        prompt=prompt,
-        image=image,
-        generator=generator,
-        strength=strength,
-        guidance_scale=0,  # matches your original code
-        num_inference_steps=num_inference_steps
-    )
-
-    progress(100, desc="Done")
-    return output.images[0]
-
-###############################################################################
-# BUILD THE GRADIO UI
-###############################################################################
-css = """
-#col-left {
-    margin: 0 auto;
-    max-width: 640px;
-}
-#col-right {
-    margin: 0 auto;
-    max-width: 640px;
-}
-"""
-
-with gr.Blocks(css=css) as demo:
-    gr.Markdown("## FLUX Img2Img — Memory-Optimized for T4\n"
-                "Using float16, CPU offload, xFormers, and image resizing to reduce VRAM usage.")
-
-    with gr.Row():
-        with gr.Column():
-            # The main input image
-            input_image = gr.Image(
-                label="Input Image (Img2Img)",
-                type="pil",
-                image_mode="RGB",
-                height=512
+    generators.append(generator)
+
+    # Use autocast for better memory efficiency
+    with torch.cuda.amp.autocast(dtype=torch.float16):
+        with torch.no_grad():
+            # more parameter see https://huggingface.co/docs/diffusers/api/pipelines/flux#diffusers.FluxInpaintPipeline
+            print(prompt)
+            output = model(
+                prompt=prompt,
+                image=image,
+                generator=generator,
+                strength=strength,
+                guidance_scale=0,
+                num_inference_steps=num_inference_steps,
+                max_sequence_length=256
            )
 
-            # Mask is not used in your code, but we keep it to match your function signature
-            mask_image = gr.Image(
-                label="Mask (unused)",
-                type="pil",
-                image_mode="RGB",
-                height=200
-            )
-
-            prompt = gr.Textbox(label="Prompt", value="A person")
-            strength_slider = gr.Slider(0.0, 1.0, value=0.75, step=0.05, label="Strength")
-            seed_box = gr.Number(value=0, label="Seed", precision=0)
-            steps_box = gr.Slider(1, 50, value=4, step=1, label="Inference Steps")
-
-            run_button = gr.Button("Generate")
-
-        with gr.Column():
-            result_image = gr.Image(
-                label="Output",
-                type="pil",
-                height=512
-            )
-
-    # Tie the button to our inference function
-    run_button.click(
-        fn=process_image,
-        inputs=[input_image, mask_image, prompt, strength_slider, seed_box, steps_box],
-        outputs=result_image
-    )
+    # TODO support mask
+    return output.images[0]
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    #args input-image input-mask output
+    image = Image.open(sys.argv[1])
+    mask = Image.open(sys.argv[2])
+    output = process_image(image, mask)
+    output.save(sys.argv[3])
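The commit replaces the Gradio UI with a plain command-line entry point: run directly, the script reads an input image, a mask, and an output path from sys.argv. A minimal sketch of driving the rewritten file from Python follows; the file paths and the module import are placeholders, not part of the commit.

# Sketch: calling the rewritten module programmatically.
# "input.png", "mask.png", "output.png" are hypothetical paths.
from PIL import Image

from flux1_img2img import process_image  # assumes the file is importable as a module

init_image = Image.open("input.png")  # image to transform
mask_image = Image.open("mask.png")   # accepted but currently unused ("TODO support mask")
result = process_image(init_image, mask_image, prompt="a person",
                       strength=0.75, seed=0, num_inference_steps=4)
if result is not None:  # process_image returns None when no input image is given
    result.save("output.png")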
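One caveat in the new get_pipe: variant="fp16" asks diffusers for weight files explicitly tagged fp16, and from_pretrained raises if the repository publishes no such variant. A defensive loading sketch, assuming a fallback of loading the default checkpoint directly in float16 is acceptable (the helper name and the fallback are suggestions, not part of the commit):

import torch
from diffusers import FluxImg2ImgPipeline

def get_pipe_with_fallback(model_id="black-forest-labs/FLUX.1-schnell"):
    # Hypothetical helper: try fp16-tagged weight files first, then fall
    # back to the default checkpoint cast to float16 at load time.
    try:
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, variant="fp16"
        )
    except (OSError, ValueError):  # repo does not ship an fp16 variant
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16
        )
    return pipe.to("cuda")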
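Recent PyTorch releases deprecate torch.cuda.amp.autocast in favor of torch.amp.autocast with an explicit device type. A sketch of the equivalent inference block, pulled into a hypothetical helper so it is self-contained (run_inference is not a name from the commit; the call arguments mirror process_image above):

import torch

def run_inference(model, prompt, image, generator, strength, num_inference_steps):
    # torch.amp.autocast("cuda", ...) is the non-deprecated spelling of
    # torch.cuda.amp.autocast(...); no_grad disables autograd bookkeeping.
    with torch.amp.autocast("cuda", dtype=torch.float16), torch.no_grad():
        return model(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        ).images[0]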