# app.py — from svjack's Hugging Face Space (commit ef88f9e, "Create app.py", 9.87 kB).
import copy
import gc
import json
import os
import random
import shlex
import shutil
import subprocess
import sys

import gradio as gr
import numpy as np
from PIL import Image
# Expected contents of the cloned repository (relied on later in this file):
#   models/      - model files offered in the gallery and passed to `./bin/sd -m`
#   images/      - presumably gallery preview images (TODO confirm against sd_cfg.json)
#   custom.css   - stylesheet passed to gr.Blocks
#   sd_cfg.json  - gallery metadata
REPO_URL = "https://huggingface.co/svjack/sd-ggml-cpp-dp"
REPO_DIR = "sd-ggml-cpp-dp"

# Always start from a fresh clone so a stale or half-finished checkout from an
# earlier run is never reused.
if os.path.exists(REPO_DIR):
    shutil.rmtree(REPO_DIR)
# check=True: stop immediately with the failing command's exit status instead
# of continuing and failing later on missing files.
subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
os.chdir(REPO_DIR)

if not os.path.exists("stable-diffusion.cpp"):
    raise RuntimeError("cloned repository is missing the stable-diffusion.cpp sources")
# Configure and build stable-diffusion.cpp in the repo root; the CLI used by
# process() is expected at ./bin/sd afterwards.
subprocess.run(["cmake", "stable-diffusion.cpp"], check=True)
subprocess.run(["cmake", "--build", ".", "--config", "Release"], check=True)
if not os.path.exists("bin"):
    raise RuntimeError("stable-diffusion.cpp build did not produce a ./bin directory")
def process(model_path, prompt, num_samples, image_resolution, sample_steps, seed,):
    """Generate ``num_samples`` images by running the stable-diffusion.cpp CLI.

    ``./bin/sd`` (relative to the current directory) is invoked once per sample,
    writing a PNG into ``output_image_dir``; each PNG is read back as a NumPy array.

    Args:
        model_path: model file passed to ``-m``.
        prompt: text prompt passed to ``-p``. It is user-supplied, so it is passed
            as a single argv element and never interpreted by a shell.
        num_samples: number of images to generate.
        image_resolution: used for both height (``-H``) and width (``-W``).
        sample_steps: number of sampling steps (``--steps``).
        seed: base seed (``-s``); sample ``i`` uses ``seed + i``.

    Returns:
        list of ``numpy.ndarray`` images, one per sample, in generation order.

    Raises:
        RuntimeError: if the CLI did not produce the expected output file.
    """
    from PIL import Image
    from uuid import uuid1

    output_path = "output_image_dir"
    # Start from an empty directory so previous outputs do not accumulate.
    if os.path.exists(output_path):
        shutil.rmtree(output_path)
    os.mkdir(output_path)

    images = []
    for i in range(num_samples):
        out_file = os.path.join(output_path, "{}.png".format(uuid1()))
        # Argument list + shell=False (the default): quotes, `$(...)` or
        # backticks in the prompt cannot inject shell commands.
        run_cmd = [
            "./bin/sd",
            "-m", str(model_path),
            "--sampling-method", "dpm++2mv2",
            "-o", out_file,
            "-p", str(prompt),
            # Slider values may arrive as floats; the CLI expects integers.
            "--steps", str(int(sample_steps)),
            "-H", str(int(image_resolution)),
            "-W", str(int(image_resolution)),
            "-s", str(int(seed) + i),
        ]
        print("run cmd: {}".format(shlex.join(run_cmd)))
        completed = subprocess.run(run_cmd)
        if not os.path.exists(out_file):
            raise RuntimeError(
                "sd did not write {} (exit status {})".format(out_file, completed.returncode)
            )
        # Close the file handle once the pixels have been copied into an array.
        with Image.open(out_file) as image:
            images.append(np.asarray(image))
    return images
# Every file under ./models is a selectable model; paths look like "models/<name>".
model_list = [os.path.join("models", name) for name in os.listdir("models")]
if not model_list:
    raise RuntimeError("no model files found in the 'models' directory")

