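"""Gradio + FastAPI app: generate images with FLUX.1-dev plus a user-supplied LoRA,
save each result as ./tmp/<hash>.png, and serve it through a /download/{hash} endpoint."""
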
import os
import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import random
import spaces
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
MAX_SEED = 2**32-1
# Create a FastAPI app
app = FastAPI()
@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
    # Load the requested LoRA weights; unload any previously loaded adapter first
    # so repeated requests don't stack LoRAs on top of each other
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)
    # Combine prompt with trigger word
    full_prompt = f"{trigger_word} {prompt}"
    # Set up generation parameters
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)
    # Generate image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    # Create ./tmp/ directory if it doesn't exist
    os.makedirs('./tmp', exist_ok=True)
    # Save the image with the provided hash
    image_path = f'./tmp/{hash_value}.png'
    image.save(image_path)
    return image, image_path
def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
    # Thin wrapper used as the Gradio click callback
    image, image_path = generate_image(prompt, width, height, lora_path, trigger_word, hash_value)
    return image, image_path
# Set up the Gradio interface
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
        hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")
    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, hash_value],
        outputs=[output_image, output_path]
    )
# FastAPI endpoint for downloading images
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    image_path = f"./tmp/{hash_value}.png"
    if os.path.exists(image_path):
        return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")
    raise HTTPException(status_code=404, detail="Image not found")
# Launch both Gradio and FastAPI
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def run_gradio():
        # Gradio serves its own UI on its default port (7860)
        gradio_app.queue()
        gradio_app.launch()

    # Start Gradio in a separate thread
    Thread(target=run_gradio, daemon=True).start()
    # Run FastAPI with uvicorn on port 8000 in the main thread
    uvicorn.run(app, host="0.0.0.0", port=8000)