import os
import torch
from diffusers import FluxImg2ImgPipeline
from PIL import Image
import sys
import spaces

# Tell the CUDA caching allocator not to split blocks larger than 128 MB,
# which reduces memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Global pipe variable for lazy loading
pipe = None

def resize_image(image: Image.Image, max_dim: int = 512) -> Image.Image:
    """Resizes image to fit within max_dim while preserving aspect ratio"""
    w, h = image.size
    ratio = min(max_dim / w, max_dim / h)
    if ratio < 1.0:
        new_w = int(w * ratio)
        new_h = int(h * ratio)
        image = image.resize((new_w, new_h), Image.LANCZOS)
    return image
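
# For example, an 800x600 input becomes 512x384 (scale 0.64), while a 400x300
# input is returned unchanged because it already fits within 512 pixels.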

def get_pipe(model_id="black-forest-labs/FLUX.1-schnell"):
    global pipe
    if pipe is None:
        # FLUX.1-schnell ships bfloat16 weights and has no separate "fp16"
        # variant, so load in bfloat16 without a variant argument
        pipe = FluxImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16
        ).to("cuda")
    return pipe
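
# Note: the first get_pipe() call downloads and loads the full checkpoint, so
# the cold start is slow; subsequent calls reuse the cached pipeline object.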

@spaces.GPU  # allocate a GPU for the duration of this call on ZeroGPU Spaces
def process_image(image, mask_image, prompt="a person", model_id="black-forest-labs/FLUX.1-schnell", strength=0.75, seed=0, num_inference_steps=4):
    print("start process image process_image")
    if image is None:
        print("empty input image returned")
        return None
    
    # Resize image to reduce memory usage
    image = resize_image(image, max_dim=512)
    
    # Get model using lazy loading
    model = get_pipe(model_id)

    generator = torch.Generator("cuda").manual_seed(seed)
    
    # Autocast plus no_grad keeps activation memory down during inference;
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast, and the
    # dtype matches the bfloat16 weights loaded above
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        with torch.no_grad():
            # For more parameters, see https://huggingface.co/docs/diffusers/api/pipelines/flux#diffusers.FluxInpaintPipeline
            print(prompt)
            output = model(
                prompt=prompt,
                image=image,
                generator=generator,
                strength=strength,
                guidance_scale=0,  # FLUX.1-schnell is guidance-distilled, so CFG is off
                num_inference_steps=num_inference_steps,
                max_sequence_length=256  # cap T5 prompt tokens to save memory
            )

    # TODO: apply mask_image (currently accepted but unused)
    return output.images[0]

if __name__ == "__main__":
    # args: input-image input-mask output
    if len(sys.argv) != 4:
        sys.exit(f"usage: {sys.argv[0]} <input-image> <input-mask> <output>")
    image = Image.open(sys.argv[1])
    mask = Image.open(sys.argv[2])
    output = process_image(image, mask)
    output.save(sys.argv[3])
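
# Example invocation (script and file names are illustrative):
#   python flux_img2img.py input.png mask.png output.png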