# Hugging Face Space page metadata (scraped together with this file):
#   Space status: Runtime error
#   File size: 2,187 bytes
#   Commits touching this file: 2caf84c, 0e0ee20, 4989e93, 607d766, c724573,
#     463aefd, f3e96f9, c59400c, f645c51, 5b82e60, fd8e800, 0b93385
#   (The per-line blame column and the 1-70 line-number gutter were page-layout
#   artifacts and are not reproduced.)
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces
# One-time model setup: choose dtype/device, then build the FLUX pipeline
# with the TAEF1 autoencoder passed in as its VAE.
MAX_SEED = 2**32 - 1  # inclusive upper bound for random.randint seeds
base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1",
    torch_dtype=dtype,
).to(device)
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    vae=taef1,
).to(device)
@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word):
    """Generate a single image from the FLUX pipeline with an optional LoRA.

    Args:
        prompt: User prompt text.
        width: Output width in pixels; cast to ``int`` because UI sliders
            may deliver floats.
        height: Output height in pixels; cast to ``int`` for the same reason.
        lora_path: Hub repo id or local path passed to
            ``pipe.load_lora_weights``. A blank/None value skips LoRA loading
            and the base model is used as-is.
        trigger_word: Text prepended to the prompt; blank/None is ignored.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    # Join only the non-empty parts so an empty trigger word does not leave
    # a stray leading space in the prompt.
    full_prompt = " ".join(
        part.strip() for part in (trigger_word, prompt) if part and part.strip()
    )
    lora_path = (lora_path or "").strip()

    # Seed the generator on the same device the pipeline was moved to
    # (module-level `device`); diffusers refuses to draw CPU latents from a
    # CUDA generator, so a hard-coded "cuda" breaks CPU runs.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    if lora_path:
        pipe.load_lora_weights(lora_path)
    try:
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=28,
            guidance_scale=3.5,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    finally:
        # `pipe` is a module-level object shared by every request. Without
        # unloading, adapters from earlier requests stay attached and each
        # call adds another one, so results would depend on request history.
        if lora_path:
            pipe.unload_lora_weights()
    return image
def run_lora(prompt, width, height, lora_path, trigger_word):
    """Gradio click handler: forward the UI values unchanged to ``generate_image``."""
    request = (prompt, width, height, lora_path, trigger_word)
    return generate_image(*request)
# Gradio UI: prompt box, size sliders, LoRA source/trigger fields, and a
# button wired to run_lora whose result is shown in the image component.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your prompt here",
        )

    with gr.Row():
        width = gr.Slider(
            label="Width", minimum=256, maximum=1024, step=64, value=512
        )
        height = gr.Slider(
            label="Height", minimum=256, maximum=1024, step=64, value=512
        )

    with gr.Row():
        lora_path = gr.Textbox(
            label="LoRA Path",
            value="SebastianBodza/Flux_Aquarell_Watercolor_v2",
        )
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")

    # The inputs list order must match run_lora's parameter order.
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word],
        outputs=[output_image],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue before serving, then start the app.
    # (Removed a stray trailing "|" page artifact that made this line a
    # SyntaxError.)
    app.queue()
    app.launch()