# Captured from a Hugging Face Spaces page; the Space status was shown as "Runtime error".
import os
import random

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderTiny, DiffusionPipeline
# Initialize the base model.
# Weights are loaded in bfloat16; the tiny autoencoder (taef1) is passed as the pipeline's
# VAE in place of the default one (presumably for lighter/faster decoding).
# NOTE(review): FLUX.1-dev is a gated Hub repo; the Space presumably needs an access token
# (e.g. an HF_TOKEN secret) to download it — confirm this is configured.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

# Inclusive upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = 2**32 - 1

# Hidden configuration for the LoRA adapter: its Hub location and the trigger word
# that is prepended to every user prompt.
LORA_PATH = "SebastianBodza/Flux_Aquarell_Watercolor_v2"
TRIGGER_WORD = "AQUACOLTOK"

# Load LoRA weights once at startup so every request reuses them.
pipe.load_lora_weights(LORA_PATH)
# Request a ZeroGPU device for the duration of each call (`spaces` is imported at the
# top of the file); the decorator has no effect on non-ZeroGPU hardware.
@spaces.GPU
def generate_image(prompt, width, height, num_inference_steps=28, guidance_scale=3.5, seed=None):
    """Generate one image from *prompt* with the LoRA trigger word prepended.

    Args:
        prompt: User text prompt.
        width: Output width in pixels; coerced to ``int`` because UI sliders may
            deliver floats.
        height: Output height in pixels; coerced to ``int`` for the same reason.
        num_inference_steps: Denoising steps passed to the pipeline (default 28).
        guidance_scale: Guidance scale passed to the pipeline (default 3.5).
        seed: RNG seed. When ``None``, a random seed in ``[0, MAX_SEED]`` is drawn.

    Returns:
        The first entry of the pipeline output's ``.images`` (presumably a PIL image
        under the pipeline's default output type).
    """
    # Prepend the LoRA's trigger word so the adapter's style is invoked.
    full_prompt = f"{TRIGGER_WORD} {prompt}"

    if seed is None:
        seed = random.randint(0, MAX_SEED)
    # Build the RNG on the same device the pipeline was moved to. A hard-coded "cuda"
    # generator raises on hosts where the module fell back to CPU.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    image = pipe(
        prompt=full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]
    return image
def run_lora(prompt, width, height):
    """Gradio click handler: forward the UI values to ``generate_image``.

    Returns whatever ``generate_image`` produces for the given prompt and size.
    """
    result = generate_image(prompt=prompt, width=width, height=height)
    return result
# Build the Gradio UI: a prompt box, width/height sliders, a button, and an output image.
# NOTE(review): original indentation was lost; Button/Image placement at the Blocks
# level (outside the slider Row) is a reconstruction — confirm the intended layout.
with gr.Blocks() as app:
    gr.Markdown("# LoRA Image Generator")

    with gr.Row():
        prompt_box = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")

    with gr.Row():
        width_slider = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height_slider = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")

    # Component values are passed to run_lora positionally as (prompt, width, height).
    generate_button.click(
        fn=run_lora,
        inputs=[prompt_box, width_slider, height_slider],
        outputs=[output_image],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server (queue() returns the app).
    app.queue().launch()