import gradio as gr
import numpy as np
import spaces
import torch
import random
import os
from PIL import Image, ImageOps
import pillow_heif  # For HEIF/AVIF support

# Import the pipeline from diffusers
from diffusers import FluxKontextPipeline

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max

# --- Global pipeline variable ---
pipe = None

def load_model():
    """Load the model on CPU first, then move to GPU when needed"""
    global pipe
    if pipe is None:
        # Register HEIF opener with PIL for AVIF/HEIF support
        pillow_heif.register_heif_opener()
        
        # Get token from environment variable
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            pipe = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev", 
                torch_dtype=torch.bfloat16,
                token=hf_token,
            )
        else:
            raise gr.Error("HF_TOKEN environment variable not found. Please add your Hugging Face token to the Space settings.")
    return pipe

# --- Core Inference Function for ChatInterface ---
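# On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of each call.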
@spaces.GPU(duration=120)  # Set duration based on expected inference time
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
    """
    Performs image generation or editing based on user input from the chat interface.
    """
    # Load and move model to GPU within the decorated function
    pipe = load_model()
    pipe = pipe.to("cuda")
    
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

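    # Seeded generator so a given seed reproduces the same output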
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    input_image = None
    if files:
        print(f"Received image: {files[0]}")
        try:
            # Try to open and convert the image
            input_image = Image.open(files[0])
            # Convert to RGB if needed (handles RGBA, P, etc.)
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")
            # Auto-orient the image based on EXIF data
            input_image = ImageOps.exif_transpose(input_image)
        except Exception as e:
            raise gr.Error(f"Could not process the uploaded image: {str(e)}. Please try uploading a different image format (JPEG, PNG, WebP).")
            
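        # Image editing: condition the generation on the uploaded image plus the text instruction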
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    else:
        print(f"Received prompt for text-to-image: {prompt}")
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    
    # Move model back to CPU to free GPU memory
    pipe = pipe.to("cpu")
    torch.cuda.empty_cache()
    
    # Return the PIL Image as a Gradio Image component
    return gr.Image(value=image)

# --- UI Definition using gr.ChatInterface ---
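# These controls are passed to chat_fn as additional_inputs, in this order (seed, randomize, guidance, steps).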

seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)

# --- Examples without external URLs ---
# Remove examples temporarily to avoid format issues
examples = None

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev]",
    description="""<p style='text-align: center;'>
    A simple chat UI for the <b>FLUX.1 Kontext</b> model running on ZeroGPU.
    <br>
    To edit an image, upload it and type your instructions (e.g., "Add a hat").
    <br>
    To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
    <br>
    Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
    </p>""",
    multimodal=True,  # This is important for MultimodalTextbox to work
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt and/or upload an image...",
        render=False
    ),
    additional_inputs=[
        seed_slider,
        randomize_checkbox,
        guidance_slider,
        steps_slider
    ],
    examples=examples,
    theme="soft"
)

if __name__ == "__main__":
    demo.launch()