# Hugging Face Spaces status when this file was captured: Runtime error
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect | |
from fastapi.responses import StreamingResponse, JSONResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.staticfiles import StaticFiles | |
from fastapi import Request | |
import markdown2 | |
from pipelines.utils.safety_checker import SafetyChecker | |
from PIL import Image | |
import logging | |
from config import config, Args | |
from connection_manager import ConnectionManager, ServerFullException | |
import uuid | |
from uuid import UUID | |
import time | |
from typing import Any, Protocol, runtime_checkable | |
from util import pil_to_frame, bytes_to_pil, is_firefox, get_pipeline_class, ParamsModel | |
from device import device, torch_dtype | |
import asyncio | |
import os | |
import time | |
import torch | |
from pydantic import BaseModel, create_model | |
@runtime_checkable
class BasePipeline(Protocol):
    """Structural interface a pipeline must expose to be served by ``App``.

    ``App`` instantiates ``Info`` to read ``input_mode`` and ``page_content``,
    calls ``schema()`` on both nested classes to describe the pipeline to the
    frontend, validates client parameters through ``InputParams`` and calls
    ``predict`` to produce each frame.
    """

    class Info:
        """Descriptive metadata about the pipeline."""

        # Invoked on the class itself (``Pipeline.Info.schema()``), hence a
        # classmethod rather than an instance method.
        @classmethod
        def schema(cls) -> dict[str, Any]:
            """Return a schema dictionary describing the pipeline info."""
            ...

        # Optional markdown text; ``App`` renders it to HTML when truthy.
        page_content: str | None
        # ``App`` expects an extra binary image message when this is "image".
        input_mode: str

    class InputParams(ParamsModel):
        """Parameters accepted by ``predict``."""

        # Invoked on the class itself (``Pipeline.InputParams.schema()``).
        @classmethod
        def schema(cls) -> dict[str, Any]:
            """Return a schema dictionary describing the input parameters."""
            ...

        def dict(self) -> dict[str, Any]:
            """Return the parameters as a plain dictionary."""
            ...

    def predict(self, params: ParamsModel) -> Image.Image | None:
        """Generate an image for ``params``; ``None`` means no frame was produced."""
        ...
