img2img_test / flux1_img2img.py
Gemini899's picture
Update flux1_img2img.py
c20ce4a verified
raw
history blame
2.33 kB
import os
import torch
from diffusers import FluxImg2ImgPipeline
from PIL import Image
import sys
import spaces
# Limit how large a cached CUDA block the allocator may split (in MB); this
# is a fragmentation/memory-usage tweak for the PyTorch caching allocator.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128")

# Process-wide pipeline cache. Stays None until get_pipe() builds the
# pipeline on first use, so importing this module does not load any weights.
pipe = None
def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
    """Shrink ``image`` so that neither side exceeds ``max_dim`` pixels.

    The aspect ratio is preserved. An image that already fits is returned
    as-is (same object); this function never upscales.

    Args:
        image: Source PIL image.
        max_dim: Maximum allowed width and height, in pixels.

    Returns:
        The original image if it already fits, otherwise a new image
        downscaled with Lanczos resampling.
    """
    w, h = image.size
    ratio = min(max_dim / w, max_dim / h)
    if ratio >= 1.0:
        return image

    # int() truncates, so on a very elongated image the short side can land
    # on 0, which Pillow refuses to resize to. Keep every side >= 1 pixel.
    new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
    return image.resize(new_size, Image.LANCZOS)
def get_pipe(model_id="black-forest-labs/FLUX.1-schnell"):
    """Return the shared FLUX image-to-image pipeline, loading it lazily.

    The pipeline is cached in the module-level ``pipe`` variable. The id of
    the checkpoint loaded by this function is remembered so that a later call
    asking for a different ``model_id`` reloads instead of silently returning
    the previously loaded model.

    Args:
        model_id: Hub id or local path passed to
            ``FluxImg2ImgPipeline.from_pretrained``.

    Returns:
        The cached pipeline, moved to the ``"cuda"`` device.
    """
    global pipe
    # Stored as a function attribute so the cache key lives next to the code
    # that maintains it, without requiring another module-level name.
    loaded_id = getattr(get_pipe, "_loaded_model_id", None)
    stale = loaded_id is not None and loaded_id != model_id

    if pipe is None or stale:
        # Drop the old pipeline before loading the replacement so both sets
        # of weights are not kept alive at the same time.
        pipe = None
        # NOTE(review): the diffusers FLUX examples load with torch.bfloat16
        # and no variant; confirm float16 + variant="fp16" is intended here.
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
        get_pipe._loaded_model_id = model_id
    return pipe
@spaces.GPU
def process_image(
    image,
    mask_image,
    prompt="a person",
    model_id="black-forest-labs/FLUX.1-schnell",
    strength=0.75,
    seed=0,
    num_inference_steps=4,
):
    """Run FLUX image-to-image on ``image`` and return the generated image.

    Args:
        image: Input PIL image, or None.
        mask_image: Currently unused; accepted for interface compatibility
            until mask support is implemented.
        prompt: Text prompt guiding the generation.
        model_id: Checkpoint passed to ``get_pipe``.
        strength: Forwarded to the pipeline's ``strength`` argument.
        seed: Seed for the CUDA random generator.
        num_inference_steps: Forwarded to the pipeline.

    Returns:
        The first generated image, or None when ``image`` is None.
    """
    print("start process image process_image")
    if image is None:
        print("empty input image returned")
        return None

    # Resize image to reduce memory usage
    image = resize_image(image, max_dim=512)

    # Pass the working resolution explicitly: per the diffusers docs the
    # pipeline otherwise falls back to its own default size (1024x1024),
    # which would undo the downscale above and ignore the aspect ratio.
    # Sizes are rounded down to a multiple of 16, which FLUX expects.
    src_w, src_h = image.size
    width = max(16, src_w // 16 * 16)
    height = max(16, src_h // 16 * 16)

    # Get model using lazy loading
    model = get_pipe(model_id)
    generator = torch.Generator("cuda").manual_seed(seed)

    # torch.autocast("cuda", ...) is the non-deprecated spelling of
    # torch.cuda.amp.autocast(...).
    with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
        # more parameter see https://huggingface.co/docs/diffusers/api/pipelines/flux#diffusers.FluxImg2ImgPipeline
        print(prompt)
        output = model(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            generator=generator,
            strength=strength,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        )
    # TODO support mask
    return output.images[0]
if __name__ == "__main__":
    # args: input-image input-mask output
    if len(sys.argv) < 4:
        # Fail with a readable usage line instead of a bare IndexError.
        sys.exit(f"usage: {sys.argv[0]} input-image input-mask output")
    input_path, mask_path, output_path = sys.argv[1:4]

    # Context managers close the source files once generation has finished.
    with Image.open(input_path) as image, Image.open(mask_path) as mask:
        output = process_image(image, mask)
    output.save(output_path)