Spaces:
Runtime error
Runtime error
File size: 3,453 Bytes
2caf84c 0e0ee20 4989e93 607d766 4989e93 d06075d 0e0ee20 c724573 463aefd c724573 f3e96f9 c59400c d06075d f645c51 d06075d f645c51 4989e93 f645c51 4989e93 5b82e60 fd8e800 4989e93 d06075d 2caf84c d06075d 0b93385 4989e93 d06075d 4989e93 0e0ee20 4989e93 0e0ee20 4989e93 f645c51 d06075d f645c51 4989e93 d06075d 4989e93 0e0ee20 d06075d 0e0ee20 d06075d 4989e93 d06075d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import random

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderTiny, DiffusionPipeline
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
# --- Runtime configuration and the shared generation pipeline ---
base_model = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tiny autoencoder passed as the pipeline's `vae` component in place of its default VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(
    base_model,
    torch_dtype=dtype,
    vae=taef1,
).to(device)

# Inclusive upper bound for randomly drawn seeds (largest unsigned 32-bit value).
MAX_SEED = (1 << 32) - 1

# FastAPI application hosting the image-download endpoint.
app = FastAPI()
@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
    """Generate one image with a LoRA applied and save it as ./tmp/{hash_value}.png.

    Args:
        prompt: User prompt text.
        width: Output width in pixels (cast to int; slider values may arrive as floats).
        height: Output height in pixels (cast to int).
        lora_path: Repo id or path handed to ``pipe.load_lora_weights``.
        trigger_word: Text prepended (with a space) to ``prompt``.
        hash_value: File-name stem for the saved PNG; must be a bare file name.

    Returns:
        ``(image, image_path)``: the first image produced by the pipeline and the
        path it was saved to.

    Raises:
        gr.Error: if ``hash_value`` is empty or contains a path separator.
    """
    # The stem comes straight from a UI textbox; without this check a value such
    # as "../x" would write outside ./tmp. Checked before loading any weights.
    if not hash_value or os.path.basename(hash_value) != hash_value:
        raise gr.Error("Hash must be a non-empty file name without path separators.")

    full_prompt = f"{trigger_word} {prompt}"
    seed = random.randint(0, MAX_SEED)
    # Seed on the same device the pipeline was moved to; a hard-coded "cuda"
    # generator fails when the app runs on CPU.
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        pipe.load_lora_weights(lora_path)
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=28,
            guidance_scale=3.5,
            width=int(width),
            height=int(height),
            generator=generator,
        ).images[0]
    finally:
        # `pipe` is shared by every request; unload this request's LoRA so
        # adapters do not accumulate in the pipeline across calls.
        pipe.unload_lora_weights()

    os.makedirs('./tmp', exist_ok=True)
    image_path = f'./tmp/{hash_value}.png'
    image.save(image_path)
    return image, image_path
def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
    """Click handler for the Generate button: delegate to ``generate_image``.

    Returns the ``(image, image_path)`` pair that ``generate_image`` produces.
    """
    result = generate_image(prompt, width, height, lora_path, trigger_word, hash_value)
    generated, saved_to = result
    return generated, saved_to
# --- Gradio UI: collects generation inputs, shows the image and its saved path ---
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            lines=3,
            placeholder="Enter your prompt here",
        )

    # Width and height share the same range, step and default.
    slider_kwargs = dict(minimum=256, maximum=1024, step=64, value=512)
    with gr.Row():
        width = gr.Slider(label="Width", **slider_kwargs)
        height = gr.Slider(label="Height", **slider_kwargs)

    with gr.Row():
        lora_path = gr.Textbox(
            label="LoRA Path",
            value="SebastianBodza/Flux_Aquarell_Watercolor_v2",
        )
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")

    hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")

    # Order must match run_lora's positional parameters.
    generation_inputs = [prompt, width, height, lora_path, trigger_word, hash_value]
    generation_outputs = [output_image, output_path]
    generate_button.click(
        fn=run_lora,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )
# FastAPI endpoint for downloading images
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    """Serve ./tmp/{hash_value}.png, the path the generator saves images to.

    Returns:
        A ``FileResponse`` with the PNG when it exists. Otherwise a JSON body
        ``{"error": "Image not found"}`` with HTTP 404, which is also the response
        when ``hash_value`` is not a bare file name.
    """
    not_found = JSONResponse(status_code=404, content={"error": "Image not found"})

    # Only a single file-name component is accepted, so the lookup cannot leave
    # ./tmp (e.g. via an OS-specific separator such as "\" on Windows).
    if not hash_value or os.path.basename(hash_value) != hash_value:
        return not_found

    image_path = f"./tmp/{hash_value}.png"
    # isfile rather than exists: a directory with that name must not be served.
    if os.path.isfile(image_path):
        return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")
    # Previously this was returned with HTTP 200, so clients could not tell
    # "not found" from success by status code.
    return not_found
# Entry point: Gradio UI in a background thread, FastAPI served from the main thread.
if __name__ == "__main__":
    from threading import Thread

    import uvicorn

    def _serve_gradio_ui():
        """Enable request queuing and start the Gradio server (launch() blocks this thread)."""
        gradio_app.queue()
        gradio_app.launch()

    # Daemon thread: it is terminated when the main thread (uvicorn) exits.
    gradio_thread = Thread(target=_serve_gradio_ui, daemon=True)
    gradio_thread.start()

    # FastAPI (download endpoint) on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)