import gradio as gr
import numpy as np
import spaces
import torch
import random
import json
import os
from PIL import Image
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image, peft_utils
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard
from safetensors.torch import load_file
import requests
import re
# Load the base model.
# Largest signed 32-bit integer; used as the upper bound for random and
# user-chosen seeds.
MAX_SEED = np.iinfo(np.int32).max
# FLUX.1 Kontext [dev] pipeline in bfloat16, created once at import time and
# moved to the GPU.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
# Temporary workaround for a diffusers LoRA loading issue: wrap diffusers'
# private `_derive_exclude_modules` so that any module name containing
# "proj_out" is removed from the list it returns.
# NOTE(review): presumably this stops diffusers from excluding `proj_out`
# layers when LoRA weights are loaded; drop the patch once diffusers is fixed.
# The patch only reaches code that looks the helper up on the `peft_utils`
# module at call time — confirm against the installed diffusers version.
try:
    from diffusers.utils.peft_utils import _derive_exclude_modules
except ImportError:
    # This diffusers version has no such private helper: nothing to patch.
    pass
else:
    def new_derive_exclude_modules(*args, **kwargs):
        """Call the original helper and strip "proj_out" entries from its result."""
        exclude_modules = _derive_exclude_modules(*args, **kwargs)
        if exclude_modules is not None:
            exclude_modules = [n for n in exclude_modules if "proj_out" not in n]
        return exclude_modules

    peft_utils._derive_exclude_modules = new_derive_exclude_modules
# Load LoRA configurations from JSON
def _normalize_lora_config(item):
    """Map one raw JSON entry to the adapter config dict used throughout the app.

    ``image``, ``title`` and ``repo`` are required (KeyError if missing); the
    remaining fields fall back to defaults. ``likes`` is carried through so
    that ``sort_lora_gallery`` (which sorts on it) has a value to sort by.
    """
    return {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item.get("trigger_word", ""),
        "trigger_position": item.get("trigger_position", "prepend"),
        "weights": item.get("weights", "pytorch_lora_weights.safetensors"),
        "likes": item.get("likes", 0),
    }


# JSON is UTF-8 by spec; don't depend on the platform's locale encoding
# (titles may contain non-ASCII characters such as emoji).
with open("lora_configs.json", "r", encoding="utf-8") as file:
    data = json.load(file)
lora_configs = [_normalize_lora_config(item) for item in data]
print(f"Loaded {len(lora_configs)} LoRAs from JSON")
# Global variables for adapter management
# Adapter config dict whose weights are currently loaded into `pipe`
# (None when no adapter has been loaded).
active_lora_adapter = None
# Local paths of already-downloaded adapter weights, keyed by
# (repo_id, weights_filename).
lora_cache = {}


def load_lora_weights(repo_id, weights_filename):
    """Download adapter weights from the Hugging Face Hub and return the local path.

    Results are memoized in ``lora_cache`` keyed by both repository and file
    name, so two adapters from the same repo with different weight files do
    not return each other's path.

    Args:
        repo_id: Hub repository id, e.g. ``"user/repo"``.
        weights_filename: Name of the weights file inside the repository.

    Returns:
        The local path of the downloaded file, or ``None`` if the download
        failed (the error is printed).
    """
    cache_key = (repo_id, weights_filename)
    cached_path = lora_cache.get(cache_key)
    if cached_path is not None:
        return cached_path
    try:
        lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
    except Exception as e:
        print(f"Error loading adapter from {repo_id}: {e}")
        return None
    lora_cache[cache_key] = lora_path
    return lora_path
def on_lora_select(selected_state: gr.SelectData, lora_configs):
    """Update the UI after the user clicks an adapter in the gallery.

    Args:
        selected_state: Gradio select event; ``.index`` is the clicked gallery position.
        lora_configs: Adapter configs in the same order as the gallery items.

    Returns:
        ``(selection label markdown, prompt textbox update, selected index or None)``.
    """
    index = selected_state.index
    # Guard against a missing or out-of-range index instead of raising.
    if index is None or not 0 <= index < len(lora_configs):
        return "### No adapter selected", gr.update(), None
    lora_repo = lora_configs[index]["repo"]
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
    new_placeholder = "optional description, e.g. 'a man with glasses and a beard'"
    return updated_text, gr.update(placeholder=new_placeholder), index