# Gallery metadata. JSON is UTF-8 by specification, so do not depend on the
# platform's default locale encoding.
with open("sd_cfg.json", "r", encoding="utf-8") as file:
    data = json.load(file)

_available_models = set(model_list)
# Keep only entries whose model file is actually present; each entry carries
# just the fields the UI uses.
sdxl_loras_raw = [
    {
        "image": item["image"],
        "title": item["title"],
        "repo": item["repo"],
        "trigger_word": item["trigger_word"],
        "model_path": item["model_path"],
    }
    for item in data
    if item["model_path"] in _available_models
]
if not sdxl_loras_raw:
    raise RuntimeError("sd_cfg.json lists no model present in the 'models' directory")
def update_selection(selected_state: gr.SelectData, sdxl_loras):
    """Handle a click on a gallery item.

    Gradio supplies ``selected_state`` (the click event data) because of the
    ``gr.SelectData`` annotation; ``selected_state.index`` is the position of the
    clicked item in ``sdxl_loras``.

    Returns:
        A 6-tuple:
        (selection markdown, trigger word to place in the prompt box,
        placeholder update for the prompt box, the event data itself,
        diffusers-usage markdown, other-UIs markdown).
        NOTE(review): the gallery.select wiring in this file binds only the
        first four values; the last two are currently unused.
    """
    entry = sdxl_loras[selected_state.index]
    lora_repo = entry["repo"]
    instance_prompt = entry["trigger_word"]
    if instance_prompt == "":
        new_placeholder = "Type a prompt. This applies for all prompts, no need for a trigger word"
    else:
        new_placeholder = "Type a prompt to use your selected LoRA"
    updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo}) ✨ "
    use_with_diffusers = f'''
## Using [`{lora_repo}`](https://huggingface.co/{lora_repo})
## Use it with diffusers:
'''
    use_with_uis = f'''
## Use it with Comfy UI, Invoke AI, SD.Next, AUTO1111:
### Download the `*.safetensors` weights of [here](https://huggingface.co/{lora_repo})
- [ComfyUI guide](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
- [Invoke AI guide](https://invoke-ai.github.io/InvokeAI/features/CONCEPTS/?h=lora#using-loras)
- [SD.Next guide](https://github.com/vladmandic/automatic)
- [AUTOMATIC1111 guide](https://stable-diffusion-art.com/lora/)
'''
    return (
        updated_text,
        instance_prompt,
        gr.update(placeholder=new_placeholder),
        selected_state,
        use_with_diffusers,
        use_with_uis,
    )
def check_selected(selected_state):
    """Reject a run request when no gallery model has been selected yet.

    Meant as the first step of an event chain: it returns nothing on success and
    raises a user-visible ``gr.Error`` when ``selected_state`` is falsy.
    """
    if selected_state:
        return
    raise gr.Error("You must select a Model")
def shuffle_gallery(sdxl_loras):
    """Shuffle the gallery entries in place and rebuild the gallery items.

    Returns:
        (list of ``(image, title)`` pairs in the new order, the same list object
        ``sdxl_loras`` now shuffled).
    """
    random.shuffle(sdxl_loras)
    gallery_items = [(entry["image"], entry["title"]) for entry in sdxl_loras]
    return gallery_items, sdxl_loras
def swap_gallery(order, sdxl_loras):
    """Reorder the gallery according to the "Order by" radio value.

    ``"random"`` delegates to ``shuffle_gallery``, which shuffles ``sdxl_loras`` in
    place. Any other value returns a new list sorted by title (ascending),
    leaving the input list untouched.

    Returns:
        (gallery items as ``(image, title)`` pairs, the reordered entry list).
    """
    if order == "random":
        return shuffle_gallery(sdxl_loras)

    def _by_title(entry):
        return entry["title"]

    ordered = sorted(sdxl_loras, key=_by_title)
    gallery_items = [(entry["image"], entry["title"]) for entry in ordered]
    return gallery_items, ordered
