File size: 5,826 Bytes
cc6558b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import spaces
import time
import torch
import gradio as gr
from PIL import Image
from huggingface_hub import hf_hub_download
from src_inference.pipeline import FluxPipeline
from src_inference.lora_helper import set_single_lora
import random

# Hugging Face model id of the base FLUX.1-dev checkpoint the pipeline loads.
base_path = "black-forest-labs/FLUX.1-dev"
    
# Download the OmniConsistency LoRA once at startup; hf_hub_download returns
# the local file path (cached under ./Model on subsequent runs).
omni_consistency_path = hf_hub_download(repo_id="showlab/OmniConsistency", 
                                        filename="OmniConsistency.safetensors", 
                                        local_dir="./Model")

# Load the FLUX pipeline in bfloat16 and move it to the GPU.
# NOTE(review): runs at import time — requires CUDA to be available.
pipe = FluxPipeline.from_pretrained(base_path, torch_dtype=torch.bfloat16).to("cuda")

# Attach the OmniConsistency LoRA to the transformer (single LoRA, full
# weight, 512px conditioning size — must match cond_size used at inference).
set_single_lora(pipe.transformer, omni_consistency_path, lora_weights=[1], cond_size=512)

# Function to clear cache
def clear_cache(transformer):
    for name, attn_processor in transformer.attn_processors.items():
        attn_processor.bank_kv.clear()

def download_all_loras():
    """Pre-fetch every style LoRA from the OmniConsistency repo.

    Files land under ./LoRAs/LoRAs/<name>_rank128_bf16.safetensors, which
    is the path generate_image() loads from at request time.
    """
    style_names = (
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", "Clay_Toy",
        "Fabric", "Ghibli", "Irasutoya", "Jojo", "LEGO", "Line",
        "Macaron", "Oil_Painting", "Origami", "Paper_Cutting",
        "Picasso", "Pixel", "Poly", "Pop_Art", "Rick_Morty",
        "Snoopy", "Van_Gogh", "Vector",
    )
    for style in style_names:
        hf_hub_download(
            repo_id="showlab/OmniConsistency",
            filename=f"LoRAs/{style}_rank128_bf16.safetensors",
            local_dir="./LoRAs",
        )

# Download all LoRAs in advance before the interface is launched
download_all_loras()

