Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from diffusers import FluxImg2ImgPipeline | |
| from PIL import Image | |
| import sys | |
| import spaces | |
# Memory tuning for PyTorch's CUDA caching allocator (see the PyTorch notes on
# PYTORCH_CUDA_ALLOC_CONF / max_split_size_mb for the exact semantics).
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:128"})

# Shared pipeline instance; stays None until get_pipe() loads it on first use.
pipe = None
def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
    """Downscale ``image`` so that neither side exceeds ``max_dim`` pixels.

    The aspect ratio is preserved. Images that already fit are returned
    unchanged (same object); this function never upscales.

    Args:
        image: Source picture.
        max_dim: Largest allowed width or height, in pixels. Must be positive.

    Returns:
        ``image`` itself when it already fits, otherwise a new
        LANCZOS-resampled copy.

    Raises:
        ValueError: If ``max_dim`` is not positive.
    """
    if max_dim <= 0:
        raise ValueError(f"max_dim must be positive, got {max_dim}")

    w, h = image.size
    ratio = min(max_dim / w, max_dim / h)
    if ratio >= 1.0:
        return image

    # int() truncation can produce 0 for very elongated inputs (for example a
    # 10000x1 image), and a zero-sized target is not a valid resize request,
    # so keep every side at least 1 px.
    new_size = (max(1, int(w * ratio)), max(1, int(h * ratio)))
    return image.resize(new_size, Image.LANCZOS)
# Repo id of the checkpoint currently held in ``pipe``; None when nothing has
# been loaded through get_pipe() yet.
_pipe_model_id = None


def get_pipe(model_id="black-forest-labs/FLUX.1-schnell"):
    """Return the shared ``FluxImg2ImgPipeline``, loading it on first use.

    The pipeline is cached in the module-level ``pipe`` variable. Asking for
    a different ``model_id`` than the cached one drops the cached pipeline
    and loads the requested checkpoint, instead of silently returning the
    wrong model.

    Args:
        model_id: Hub repo id (or local path) passed to ``from_pretrained``.

    Returns:
        The pipeline, moved to the ``"cuda"`` device.
    """
    global pipe, _pipe_model_id

    # ``_pipe_model_id`` is None when ``pipe`` was assigned by other code; in
    # that case keep the previous behaviour and hand back whatever is there.
    if pipe is not None and _pipe_model_id in (None, model_id):
        return pipe

    if pipe is not None:
        # Switching checkpoints: release the old weights before loading the
        # new ones so both do not have to fit on the GPU at the same time.
        pipe = None
        _pipe_model_id = None
        torch.cuda.empty_cache()

    # NOTE(review): variant="fp16" requires the repo to publish fp16 variant
    # weight files, and the diffusers Flux examples load in bfloat16 -- confirm
    # both against the target checkpoint and the installed diffusers version.
    pipe = FluxImg2ImgPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to("cuda")
    _pipe_model_id = model_id
    return pipe
def process_image(
    image,
    mask_image,
    prompt="a person",
    model_id="black-forest-labs/FLUX.1-schnell",
    strength=0.75,
    seed=0,
    num_inference_steps=4,
):
    """Run a Flux image-to-image pass over ``image`` and return the result.

    Args:
        image: Input PIL image, or None.
        mask_image: Accepted for interface compatibility but currently
            ignored (mask support is still a TODO).
        prompt: Text prompt passed to the pipeline.
        model_id: Checkpoint to load via ``get_pipe``.
        strength: Img2img strength passed to the pipeline.
        seed: Seed for the CUDA random generator; coerced to ``int``.
        num_inference_steps: Number of steps passed to the pipeline.

    Returns:
        The first generated image, or None when ``image`` is None.
    """
    print("start process image process_image")
    if image is None:
        print("empty input image returned")
        return None

    # Normalise to 3-channel RGB so RGBA / palette / grayscale inputs do not
    # reach the pipeline with an unexpected channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Resize image to reduce memory usage
    image = resize_image(image, max_dim=512)

    # Get model using lazy loading
    model = get_pipe(model_id)

    # manual_seed() needs an integer; UI widgets may deliver the seed as float.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    print(prompt)
    # Use autocast for better memory efficiency. torch.autocast("cuda", ...)
    # replaces the deprecated torch.cuda.amp.autocast spelling.
    with torch.autocast(device_type="cuda", dtype=torch.float16), torch.no_grad():
        # More parameters: see
        # https://huggingface.co/docs/diffusers/api/pipelines/flux#diffusers.FluxImg2ImgPipeline
        output = model(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        )

    # TODO support mask
    return output.images[0]
if __name__ == "__main__":
    # args: input-image input-mask output
    if len(sys.argv) < 4:
        # Fail with a usage hint instead of a bare IndexError.
        sys.exit(f"usage: {sys.argv[0]} INPUT_IMAGE INPUT_MASK OUTPUT_IMAGE")

    # Keep the input files open only while they are being processed; the
    # result is a separate image, so it can be saved after they are closed.
    with Image.open(sys.argv[1]) as image, Image.open(sys.argv[2]) as mask:
        output = process_image(image, mask)
    output.save(sys.argv[3])