def run_lora(prompt, selected_state, sdxl_loras,
             image_resolution, sample_steps, seed,
             progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the model currently selected in the gallery.

    Args:
        prompt: text prompt from the prompt textbox.
        selected_state: the ``gr.SelectData`` stored when a gallery item was
            clicked; falsy while nothing has been selected.
        sdxl_loras: gallery entries in their displayed order, which is the order
            ``selected_state.index`` refers to.
        image_resolution: side length in pixels, forwarded to ``process``.
        sample_steps: sampling steps, forwarded to ``process``.
        seed: RNG seed, forwarded to ``process``.
        progress: Gradio progress tracker. It is not used directly; it is
            declared so Gradio can mirror tqdm progress.

    Returns:
        ``PIL.Image.Image`` holding the generated image.

    Raises:
        gr.Error: if no model has been selected.
    """
    if not selected_state:
        raise gr.Error("You must select a Model")
    model_path = sdxl_loras[selected_state.index]["model_path"]
    # A single sample is generated per request.
    num_samples = 1
    image = process(model_path, prompt, num_samples, image_resolution, sample_steps, seed,)[0]
    return Image.fromarray(image.astype(np.uint8))
with gr.Blocks(css="custom.css") as demo:
    # Per-session copy of the gallery entries. The reorder handlers replace it
    # together with the gallery, so a gallery index always maps to the right entry.
    gr_sdxl_loras = gr.State(value=sdxl_loras_raw)
    title = gr.HTML(
        """<h1><img src="https://i.imgur.com/vT48NAO.png" alt="SD"> StableDiffusion GGML Explorer</h1>""",
        elem_id="title",
    )
    # Holds the event data of the last clicked gallery item (empty until a click).
    selected_state = gr.State()

    with gr.Row(elem_id="main_app"):
        with gr.Box(elem_id="gallery_box"):
            order_gallery = gr.Radio(
                choices=["random", "alphabetical"],
                value="random",
                label="Order by",
                elem_id="order_radio",
            )
            gallery = gr.Gallery(
                label="SD Model Gallery",
                allow_preview=True,
                columns=2,
                min_width=256,
                elem_id="gallery",
                show_share_button=False,
                height=512,
            )
        with gr.Column():
            prompt_title = gr.Markdown(
                value="### Click on a Model in the gallery to select it",
                visible=True,
                elem_id="selected_model",
            )
            with gr.Row():
                prompt = gr.Textbox(
                    label="Prompt",
                    show_label=False,
                    lines=1,
                    max_lines=1,
                    placeholder="Type a prompt after selecting a Model",
                    elem_id="prompt",
                )
                button = gr.Button("Run", elem_id="run_button")
            result = gr.Image(
                interactive=False,
                label="Generated Image",
                elem_id="result-image",
            )
            with gr.Accordion("Advanced options", open=False):
                image_resolution = gr.Slider(
                    label="Image Resolution", minimum=256, maximum=768, value=512, step=256
                )
                sample_steps = gr.Slider(
                    label="Steps", minimum=1, maximum=100, value=8, step=1
                )
                seed = gr.Slider(
                    label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True
                )

    order_gallery.change(
        fn=swap_gallery,
        inputs=[order_gallery, gr_sdxl_loras],
        outputs=[gallery, gr_sdxl_loras],
        queue=False,
    )
    # `prompt` appears twice: first it receives the trigger word as its value,
    # then a placeholder update.
    gallery.select(
        fn=update_selection,
        inputs=[gr_sdxl_loras],
        outputs=[prompt_title, prompt, prompt, selected_state],
        queue=False,
        show_progress=False,
    )
    # Pressing Enter in the prompt box and clicking Run behave identically:
    # validate the selection first, and generate only if validation succeeded.
    for trigger in (prompt.submit, button.click):
        trigger(
            fn=check_selected,
            inputs=[selected_state],
            queue=False,
            show_progress=False,
        ).success(
            fn=run_lora,
            inputs=[prompt, selected_state, gr_sdxl_loras, image_resolution, sample_steps, seed],
            outputs=result,
        )
    demo.load(fn=shuffle_gallery, inputs=[gr_sdxl_loras], outputs=[gallery, gr_sdxl_loras], queue=False)

demo.queue(max_size=20)
demo.launch()