def _pick_safetensors_file(file_paths):
    """Pick the adapter weights file name from a repository file listing.

    Prefers a ``.safetensors`` file whose *file name* contains "lora"; otherwise
    the first ``.safetensors`` file; otherwise the conventional
    ``pytorch_lora_weights.safetensors`` default.
    """
    names = [path.split("/")[-1] for path in file_paths if path.endswith(".safetensors")]
    for name in names:
        # Match on the file name only: matching the full "owner/repo/file" path
        # made every file match whenever the repo name itself contained "lora".
        if "lora" in name.lower():
            return name
    # A repo whose single weights file has another name would otherwise be
    # pointed at a default file that may not exist.
    return names[0] if names else "pytorch_lora_weights.safetensors"


def fetch_lora_from_hf(link):
    """Resolve a ``owner/name`` Hugging Face link to adapter download details.

    Args:
        link: Repository id in ``owner/name`` form.

    Returns:
        ``(repo_name, weights_filename, trigger_word)`` where ``repo_name`` is the
        part after the slash and ``trigger_word`` is the model card's
        ``instance_prompt`` (empty string if absent).

    Raises:
        Exception: "Invalid HuggingFace repository format" for a malformed link,
            or "Error loading adapter: ..." (chained to the original error) if
            the model card or file listing cannot be fetched.
    """
    parts = link.split("/")
    # Both owner and repo name must be non-empty (rejects "user/" and "/repo").
    if len(parts) != 2 or not all(parts):
        raise Exception("Invalid HuggingFace repository format")
    try:
        model_card = ModelCard.load(link)
        # `or ""` also covers an explicit null instance_prompt in the card metadata.
        trigger_word = model_card.data.get("instance_prompt") or ""
        files = HfFileSystem().ls(link, detail=False)
    except Exception as e:
        raise Exception(f"Error loading adapter: {e}") from e
    return parts[1], _pick_safetensors_file(files), trigger_word
def load_user_lora(link):
    """Load a user-provided adapter from an ``owner/name`` link and update the UI.

    Args:
        link: Text typed into the custom-adapter textbox.

    Returns:
        A 7-tuple in the order of the event's outputs:
        - card visibility update;
        - card HTML;
        - remove-button update;
        - custom adapter state (dict with ``repo``/``weights``/``trigger_word``, or None);
        - gallery update;
        - selection label text;
        - cleared gallery selection index (None).
    """
    import html  # escape untrusted repo / model-card text before rendering in gr.HTML

    if not link:
        return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on an adapter in the gallery to select it", None
    try:
        repo_name, weights_file, trigger_word = fetch_lora_from_hf(link)
    except Exception as e:
        return gr.update(visible=True), f"Error: {html.escape(str(e))}", gr.update(visible=False), None, gr.update(), "### Click on an adapter in the gallery to select it", None

    # repo_name comes from user input and trigger_word from a remote model
    # card, so both are escaped before being interpolated into HTML.
    if trigger_word:
        trigger_info = f"Using: <code>{html.escape(str(trigger_word))}</code> as trigger word"
    else:
        trigger_info = "No trigger word found"
    card = (
        '<div class="user_lora_card">'
        "<div>Loaded custom adapter:</div>"
        f"<div><b>{html.escape(repo_name)}</b></div>"
        f"<small>{trigger_info}</small>"
        "</div>"
    )
    user_lora_data = {
        "repo": link,
        "weights": weights_file,
        "trigger_word": trigger_word,
    }
    return gr.update(visible=True), card, gr.update(visible=True), user_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
