# flux-lora-lab / app.py
# Source: Hugging Face Space by multimodalart, commit 512c7c4 ("Update app.py", verified), 28 kB.
import copy
import json
import logging
import os
import random
import re
import time
from urllib.parse import urlsplit

import gradio as gr
import pandas as pd
import requests
import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download, HfFileSystem, ModelCard, snapshot_download
from PIL import Image

from live_preview_helpers import calculate_shift, retrieve_timesteps, flux_pipe_call_that_returns_an_iterable_of_images
# Load prompts for randomization: every cell of the header-less prompts.csv is
# flattened into a single array that randomize_loras() picks from.
df = pd.read_csv('prompts.csv', header=None)
prompt_values = df.values.flatten()
# Load LoRAs from JSON file. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform/locale default encoding.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)
# Initialize the base model.
# All weights are loaded in bfloat16; the device falls back to CPU when CUDA is
# not visible at import time.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
# Two VAEs are loaded:
# - taef1 (a tiny autoencoder) is installed as the text-to-image pipeline's VAE;
# - good_vae (the full FLUX VAE) is given to pipe_i2i below and passed as
#   `good_vae=` to the live-preview call in generate_image().
# Presumably taef1 decodes cheap intermediate previews and good_vae the final
# image -- TODO confirm in live_preview_helpers.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
# The image-to-image pipeline is handed pipe's transformer, text encoders and
# tokenizers, so those modules are the same objects in both pipelines; only the
# VAE differs. It is not moved to a device here: generate_image_to_image()
# calls pipe_i2i.to("cuda") before each run.
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
    base_model,
    vae=good_vae,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype
)
# Largest seed offered by the UI slider and by seed randomization (2**32 - 1).
MAX_SEED = 2**32 - 1
# function.__get__(pipe) creates a method bound to this pipe instance, so the
# helper can be called as pipe.flux_pipe_call_that_returns_an_iterable_of_images(...)
# with `self` set to pipe.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
class calculateDuration:
    """Context manager that prints how long its ``with`` body took.

    Example::

        with calculateDuration("Loading LoRA weights"):
            ...

    After the block exits:

    * ``start_time`` / ``end_time`` hold wall-clock timestamps (``time.time()``),
      kept for backward compatibility;
    * ``elapsed_time`` holds the duration in seconds.

    The duration is measured with ``time.perf_counter()``, which is monotonic,
    so clock adjustments cannot make it negative or wrong. Exceptions raised
    inside the block are not suppressed.
    """

    def __init__(self, activity_name=""):
        self.activity_name = activity_name

    def __enter__(self):
        self.start_time = time.time()
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.time()
        # Duration from the monotonic high-resolution counter; time.time() can
        # jump (NTP sync, manual clock changes) between __enter__ and __exit__.
        self.elapsed_time = time.perf_counter() - self._perf_start
        if self.activity_name:
            print(f"Elapsed time for {self.activity_name}: {self.elapsed_time:.6f} seconds")
        else:
            print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
        # Falsy return value: any exception from the block propagates.
        return False
