Spaces:
Runtime error
Runtime error
import os
import random

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderTiny, DiffusionPipeline
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
# Choose the dtype and compute device once; every model below is moved there.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Base FLUX.1-dev pipeline, with the AutoencoderTiny weights from
# "madebyollin/taef1" passed in as the pipeline's VAE.
base_model = "black-forest-labs/FLUX.1-dev"
taef1 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1",
    torch_dtype=dtype,
).to(device)
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    vae=taef1,
).to(device)

# Largest value (2**32 - 1) used when drawing random generation seeds.
MAX_SEED = (1 << 32) - 1

# FastAPI application; it is served by uvicorn from the __main__ block.
app = FastAPI()
# Requests a GPU for the duration of the call on ZeroGPU Spaces; the `spaces`
# documentation describes the decorator as effect-free on other hardware.
@spaces.GPU
def generate_image(
    prompt,
    width,
    height,
    lora_path,
    trigger_word,
    hash_value,
    num_inference_steps=28,
    guidance_scale=3.5,
    seed=None,
):
    """Generate one image with a LoRA applied and save it as a PNG under ./tmp/.

    The LoRA at ``lora_path`` is loaded into the shared ``pipe`` for this call
    only and unloaded afterwards, so adapters from earlier requests do not
    accumulate on the pipeline.

    Args:
        prompt: Text prompt; ``trigger_word`` is prepended to it.
        width: Output width in pixels (coerced to int; UI sliders may send floats).
        height: Output height in pixels (coerced to int).
        lora_path: Value passed to ``pipe.load_lora_weights`` (e.g. a Hub repo id).
        trigger_word: Token prepended to the prompt to activate the LoRA.
        hash_value: Caller-chosen file stem. Only its final path component is
            used, so it cannot write outside ./tmp/. If nothing usable remains
            (empty, "." or ".."), the seed is used as the file stem instead.
        num_inference_steps: Number of denoising steps (default 28).
        guidance_scale: Guidance scale passed to the pipeline (default 3.5).
        seed: Fixed seed for reproducible output. When None, a random seed in
            ``[0, MAX_SEED]`` is drawn.

    Returns:
        Tuple ``(image, image_path)``: the generated image and the path it
        was saved to.
    """
    if seed is None:
        seed = random.randint(0, MAX_SEED)

    # Reduce the user-supplied hash to a bare file name before any heavy work.
    file_stem = os.path.basename(str(hash_value or ""))
    if file_stem in ("", ".", ".."):
        file_stem = str(seed)

    full_prompt = f"{trigger_word} {prompt}"

    pipe.load_lora_weights(lora_path)
    try:
        # Build the generator on the same device as the pipeline instead of a
        # hard-coded "cuda", so CPU-only hosts also work.
        generator = torch.Generator(device=device).manual_seed(int(seed))
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    finally:
        # Remove this request's adapter so the next call starts from the base model.
        pipe.unload_lora_weights()

    os.makedirs("./tmp", exist_ok=True)
    image_path = os.path.join("./tmp", f"{file_stem}.png")
    image.save(image_path)
    return image, image_path
def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
    """Gradio click handler that forwards the form fields to ``generate_image``.

    Returns the generated image and its saved path, in the same order as the
    ``outputs`` list wired to the "Generate Image" button.
    """
    result = generate_image(
        prompt, width, height, lora_path, trigger_word, hash_value
    )
    generated, saved_path = result
    return (generated, saved_path)
# Gradio front end: prompt, size and LoRA settings are fed to run_lora, and the
# result is shown as an image plus the path it was saved to.
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")

    # Prompt text.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your prompt here",
        )

    # Output size in pixels: 256-1024 in steps of 64, defaulting to 512.
    with gr.Row():
        width = gr.Slider(
            label="Width", minimum=256, maximum=1024, step=64, value=512
        )
        height = gr.Slider(
            label="Height", minimum=256, maximum=1024, step=64, value=512
        )

    # Which LoRA to apply, and the trigger word prepended to the prompt.
    with gr.Row():
        lora_path = gr.Textbox(
            label="LoRA Path",
            value="SebastianBodza/Flux_Aquarell_Watercolor_v2",
        )
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")

    hash_value = gr.Textbox(
        label="Hash",
        placeholder="Enter a unique hash for the image",
    )
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")

    # Input order must match run_lora's positional parameters.
    form_inputs = [prompt, width, height, lora_path, trigger_word, hash_value]
    form_outputs = [output_image, output_path]
    generate_button.click(fn=run_lora, inputs=form_inputs, outputs=form_outputs)
# FastAPI endpoint for downloading images saved under ./tmp/.
# Without a route decorator the function was never registered on `app`.
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    """Return ``./tmp/<hash_value>.png`` as a PNG file download.

    Only the final path component of ``hash_value`` is used, so values such as
    ``../secret`` cannot reach files outside ./tmp/.

    Returns:
        A ``FileResponse`` for the image when it exists. Otherwise an HTTP 404
        ``JSONResponse`` with body ``{"error": "Image not found"}``.
    """
    file_stem = os.path.basename(hash_value)
    image_path = os.path.join("./tmp", f"{file_stem}.png")
    if file_stem not in ("", ".", "..") and os.path.isfile(image_path):
        return FileResponse(
            image_path, media_type="image/png", filename=f"{file_stem}.png"
        )
    # Keep the original error body, but signal "not found" with a real status code.
    return JSONResponse(status_code=404, content={"error": "Image not found"})
# Serve the Gradio UI and the FastAPI routes from a single server, instead of
# running Gradio in a daemon thread next to a separate uvicorn on port 8000.
if __name__ == "__main__":
    import uvicorn

    # Enable Gradio's request queue, as the previous launch path did.
    gradio_app.queue()
    # Mount the UI at "/" on the FastAPI app. The mount is appended after any
    # routes already registered on `app`, and Starlette matches routes in
    # registration order, so those routes keep precedence.
    app = gr.mount_gradio_app(app, gradio_app, path="/")
    # GRADIO_SERVER_PORT is the variable Gradio itself reads for its port, and
    # 7860 is Gradio's default, so the UI stays on the port launch() used.
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)