import spaces
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
import random
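
# Note: `src_inference` is not a published package; FluxPipeline and set_single_lora
# are assumed to be vendored alongside this app from the OmniConsistency project code.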

base_path = "black-forest-labs/FLUX.1-dev"

# Download the OmniConsistency LoRA checkpoint from the Hub
omni_consistency_path = hf_hub_download(
    repo_id="showlab/OmniConsistency",
    filename="OmniConsistency.safetensors",
    local_dir="./Model",
)

# Initialize the FLUX pipeline from the base model
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16).to("cuda")

# Attach the OmniConsistency LoRA weights to the transformer
set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)
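
# Two LoRA levels are used: the consistency module attached above stays loaded for the
# lifetime of the process, while the per-style LoRAs are swapped in and out for each
# request inside generate_image() below.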

# Clear the cached key/value banks in the attention processors between generations
def clear_cache(transformer):
    for attn_processor in transformer.attn_processors.values():
        attn_processor.bank_kv.clear()

# Download all style LoRAs up front so generation never blocks on a download
def download_all_loras():
    lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink",
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art",
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector",
    ]
    for lora_name in lora_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{lora_name}_rank128_bf16.safetensors",
            local_dir="./LoRAs",
        )

# Download all LoRAs in advance, before the interface is launched
download_all_loras()
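
# With local_dir="./LoRAs" and filenames under the repo's "LoRAs/" folder,
# hf_hub_download preserves the repo-relative path, so the weights end up in
# ./LoRAs/LoRAs/<style>_rank128_bf16.safetensors (hence the doubled folder below).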

# Main generation function; @spaces.GPU() requests a ZeroGPU device for each call
@spaces.GPU()
def generate_image(lora_name, prompt, uploaded_image, width, height,
                   guidance_scale, num_inference_steps, seed):
    # Swap in the selected style LoRA (already downloaded to ./LoRAs/LoRAs/)
    pipe.unload_lora_weights()
    pipe.load_lora_weights("./LoRAs/LoRAs",
                           weight_name=f"{lora_name}_rank128_bf16.safetensors")

    # The uploaded image is used as the spatial conditioning input
    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start_time = time.time()
    # Generate the stylized image; width and height are rounded down to multiples of 8
    image = pipe(
        prompt,
        height=(int(height) // 8) * 8,
        width=(int(width) // 8) * 8,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        spatial_images=spatial_image,
        subject_images=subject_images,
        cond_size=512,
    ).images[0]
    elapsed_time = time.time() - start_time
    print(f"Generation took {elapsed_time:.2f} s")

    # Clear the attention cache so state does not leak into the next request
    clear_cache(pipe.transformer)
    return image

# Example inputs: [LoRA, prompt, image, width, height, guidance scale, steps, seed]
examples = [
    ["3D_Chibi", "3D Chibi style", Image.open("./test_imgs/00.png"), 1024, 680, 3.5, 24, 42],
    ["Origami", "Origami style", Image.open("./test_imgs/01.png"), 1024, 560, 3.5, 24, 42],
    ["American_Cartoon", "American Cartoon style", Image.open("./test_imgs/02.png"), 1024, 568, 3.5, 24, 42],
    ["Origami", "Origami style", Image.open("./test_imgs/03.png"), 672, 768, 3.5, 24, 42],
    ["Paper_Cutting", "Paper Cutting style", Image.open("./test_imgs/04.png"), 1024, 696, 3.5, 24, 42],
]
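
# The example images are opened eagerly at startup, so ./test_imgs/00.png through
# ./test_imgs/04.png must be present in the Space repository for the app to launch.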

# Gradio interface setup
def create_gradio_interface():
    lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink",
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art",
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector",
    ]
    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        with gr.Row():
            with gr.Column(scale=1):
                lora_dropdown = gr.Dropdown(lora_names, label="Select LoRA")
                prompt_box = gr.Textbox(label="Prompt", placeholder="Enter a prompt...")
                image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                width_box = gr.Textbox(label="Width", value="1024")
                height_box = gr.Textbox(label="Height", value="1024")
                guidance_slider = gr.Slider(minimum=0.1, maximum=20, value=3.5, step=0.1, label="Guidance Scale")
                steps_slider = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Inference Steps")
                seed_slider = gr.Slider(minimum=1, maximum=10000000000, value=42, step=1, label="Seed")
                generate_button = gr.Button("Generate")
        output_image = gr.Image(type="pil", label="Generated Image")

        # Examples; the input order matches the generate_image signature
        gr.Examples(
            examples=examples,
            inputs=[
                lora_dropdown, prompt_box, image_input,
                width_box, height_box, guidance_slider,
                steps_slider, seed_slider,
            ],
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples",
        )
        generate_button.click(
            fn=generate_image,
            inputs=[
                lora_dropdown, prompt_box, image_input,
                width_box, height_box, guidance_slider,
                steps_slider, seed_slider,
            ],
            outputs=output_image,
        )
    return demo

# Launch the Gradio interface
interface = create_gradio_interface()
interface.launch()