import os
import random

import gradio as gr
import spaces
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
from fastapi import FastAPI
from fastapi.responses import FileResponse
from PIL import Image

# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

MAX_SEED = 2**32 - 1

# Create a FastAPI app
app = FastAPI()


@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value):
    # Drop any previously loaded adapter so LoRAs don't stack across requests,
    # then load the requested LoRA weights
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)

    # Combine the prompt with the LoRA's trigger word
    full_prompt = f"{trigger_word} {prompt}"

    # Set up generation parameters with a random seed
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate the image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=28,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Create the ./tmp/ directory if it doesn't exist
    os.makedirs("./tmp", exist_ok=True)

    # Save the image under the provided hash
    image_path = f"./tmp/{hash_value}.png"
    image.save(image_path)

    return image, image_path


def run_lora(prompt, width, height, lora_path, trigger_word, hash_value):
    image, image_path = generate_image(prompt, width, height, lora_path, trigger_word, hash_value)
    return image, image_path


# Set up the Gradio interface
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")
    with gr.Row():
        width = gr.Slider(label="Width", minimum=256, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=256, maximum=1024, step=64, value=512)
    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
        hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")
    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")

    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, hash_value],
        outputs=[output_image, output_path],
    )


# FastAPI endpoint for downloading generated images by hash
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    image_path = f"./tmp/{hash_value}.png"
    if os.path.exists(image_path):
        return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")
    return {"error": "Image not found"}


# Launch both Gradio and FastAPI
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def run_gradio():
        gradio_app.queue()
        gradio_app.launch()

    # Start Gradio (serving on its default port, 7860) in a separate daemon thread
    Thread(target=run_gradio, daemon=True).start()

    # Run FastAPI with uvicorn on port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)
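
# Illustrative usage sketch (comments only, assuming the defaults above and a
# hypothetical hash value "abc123"): once the script is running, generate an
# image through the Gradio UI on http://localhost:7860, entering "abc123" in the
# Hash field, then fetch the saved PNG from the FastAPI endpoint on port 8000:
#
#   curl -o abc123.png http://localhost:8000/download/abc123
#
# If no image has been generated under that hash yet, the endpoint returns the
# JSON body {"error": "Image not found"} instead of a file.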