Update flux1_img2img.py
flux1_img2img.py  CHANGED  (+31 -15)
@@ -1,5 +1,5 @@
 import torch
-from diffusers import
+from diffusers import FluxImg2ImgPipeline
 from PIL import Image
 import sys
 import spaces
@@ -7,43 +7,59 @@ import spaces
 @spaces.GPU
 def process_image(
     image,
-    mask_image,
+    mask_image,
     prompt="a person",
-    model_id="
+    model_id="black-forest-labs/FLUX.1-schnell",
     strength=0.75,
     seed=0,
-    num_inference_steps=
+    num_inference_steps=4
 ):
+    print("start process image process_image")
     if image is None:
-        print("
+        print("empty input image returned")
         return None
 
-    #
-    #
-
-
+    # 1) Use float16 (T4 doesn't have native bf16 support)
+    # 2) low_cpu_mem_usage=True for more efficient loading
+    # 3) Optionally enable xFormers
+    pipe = FluxImg2ImgPipeline.from_pretrained(
+        model_id,
+        torch_dtype=torch.float16,
+        revision="fp16",  # sometimes needed if the repo has an FP16 branch
+        low_cpu_mem_usage=True
     )
+
+    # Move to GPU
     pipe.to("cuda")
 
-    #
+    # If you have xFormers installed (pip install xformers):
+    try:
+        pipe.enable_xformers_memory_efficient_attention()
+        print("Enabled xFormers memory efficient attention.")
+    except Exception as e:
+        print("xFormers not available:", e)
+
+    # Create a reproducible generator
     generator = torch.Generator("cuda").manual_seed(seed)
 
-
+    print(f"Prompt: {prompt}")
     output = pipe(
         prompt=prompt,
         image=image,
+        generator=generator,
         strength=strength,
-        guidance_scale=
+        guidance_scale=0,  # same as your original code
         num_inference_steps=num_inference_steps,
-
+        max_sequence_length=256
     )
 
+    # TODO: support mask if needed
     return output.images[0]
 
 if __name__ == "__main__":
-    # Usage: python
+    # Usage: python img2img.py input_image.png input_mask.png output.png
     image = Image.open(sys.argv[1]).convert("RGB")
-    mask = Image.open(sys.argv[2]).convert("RGB") #
+    mask = Image.open(sys.argv[2]).convert("RGB")  # currently unused
     result = process_image(image, mask)
     if result is not None:
         result.save(sys.argv[3])
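Beyond the CLI usage noted in the `__main__` block, the updated `process_image` can also be called directly from Python. The snippet below is a hypothetical sketch, not part of the commit: it assumes the `diffusers` and `spaces` packages are installed, a CUDA GPU is available, and uses placeholder file names; the mask argument is still ignored by the function.

```python
from PIL import Image
from flux1_img2img import process_image  # module updated in this commit

# Hypothetical input paths; the mask is accepted but currently unused.
image = Image.open("input_image.png").convert("RGB")
mask = Image.open("input_mask.png").convert("RGB")

result = process_image(image, mask, prompt="a person", strength=0.75, seed=0)
if result is not None:
    result.save("output.png")
```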
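The commit still leaves the mask unused (see the `# TODO: support mask if needed` comment). One possible follow-up, not part of this change, would be to switch to diffusers' FluxInpaintPipeline, which does accept a `mask_image`. A minimal sketch under that assumption:

```python
import torch
from diffusers import FluxInpaintPipeline  # assumed available in recent diffusers releases
from PIL import Image

# Hypothetical inpainting variant: white areas of the mask are regenerated.
pipe = FluxInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

image = Image.open("input_image.png").convert("RGB")
mask = Image.open("input_mask.png").convert("L")  # single-channel mask

result = pipe(
    prompt="a person",
    image=image,
    mask_image=mask,
    strength=0.75,
    guidance_scale=0,
    num_inference_steps=4,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
result.save("output.png")
```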