import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image

# Import the pipeline from diffusers
from diffusers import FluxKontextPipeline

# --- Constants and Model Loading ---
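# Seeds are capped at the int32 maximum so any value the UI slider produces
# stays in a range that every seeding API accepts.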
MAX_SEED = np.iinfo(np.int32).max

# Load the pretrained model.
# Note: this requires a CUDA-enabled GPU; the try/except below lets the app
# start and report a clear error on machines without one.
try:
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", 
        torch_dtype=torch.bfloat16
    ).to("cuda")
except Exception as e:
    pipe = None
    print(f"Warning: Could not load the model on CUDA. GPU is required. Error: {e}")
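# If GPU memory is tight, diffusers' pipe.enable_model_cpu_offload() can be
# used in place of .to("cuda"), trading speed for a smaller VRAM footprint.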

# --- Core Inference Function for ChatInterface ---
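# On Hugging Face ZeroGPU Spaces, the @spaces.GPU decorator below attaches a
# GPU to each call; on hardware with a dedicated GPU it is effectively a no-op.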

@spaces.GPU
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
    """
    Performs image generation or editing based on user input from the chat interface.

    Args:
        message (dict): A dictionary from gr.MultimodalTextbox, containing:
                        - "text" (str): The user's text prompt.
                        - "files" (list): A list of paths to uploaded files.
        chat_history (list): The history of the conversation (managed by ChatInterface).
        seed (int): The random seed for generation.
        randomize_seed (bool): If True, a random seed is used.
        guidance_scale (float): Controls adherence to the prompt.
        steps (int): Number of inference steps.
        progress (gr.Progress): Gradio progress tracker.

    Returns:
        gr.Image: The generated or edited image, wrapped in a component so
                  that ChatInterface can render it in the chat.
    """
    if pipe is None:
        raise gr.Error("Model could not be loaded. A CUDA-enabled GPU is required to run this application.")

    prompt = message["text"]
    files = message["files"]

    # Input validation
    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Set up a PyTorch generator for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(seed)

    input_image = None
    if files:
        # User has uploaded an image for editing (image-to-image)
        print(f"Received image: {files[0]}")
        input_image = Image.open(files[0]).convert("RGB")
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    else:
        # No image uploaded, perform text-to-image generation
        print(f"Received prompt for text-to-image: {prompt}")
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
        
    # ChatInterface accepts a string or a Gradio component as the response;
    # wrap the PIL image in gr.Image so it renders in the chat. To also report
    # the seed to the user, it would have to be embedded in the response text.
    return gr.Image(value=image)

# --- UI Definition using gr.ChatInterface ---

# Define the components for "Advanced Settings" that will be passed to `additional_inputs`
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
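
# ChatInterface passes these components' values positionally to chat_fn after
# (message, chat_history), in the order they appear in additional_inputs.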

# Create the ChatInterface
demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev]",
    description="""<p style='text-align: center;'>
    A simple chat UI for the <b>FLUX.1 Kontext</b> model.
    <br>
    To edit an image, upload it and type your instructions (e.g., "Add a hat").
    <br>
    To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
    <br>
    Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
    </p>""",
    # multimodal=True makes ChatInterface deliver messages as {"text", "files"}
    # dicts, which is the structure chat_fn expects.
    multimodal=True,
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt and/or upload an image...",
        render=False  # Let ChatInterface render the textbox itself
    ),
    additional_inputs=[
        seed_slider,
        randomize_checkbox,
        guidance_slider,
        steps_slider
    ],
    examples=[
        {"text": "A cute robot reading a book", "files": []},
        {"text": "change his shirt to a hawaiian shirt", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/chewbacca.png"]},
        {"text": "make it a wooden house", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/house.png"]},
    ],
    theme="soft"
)

# Launch the application
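# demo.queue() can be chained before launch() when serving multiple concurrent
# users, since each generation holds the GPU for several seconds.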
if __name__ == "__main__":
    demo.launch()