File size: 2,067 Bytes
2caf84c
0e0ee20
 
4989e93
607d766
4989e93
0e0ee20
 
c724573
 
463aefd
c724573
 
 
0e0ee20
f3e96f9
c59400c
4989e93
3ea2ab0
 
c724573
4989e93
 
0e0ee20
0b93385
4989e93
 
 
 
 
 
5b82e60
fd8e800
4989e93
 
 
 
 
 
 
 
 
 
 
2caf84c
4989e93
 
0b93385
4989e93
 
 
 
0e0ee20
4989e93
 
0e0ee20
4989e93
 
 
 
 
 
 
 
0e0ee20
4989e93
 
0e0ee20
 
4989e93
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces

# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# LoRA-specific settings (set these for the LoRA you want to serve).
LORA_PATH = "SebastianBodza/Flux_Aquarell_Watercolor_v2"  # LoRA weights passed to load_lora_weights
TRIGGER_WORD = "AQUACOLTOK"  # prepended to every user prompt by generate_image
MAX_SEED = (1 << 32) - 1  # inclusive upper bound for randomly drawn seeds

base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ---------------------------------------------------------------------------
# Model loading (executed once, at import time)
# ---------------------------------------------------------------------------
# Load the "madebyollin/taef1" AutoencoderTiny and hand it to the pipeline as its VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

# Attach the LoRA weights a single time at startup instead of per request.
pipe.load_lora_weights(LORA_PATH)

@spaces.GPU(duration=70)
def generate_image(prompt, width, height, num_inference_steps=28, guidance_scale=3.5, seed=None):
    """Generate a single image for *prompt* with the LoRA-loaded pipeline.

    ``TRIGGER_WORD`` is prepended to the prompt so the loaded LoRA is
    activated.

    Args:
        prompt: User prompt text. ``None`` is treated as an empty prompt.
        width: Output width in pixels (cast to ``int``).
        height: Output height in pixels (cast to ``int``).
        num_inference_steps: Denoising steps passed to the pipeline
            (default 28, as before).
        guidance_scale: Guidance scale passed to the pipeline
            (default 3.5, as before).
        seed: Seed for the random generator. When ``None`` (the default,
            matching the original behavior) a fresh seed is drawn uniformly
            from ``[0, MAX_SEED]``.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    prompt = (prompt or "").strip()
    # Combine prompt with trigger word; avoid a dangling space (or the literal
    # text "None") when the user supplied no prompt.
    full_prompt = f"{TRIGGER_WORD} {prompt}" if prompt else TRIGGER_WORD

    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Put the generator on the same device as the pipeline (`device`), not a
    # hard-coded "cuda": diffusers rejects a CUDA generator when sampling CPU
    # latents, and constructing a CUDA generator fails on a CPU-only host.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    image = pipe(
        prompt=full_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return image

def run_lora(prompt, width, height):
    """Button callback: forward the UI inputs unchanged to ``generate_image``.

    Returns exactly what ``generate_image`` returns (the generated image).
    """
    result = generate_image(prompt, width, height)
    return result

# Set up the Gradio interface.
# Components render in the order they are created inside the Blocks context,
# so statement order below defines the page layout.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")

    # Free-text prompt; generate_image prepends TRIGGER_WORD to it.
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")

    # Output size controls, side by side: 256-1024 px in steps of 64, default 512.
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)

    generate_button = gr.Button("Generate Image")

    # Displays the image returned by run_lora.
    output_image = gr.Image(label="Generated Image")

    # On click, pass the current prompt/width/height values to run_lora and
    # show its return value in output_image.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height],
        outputs=[output_image]
    )

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server. Blocks.queue()
    # returns the same Blocks instance, so the two calls chain.
    app.queue().launch()