import os
import random
import re

import gradio as gr
import torch
from diffusers import DiffusionPipeline, AutoencoderTiny
import spaces
from PIL import Image
from fastapi import FastAPI
from fastapi.responses import FileResponse, JSONResponse
import transformers

transformers.utils.move_cache()

# Initialize the base model
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "black-forest-labs/FLUX.1-dev"

taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

MAX_SEED = 2**32 - 1

# Directory that generated images are written to and served from.
OUTPUT_DIR = "./tmp"

# Hash values end up inside a file name, so only characters that cannot
# form a path separator or a parent-directory reference are accepted.
_HASH_PATTERN = re.compile(r"[A-Za-z0-9_-]{1,128}")

# Create a FastAPI app
app = FastAPI()


def _image_path_for(hash_value):
    """Return the PNG path inside ``OUTPUT_DIR`` for *hash_value*.

    Args:
        hash_value: Caller-supplied identifier used as the file stem.

    Returns:
        The path ``./tmp/<hash_value>.png`` as a string.

    Raises:
        ValueError: If *hash_value* is not a string of 1-128 letters,
            digits, underscores or hyphens.
    """
    if not isinstance(hash_value, str) or _HASH_PATTERN.fullmatch(hash_value) is None:
        raise ValueError(
            "Hash must be 1-128 characters long and contain only letters, "
            "digits, '_' or '-'"
        )
    return f"{OUTPUT_DIR}/{hash_value}.png"


@spaces.GPU()
def generate_image(prompt, width, height, lora_path, trigger_word, hash_value, steps):
    """Generate one image with the given LoRA and save it under ``./tmp``.

    Args:
        prompt: User prompt text.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        lora_path: Identifier passed to ``pipe.load_lora_weights``.
        trigger_word: Text prepended to *prompt*, separated by a space.
        hash_value: File stem for the saved PNG; see ``_image_path_for``.
        steps: Number of inference steps.

    Returns:
        A ``(image, image_path)`` tuple: the generated image and the path
        it was saved to.

    Raises:
        ValueError: If *hash_value* is not a valid file stem.
    """
    # Validate before any expensive work so a bad hash fails fast.
    image_path = _image_path_for(hash_value)

    # The pipeline is shared across requests: drop whatever LoRA an earlier
    # request attached before loading the one requested now.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(lora_path)

    # Combine prompt with trigger word
    full_prompt = f"{trigger_word} {prompt}"

    # Use a fresh random seed per request; the generator lives on the same
    # device as the pipeline so this also works on CPU-only hosts.
    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    # Generate image
    image = pipe(
        prompt=full_prompt,
        num_inference_steps=steps,
        guidance_scale=3.5,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    # Create the output directory if it doesn't exist
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Save the image with the provided hash
    image.save(image_path)

    return image, image_path


def run_lora(prompt, width, height, lora_path, trigger_word, hash_value, steps):
    """Gradio click handler: validate the hash, then generate the image.

    Returns:
        The ``(image, image_path)`` tuple produced by ``generate_image``.

    Raises:
        gr.Error: If *hash_value* is not a valid file stem.
    """
    # Check the hash here, before entering the GPU-decorated function, and
    # convert the failure into an error Gradio can show to the user.
    try:
        _image_path_for(hash_value)
    except ValueError as err:
        raise gr.Error(str(err)) from err

    return generate_image(prompt, width, height, lora_path, trigger_word, hash_value, steps)


# Set up the Gradio interface
with gr.Blocks() as gradio_app:
    gr.Markdown("# LoRA Image Generator")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3, placeholder="Enter your prompt here")

    with gr.Row():
        width = gr.Slider(label="Width", minimum=128, maximum=1024, step=64, value=512)
        height = gr.Slider(label="Height", minimum=128, maximum=1024, step=64, value=512)
        steps = gr.Slider(label="Inference Steps", minimum=1, maximum=100, step=1, value=28)

    with gr.Row():
        lora_path = gr.Textbox(label="LoRA Path", value="SebastianBodza/Flux_Aquarell_Watercolor_v2")
        trigger_word = gr.Textbox(label="Trigger Word", value="AQUACOLTOK")
        hash_value = gr.Textbox(label="Hash", placeholder="Enter a unique hash for the image")

    generate_button = gr.Button("Generate Image")
    output_image = gr.Image(label="Generated Image")
    output_path = gr.Textbox(label="Saved Image Path")

    generate_button.click(
        fn=run_lora,
        inputs=[prompt, width, height, lora_path, trigger_word, hash_value, steps],
        outputs=[output_image, output_path],
    )


# FastAPI endpoint for downloading images
@app.get("/download/{hash_value}")
async def download_image(hash_value: str):
    """Serve the PNG previously saved for *hash_value*.

    Returns:
        The image as a ``FileResponse``; a 400 JSON error if the hash is
        malformed; a 404 JSON error if no such image exists.
    """
    try:
        image_path = _image_path_for(hash_value)
    except ValueError:
        return JSONResponse(status_code=400, content={"error": "Invalid hash"})

    if not os.path.isfile(image_path):
        # Same body as before, now with a status code that reflects the failure.
        return JSONResponse(status_code=404, content={"error": "Image not found"})

    return FileResponse(image_path, media_type="image/png", filename=f"{hash_value}.png")


# Launch both Gradio and FastAPI
if __name__ == "__main__":
    import uvicorn
    from threading import Thread

    def run_gradio():
        """Start the Gradio UI with request queuing enabled."""
        gradio_app.queue()
        gradio_app.launch()

    # Start Gradio in a separate daemon thread so it does not block uvicorn
    # and exits together with the main process.
    Thread(target=run_gradio, daemon=True).start()

    # Run FastAPI with uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)