def download_file(url, directory=None, timeout=60):
    """Download ``url`` into ``directory`` and return the local file path.

    Args:
        url: HTTP(S) URL of the file to fetch.
        directory: Destination directory; defaults to the current working
            directory.
        timeout: Seconds ``requests`` waits for the server to connect or to
            send more data before giving up. ``None`` waits forever.

    Returns:
        ``directory`` joined with the last path segment of the URL. The query
        string and fragment are ignored when naming the file.

    Raises:
        ValueError: If the URL path has no file-name component.
        requests.HTTPError: If the server answers with an error status.
    """
    if directory is None:
        directory = os.getcwd()  # Use current working directory if not specified
    # Name the file from the URL *path* only, so "...?download=true" does not
    # end up in the local file name.
    filename = urlsplit(url).path.rsplit("/", 1)[-1]
    if not filename:
        raise ValueError(f"Cannot determine a file name from URL: {url}")
    filepath = os.path.join(directory, filename)
    # Stream the body straight to disk instead of buffering the whole
    # (possibly hundreds of MB) weights file in memory. The timeout keeps a
    # stalled server from hanging the request forever.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(filepath, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    return filepath
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
    """Toggle the clicked gallery LoRA in or out of the (max two) selection slots.

    Clicking an already-selected LoRA deselects it. Clicking an unselected one
    adds it when a slot is free; otherwise a warning is shown and the outputs
    are left unchanged.

    Returns:
        A 10-tuple matching the ``gallery.select`` outputs: prompt-box update,
        slot 1 / slot 2 markdown, the selection list, both slider values,
        width, height, and the slot 1 / slot 2 preview images.
    """
    clicked = evt.index
    chosen = selected_indices or []

    if clicked in chosen:
        chosen.remove(clicked)
    elif len(chosen) >= 2:
        gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")
        return gr.update(), gr.update(), gr.update(), chosen, gr.update(), gr.update(), width, height, gr.update(), gr.update()
    else:
        chosen.append(clicked)

    # Default text/images for empty slots; filled below for occupied slots.
    infos = ["Select a LoRA 1", "Select a LoRA 2"]
    images = [None, None]
    for slot, lora_idx in enumerate(chosen[:2]):
        entry = loras_state[lora_idx]
        infos[slot] = f"### LoRA {slot + 1} Selected: [{entry['title']}](https://huggingface.co/{entry['repo']}) ✨"
        images[slot] = entry['image']

    # The prompt placeholder names the most recently selected LoRA.
    if chosen:
        placeholder = f"Type a prompt for {loras_state[chosen[-1]]['title']}"
    else:
        placeholder = "Type a prompt after selecting a LoRA"

    default_scale = 1.15
    return (
        gr.update(placeholder=placeholder),
        infos[0],
        infos[1],
        chosen,
        default_scale,
        default_scale,
        width,
        height,
        images[0],
        images[1],
    )
