import io
import time

from fastapi import FastAPI, Body
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import StreamingResponse, Response, HTMLResponse
from pydantic import BaseModel
from huggingface_hub import snapshot_download
from diffusers import OnnxStableDiffusionPipeline

# onnxruntime must be installed for the ONNX pipeline to run,
# but it is not imported directly here.

class ImageRequest(BaseModel):
    prompt: str
    num_inference_steps: int = 50
    guidance_scale: float = 7.5
    format: str = "png"  # or "jpeg"

model_repo = "runwayml/stable-diffusion-v1-5"  # or any other ONNX-exported Stable Diffusion model
local_dir = "sd_onnx_models_snapshot"  # separate folder for the SD model snapshot

snapshot_download(
    repo_id=model_repo,
    revision="onnx",
    local_dir=local_dir,
    local_dir_use_symlinks=False,
    # Patterns match full paths, so "*.onnx" also covers the unet/vae subfolders.
    # "*.txt" is needed for tokenizer files such as merges.txt.
    allow_patterns=["*.onnx", "*.json", "*.txt"],
)

pipeline = OnnxStableDiffusionPipeline.from_pretrained(
    local_dir,  # local path populated by snapshot_download above
    provider="CPUExecutionProvider",  # or "CUDAExecutionProvider" if a GPU is available
)

app = FastAPI(
    title="FastAPI Image Generation with ONNX",
    # GZip mainly benefits the HTML page; PNG/JPEG payloads are already compressed.
    middleware=[Middleware(GZipMiddleware, compresslevel=9)],
)

@app.post("/generate-image/streaming")
async def generate_image_streaming(request: ImageRequest = Body(...)):
    prompt = request.prompt
    num_inference_steps = request.num_inference_steps
    guidance_scale = request.guidance_scale
    # Default to PNG if the requested format is not supported.
    img_format = request.format.lower() if request.format.lower() in ["png", "jpeg"] else "png"

    def image_generator():
        try:
            start_time = time.time()
            image = pipeline(
                prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]
            print(f"Image generation inference time: {time.time() - start_time:.3f}s")
            img_byte_arr = io.BytesIO()
            image.save(img_byte_arr, format=img_format.upper())
            # The finished image is yielded as a single chunk; the response is
            # "streaming" in the transfer sense, not progressive rendering.
            yield img_byte_arr.getvalue()
        except Exception as e:
            # Headers are already sent once streaming begins, so we can only
            # log the failure and end the stream early.
            print(f"Error processing image generation: {e}")

    return StreamingResponse(
        image_generator(),
        media_type=f"image/{img_format}",
        headers={"Cache-Control": "no-cache"},
    )

@app.post("/generate-image/full")
async def generate_image_full(request: ImageRequest = Body(...)):
    prompt = request.prompt
    num_inference_steps = request.num_inference_steps
    guidance_scale = request.guidance_scale
    # Default to PNG if the requested format is not supported.
    img_format = request.format.lower() if request.format.lower() in ["png", "jpeg"] else "png"

    start_time = time.time()
    image = pipeline(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    ).images[0]
    print(f"Full image generation inference time: {time.time() - start_time:.3f}s")

    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format=img_format.upper())
    return Response(content=img_byte_arr.getvalue(), media_type=f"image/{img_format}")

@app.get("/", response_class=HTMLResponse)
def index():
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>FastAPI Image Generation Demo</title>
        <style>
            body { font-family: Arial, sans-serif; }
            .container { width: 80%; margin: auto; padding-top: 20px; }
            h1 { text-align: center; }
            .form-group { margin-bottom: 15px; }
            label { display: block; margin-bottom: 5px; font-weight: bold; }
            input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
            textarea { height: 100px; }
            button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
            button:hover { background-color: #0056b3; }
            img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
        </style>
    </head>
    <body>
        <div class="container">
            <h1>FastAPI Image Generation Demo</h1>
            <div class="form-group">
                <label for="prompt">Text Prompt:</label>
                <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
            </div>
            <div class="form-group">
                <label for="num_inference_steps">Number of Inference Steps:</label>
                <input type="number" id="num_inference_steps" value="50">
            </div>
            <div class="form-group">
                <label for="guidance_scale">Guidance Scale:</label>
                <input type="number" step="0.5" id="guidance_scale" value="7.5">
            </div>
            <div class="form-group">
                <label for="format">Format:</label>
                <select id="format">
                    <option value="png" selected>PNG</option>
                    <option value="jpeg">JPEG</option>
                </select>
            </div>
            <div class="form-group">
                <button onclick="generateStreamingImage()">Generate Streaming Image</button>
                <button onclick="generateFullImage()">Generate Full Image</button>
            </div>
            <div id="image-container">
                <img id="image" src="#" alt="Generated Image" style="display:none;">
            </div>
        </div>
        <script>
            // Both buttons share the same request logic; only the endpoint differs.
            function generateImage(endpoint) {
                const prompt = document.getElementById('prompt').value;
                const num_inference_steps = document.getElementById('num_inference_steps').value;
                const guidance_scale = document.getElementById('guidance_scale').value;
                const format = document.getElementById('format').value;
                const imageElement = document.getElementById('image');

                fetch(endpoint, {
                    method: 'POST',
                    headers: { 'Content-Type': 'application/json' },
                    body: JSON.stringify({
                        prompt: prompt,
                        num_inference_steps: parseInt(num_inference_steps),
                        guidance_scale: parseFloat(guidance_scale),
                        format: format
                    })
                })
                .then(response => response.blob())
                .then(blob => {
                    imageElement.src = URL.createObjectURL(blob);
                    imageElement.style.display = 'block'; // Show the image
                });
            }

            function generateStreamingImage() { generateImage('/generate-image/streaming'); }
            function generateFullImage() { generateImage('/generate-image/full'); }
        </script>
    </body>
    </html>
    """

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)