cybergamer0123 commited on
Commit
0ed9e98
·
verified ·
1 Parent(s): f825c75

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import time
4
+ import json
5
+
6
+ import numpy as np
7
+ from fastapi import FastAPI, HTTPException, Body
8
+ from fastapi.responses import StreamingResponse, Response, HTMLResponse
9
+ from fastapi.middleware import Middleware
10
+ from fastapi.middleware.gzip import GZipMiddleware
11
+ from pydantic import BaseModel
12
+
13
+ from onnxruntime import InferenceSession
14
+ from huggingface_hub import snapshot_download
15
+ from scipy.io.wavfile import write as write_wav
16
+
17
+ from diffusers import OnnxStableDiffusionPipeline
18
+ from PIL import Image
19
+
20
+ class ImageRequest(BaseModel):
21
+ prompt: str
22
+ num_inference_steps: int = 50
23
+ guidance_scale: float = 7.5
24
+ format: str = "png" # or "jpeg"
25
+
26
+
27
+ model_repo = "runwayml/stable-diffusion-v1-5" # Or any other ONNX compatible Stable Diffusion model
28
+ model_name = "model_q4.onnx" # if specific model file needed, otherwise directory is enough
29
+ voice_file_pattern = "*.bin" # not used, keep for inspiration, remove if not needed
30
+ local_dir = "sd_onnx_models_snapshot" # different folder for sd models
31
+ snapshot_download(
32
+ repo_id=model_repo,
33
+ revision="onnx",
34
+ local_dir=local_dir,
35
+ local_dir_use_symlinks=False,
36
+ allow_patterns=["*.onnx", "*.json", "vae/*.onnx"] # Specify necessary file patterns (adjust as needed)
37
+ )
38
+
39
+
40
+ pipeline = OnnxStableDiffusionPipeline.from_pretrained(
41
+ local_dir, # Use the local path from snapshot_download
42
+ provider="CPUExecutionProvider", # Or "CUDAExecutionProvider" if you have GPU
43
+ )
44
+
45
+
46
+ app = FastAPI(
47
+ title="FastAPI Image Generation with ONNX",
48
+ middleware=[Middleware(GZipMiddleware, compresslevel=9)] # maybe compression is not needed for images? check later
49
+ )
50
+
51
+
52
+ @app.post("/generate-image/streaming", summary="Streaming Image Generation")
53
+ async def generate_image_streaming(request: ImageRequest = Body(...)):
54
+ prompt = request.prompt
55
+ num_inference_steps = request.num_inference_steps
56
+ guidance_scale = request.guidance_scale
57
+ format = request.format.lower()
58
+
59
+ def image_generator():
60
+
61
+ try:
62
+ start_time = time.time()
63
+ image = pipeline(
64
+ prompt,
65
+ num_inference_steps=num_inference_steps,
66
+ guidance_scale=guidance_scale
67
+ ).images[0]
68
+ print(f"Image generation inference time: {time.time() - start_time:.3f}s")
69
+
70
+ img_byte_arr = io.BytesIO()
71
+ image_format = format.upper() if format in ["png", "jpeg"] else "PNG" # Default to PNG if format is invalid
72
+ image.save(img_byte_arr, format=image_format)
73
+ img_byte_arr = img_byte_arr.getvalue()
74
+ yield img_byte_arr
75
+
76
+ except Exception as e:
77
+ print(f"Error processing image generation: {e}")
78
+ # yield error response? or just error out
79
+
80
+ media_type = f"image/{format}" if format in ["png", "jpeg"] else "image/png"
81
+ return StreamingResponse(
82
+ image_generator(),
83
+ media_type=media_type,
84
+ headers={"Cache-Control": "no-cache"},
85
+ )
86
+
87
+
88
+ @app.post("/generate-image/full", summary="Full Image Generation")
89
+ async def generate_image_full(request: ImageRequest = Body(...)):
90
+ prompt = request.prompt
91
+ num_inference_steps = request.num_inference_steps
92
+ guidance_scale = request.guidance_scale
93
+ format = request.format.lower()
94
+
95
+ start_time = time.time()
96
+ image = pipeline(
97
+ prompt,
98
+ num_inference_steps=num_inference_steps,
99
+ guidance_scale=guidance_scale
100
+ ).images[0]
101
+ print(f"Full Image generation inference time: {time.time()-start_time:.3f}s")
102
+
103
+
104
+ img_byte_arr = io.BytesIO()
105
+ image_format = format.upper() if format in ["png", "jpeg"] else "PNG"
106
+ image.save(img_byte_arr, format=image_format)
107
+ img_byte_arr.seek(0)
108
+
109
+
110
+ media_type = f"image/{format}" if format in ["png", "jpeg"] else "image/png"
111
+ return Response(content=img_byte_arr.read(), media_type=media_type)
112
+
113
+
114
+ @app.get("/", response_class=HTMLResponse)
115
+ def index():
116
+ return """
117
+ <!DOCTYPE html>
118
+ <html>
119
+ <head>
120
+ <title>FastAPI Image Generation Demo</title>
121
+ <style>
122
+ body { font-family: Arial, sans-serif; }
123
+ .container { width: 80%; margin: auto; padding-top: 20px; }
124
+ h1 { text-align: center; }
125
+ .form-group { margin-bottom: 15px; }
126
+ label { display: block; margin-bottom: 5px; font-weight: bold; }
127
+ input[type="text"], input[type="number"], textarea, select { width: 100%; padding: 8px; box-sizing: border-box; margin-bottom: 10px; border: 1px solid #ccc; border-radius: 4px; }
128
+ textarea { height: 100px; }
129
+ button { padding: 10px 15px; border: none; color: white; background-color: #007bff; border-radius: 4px; cursor: pointer; }
130
+ button:hover { background-color: #0056b3; }
131
+ img { display: block; margin-top: 20px; max-width: 500px; } /* Adjust max-width as needed */
132
+ </style>
133
+ </head>
134
+ <body>
135
+ <div class="container">
136
+ <h1>FastAPI Image Generation Demo</h1>
137
+ <div class="form-group">
138
+ <label for="prompt">Text Prompt:</label>
139
+ <textarea id="prompt" rows="4" placeholder="Enter text prompt here"></textarea>
140
+ </div>
141
+ <div class="form-group">
142
+ <label for="num_inference_steps">Number of Inference Steps:</label>
143
+ <input type="number" id="num_inference_steps" value="50">
144
+ </div>
145
+ <div class="form-group">
146
+ <label for="guidance_scale">Guidance Scale:</label>
147
+ <input type="number" step="0.5" id="guidance_scale" value="7.5">
148
+ </div>
149
+ <div class="form-group">
150
+ <label for="format">Format:</label>
151
+ <select id="format">
152
+ <option value="png" selected>PNG</option>
153
+ <option value="jpeg">JPEG</option>
154
+ </select>
155
+ </div>
156
+ <div class="form-group">
157
+ <button onclick="generateStreamingImage()">Generate Streaming Image</button>
158
+ <button onclick="generateFullImage()">Generate Full Image</button>
159
+ </div>
160
+ <div id="image-container">
161
+ <img id="image" src="#" alt="Generated Image" style="display:none;">
162
+ </div>
163
+ </div>
164
+ <script>
165
+ function generateStreamingImage() {
166
+ const prompt = document.getElementById('prompt').value;
167
+ const num_inference_steps = document.getElementById('num_inference_steps').value;
168
+ const guidance_scale = document.getElementById('guidance_scale').value;
169
+ const format = document.getElementById('format').value;
170
+ const imageElement = document.getElementById('image');
171
+ const imageContainer = document.getElementById('image-container');
172
+
173
+ fetch('/generate-image/streaming', {
174
+ method: 'POST',
175
+ headers: {
176
+ 'Content-Type': 'application/json'
177
+ },
178
+ body: JSON.stringify({
179
+ prompt: prompt,
180
+ num_inference_steps: parseInt(num_inference_steps),
181
+ guidance_scale: parseFloat(guidance_scale),
182
+ format: format
183
+ })
184
+ })
185
+ .then(response => response.blob())
186
+ .then(blob => {
187
+ const imageUrl = URL.createObjectURL(blob);
188
+ imageElement.src = imageUrl;
189
+ imageElement.style.display = 'block'; // Show the image
190
+ imageContainer.style.display = 'block'; // Show the container if hidden
191
+ });
192
+ }
193
+
194
+ function generateFullImage() {
195
+ const prompt = document.getElementById('prompt').value;
196
+ const num_inference_steps = document.getElementById('num_inference_steps').value;
197
+ const guidance_scale = document.getElementById('guidance_scale').value;
198
+ const format = document.getElementById('format').value;
199
+ const imageElement = document.getElementById('image');
200
+ const imageContainer = document.getElementById('image-container');
201
+
202
+
203
+ fetch('/generate-image/full', {
204
+ method: 'POST',
205
+ headers: {
206
+ 'Content-Type': 'application/json'
207
+ },
208
+ body: JSON.stringify({
209
+ prompt: prompt,
210
+ num_inference_steps: parseInt(num_inference_steps),
211
+ guidance_scale: parseFloat(guidance_scale),
212
+ format: format
213
+ })
214
+ })
215
+ .then(response => response.blob())
216
+ .then(blob => {
217
+ const imageUrl = URL.createObjectURL(blob);
218
+ imageElement.src = imageUrl;
219
+ imageElement.style.display = 'block'; // Show the image
220
+ imageContainer.style.display = 'block'; // Show the container if hidden
221
+ });
222
+ }
223
+ </script>
224
+ </body>
225
+ </html>
226
+ """
227
+
228
+ if __name__ == "__main__":
229
+ import uvicorn
230
+ uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)