def remove_lora_1(selected_indices, loras_state):
    """Clear selection slot 1; a LoRA in slot 2 (if any) moves up to slot 1.

    Args:
        selected_indices: List of selected indices into ``loras_state``
            (mutated in place).
        loras_state: List of LoRA dicts with ``title``, ``repo`` and ``image``.

    Returns:
        ``(info_1, info_2, selected_indices, scale_1, scale_2, image_1, image_2)``
        for the ``remove_button_1`` outputs. Both scales are reset to 1.15.
    """
    if selected_indices:
        selected_indices.pop(0)
    selected_info_1 = "Select a LoRA 1"
    selected_info_2 = "Select a LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = loras_state[selected_indices[0]]
        # Absolute Hub URL, as in update_selection/randomize_loras. A bare
        # "user/repo" link target would resolve relative to the current page.
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = loras_state[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
def remove_lora_2(selected_indices, loras_state):
    """Clear selection slot 2, keeping whatever is in slot 1.

    Args:
        selected_indices: List of selected indices into ``loras_state``
            (mutated in place).
        loras_state: List of LoRA dicts with ``title``, ``repo`` and ``image``.

    Returns:
        ``(info_1, info_2, selected_indices, scale_1, scale_2, image_1, image_2)``
        for the ``remove_button_2`` outputs. Both scales are reset to 1.15.
    """
    if len(selected_indices) >= 2:
        selected_indices.pop(1)
    # Same empty-slot wording as the initial UI and the other handlers.
    selected_info_1 = "Select a LoRA 1"
    selected_info_2 = "Select a LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = loras_state[selected_indices[0]]
        # Absolute Hub URL; a bare "user/repo" link would be page-relative.
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = loras_state[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2
def randomize_loras(selected_indices, loras_state):
    """Select two distinct random LoRAs and pick a random prompt.

    ``selected_indices`` is ignored: the selection is replaced wholesale.

    Raises:
        gr.Error: When fewer than two LoRAs are available.

    Returns:
        The 8 values wired to ``randomize_button.click``: slot 1 / slot 2
        markdown, the new selection, both slider values (1.15), slot 1 /
        slot 2 images, and the prompt text.
    """
    if len(loras_state) < 2:
        raise gr.Error("Not enough LoRAs to randomize.")
    picks = random.sample(range(len(loras_state)), 2)
    first, second = (loras_state[i] for i in picks)
    info_first = f"### LoRA 1 Selected: [{first['title']}](https://huggingface.co/{first['repo']}) ✨"
    info_second = f"### LoRA 2 Selected: [{second['title']}](https://huggingface.co/{second['repo']}) ✨"
    image_first = first['image']
    image_second = second['image']
    prompt_text = random.choice(prompt_values)
    return info_first, info_second, picks, 1.15, 1.15, image_first, image_second, prompt_text
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
    """Add a user-supplied LoRA to the gallery and select it if a slot is free.

    Args:
        custom_lora: Text from the "Custom LoRA" box: a Hugging Face repo id, a
            huggingface.co URL, or a direct ``*.safetensors`` URL (resolved by
            ``check_custom_model``).
        selected_indices: Current selection (indices into ``current_loras``);
            mutated in place.
        current_loras: Per-session LoRA list (``loras_state``); a new entry is
            appended in place.
        gallery: Unused; present because the event passes the gallery value.

    Returns:
        A 9-tuple for ``add_custom_lora_button.click``: the LoRA list, gallery
        update, slot 1 / slot 2 markdown, selection, both slider values, and
        slot 1 / slot 2 images. On empty input or any failure, the LoRA list
        and selection are returned unchanged and other outputs are untouched.
    """
    if not custom_lora:
        return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
    try:
        title, repo, path, trigger_word, image = check_custom_model(custom_lora)
        print(f"Loaded custom LoRA: {repo}")
        existing_item_index = next((index for (index, item) in enumerate(current_loras) if item['repo'] == repo), None)
        if existing_item_index is None:
            if repo.endswith(".safetensors") and repo.startswith("http"):
                # Direct weight URLs are downloaded so the stored "repo" is a local path.
                repo = download_file(repo)
            new_item = {
                "image": image if image else "/home/user/app/custom.png",
                "title": title,
                "repo": repo,
                "weights": path,
                "trigger_word": trigger_word
            }
            print(f"New LoRA: {new_item}")
            existing_item_index = len(current_loras)
            current_loras.append(new_item)

        # Update gallery
        gallery_items = [(item["image"], item["title"]) for item in current_loras]

        # Select the LoRA unless it is already selected. Appending it twice
        # would make run_lora load the same weights into two adapters.
        if existing_item_index not in selected_indices:
            if len(selected_indices) < 2:
                selected_indices.append(existing_item_index)
            else:
                gr.Warning("You can select up to 2 LoRAs, remove one to select a new one.")

        # Update selected_info and images
        selected_info_1 = "Select a LoRA 1"
        selected_info_2 = "Select a LoRA 2"
        lora_scale_1 = 1.15
        lora_scale_2 = 1.15
        lora_image_1 = None
        lora_image_2 = None
        if len(selected_indices) >= 1:
            lora1 = current_loras[selected_indices[0]]
            selected_info_1 = f"### LoRA 1 Selected: {lora1['title']} ✨"
            lora_image_1 = lora1['image'] if lora1['image'] else None
        if len(selected_indices) >= 2:
            lora2 = current_loras[selected_indices[1]]
            selected_info_2 = f"### LoRA 2 Selected: {lora2['title']} ✨"
            lora_image_2 = lora2['image'] if lora2['image'] else None
        print("Finished adding custom LoRA")
        return (
            current_loras,
            gr.update(value=gallery_items),
            selected_info_1,
            selected_info_2,
            selected_indices,
            lora_scale_1,
            lora_scale_2,
            lora_image_1,
            lora_image_2
        )
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing the event.
        print(e)
        gr.Warning(str(e))
        return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
def remove_custom_lora(selected_indices, current_loras, gallery):
    """Drop the last entry of the LoRA list and refresh the selection UI.

    NOTE(review): this removes whatever entry is last. That is the most recent
    custom LoRA only if one was added. The button wired to it is created with
    visible=False, so it is presumably unused -- confirm before exposing it.

    Args:
        selected_indices: Current selection; the removed index is dropped from
            it in place.
        current_loras: Per-session LoRA list (``loras_state``).
        gallery: Unused; present because the event passes the gallery value.

    Returns:
        A 9-tuple for the ``remove_custom_lora_button`` outputs, in the same
        shape as ``add_custom_lora``.
    """
    if not current_loras:
        # Nothing to remove. Return one value per output component (the
        # original implicitly returned None here, which Gradio cannot map).
        return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()

    # Remove from loras list (a new list; the caller's list is not mutated).
    current_loras = current_loras[:-1]
    # The removed entry had index len(current_loras); deselect it if needed.
    removed_index = len(current_loras)
    if removed_index in selected_indices:
        selected_indices.remove(removed_index)

    gallery_items = [(item["image"], item["title"]) for item in current_loras]

    selected_info_1 = "Select a LoRA 1"
    selected_info_2 = "Select a LoRA 2"
    lora_scale_1 = 1.15
    lora_scale_2 = 1.15
    lora_image_1 = None
    lora_image_2 = None
    if len(selected_indices) >= 1:
        lora1 = current_loras[selected_indices[0]]
        # Absolute Hub URL, matching update_selection; a bare "user/repo" link is page-relative.
        selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
        lora_image_1 = lora1['image']
    if len(selected_indices) >= 2:
        lora2 = current_loras[selected_indices[1]]
        selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
        lora_image_2 = lora2['image']
    return (
        current_loras,
        gr.update(value=gallery_items),
        selected_info_1,
        selected_info_2,
        selected_indices,
        lora_scale_1,
        lora_scale_2,
        lora_image_1,
        lora_image_2
    )
def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress):
    """Run text-to-image with ``pipe`` and yield images as they are produced.

    The pipeline (with whatever LoRA adapters are currently active) is moved
    to CUDA and seeded with ``seed`` for reproducibility. ``progress`` is
    accepted but not used here. The whole generation is timed with
    ``calculateDuration``.

    Yields:
        Images from ``pipe.flux_pipe_call_that_returns_an_iterable_of_images``.
        ``run_lora`` treats the last one yielded as the final result.
    """
    print("Generating image...")
    pipe.to("cuda")
    rng = torch.Generator(device="cuda").manual_seed(seed)
    call_kwargs = dict(
        prompt=prompt_mash,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=rng,
        joint_attention_kwargs={"scale": 1.0},
        output_type="pil",
        good_vae=good_vae,
    )
    with calculateDuration("Generating image"):
        frames = pipe.flux_pipe_call_that_returns_an_iterable_of_images(**call_kwargs)
        for frame in frames:
            yield frame
def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps, cfg_scale, width, height, seed):
    """Run image-to-image with ``pipe_i2i`` and return the single output image.

    Args:
        prompt_mash: Final prompt, including LoRA trigger words.
        image_input_path: Path/URL of the init image, loaded via ``load_image``.
        image_strength: Denoise strength passed as ``strength``.
        steps, cfg_scale, width, height: Forwarded to the pipeline.
        seed: Seed for a CUDA ``torch.Generator``.

    Returns:
        The first image of the pipeline output (``output_type="pil"``).
    """
    pipe_i2i.to("cuda")
    rng = torch.Generator(device="cuda").manual_seed(seed)
    init_image = load_image(image_input_path)
    output = pipe_i2i(
        prompt=prompt_mash,
        image=init_image,
        strength=image_strength,
        num_inference_steps=steps,
        guidance_scale=cfg_scale,
        width=width,
        height=height,
        generator=rng,
        joint_attention_kwargs={"scale": 1.0},
        output_type="pil",
    )
    return output.images[0]
