Spaces:
Running
Running
| from fastapi import FastAPI, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import PlainTextResponse, JSONResponse | |
| from loguru import logger | |
class OpenAIError(Exception):
    """Root of the exception hierarchy defined in this module."""


class APIError(OpenAIError):
    """Error that carries the fields reported back to the client.

    Attributes:
        message: Human-readable description of the failure.
        code: Numeric status code of the error (500 unless overridden).
        param: Name of the offending request parameter, if any.
        type: Name of the concrete exception class, as a plain string.
        internal_message: Extra detail kept on the exception alongside the
            client-facing message.
    """

    message: str
    code: int = None
    param: str = None
    type: str = None

    def __init__(self, message: str, code: int = 500, param: str = None, internal_message: str = ''):
        """Store the error details on the exception.

        Args:
            message: Client-facing description; also used as the exception text.
            code: Status code for the error.
            param: Request parameter the error relates to, or ``None``.
            internal_message: Additional detail stored next to ``message``.
        """
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        # Must be a bare string: a trailing comma here would silently turn the
        # value into a one-element tuple.
        self.type = self.__class__.__name__
        self.internal_message = internal_message

    def __repr__(self):
        """Return ``ClassName(message=..., code=..., param=...)``."""
        # Plain formatting of ``code`` (instead of ``%d``) keeps repr() from
        # raising when a non-numeric code was supplied.
        return (
            f"{self.__class__.__name__}"
            f"(message={self.message!r}, code={self.code}, param={self.param})"
        )
class InternalServerError(APIError):
    """Server-side failure; adds nothing to APIError and keeps its defaults."""
class ServiceUnavailableError(APIError):
    """APIError preset for a temporarily unavailable service (code 503)."""

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        """Build the error with a default message and code.

        Args:
            message: Client-facing description.
            code: Status code for the error, 503 by default.
            internal_message: Additional detail stored next to ``message``.
        """
        # APIError's third positional parameter is ``param``; pass
        # ``internal_message`` by keyword so it is not stored as ``param``.
        super().__init__(message, code, internal_message=internal_message)
class APIStatusError(APIError):
    """APIError whose code is fixed by the class-level ``status_code``.

    Subclasses only override ``status_code``; the constructor forwards it as
    the error's ``code``.
    """

    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        """Create the error using this class's ``status_code`` as its code.

        Args:
            message: Client-facing description.
            param: Request parameter the error relates to, or ``None``.
            internal_message: Additional detail stored next to ``message``.
        """
        super().__init__(
            message,
            code=self.status_code,
            param=param,
            internal_message=internal_message,
        )
class BadRequestError(APIStatusError):
    """Status error reported with code 400."""

    status_code: int = 400
class AuthenticationError(APIStatusError):
    """Status error reported with code 401."""

    status_code: int = 401
class PermissionDeniedError(APIStatusError):
    """Status error reported with code 403."""

    status_code: int = 403
class NotFoundError(APIStatusError):
    """Status error reported with code 404."""

    status_code: int = 404
class ConflictError(APIStatusError):
    """Status error reported with code 409."""

    status_code: int = 409
class UnprocessableEntityError(APIStatusError):
    """Status error reported with code 422."""

    status_code: int = 422
class RateLimitError(APIStatusError):
    """Status error reported with code 429."""

    status_code: int = 429
class OpenAIStub(FastAPI):
    """FastAPI application with a model registry and OpenAI-style helpers.

    The app keeps a name -> model mapping in ``self.models`` and defines error
    handlers, a request-logging middleware and a few small endpoints inside
    ``__init__``.

    NOTE(review): in this view of the file the functions defined inside
    ``__init__`` carry no registration decorators (such as
    ``@self.exception_handler(...)``, ``@self.middleware("http")`` or
    ``@self.get(...)``) and are never handed to the app, so as written they are
    never invoked. The decorators were presumably lost when the file was
    extracted; restore them from the original source -- the route paths cannot
    be determined from here.
    """

    def __init__(self, **kwargs) -> None:
        """Create the app, enable CORS and define the built-in handlers.

        Args:
            **kwargs: Forwarded unchanged to ``FastAPI.__init__``.
        """
        super().__init__(**kwargs)
        # Registered name -> model it stands for (see register_model).
        self.models = {}

        # CORS is fully permissive: any origin, method and header, with
        # credentials allowed.
        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        def _error_response(exc: APIError) -> JSONResponse:
            """Build the JSON error body for *exc*; the HTTP status is exc.code."""
            return JSONResponse(status_code=exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            """Answer any other exception with a generic 500 body (no details)."""
            # Traceback logging is currently disabled for generic errors:
            #logger.opt(exception=exc).error("Logging exception traceback")
            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            """Log a server-side APIError with its traceback and report it."""
            logger.opt(exception=exc).error("Logging exception traceback")
            if exc.internal_message:
                logger.info(exc.internal_message)
            return _error_response(exc)

        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            """Log a client-side APIStatusError (no traceback) and report it."""
            logger.info(repr(exc))
            if exc.internal_message:
                logger.info(exc.internal_message)
            return _error_response(exc)

        async def log_requests(request: Request, call_next):
            """Debug-log the incoming request and the outgoing response.

            NOTE(review): headers and the raw body are written to the log in
            full; they may contain credentials -- confirm debug logging is not
            enabled where that matters.
            """
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            logger.debug(f"Request body: {await request.body()}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")
            return response

        async def handle_billing_usage():
            """Report a constant usage of zero."""
            return { 'total_usage': 0 }

        async def root():
            """Empty text response: 200 once a model is registered, else 503."""
            return PlainTextResponse(content="", status_code=200 if self.models else 503)

        async def health():
            """Status is "ok" once a model is registered, otherwise "unk"."""
            return {"status": "ok" if self.models else "unk" }

        async def get_model_list():
            """Return the listing produced by ``model_list``."""
            return self.model_list()

        async def get_model_info(model_id: str):
            """Return the descriptor produced by ``model_info`` for *model_id*."""
            return self.model_info(model_id)

    def register_model(self, name: str, model: str = None) -> None:
        """Register *name*, mapped to *model* (or to itself when *model* is falsy)."""
        self.models[name] = model or name

    def deregister_model(self, name: str) -> None:
        """Remove *name* from the registry; unknown names are ignored."""
        self.models.pop(name, None)

    def model_info(self, model: str) -> dict:
        """Return a static model descriptor whose ``id`` is *model*.

        The registry is not consulted, so any id yields a descriptor.
        """
        return {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }

    def model_list(self) -> dict:
        """Return descriptors for every registered name and mapped model.

        Returns:
            ``{}`` when nothing is registered; otherwise a dict with
            ``"object": "list"`` and a ``"data"`` list of model descriptors.
        """
        if not self.models:
            return {}
        # Names and the models they map to are both listed; the set union drops
        # duplicates. Sorting (by string form, so mixed value types cannot
        # break the comparison) keeps the order stable between runs instead of
        # depending on set iteration order.
        model_ids = sorted(set(self.models) | set(self.models.values()), key=str)
        return {
            "object": "list",
            "data": [self.model_info(model_id) for model_id in model_ids if model_id],
        }