import gradio as gr
import numpy as np
import spaces
import torch
import random
import os
from PIL import Image
# Import the pipeline from diffusers
from diffusers import FluxKontextPipeline

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max

# --- Global pipeline variable ---
pipe = None

def load_model():
    """Load the model on CPU first, then move to GPU when needed"""
    global pipe
    if pipe is None:
        # Get token from environment variable
        hf_token = os.getenv("HF_TOKEN")
        if hf_token:
            pipe = FluxKontextPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-Kontext-dev",
                torch_dtype=torch.bfloat16,
                token=hf_token,
            )
        else:
            raise gr.Error(
                "HF_TOKEN environment variable not found. "
                "Please add your Hugging Face token to the Space settings."
            )
    return pipe
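
# Note: on ZeroGPU Spaces, CUDA is only available inside functions decorated with
# @spaces.GPU, so load_model() keeps the pipeline on CPU and chat_fn() moves it
# to the GPU for the duration of each request.
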
# --- Core Inference Function for ChatInterface ---
@spaces.GPU(duration=120)  # Set duration based on expected inference time
def chat_fn(message, chat_history, seed, randomize_seed, guidance_scale, steps, progress=gr.Progress(track_tqdm=True)):
    """
    Performs image generation or editing based on user input from the chat interface.
    """
    # Load and move model to GPU within the decorated function
    pipe = load_model()
    pipe = pipe.to("cuda")

    # gr.MultimodalTextbox delivers the message as {"text": str, "files": [paths]}
    prompt = message["text"]
    files = message["files"]

    if not prompt and not files:
        raise gr.Error("Please provide a prompt and/or upload an image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    input_image = None
    if files:
        # Image editing: condition generation on the uploaded image
        print(f"Received image: {files[0]}")
        input_image = Image.open(files[0]).convert("RGB")
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]
    else:
        # Text-to-image: no conditioning image
        print(f"Received prompt for text-to-image: {prompt}")
        image = pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=generator,
        ).images[0]

    # Move model back to CPU to free GPU memory for other ZeroGPU users
    pipe = pipe.to("cpu")
    torch.cuda.empty_cache()

    # Return the PIL Image directly - ChatInterface will handle it properly
    return image

# --- UI Definition using gr.ChatInterface ---
seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
randomize_checkbox = gr.Checkbox(label="Randomize seed", value=False)
guidance_slider = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, step=0.1, value=2.5)
steps_slider = gr.Slider(label="Steps", minimum=1, maximum=30, value=28, step=1)
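
# These controls are wired to chat_fn via additional_inputs below; their values
# arrive as the seed, randomize_seed, guidance_scale, and steps arguments, in order.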

# --- Examples ---
# Examples are disabled for now to avoid multimodal message-format issues.
examples = None
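# If examples are re-enabled, each one must match the multimodal message format
# used by chat_fn; the exact structure can vary across Gradio versions, so treat
# this as a sketch rather than a guaranteed schema:
# examples = [
#     [{"text": "A photo of an astronaut on a horse", "files": []}],
# ]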

demo = gr.ChatInterface(
    fn=chat_fn,
    title="FLUX.1 Kontext [dev]",
    description="""<p style='text-align: center;'>
    A simple chat UI for the <b>FLUX.1 Kontext</b> model running on ZeroGPU.
    <br>
    To edit an image, upload it and type your instructions (e.g., "Add a hat").
    <br>
    To generate an image, just type a prompt (e.g., "A photo of an astronaut on a horse").
    <br>
    Find the model on <a href='https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev' target='_blank'>Hugging Face</a>.
    </p>""",
    multimodal=True,  # This is important for MultimodalTextbox to work
    textbox=gr.MultimodalTextbox(
        file_types=["image"],
        placeholder="Type a prompt and/or upload an image...",
        render=False,
    ),
    additional_inputs=[
        seed_slider,
        randomize_checkbox,
        guidance_slider,
        steps_slider,
    ],
    examples=examples,
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()
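
# A minimal local smoke test, kept commented out so it never runs on the Space.
# It assumes HF_TOKEN is set and a CUDA GPU is available; the message dict mirrors
# gr.MultimodalTextbox's {"text": ..., "files": [...]} payload, and "out.png" is
# a hypothetical output path:
#
#     result = chat_fn(
#         {"text": "A photo of an astronaut on a horse", "files": []},
#         chat_history=[],
#         seed=42,
#         randomize_seed=False,
#         guidance_scale=2.5,
#         steps=28,
#     )
#     result.save("out.png")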