def _build_prompt_mash(prompt, selected_loras):
    """Combine the user prompt with each selected LoRA's trigger word.

    Trigger words of LoRAs whose ``trigger_position`` is ``"prepend"`` go before
    the prompt; all others go after it, in selection order. Empty or missing
    trigger words are skipped. Parts are joined with single spaces.
    """
    prepends = []
    appends = []
    for lora in selected_loras:
        trigger_word = lora.get('trigger_word', '')
        if not trigger_word:
            continue
        if lora.get("trigger_position") == "prepend":
            prepends.append(trigger_word)
        else:
            appends.append(trigger_word)
    return " ".join(prepends + [prompt] + appends)


@spaces.GPU(duration=75)
def run_lora(prompt, image_input, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
    """Load the selected LoRAs and generate an image, streaming previews.

    Decorated with ``spaces.GPU(duration=75)``. Uses image-to-image when
    ``image_input`` is given, text-to-image otherwise.

    Args:
        prompt: User prompt; trigger words are added by ``_build_prompt_mash``.
        image_input: Optional init-image path (enables image-to-image).
        image_strength: Denoise strength for image-to-image.
        cfg_scale, steps, width, height: Generation settings.
        selected_indices: Indices into ``loras_state`` (max two are expected).
        lora_scale_1, lora_scale_2: Adapter weights for the first/second LoRA.
        randomize_seed: When true, ``seed`` is replaced by a random one.
        seed: Seed used when not randomizing.
        loras_state: Per-session LoRA list.
        progress: Gradio progress tracker, forwarded to ``generate_image``.

    Yields:
        ``(image, seed, progress_bar_update)`` tuples for the ``result``,
        ``seed`` and ``progress_bar`` outputs. The last tuple carries the
        final image and hides the progress bar.

    Raises:
        gr.Error: If no LoRA is selected.
    """
    if not selected_indices:
        raise gr.Error("You must select at least one LoRA before proceeding.")
    selected_loras = [loras_state[idx] for idx in selected_indices]

    prompt_mash = _build_prompt_mash(prompt, selected_loras)
    print("Prompt Mash: ", prompt_mash)

    # Unload previous LoRA weights from both pipelines before loading the new selection.
    with calculateDuration("Unloading LoRA"):
        pipe.unload_lora_weights()
        pipe_i2i.unload_lora_weights()
    print(pipe.get_active_adapters())

    # Load each LoRA as its own named adapter (lora_0, lora_1, ...) into the
    # pipeline that will actually run, weighted by the matching slider value.
    pipe_to_use = pipe_i2i if image_input is not None else pipe
    lora_names = []
    lora_weights = []
    with calculateDuration("Loading LoRA weights"):
        for idx, lora in enumerate(selected_loras):
            lora_name = f"lora_{idx}"
            lora_names.append(lora_name)
            print(f"Lora Name: {lora_name}")
            lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
            lora_path = lora['repo']
            weight_name = lora.get("weights")
            print(f"Lora Path: {lora_path}")
            pipe_to_use.load_lora_weights(
                lora_path,
                weight_name=weight_name if weight_name else None,
                low_cpu_mem_usage=True,
                adapter_name=lora_name
            )
    print("Loaded LoRAs:", lora_names)
    print("Adapter weights:", lora_weights)
    pipe_to_use.set_adapters(lora_names, adapter_weights=lora_weights)
    print(pipe.get_active_adapters())

    # Replace the seed with a random one when requested. The seed actually used
    # is yielded back to the UI so the result can be reproduced.
    with calculateDuration("Randomizing seed"):
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

    if image_input is not None:
        final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
        yield final_image, seed, gr.update(visible=False)
    else:
        image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
        final_image = None
        step_counter = 0
        # Bound up front so the final yield cannot hit an UnboundLocalError
        # when the generator produces no images.
        progress_bar = ""
        for image in image_generator:
            step_counter += 1
            final_image = image
            progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
            yield image, seed, gr.update(value=progress_bar, visible=True)
        yield final_image, seed, gr.update(value=progress_bar, visible=False)


run_lora.zerogpu = True
def get_huggingface_safetensors(link):
split_link = link.split("/")
if len(split_link) != 2:
raise Exception("Invalid Hugging Face repository link format.")
print(f"Repository attempted: {split_link}")
# Load model card
model_card = ModelCard.load(link)
base_model = model_card.data.get("base_model")
print(f"Base model: {base_model}")
# Validate model type
acceptable_models = {"black-forest-labs/FLUX.1-dev", "black-forest-labs/FLUX.1-schnell"}
models_to_check = base_model if isinstance(base_model, list) else [base_model]
if not any(model in acceptable_models for model in models_to_check):
raise Exception("Not a FLUX LoRA!")
# Extract image and trigger word
print("Before trying to get image")
image_path = model_card.data.get("widget", [{}])[0].get("output", {}).get("url", None)
print(f"Image path {image_path}")
trigger_word = model_card.data.get("instance_prompt", "")
print(f"Image path {trigger_word}")
image_url = f"https://huggingface.co/{link}/resolve/main/{image_path}" if image_path else None
print(f"Image URL {image_url}")
# Initialize Hugging Face file system
fs = HfFileSystem()
try:
list_of_files = fs.ls(link, detail=False)
# Initialize variables for safetensors selection
safetensors_name = None
highest_trained_file = None
highest_steps = -1
last_safetensors_file = None
step_pattern = re.compile(r"_0{3,}\d+") # Detects step count `_000...`
for file in list_of_files:
filename = file.split("/")[-1]
# Select safetensors file
if filename.endswith(".safetensors"):
last_safetensors_file = filename # Track last encountered file
match = step_pattern.search(filename)
if not match:
# Found a full model without step numbers, return immediately
safetensors_name = filename
break
else:
# Extract step count and track highest
steps = int(match.group().lstrip("_"))
if steps > highest_steps:
highest_trained_file = filename
highest_steps = steps
# Select an image file if not found in model card
if not image_url and filename.lower().endswith((".jpg", ".jpeg", ".png", ".webp")):
image_url = f"https://huggingface.co/{link}/resolve/main/{filename}"
# If no full model found, fall back to the most trained safetensors file
if not safetensors_name:
safetensors_name = highest_trained_file if highest_trained_file else last_safetensors_file
# If still no safetensors file found, raise an exception
if not safetensors_name:
raise Exception("No valid *.safetensors file found in the repository.")
except Exception as e:
print(e)
raise Exception("You didn't include a valid Hugging Face repository with a *.safetensors LoRA")
return split_link[1], link, safetensors_name, trigger_word, image_url
def check_custom_model(link):
    """Resolve the "Custom LoRA" text box value into LoRA metadata.

    Accepted forms:

    * a direct ``*.safetensors`` URL, used as-is with no trigger word or image;
    * an ``http(s)://huggingface.co/<user>/<repo>`` URL;
    * a bare ``<user>/<repo>`` Hugging Face repo id.

    Returns:
        ``(title, repo, weight_file_name, trigger_word, image_url)``.

    Raises:
        Exception: ``"Unsupported URL"`` for non-Hugging-Face URLs that are not
            ``*.safetensors``, or whatever ``get_huggingface_safetensors``
            raises for a repo it cannot resolve.
    """
    # Pasted values often carry stray whitespace. Without stripping,
    # " user/repo" fails the id format check and "x.safetensors\n" is missed.
    link = link.strip()
    if link.endswith(".safetensors"):
        # Treat as direct link to the LoRA weights
        title = os.path.basename(link)
        return title, link, None, "", None
    if link.startswith(("https://", "http://")):
        if "huggingface.co" not in link:
            raise Exception("Unsupported URL")
        # partition() yields "" instead of raising IndexError when nothing
        # follows the host. Stripping "/" makes ".../user/repo/" resolve to "user/repo".
        repo_id = link.partition("huggingface.co/")[2].strip("/")
        return get_huggingface_safetensors(repo_id)
    # Assume it's a Hugging Face model path
    return get_huggingface_safetensors(link)
def update_history(new_image, history):
    """Return the history gallery list with ``new_image`` prepended (newest first).

    ``None`` (no image, e.g. generation did not produce one) leaves the history
    unchanged instead of inserting an empty entry. The input list is not
    mutated; a new list is returned.
    """
    history = list(history) if history is not None else []
    if new_image is None:
        return history
    return [new_image, *history]
# Custom CSS passed to gr.Blocks below.
# The .progress-bar width is computed from the --current/--total CSS variables
# that run_lora writes into its progress-bar HTML.
# NOTE(review): "#component-8" and "#component-11" look like Gradio's
# auto-generated element ids, which shift whenever components are added or
# reordered -- giving those components an elem_id/elem_classes would be sturdier.
css = '''
#gen_btn{height: 100%}
#title{text-align: center}
#title h1{font-size: 3em; display:inline-flex; align-items:center}
#title img{width: 100px; margin-right: 0.25em}
#gallery .grid-wrap{height: 5vh}
#lora_list{background: var(--block-background-fill);padding: 0 1em .3em; font-size: 90%}
.custom_lora_card{margin-bottom: 1em}
.card_internal{display: flex;height: 100px;margin-top: .5em}
.card_internal img{margin-right: 1em}
.styler{--form-gap-width: 0px !important}
#progress{height:30px}
#progress .generating{display:none}
.progress-container {width: 100%;height: 30px;background-color: #f0f0f0;border-radius: 15px;overflow: hidden;margin-bottom: 20px}
.progress-bar {height: 100%;background-color: #4f46e5;width: calc(var(--current) / var(--total) * 100%);transition: width 0.5s ease-in-out}
#component-8, .button_total{height: 100%; align-self: stretch;}
#loaded_loras [data-testid="block-info"]{font-size:80%}
#custom_lora_structure{background: var(--block-background-fill)}
#custom_lora_btn{margin-top: auto;margin-bottom: 11px}
#random_btn{font-size: 300%}
#component-11{align-self: stretch;}
'''
# Gradio UI: layout first, then event wiring; the app is queued and launched at the end.
with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/wMh2Oek.png" alt="LoRA"> LoRA Lab [beta]</h1><br><span style="
margin-top: -25px !important;
display: block;
margin-left: 37px;
">Mix and match any FLUX[dev] LoRAs</span>""",
        elem_id="title",
    )
    # Per-session state: the LoRA catalogue (custom LoRAs get appended) and the
    # indices of the (up to two) selected LoRAs.
    loras_state = gr.State(loras)
    selected_indices = gr.State([])

    # Prompt + generate button.
    with gr.Row():
        with gr.Column(scale=3):
            prompt = gr.Textbox(label="Prompt", lines=1, placeholder="Type a prompt after selecting a LoRA")
        with gr.Column(scale=1):
            generate_button = gr.Button("Generate", variant="primary", elem_classes=["button_total"])

    # The two selection slots (image, title, scale slider, remove button) plus the dice button.
    with gr.Row(elem_id="loaded_loras"):
        with gr.Column(scale=1, min_width=25):
            randomize_button = gr.Button("🎲", variant="secondary", scale=1, elem_id="random_btn")
        with gr.Column(scale=8):
            with gr.Row():
                with gr.Column(scale=0, min_width=50):
                    lora_image_1 = gr.Image(label="LoRA 1 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
                with gr.Column(scale=3, min_width=100):
                    selected_info_1 = gr.Markdown("Select a LoRA 1")
                with gr.Column(scale=5, min_width=50):
                    lora_scale_1 = gr.Slider(label="LoRA 1 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
            with gr.Row():
                remove_button_1 = gr.Button("Remove", size="sm")
        with gr.Column(scale=8):
            with gr.Row():
                with gr.Column(scale=0, min_width=50):
                    lora_image_2 = gr.Image(label="LoRA 2 Image", interactive=False, min_width=50, width=50, show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False, height=50)
                with gr.Column(scale=3, min_width=100):
                    selected_info_2 = gr.Markdown("Select a LoRA 2")
                with gr.Column(scale=5, min_width=50):
                    lora_scale_2 = gr.Slider(label="LoRA 2 Scale", minimum=0, maximum=3, step=0.01, value=1.15)
            with gr.Row():
                remove_button_2 = gr.Button("Remove", size="sm")

    # Left: custom-LoRA input and the gallery. Right: progress, result, history.
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Row(elem_id="custom_lora_structure"):
                    custom_lora = gr.Textbox(label="Custom LoRA", info="LoRA Hugging Face path or *.safetensors public URL", placeholder="multimodalart/vintage-ads-flux", scale=3, min_width=150)
                    add_custom_lora_button = gr.Button("Add Custom LoRA", elem_id="custom_lora_btn", scale=2, min_width=150)
                remove_custom_lora_button = gr.Button("Remove Custom LoRA", visible=False)
            gr.Markdown("[Check the list of FLUX LoRAs](https://huggingface.co/models?other=base_model:adapter:black-forest-labs/FLUX.1-dev)", elem_id="lora_list")
            gallery = gr.Gallery(
                [(item["image"], item["title"]) for item in loras],
                label="Or pick from the LoRA Explorer gallery",
                allow_preview=False,
                columns=5,
                elem_id="gallery",
                show_share_button=False,
                interactive=False
            )
        with gr.Column():
            progress_bar = gr.Markdown(elem_id="progress", visible=False)
            result = gr.Image(label="Generated Image", interactive=False, show_share_button=False)
            with gr.Accordion("History", open=False):
                history_gallery = gr.Gallery(label="History", columns=6, object_fit="contain", interactive=False)

    # Advanced settings: optional init image (switches to image-to-image) and sampler parameters.
    with gr.Row():
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                input_image = gr.Image(label="Input image", type="filepath", show_share_button=False)
                image_strength = gr.Slider(label="Denoise Strength", info="Lower means more image influence", minimum=0.1, maximum=1.0, step=0.01, value=0.75)
            with gr.Column():
                with gr.Row():
                    cfg_scale = gr.Slider(label="CFG Scale", minimum=1, maximum=20, step=0.5, value=3.5)
                    steps = gr.Slider(label="Steps", minimum=1, maximum=50, step=1, value=28)
                with gr.Row():
                    width = gr.Slider(label="Width", minimum=256, maximum=1536, step=64, value=1024)
                    height = gr.Slider(label="Height", minimum=256, maximum=1536, step=64, value=1024)
                with gr.Row():
                    randomize_seed = gr.Checkbox(True, label="Randomize seed")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)

    # Event wiring. Output lists must match each handler's return tuple order.
    gallery.select(
        update_selection,
        inputs=[selected_indices, loras_state, width, height],
        outputs=[prompt, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, width, height, lora_image_1, lora_image_2])
    remove_button_1.click(
        remove_lora_1,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
    )
    remove_button_2.click(
        remove_lora_2,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
    )
    randomize_button.click(
        randomize_loras,
        inputs=[selected_indices, loras_state],
        outputs=[selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, prompt]
    )
    add_custom_lora_button.click(
        add_custom_lora,
        inputs=[custom_lora, selected_indices, loras_state, gallery],
        outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
    )
    remove_custom_lora_button.click(
        remove_custom_lora,
        inputs=[selected_indices, loras_state, gallery],
        outputs=[loras_state, gallery, selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2]
    )
    # Generate on button click or Enter in the prompt box. The history is only
    # updated when generation succeeds: with .then() it also ran after errors
    # (e.g. no LoRA selected) and re-added the previous, stale result.
    gr.on(
        triggers=[generate_button.click, prompt.submit],
        fn=run_lora,
        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
        outputs=[result, seed, progress_bar]
    ).success(
        fn=update_history,
        inputs=[result, history_gallery],
        outputs=history_gallery,
    )

app.queue()
app.launch()