"""FastAPI demo that serves Stable Diffusion image generation via an ONNX pipeline."""
import io
import time

from fastapi import FastAPI, Body
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from diffusers import OnnxStableDiffusionPipeline
class ImageRequest(BaseModel):
    """Parameters accepted by both image-generation endpoints."""
    prompt: str
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    format: str = "png"  # "png" or "jpeg"
model_repo = "runwayml/stable-diffusion-v1-5"  # any ONNX-compatible Stable Diffusion repo works here
local_dir = "sd_onnx_models_snapshot"  # local folder that receives the downloaded ONNX snapshot
snapshot_download(
    repo_id=model_repo,
    revision="onnx",
    local_dir=local_dir,
    local_dir_use_symlinks=False,
    # The pipeline needs the tokenizer and scheduler configs as well as the ONNX
    # graphs, so include text files (e.g. merges.txt for the CLIP tokenizer).
    allow_patterns=["*.onnx", "*.json", "*.txt"],
)
pipeline = OnnxStableDiffusionPipeline.from_pretrained(
    local_dir,  # local path populated by snapshot_download above
    provider="CPUExecutionProvider",  # or "CUDAExecutionProvider" with onnxruntime-gpu installed
)
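
# Optional warm-up (a sketch, left commented out): the first pipeline call pays
# one-off ONNX session initialisation costs, so a single-step dummy run at startup
# can make the first real request faster, at the price of a slower startup.
# _ = pipeline("warm-up", num_inference_steps=1)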
app = FastAPI(
    title="FastAPI Image Generation with ONNX",
    # GZip adds little for already-compressed PNG/JPEG payloads, but it does shrink
    # the HTML demo page served at "/".
    middleware=[Middleware(GZipMiddleware, compresslevel=9)],
)
@app.post("/generate-image/streaming", summary="Streaming Image Generation")
async def generate_image_streaming(request: ImageRequest = Body(...)):
    prompt = request.prompt
    num_inference_steps = request.num_inference_steps
    guidance_scale = request.guidance_scale
    # Fall back to PNG when an unsupported format is requested.
    image_format = request.format.lower() if request.format.lower() in ("png", "jpeg") else "png"

    def image_generator():
        try:
            start_time = time.time()
            image = pipeline(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
            print(f"Image generation inference time: {time.time() - start_time:.3f}s")
            img_byte_arr = io.BytesIO()
            image.save(img_byte_arr, format=image_format.upper())
            yield img_byte_arr.getvalue()
        except Exception as e:
            print(f"Error processing image generation: {e}")
            # Re-raise so the client sees a failed stream instead of an empty 200 response.
            raise

    return StreamingResponse(
        image_generator(),
        media_type=f"image/{image_format}",
        headers={"Cache-Control": "no-cache"},
    )
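
# Note: the "streaming" endpoint above still runs the full diffusion loop before it
# yields anything, so the client receives the finished image as a single chunk; in
# practice its behaviour matches the /generate-image/full endpoint below.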
@app.post("/generate-image/full", summary="Full Image Generation")
async def generate_image_full(request: ImageRequest = Body(...)):
    prompt = request.prompt
    num_inference_steps = request.num_inference_steps
    guidance_scale = request.guidance_scale
    # Fall back to PNG when an unsupported format is requested.
    image_format = request.format.lower() if request.format.lower() in ("png", "jpeg") else "png"

    start_time = time.time()
    image = pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    print(f"Full image generation inference time: {time.time() - start_time:.3f}s")

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format=image_format.upper())
    return Response(content=img_byte_arr.getvalue(), media_type=f"image/{image_format}")
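
# Example usage (a sketch; adjust host/port to your deployment — 7860 matches the
# uvicorn call at the bottom of this file):
# curl -X POST http://localhost:7860/generate-image/full \
#   -H "Content-Type: application/json" \
#   -d '{"prompt": "a sunset over snow-capped mountains", "format": "png"}' \
#   --output out.png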
@app.get("/", response_class=HTMLResponse)
def index():
return """
<!DOCTYPE html>
<html>
<head>
<title>FastAPI Image Generation Demo</title>
<style>
body { font-family: Arial, sans-serif; }
.container { width: 80%; margin: auto; padding-top: 20px; }
h1 { text-align: center; }
.form-group { margin-bottom: 15px; }
label { display: block; margin-bottom: 5px; font-weight: bold; }
input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
textarea { height: 100px; }
button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
button:hover { background-color: #0056b3; }
img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
</style>
</head>
<body>
<div class="container">
<h1>FastAPI Image Generation Demo</h1>
<div class="form-group">
<label for="prompt">Text Prompt:</label>
<textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
</div>
<div class="form-group">
<label for="num_inference_steps">Number of Inference Steps:</label>
<input type="number" id="num_inference_steps" value="50">
</div>
<div class="form-group">
<label for="guidance_scale">Guidance Scale:</label>
<input type="number" step="0.5" id="guidance_scale" value="7.5">
</div>
<div class="form-group">
<label for="format">Format:</label>
<select id="format">
<option value="png" selected>PNG</option>
<option value="jpeg">JPEG</option>
</select>
</div>
<div class="form-group">
<button onclick="generateStreamingImage()">Generate Streaming Image</button>
<button onclick="generateFullImage()">Generate Full Image</button>
</div>
<div id="image-container">
<img id="image" src="#" alt="Generated Image" style="display:none;">
</div>
</div>
<script>
function generateStreamingImage() {
const prompt = document.getElementById('prompt').value;
const num_inference_steps = document.getElementById('num_inference_steps').value;
const guidance_scale = document.getElementById('guidance_scale').value;
const format = document.getElementById('format').value;
const imageElement = document.getElementById('image');
const imageContainer = document.getElementById('image-container');
fetch('/generate-image/streaming', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
prompt: prompt,
num_inference_steps: parseInt(num_inference_steps),
guidance_scale: parseFloat(guidance_scale),
format: format
})
})
.then(response => response.blob())
.then(blob => {
const imageUrl = URL.createObjectURL(blob);
imageElement.src = imageUrl;
imageElement.style.display = 'block'; // Show the image
imageContainer.style.display = 'block'; // Show the container if hidden
});
}
function generateFullImage() {
const prompt = document.getElementById('prompt').value;
const num_inference_steps = document.getElementById('num_inference_steps').value;
const guidance_scale = document.getElementById('guidance_scale').value;
const format = document.getElementById('format').value;
const imageElement = document.getElementById('image');
const imageContainer = document.getElementById('image-container');
fetch('/generate-image/full', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
prompt: prompt,
num_inference_steps: parseInt(num_inference_steps),
guidance_scale: parseFloat(guidance_scale),
format: format
})
})
.then(response => response.blob())
.then(blob => {
const imageUrl = URL.createObjectURL(blob);
imageElement.src = imageUrl;
imageElement.style.display = 'block'; // Show the image
imageContainer.style.display = 'block'; // Show the container if hidden
});
}
</script>
</body>
</html>
"""
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)