File size: 2,187 Bytes
2caf84c
0e0ee20
 
4989e93
607d766
4989e93
0e0ee20
 
c724573
 
463aefd
c724573
 
f3e96f9
c59400c
f645c51
 
 
 
 
4989e93
f645c51
4989e93
 
 
5b82e60
fd8e800
4989e93
 
 
 
 
 
 
 
 
 
 
2caf84c
f645c51
 
0b93385
4989e93
 
 
 
0e0ee20
4989e93
 
0e0ee20
4989e93
 
 
f645c51
 
 
 
4989e93
 
 
 
 
0e0ee20
f645c51
4989e93
0e0ee20
 
4989e93
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces

# Model setup: load FLUX.1-dev once at import time, with the tiny autoencoder
# ("taef1") substituted for the pipeline's default VAE.
MAX_SEED = 2**32 - 1  # upper bound for the randomly drawn per-request seed
base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Fall back to CPU when no CUDA device is visible.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
taef1 = taef1.to(device)

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1)
pipe = pipe.to(device)

@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word):
    """Generate a single image with the shared pipeline and the given LoRA.

    Args:
        prompt: Text prompt entered by the user.
        width: Requested image width in pixels; coerced to ``int``.
        height: Requested image height in pixels; coerced to ``int``.
        lora_path: LoRA identifier handed to ``pipe.load_lora_weights``.
        trigger_word: Token prepended to the prompt. May be empty or ``None``,
            in which case the prompt is used unchanged.

    Returns:
        The first image in the pipeline's output.
    """
    # The pipeline is a module-level singleton shared by all requests. Remove
    # whatever LoRA a previous call attached before loading the requested one,
    # otherwise adapters accumulate across calls.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)

    # Prepend the trigger word only when there is one, so an empty field does
    # not leave a stray leading space in the prompt.
    trigger = (trigger_word or "").strip()
    full_prompt = f"{trigger} {prompt}" if trigger else prompt

    # Fresh random seed per request. The generator is created on the same
    # device the pipeline was moved to, rather than assuming CUDA, so the
    # CPU fallback chosen at import time keeps working.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        prompt=full_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        # UI sliders may deliver numeric values as floats; the pipeline
        # expects integer pixel dimensions.
        width=int(width),
        height=int(height),
        generator=generator,
    )
    return result.images[0]

def run_lora(prompt, width, height, lora_path, trigger_word):
    """Click handler for the UI: forward all inputs to ``generate_image``.

    Takes the same five arguments as ``generate_image`` and returns whatever
    it returns.
    """
    image = generate_image(prompt, width, height, lora_path, trigger_word)
    return image

# Set up the Gradio interface.
# Components are created inside the `gr.Blocks` context; they are declared in
# this order: title, prompt, size sliders, LoRA settings, button, output image.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")
    
    # Free-text prompt (three visible lines).
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    
    # Output dimensions: 256-1024 px in steps of 64, defaulting to 512x512.
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    
    # LoRA to load and the trigger word that gets prepended to the prompt;
    # both are pre-filled with defaults but editable.
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
    
    generate_button = gr.Button("Generate Image")
    
    output_image = gr.Image(label="Generated Image")
    
    # Clicking the button calls `run_lora` with the five input values, in the
    # same order as its parameters, and routes the result to `output_image`.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word],
        outputs=[output_image]
    )

if __name__ == "__main__":
    # Enable the request queue, then start the app. `queue()` returns the
    # Blocks instance, so the two calls can be chained.
    app.queue().launch()