import os
import random

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image, ImageOps
import pillow_heif  # For HEIF/AVIF support

# Import the pipeline from diffusers
from diffusers import FluxKontextPipeline

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max

# --- Global pipeline variable ---
pipe = None


def load_model():
    """Load the model on CPU first, then move it to the GPU when needed."""
    global pipe
    if pipe is None:
        # Register the HEIF opener with PIL for AVIF/HEIF support
        pillow_heif.register_heif_opener()

        # Read the access token from the environment
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            pipe = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev",
                torch_dtype=torch.bfloat16,
                token=hf_token,
            )
        else:
            raise gr.Error(
                "HF_TOKEN environment variable not found. "
                "Please add your Hugging Face token to the Space settings."
            )
    return pipe


# --- Core Inference Function for ChatInterface ---
@spaces.GPU(duration=120)  # Set duration based on expected inference time
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps,
            progress=gr.Progress(track_tqdm=True)):
    """Perform image generation or editing based on user input from the chat interface."""
    # Load the model and move it to the GPU inside the @spaces.GPU-decorated function
    pipe = load_model()
    pipe = pipe.to("cuda")

    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    input_image = None
    if files:
        print(f"Received image: {files[0]}")
        try:
            # Try to open the uploaded image
            input_image = Image.open(files[0])
            # Auto-orient based on EXIF data first, since mode conversion
            # can drop the EXIF tags
            input_image = ImageOps.exif_transpose(input_image)
            # Convert to RGB if needed (handles RGBA, P, etc.)
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")
        except Exception as e:
            raise gr.Error(
                f"Could not process the uploaded image: {e}. "
                "Please try uploading a different image format (JPEG, PNG, WebP)."
            )
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    else:
        print(f"Received prompt for text-to-image: {prompt}")
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]

    # Move the model back to the CPU to free GPU memory
    pipe = pipe.to("cpu")
    torch.cuda.empty_cache()

    # Return the PIL image as a Gradio Image component
    return gr.Image(value=image)


# --- UI Definition using gr.ChatInterface ---
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)

# --- Examples without external URLs ---
# Examples are removed temporarily to avoid format issues
examples = None

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev]",
    description="""
A simple chat UI for the FLUX.1 Kontext model running on ZeroGPU.
To edit an image, upload it and type your instructions (e.g., "Add a hat").
To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
Find the model on [Hugging Face](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
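""",
    # The section ends mid-call, so the remaining arguments below are a sketch
    # of a plausible completion rather than the confirmed original. It assumes
    # multimodal chat input (chat_fn reads message["text"] and message["files"])
    # and wires the components defined above to chat_fn's extra parameters, in
    # the same order as its signature: seed, randomize_seed, guidance_scale, steps.
    multimodal=True,
    additional_inputs=[seed_slider, randomize_checkbox, guidance_slider, steps_slider],
    examples=examples,
)

if __name__ == "__main__":
    demo.launch()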