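"""Gradio app that generates watercolor-style images with FLUX.1-dev plus a LoRA.

Loads the SebastianBodza/Flux_Aquarell_Watercolor_v2 LoRA once at startup,
prepends its trigger word to every prompt, and serves a simple Gradio UI.
Intended to run on Hugging Face ZeroGPU via the `spaces` package.
"""
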
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces

# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

MAX_SEED = 2**32 - 1

# LoRA settings (set these for the specific LoRA you want to serve)
LORA_PATH = "SebastianBodza/Flux_Aquarell_Watercolor_v2"
TRIGGER_WORD = "AQUACOLTOK"

# Load LoRA weights (do this once at startup)
pipe.load_lora_weights(LORA_PATH)
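

# ZeroGPU: request a GPU allocation of up to 70 seconds for each call below.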
@spaces.GPU(duration=70)
def generate_image(prompt, width, height):
    # Combine the user prompt with the LoRA trigger word
    full_prompt = f"{TRIGGER_WORD} {prompt}"

    # Pick a random seed and build a generator on the same device as the pipeline
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate the image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image


def run_lora(prompt, width, height):
    return generate_image(prompt, width, height)


# Set up the Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")

    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height],
        outputs=[output_image]
    )
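

# app.queue() enables Gradio's request queue so concurrent generation requests
# wait their turn instead of overlapping on the GPU.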
if __name__ == "__main__":
    app.queue()
    app.launch()