# Pause, in seconds, that the frame generator sleeps for (via ``asyncio.sleep``)
# whenever there are no new parameters to render.
THROTTLE: float = 1 / 120
class App:
    """Bind a pipeline instance to a FastAPI application.

    Construction stores the configuration and the pipeline, creates the
    ``FastAPI`` app and a ``ConnectionManager``, builds a ``SafetyChecker`` when
    ``config.safety_checker`` is truthy, and finally calls :meth:`init_app` to
    install middleware, API routes and the static frontend mount.
    """

    def __init__(self, config: Args, pipeline_instance: BasePipeline):
        """Store collaborators and configure the FastAPI app.

        Args:
            config: Parsed server arguments (``safety_checker``, ``timeout``,
                ``max_queue_size`` and ``debug`` are read by this class).
            pipeline_instance: Already-constructed pipeline used for inference.
        """
        self.args = config
        self.pipeline = pipeline_instance
        self.app = FastAPI()
        self.conn_manager = ConnectionManager()
        # Only instantiated when enabled in the configuration.
        self.safety_checker: SafetyChecker | None = None
        if self.args.safety_checker:
            self.safety_checker = SafetyChecker(device=device.type)
        self.init_app()

    def init_app(self) -> None:
        """Install CORS middleware, register the API routes and mount the frontend."""
        # Local import: only needed for the generator annotation below.
        from collections.abc import AsyncGenerator

        # NOTE(review): wildcard origins combined with credentials is very
        # permissive; tighten if this server is exposed beyond a demo.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        # NOTE(review): the handlers below were previously defined but never
        # registered on ``self.app``. The route paths are reconstructed --
        # confirm them against the URLs the frontend actually requests.
        @self.app.websocket("/api/ws/{user_id}")
        async def websocket_endpoint(user_id: UUID, websocket: WebSocket) -> None:
            """Register the socket for ``user_id`` and process its messages."""
            try:
                await self.conn_manager.connect(
                    user_id, websocket, self.args.max_queue_size
                )
                await handle_websocket_data(user_id)
            except ServerFullException as e:
                logging.error(f"Server Full: {e}")
            finally:
                # Runs on every exit path, including a rejected connection.
                await self.conn_manager.disconnect(user_id)
                logging.info(f"User disconnected: {user_id}")

        async def handle_websocket_data(user_id: UUID) -> None:
            """Receive client messages and store the latest parameters.

            A ``{"status": "next_frame"}`` message is followed by a JSON message
            with the parameters and, when the pipeline's ``input_mode`` is
            ``"image"``, by a binary message holding the input image.
            """
            if not self.conn_manager.check_user(user_id):
                raise HTTPException(status_code=404, detail="User not found")
            # Never refreshed inside the loop, so a positive ``timeout`` bounds
            # the total session length rather than the idle time.
            last_time = time.time()
            try:
                while True:
                    if (
                        self.args.timeout > 0
                        and time.time() - last_time > self.args.timeout
                    ):
                        await self.conn_manager.send_json(
                            user_id,
                            {
                                "status": "timeout",
                                "message": "Your session has ended",
                            },
                        )
                        await self.conn_manager.disconnect(user_id)
                        return
                    data = await self.conn_manager.receive_json(user_id)
                    if data is None:
                        continue
                    if data["status"] == "next_frame":
                        info = self.pipeline.Info()
                        params_data = await self.conn_manager.receive_json(user_id)
                        if params_data is None:
                            continue
                        params = self.pipeline.InputParams.model_validate(params_data)
                        if info.input_mode == "image":
                            image_data = await self.conn_manager.receive_bytes(user_id)
                            if image_data is None or len(image_data) == 0:
                                # No usable image: ask the client for another frame.
                                await self.conn_manager.send_json(
                                    user_id, {"status": "send_frame"}
                                )
                                continue
                            # Re-validate with the decoded image added as a field.
                            params_dict = params.model_dump()
                            params_dict["image"] = bytes_to_pil(image_data)
                            params = self.pipeline.InputParams.model_validate(
                                params_dict
                            )
                        await self.conn_manager.update_data(user_id, params)
                        await self.conn_manager.send_json(user_id, {"status": "wait"})
            except Exception as e:
                # Boundary handler: any failure ends this user's session.
                logging.error(f"Websocket Error: {e}, {user_id} ")
                await self.conn_manager.disconnect(user_id)

        @self.app.get("/api/queue")
        async def get_queue_size() -> JSONResponse:
            """Report the connection manager's current user count."""
            queue_size = self.conn_manager.get_user_count()
            return JSONResponse({"queue_size": queue_size})

        @self.app.get("/api/stream/{user_id}")
        async def stream(user_id: UUID, request: Request) -> StreamingResponse:
            """Stream generated frames as a ``multipart/x-mixed-replace`` response."""
            try:

                async def generate() -> AsyncGenerator[bytes, None]:
                    """Yield an encoded frame each time the user's parameters change."""
                    last_params: ParamsModel | None = None
                    while True:
                        last_time = time.time()
                        await self.conn_manager.send_json(
                            user_id, {"status": "send_frame"}
                        )
                        params = await self.conn_manager.get_latest_data(user_id)
                        # Skip inference when nothing new arrived since the last frame.
                        if params is None or (
                            last_params is not None
                            and params.model_dump() == last_params.model_dump()
                        ):
                            await asyncio.sleep(THROTTLE)
                            continue
                        last_params = params
                        image = self.pipeline.predict(params)
                        if (
                            self.args.safety_checker
                            and self.safety_checker is not None
                            and image is not None
                        ):
                            image, has_nsfw_concept = self.safety_checker(image)
                            if has_nsfw_concept:
                                image = None
                        if image is None:
                            continue
                        frame = pil_to_frame(image)
                        yield frame
                        # Non-Firefox browsers get the frame twice, see
                        # https://bugs.chromium.org/p/chromium/issues/detail?id=1250396
                        # A missing User-Agent header is treated as an empty string
                        # instead of raising KeyError in the middle of the stream.
                        if not is_firefox(request.headers.get("user-agent", "")):
                            yield frame
                        if self.args.debug:
                            print(f"Time taken: {time.time() - last_time}")

                # NOTE: the generator body runs lazily while the response is being
                # sent, so this ``try`` only guards construction of the response.
                return StreamingResponse(
                    generate(),
                    media_type="multipart/x-mixed-replace;boundary=frame",
                    headers={"Cache-Control": "no-cache"},
                )
            except Exception as e:
                logging.error(f"Streaming Error: {e}, {user_id} ")
                raise HTTPException(status_code=404, detail="User not found") from e

        # Route used by the frontend to configure itself.
        @self.app.get("/api/settings")
        async def settings() -> JSONResponse:
            """Return pipeline schemas, the queue limit and the rendered page content."""
            info_schema = self.pipeline.Info.schema()
            info = self.pipeline.Info()
            page_content = ""
            if hasattr(info, "page_content") and info.page_content:
                page_content = markdown2.markdown(info.page_content)
            input_params = self.pipeline.InputParams.schema()
            return JSONResponse(
                {
                    "info": info_schema,
                    "input_params": input_params,
                    "max_queue_size": self.args.max_queue_size,
                    "page_content": page_content,
                }
            )

        # Ensure the directory that is actually mounted exists (the previous
        # code created "public" while mounting "frontend/public").
        static_dir = "frontend/public"
        os.makedirs(static_dir, exist_ok=True)
        # Mounted last so the API routes registered above take precedence over "/".
        self.app.mount(
            "/", StaticFiles(directory=static_dir, html=True), name="public"
        )
# def create_app(config): | |
# print(f"Device: {device}") | |
# print(f"torch_dtype: {torch_dtype}") | |
# # Create pipeline once | |
# pipeline_class = get_pipeline_class(config.pipeline) | |
# pipeline_instance = pipeline_class(config, device, torch_dtype) | |
# # Pass the existing pipeline instance to App | |
# app = App(config, pipeline_instance).app | |
# return app | |
# Build the pipeline and the FastAPI app once, at import time. ``app`` is the
# FastAPI instance handed to ``uvicorn.run`` in the ``__main__`` guard below.
print(f"Device: {device}")
print(f"torch_dtype: {torch_dtype}")

# Resolve the pipeline class named in the configuration, then instantiate it.
pipeline_class = get_pipeline_class(config.pipeline)
pipeline_instance = pipeline_class(config, device, torch_dtype)

# ``App`` wires the routes in its constructor; only the FastAPI object is kept.
app = App(config, pipeline_instance).app
if __name__ == "__main__":
    import uvicorn

    # Uvicorn refuses to start (it logs a warning and exits) when ``reload`` is
    # requested together with an application object, so pass an import string
    # in that case. Without reload, reuse the already-built ``app`` so the
    # pipeline is created only once.
    # NOTE(review): the import string assumes this script's directory is
    # importable under the file's own name -- confirm for your launch command.
    _app_target = app
    if config.reload:
        _app_target = f"{os.path.splitext(os.path.basename(__file__))[0]}:app"

    uvicorn.run(
        _app_target,
        host=config.host,
        port=config.port,
        reload=config.reload,
        ssl_certfile=config.ssl_certfile,
        ssl_keyfile=config.ssl_keyfile,
    )