def unload_user_lora():
    """Forget the custom adapter and reset the widgets tied to it.

    Returns, in output order: an empty value for the custom-adapter textbox,
    a hide update for the remove button, a hide update for the adapter card,
    and ``None`` for both the custom-adapter state and the gallery selection.
    """
    textbox_value = ""
    button_update = gr.update(visible=False)
    card_update = gr.update(visible=False)
    no_custom_adapter = None
    no_selection = None
    return textbox_value, button_update, card_update, no_custom_adapter, no_selection
def sort_lora_gallery(lora_configs):
    """Order adapters by their ``likes`` count, most-liked first.

    Entries without ``likes`` count as 0; ties keep their original relative
    order. The input sequence is not modified.

    Returns:
        ``(gallery_items, ordered_configs)``: a list of ``(image, title)``
        tuples for the gallery, and the config dicts in that same order.
    """
    def likes_of(config):
        return config.get("likes", 0)

    ordered_configs = list(lora_configs)
    ordered_configs.sort(key=likes_of, reverse=True)
    gallery_items = [(cfg["image"], cfg["title"]) for cfg in ordered_configs]
    return gallery_items, ordered_configs
def generate_image_wrapper(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.75, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Entry point bound to the Generate button and prompt submit.

    Forwards every argument, positionally and unchanged, to the GPU-decorated
    ``generate_image`` and returns its ``(image, seed, reuse_button_update)``.

    NOTE(review): the ``lora_scale`` default here (1.75) differs from
    ``generate_image``'s (1.0); the UI always passes the slider value, so these
    defaults only matter for direct callers.
    """
    forwarded = (
        input_image,
        prompt,
        selected_index,
        user_lora,
        seed,
        randomize_seed,
        steps,
        guidance_scale,
        lora_scale,
        width,
        height,
        lora_configs,
        progress,
    )
    return generate_image(*forwarded)
def _select_adapter(selected_index, user_lora, lora_configs):
    """Return the adapter config to apply, or None when nothing is selected.

    A custom adapter (``user_lora``) takes precedence over a gallery selection.
    """
    if user_lora:
        return user_lora
    if selected_index is not None and lora_configs and 0 <= selected_index < len(lora_configs):
        return lora_configs[selected_index]
    return None


def _activate_adapter(lora_to_use, lora_scale):
    """Make ``lora_to_use`` the adapter loaded into ``pipe`` and apply ``lora_scale``.

    Best effort: loading errors are printed and generation continues.
    ``active_lora_adapter`` is kept in sync with what is actually loaded.
    """
    global active_lora_adapter
    if lora_to_use != active_lora_adapter:
        try:
            if active_lora_adapter:
                pipe.unload_lora_weights()
                # Clear the record immediately: if the load below fails, a later
                # request for the old adapter must not be treated as "already loaded".
                active_lora_adapter = None
            lora_path = load_lora_weights(lora_to_use["repo"], lora_to_use["weights"])
            if lora_path:
                pipe.load_lora_weights(lora_path, adapter_name="selected_lora")
                active_lora_adapter = lora_to_use
                print(f"loaded: {lora_path} with scale {lora_scale}")
        except Exception as e:
            print(f"Error loading adapter: {e}")
    else:
        print(f"using already loaded adapter: {lora_to_use}")

    if active_lora_adapter == lora_to_use:
        # Set the weight on every request (not only on first load) so changes to
        # the scale slider take effect for an already-loaded adapter.
        try:
            pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
        except Exception as e:
            print(f"Error setting adapter scale: {e}")


def _build_prompt(prompt, trigger_word):
    """Wrap the user's prompt in the instruction template for this adapter's trigger word."""
    if trigger_word == ", How2Draw":
        return f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
    if trigger_word == "__ ":
        return f" {prompt}. Accurately render the toolimpact logo and any tool impact iconography. The toolimpact logo begins with a two-line-tall drop-cap capital letter T with a dot in the center of its top bar."
    return f" {prompt}. convert the style of this photo or image to {trigger_word}. Maintain the facial identity of any persons and the general features of the image!"


@spaces.GPU
def generate_image(input_image, prompt, selected_index, user_lora, seed=42, randomize_seed=False, steps=28, guidance_scale=2.5, lora_scale=1.0, width=960, height=1280, lora_configs=None, progress=gr.Progress(track_tqdm=True)):
    """Edit ``input_image`` with the FLUX.1 Kontext pipeline using the selected adapter.

    Args:
        input_image: PIL image to edit (converted to RGB).
        prompt: User's editing prompt; wrapped in an adapter-specific template.
        selected_index: Gallery index into ``lora_configs`` (may be None).
        user_lora: Custom adapter dict from ``load_user_lora``; wins over the gallery.
        seed: Generator seed; replaced by a random one when ``randomize_seed``.
        steps, guidance_scale, width, height: Passed through to the pipeline.
        lora_scale: Adapter weight, applied on every call.
        lora_configs: Adapter configs aligned with the gallery.
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        ``(image or None on failure, seed actually used, reuse-button visibility update)``.

    Raises:
        gr.Error: if no image was uploaded or no adapter is selected.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Slider values may arrive as floats; torch.Generator.manual_seed needs an int.
    seed = int(seed)

    if input_image is None:
        raise gr.Error("Please upload an image to edit.")
    lora_to_use = _select_adapter(selected_index, user_lora, lora_configs)
    if lora_to_use is None:
        # The prompt templates need an adapter's trigger word; fail clearly
        # instead of with a TypeError on None.
        raise gr.Error("Please select an adapter from the gallery or enter a custom one.")

    _activate_adapter(lora_to_use, lora_scale)
    input_image = input_image.convert("RGB")
    prompt = _build_prompt(prompt, lora_to_use.get("trigger_word") or "")
    try:
        image = pipe(
            image=input_image,
            prompt=prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=steps,
            generator=torch.Generator().manual_seed(seed),
            width=width,
            height=height,
            max_area=width * height,
        ).images[0]
    except Exception as e:
        print(f"Error during generation: {e}")
        return None, seed, gr.update(visible=False)
    return image, seed, gr.update(visible=True)
# CSS styling
# Stylesheet passed to gr.Blocks(css=...) below. The `#...` selectors match the
# elem_id values given to components in the interface (app_container,
# left_panel, lora_info, edit_prompt, generate_button, lora_gallery);
# `.user_lora_card` is a class selector.
css = """
#app_container {
display: flex;
gap: 20px;
}
#left_panel {
min-width: 400px;
}
#lora_info {
color: #2563eb;
font-weight: bold;
}
#edit_prompt {
flex-grow: 1;
}
#generate_button {
background: linear-gradient(45deg, #2563eb, #3b82f6);
color: white;
border: none;
padding: 8px 16px;
border-radius: 6px;
font-weight: bold;
}
.user_lora_card {
background: #f8fafc;
border: 1px solid #e2e8f0;
border-radius: 8px;
padding: 12px;
margin: 8px 0;
}
#lora_gallery{
overflow: scroll !important
}
"""
# Build the Gradio interface
# NOTE(review): this file's original indentation was lost; the component
# nesting below (which Row/Column/Group/Accordion each component sits in) is
# a reconstruction — confirm the intended layout.
# delete_cache=(60, 60): per Gradio's docs, (frequency, age) in seconds for
# periodically deleting cached files.
with gr.Blocks(theme=gr.themes.Soft(), css=css, delete_cache=(60, 60)) as demo:
    # Per-session adapter configs; demo.load below replaces this with the
    # likes-sorted list so gallery positions and config indices stay aligned.
    gr_lora_configs = gr.State(value=lora_configs)
    title = gr.HTML(
        """Flux Kontext DLC😍
""",
    )
    # Gallery index chosen by the user (None until a click, or after a
    # custom adapter is loaded/removed).
    selected_state = gr.State(value=None)
    # Custom adapter dict built by load_user_lora (None when not loaded).
    user_lora = gr.State(value=None)
    with gr.Row(elem_id="app_container"):
        with gr.Column(scale=4, elem_id="left_panel"):
            with gr.Group(elem_id="lora_selection"):
                input_image = gr.Image(label="Upload a picture", type="pil", height=300)
                gallery = gr.Gallery(
                    label="Pick an Adapter",
                    allow_preview=False,
                    columns=3,
                    elem_id="lora_gallery",
                    show_share_button=False,
                    height=400
                )
                user_lora_input = gr.Textbox(
                    label="Or enter a custom HuggingFace adapter",
                    placeholder="e.g., username/adapter-name",
                    visible=True
                )
                user_lora_card = gr.HTML(visible=False)
                unload_user_lora_button = gr.Button("Remove custom adapter", visible=True)
        with gr.Column(scale=5):
            with gr.Row():
                prompt = gr.Textbox(
                    label="Editing Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="optional description, e.g. 'colorize and stylize, leave all else as is'",
                    elem_id="edit_prompt"
                )
                run_button = gr.Button("Generate", elem_id="generate_button")
            result = gr.Image(label="Generated Image", interactive=False)
            # Shown by generate_image_wrapper's third output after a successful run.
            reuse_button = gr.Button("Reuse this image", visible=False)
            with gr.Accordion("Advanced Settings", open=True):
                lora_scale = gr.Slider(
                    label="Adapter Scale",
                    minimum=0,
                    maximum=2,
                    step=0.1,
                    value=1.5,
                    info="Controls the strength of the adapter effect"
                )
                # Upper bound matches the range used for randomized seeds.
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=40,
                    value=10,
                    step=1
                )
                width = gr.Slider(
                    label="Width",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=960,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=128,
                    maximum=2560,
                    step=1,
                    value=1280,
                )
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=2.8,
                )
            prompt_title = gr.Markdown(
                value="### Click on an adapter in the gallery to select it",
                visible=True,
                elem_id="lora_info",
            )
    # Event handlers
    # Load a custom adapter from the textbox.
    # NOTE(review): per Gradio's docs `.input` fires on each user edit, so this
    # queries the Hub while the user is still typing; `.submit` may be intended.
    # NOTE(review): user_lora_card appears twice in `outputs` — load_user_lora
    # returns a visibility update and then the card markup for the same
    # component; confirm Gradio applies both as intended.
    user_lora_input.input(
        fn=load_user_lora,
        inputs=[user_lora_input],
        outputs=[user_lora_card, user_lora_card, unload_user_lora_button, user_lora, gallery, prompt_title, selected_state],
    )
    # Clear the textbox, hide the button and card, and reset both selection states.
    unload_user_lora_button.click(
        fn=unload_user_lora,
        outputs=[user_lora_input, unload_user_lora_button, user_lora_card, user_lora, selected_state]
    )
    # Record the clicked gallery index and update the label and prompt placeholder.
    # NOTE(review): this does not clear `user_lora`, and generate_image prefers
    # user_lora when it is set, so a gallery pick has no effect on generation
    # until the custom adapter is removed.
    gallery.select(
        fn=on_lora_select,
        inputs=[gr_lora_configs],
        outputs=[prompt_title, prompt, selected_state],
        show_progress=False
    )
    # Generate on button click or Enter in the prompt box; the seed slider is
    # updated with the seed actually used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=generate_image_wrapper,
        inputs=[input_image, prompt, selected_state, user_lora, seed, randomize_seed, steps, guidance_scale, lora_scale, width, height, gr_lora_configs],
        outputs=[result, seed, reuse_button]
    )
    # Copy the generated result back into the input image for a further edit.
    reuse_button.click(
        fn=lambda image: image,
        inputs=[result],
        outputs=[input_image]
    )
    # Initialize the gallery
    # On page load, fill the gallery and replace the configs state with the
    # same sorted order.
    demo.load(
        fn=sort_lora_gallery,
        inputs=[gr_lora_configs],
        outputs=[gallery, gr_lora_configs]
    )
# Per Gradio's docs, default_concurrency_limit=None removes the default
# per-event concurrency limit.
demo.queue(default_concurrency_limit=None)
demo.launch()