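# app.py — scraped from the Hugging Face file viewer. The viewer's page residue
# has been turned into comments so this file parses as Python:
#   author: akhaliq (HF Staff)
#   commit: 1bafe30 "Create app.py" (verified)
#   size:   4.97 kB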
import gradio as gr
import numpy as np
import spaces
import torch
import random
from PIL import Image
# Import the pipeline from diffusers
from diffusers import FluxKontextPipeline
# --- Constants and Model Loading ---
# Upper bound for generation seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Load the FLUX.1 Kontext [dev] pipeline in bfloat16 and move it to the GPU at
# import time. Any failure falls back to `pipe = None` instead of crashing on
# startup, and the error is printed (e.g. no CUDA device, download or auth
# problem). chat_fn checks for None and shows a user-facing error.
try:
    pipe = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
except Exception as load_error:
    pipe = None
    print(f"Warning: Could not load the model on CUDA. GPU is required. Error: {load_error}")
# --- Core Inference Function for ChatInterface ---
@spaces.GPU
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
    """
    Performs image generation or editing based on user input from the chat interface.

    If an image is attached, the pipeline edits it according to the prompt
    (image-to-image). Otherwise, it generates a new image from the prompt
    alone (text-to-image).

    Args:
        message (dict | str): A dictionary from gr.MultimodalTextbox, containing:
            - "text" (str): The user's text prompt.
            - "files" (list): A list of paths to uploaded files. Only the
              first file is used, as the image to edit.
            A plain string is also accepted and treated as a text-only prompt.
        chat_history (list): The history of the conversation (managed by
            ChatInterface). Not used: each request is handled independently.
        seed (int): The random seed for generation. Ignored when
            ``randomize_seed`` is True.
        randomize_seed (bool): If True, a random seed in [0, MAX_SEED] is used.
        guidance_scale (float): Controls adherence to the prompt.
        steps (int): Number of inference steps.
        progress (gr.Progress): Gradio progress tracker. ``track_tqdm=True``
            mirrors tqdm progress bars in the UI.

    Returns:
        PIL.Image.Image: The generated or edited image to be displayed in the chat.

    Raises:
        gr.Error: If the model failed to load, or if neither a prompt nor an
            image was provided.
    """
    if pipe is None:
        raise gr.Error("Model could not be loaded. A CUDA-enabled GPU is required to run this application.")

    # Accept a bare string too, in case the interface delivers text-only messages.
    if isinstance(message, str):
        message = {"text": message, "files": []}
    # A whitespace-only prompt carries no instruction, so treat it as empty.
    prompt = (message.get("text") or "").strip()
    files = message.get("files") or []

    # Input validation
    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats, but manual_seed and the scheduler's
    # step count both require integers.
    seed = int(seed)
    num_steps = int(steps)

    # Set up a PyTorch generator for reproducible results
    generator = torch.Generator(device="cuda").manual_seed(seed)

    pipe_kwargs = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_steps,
        "generator": generator,
    }

    if files:
        # User has uploaded an image for editing (image-to-image)
        print(f"Received image: {files[0]}")
        # Use a context manager so the file handle is closed. convert() returns
        # an independent in-memory copy, so the image stays usable afterwards.
        with Image.open(files[0]) as uploaded:
            pipe_kwargs["image"] = uploaded.convert("RGB")
    else:
        # No image uploaded, perform text-to-image generation
        print(f"Received prompt for text-to-image: {prompt}")

    image = pipe(**pipe_kwargs).images[0]

    # The seed actually used is not shown to the user. To expose it, one could
    # return e.g. (image, f"Seed: {seed}") instead of just the image.
    return image
# --- UI Definition using gr.ChatInterface ---
# Components shown in the "Advanced Settings" accordion. ChatInterface calls
# fn(message, history, *additional_inputs), so the order of this list must
# match chat_fn's parameters after chat_history: seed, randomize_seed,
# guidance_scale, steps.
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)

# Create the ChatInterface
demo = gr.ChatInterface(
    fn=chat_fn,
    # Required so that messages (and the dict-style examples below) are
    # delivered as {"text": ..., "files": [...]}, matching the MultimodalTextbox.
    multimodal=True,
    title="FLUX.1 Kontext [dev]",
    description="""<p style='text-align: center;'>
A simple chat UI for the <b>FLUX.1 Kontext</b> model.
<br>
To edit an image, upload it and type your instructions (e.g., "Add a hat").
<br>
To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
<br>
Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
</p>""",
    # Use a multimodal textbox to allow both text and image uploads
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt and/or upload an image...",
        render=False,  # Important: Let ChatInterface render the textbox
    ),
    additional_inputs=[
        seed_slider,
        randomize_checkbox,
        guidance_slider,
        steps_slider,
    ],
    # Label the accordion that holds the additional inputs.
    additional_inputs_accordion="Advanced Settings",
    examples=[
        {"text": "A cute robot reading a book", "files": []},
        {"text": "change his shirt to a hawaiian shirt", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/chewbacca.png"]},
        {"text": "make it a wooden house", "files": ["https://gradio-builds.s3.amazonaws.com/demo-files/house.png"]},
    ],
    theme="soft",
)
# Entry point: start the Gradio server only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()