# Main function to generate the image
@spaces.GPU()
def generate_image(lora_name, prompt, uploaded_image, width, height, guidance_scale, num_inference_steps, seed):
    """Generate a stylized image with the selected OmniConsistency LoRA.

    Args:
        lora_name: Style name matching a pre-downloaded LoRA file
            (see download_all_loras()).
        prompt: Text prompt for generation.
        uploaded_image: PIL image used as the spatial conditioning input.
        width, height: Target size; strings from the UI are accepted and
            each dimension is rounded down to a multiple of 8.
        guidance_scale: Classifier-free guidance strength.
        num_inference_steps: Number of denoising steps.
        seed: RNG seed; cast to int because gr.Slider delivers floats.

    Returns:
        The generated PIL image.
    """
    # Swap in the requested style LoRA (files were pre-downloaded into
    # ./LoRAs/LoRAs by download_all_loras(), so this is a local load).
    pipe.unload_lora_weights()
    pipe.load_lora_weights("./LoRAs/LoRAs",
                           weight_name=f"{lora_name}_rank128_bf16.safetensors")

    # The uploaded picture is the (single) spatial conditioning image;
    # no subject images are used in this demo.
    spatial_image = [uploaded_image.convert("RGB")]
    subject_images = []

    start_time = time.time()

    # Generate the image. Width/height are snapped to multiples of 8 as
    # required by the VAE; seed/steps are cast to int because the Gradio
    # sliders deliver floats and torch.Generator.manual_seed needs an int.
    image = pipe(
        prompt,
        height=(int(height) // 8) * 8,
        width=(int(width) // 8) * 8,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(int(seed)),
        spatial_images=spatial_image,
        subject_images=subject_images,
        cond_size=512,
    ).images[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"code running time: {elapsed_time} s")

    # Clear the attention processors' KV cache so state from this run
    # cannot leak into the next generation.
    clear_cache(pipe.transformer)

    return image

# Example data for gr.Examples: (lora_name, prompt, input image, width-ish,
# height-ish, guidance_scale, steps, seed).
# NOTE(review): Image.open runs eagerly at import time — the app fails to
# start if any ./test_imgs file is missing.
examples = [
    ["3D_Chibi", "3D Chibi style",                  Image.open("./test_imgs/00.png"), 680, 1024, 3.5, 24, 42],
    ["Origami", "Origami style",                    Image.open("./test_imgs/01.png"), 560, 1024, 3.5, 24, 42],
    ["American_Cartoon", "American Cartoon style",  Image.open("./test_imgs/02.png"), 568, 1024, 3.5, 24, 42],
    ["Origami", "Origami style",                    Image.open("./test_imgs/03.png"), 768, 672, 3.5, 24, 42],
    ["Paper_Cutting", "Paper Cutting style",        Image.open("./test_imgs/04.png"), 696, 1024, 3.5, 24, 42]
]

# Gradio interface setup
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for the demo.

    Wires the style dropdown, prompt, image upload and generation
    parameters to generate_image(), and registers the example rows.
    """
    lora_names = [
        "3D_Chibi", "American_Cartoon", "Chinese_Ink", 
        "Clay_Toy", "Fabric", "Ghibli", "Irasutoya",
        "Jojo", "LEGO", "Line", "Macaron",
        "Oil_Painting", "Origami", "Paper_Cutting", 
        "Picasso", "Pixel", "Poly", "Pop_Art", 
        "Rick_Morty", "Snoopy", "Van_Gogh", "Vector"
    ]

    with gr.Blocks() as demo:
        gr.Markdown("# OmniConsistency LoRA Image Generation")
        gr.Markdown("Select a LoRA, enter a prompt, and upload an image to generate a new image with OmniConsistency.")
        with gr.Row():
            with gr.Column(scale=1):
                lora_dropdown = gr.Dropdown(lora_names, label="Select LoRA")
                prompt_box = gr.Textbox(label="Prompt", placeholder="Enter a prompt...")
                image_input = gr.Image(type="pil", label="Upload Image")
            with gr.Column(scale=1):
                width_box = gr.Textbox(label="Width", value="1024")
                height_box = gr.Textbox(label="Height", value="1024")
                guidance_slider = gr.Slider(minimum=0.1, maximum=20, value=3.5, step=0.1, label="Guidance Scale")
                steps_slider = gr.Slider(minimum=1, maximum=50, value=25, step=1, label="Inference Steps")
                seed_slider = gr.Slider(minimum=1, maximum=10000000000, value=42, step=1, label="Seed")
                generate_button = gr.Button("Generate")
                output_image = gr.Image(type="pil", label="Generated Image")
        # Examples for generation.
        # BUGFIX: inputs must follow generate_image's parameter order
        # (width before height) — the previous height_box/width_box order
        # displayed swapped Width/Height values when an example was picked.
        gr.Examples(
            examples=examples,
            inputs=[lora_dropdown, prompt_box, image_input, width_box, height_box, guidance_slider, steps_slider, seed_slider],
            outputs=output_image,
            fn=generate_image,
            cache_examples=False,
            label="Examples"
        )

        # The click handler passes components in the same order as
        # generate_image's signature.
        generate_button.click(
            fn=generate_image,
            inputs=[
                lora_dropdown, prompt_box, image_input,
                width_box, height_box, guidance_slider,
                steps_slider, seed_slider
            ],
            outputs=output_image
        )

    return demo


# Launch the Gradio interface (blocks until the server is stopped).
interface = create_gradio_interface()
interface.launch()