Spaces:
Runtime error
Runtime error
extra msgs
Browse files- app.py +38 -32
- public/index.html +7 -5
app.py
CHANGED
|
@@ -55,7 +55,7 @@ def predict(input_image, prompt, guidance_scale=8.0, strength=0.5, seed=2159232)
|
|
| 55 |
strength=strength,
|
| 56 |
num_inference_steps=num_inference_steps,
|
| 57 |
guidance_scale=guidance_scale,
|
| 58 |
-
lcm_origin_steps=
|
| 59 |
output_type="pil",
|
| 60 |
)
|
| 61 |
nsfw_content_detected = (
|
|
@@ -106,9 +106,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 106 |
"queue": asyncio.Queue(),
|
| 107 |
"params": params,
|
| 108 |
}
|
|
|
|
|
|
|
|
|
|
| 109 |
await handle_websocket_data(websocket, uid)
|
| 110 |
except WebSocketDisconnect as e:
|
| 111 |
-
logging.error(f"Error: {e}")
|
| 112 |
traceback.print_exc()
|
| 113 |
finally:
|
| 114 |
print(f"User disconnected: {uid}")
|
|
@@ -131,36 +134,39 @@ async def get_queue_size():
|
|
| 131 |
@app.get("/stream/{user_id}")
|
| 132 |
async def stream(user_id: uuid.UUID):
|
| 133 |
uid = str(user_id)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
|
| 163 |
-
)
|
| 164 |
|
| 165 |
|
| 166 |
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
@@ -182,7 +188,7 @@ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
| 182 |
continue
|
| 183 |
await queue.put(pil_image)
|
| 184 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
| 185 |
-
await websocket.send_json(
|
| 186 |
{
|
| 187 |
"status": "timeout",
|
| 188 |
"message": "Your session has ended",
|
|
|
|
| 55 |
strength=strength,
|
| 56 |
num_inference_steps=num_inference_steps,
|
| 57 |
guidance_scale=guidance_scale,
|
| 58 |
+
lcm_origin_steps=30,
|
| 59 |
output_type="pil",
|
| 60 |
)
|
| 61 |
nsfw_content_detected = (
|
|
|
|
| 106 |
"queue": asyncio.Queue(),
|
| 107 |
"params": params,
|
| 108 |
}
|
| 109 |
+
await websocket.send_json(
|
| 110 |
+
{"status": "start", "message": "Start Streaming", "userId": uid}
|
| 111 |
+
)
|
| 112 |
await handle_websocket_data(websocket, uid)
|
| 113 |
except WebSocketDisconnect as e:
|
| 114 |
+
logging.error(f"WebSocket Error: {e}, {uid}")
|
| 115 |
traceback.print_exc()
|
| 116 |
finally:
|
| 117 |
print(f"User disconnected: {uid}")
|
|
|
|
| 134 |
@app.get("/stream/{user_id}")
|
| 135 |
async def stream(user_id: uuid.UUID):
|
| 136 |
uid = str(user_id)
|
| 137 |
+
try:
|
| 138 |
+
user_queue = user_queue_map[uid]
|
| 139 |
+
queue = user_queue["queue"]
|
| 140 |
+
params = user_queue["params"]
|
| 141 |
+
seed = params.seed
|
| 142 |
+
prompt = params.prompt
|
| 143 |
+
strength = params.strength
|
| 144 |
+
guidance_scale = params.guidance_scale
|
| 145 |
+
|
| 146 |
+
async def generate():
|
| 147 |
+
while True:
|
| 148 |
+
input_image = await queue.get()
|
| 149 |
+
if input_image is None:
|
| 150 |
+
continue
|
| 151 |
|
| 152 |
+
image = predict(input_image, prompt, guidance_scale, strength, seed)
|
| 153 |
+
if image is None:
|
| 154 |
+
continue
|
| 155 |
+
frame_data = io.BytesIO()
|
| 156 |
+
image.save(frame_data, format="JPEG")
|
| 157 |
+
frame_data = frame_data.getvalue()
|
| 158 |
+
if frame_data is not None and len(frame_data) > 0:
|
| 159 |
+
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
|
| 160 |
+
|
| 161 |
+
await asyncio.sleep(1.0 / 120.0)
|
| 162 |
+
|
| 163 |
+
return StreamingResponse(
|
| 164 |
+
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
|
| 165 |
+
)
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logging.error(f"Streaming Error: {e}, {user_queue_map}")
|
| 168 |
+
traceback.print_exc()
|
| 169 |
+
return HTTPException(status_code=404, detail="User not found")
|
|
|
|
|
|
|
| 170 |
|
| 171 |
|
| 172 |
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
|
|
|
| 188 |
continue
|
| 189 |
await queue.put(pil_image)
|
| 190 |
if TIMEOUT > 0 and time.time() - last_time > TIMEOUT:
|
| 191 |
+
await websocket.send_json(
|
| 192 |
{
|
| 193 |
"status": "timeout",
|
| 194 |
"message": "Your session has ended",
|
public/index.html
CHANGED
|
@@ -47,9 +47,10 @@
|
|
| 47 |
switch (data.status) {
|
| 48 |
case "success":
|
| 49 |
socket.send(JSON.stringify(params));
|
|
|
|
|
|
|
| 50 |
const userId = data.userId;
|
| 51 |
-
|
| 52 |
-
initVideoStream();
|
| 53 |
break;
|
| 54 |
case "timeout":
|
| 55 |
stop();
|
|
@@ -72,7 +73,8 @@
|
|
| 72 |
websocket.send(blob);
|
| 73 |
}
|
| 74 |
|
| 75 |
-
function initVideoStream() {
|
|
|
|
| 76 |
const constraints = {
|
| 77 |
audio: false,
|
| 78 |
video: { width: 512, height: 512 },
|
|
@@ -118,7 +120,7 @@
|
|
| 118 |
}
|
| 119 |
setTimeout(() => {
|
| 120 |
errorEl.hidden = true;
|
| 121 |
-
},
|
| 122 |
}
|
| 123 |
|
| 124 |
|
|
@@ -154,7 +156,7 @@
|
|
| 154 |
.catch((err) => {
|
| 155 |
console.log(err);
|
| 156 |
})
|
| 157 |
-
,
|
| 158 |
</script>
|
| 159 |
</head>
|
| 160 |
|
|
|
|
| 47 |
switch (data.status) {
|
| 48 |
case "success":
|
| 49 |
socket.send(JSON.stringify(params));
|
| 50 |
+
break;
|
| 51 |
+
case "start":
|
| 52 |
const userId = data.userId;
|
| 53 |
+
initVideoStream(userId);
|
|
|
|
| 54 |
break;
|
| 55 |
case "timeout":
|
| 56 |
stop();
|
|
|
|
| 73 |
websocket.send(blob);
|
| 74 |
}
|
| 75 |
|
| 76 |
+
function initVideoStream(userId) {
|
| 77 |
+
liveImage.src = `/stream/${userId}`;
|
| 78 |
const constraints = {
|
| 79 |
audio: false,
|
| 80 |
video: { width: 512, height: 512 },
|
|
|
|
| 120 |
}
|
| 121 |
setTimeout(() => {
|
| 122 |
errorEl.hidden = true;
|
| 123 |
+
}, 2000);
|
| 124 |
}
|
| 125 |
|
| 126 |
|
|
|
|
| 156 |
.catch((err) => {
|
| 157 |
console.log(err);
|
| 158 |
})
|
| 159 |
+
, 5000);
|
| 160 |
</script>
|
| 161 |
</head>
|
| 162 |
|