Spaces:
Paused
Paused
| import uvicorn | |
| from fastapi import FastAPI | |
| from pydantic import BaseModel, Field | |
| from sse_starlette.sse import EventSourceResponse | |
| from utils.logger import logger | |
| from networks.message_streamer import MessageStreamer | |
| from messagers.message_composer import MessageComposer | |
class ChatAPIApp:
    """Own a FastAPI application that serves an LLM chat API.

    Two endpoints are registered twice, once at the root and once under
    the ``/v1`` prefix:

    * ``GET  {prefix}/models``           -> ``get_available_models``
    * ``POST {prefix}/chat/completions`` -> ``chat_completions``

    The configured application object is exposed as ``self.app``.
    """

    def __init__(self):
        """Build ``self.app`` and attach every route to it."""
        self.app = FastAPI(
            title="HuggingFace LLM API",
            version="1.0",
            # Serve the interactive Swagger UI at the site root.
            docs_url="/",
            # -1 tells Swagger UI to hide the models/schemas section entirely.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
        )
        self.setup_routes()

    # NOTE: the route handlers and the request model below use `#` comments
    # instead of docstrings on purpose. FastAPI publishes an endpoint's
    # docstring as its OpenAPI description, and pydantic publishes a model's
    # docstring as its schema description, so docstrings would alter the
    # served API documentation.

    def get_available_models(self):
        # Handler for GET {prefix}/models.
        # Rebuilds the model catalogue on each call, stores it on
        # `self.available_models`, and returns that same list.
        catalogue = [
            {
                "id": "mixtral-8x7b",
                "description": "Mixtral-8x7B: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
            },
        ]
        self.available_models = catalogue
        return catalogue

    class ChatCompletionsPostItem(BaseModel):
        # Request body for POST {prefix}/chat/completions.
        # The order of the fields is left unchanged.
        model: str = Field(
            description="(str) `mixtral-8x7b`",
            default="mixtral-8x7b",
        )
        # Chat history. Presumably a list of {"role": ..., "content": ...}
        # dicts, like the default value. TODO: confirm against
        # MessageComposer.merge.
        messages: list = Field(
            description="(list) Messages",
            default=[{"role": "user", "content": "Hello, who are you?"}],
        )
        # Passed unchanged to the streamer as `temperature`.
        temperature: float = Field(
            description="(float) Temperature",
            default=0.01,
        )
        # Passed to the streamer as `max_new_tokens`.
        max_tokens: int = Field(
            description="(int) Max tokens",
            default=32000,
        )
        # Passed to the streamer as `stream`.
        stream: bool = Field(
            description="(bool) Stream",
            default=True,
        )

    def chat_completions(self, item: ChatCompletionsPostItem):
        # Handler for POST {prefix}/chat/completions.
        # Merges the chat history into a single prompt string for the
        # requested model, then wraps the streamer's output in a server-sent
        # events response. The reply is always an EventSourceResponse;
        # `item.stream` is only forwarded to `streamer.chat`.
        model_name = item.model
        streamer = MessageStreamer(model=model_name)
        composer = MessageComposer(model=model_name)
        composer.merge(messages=item.messages)
        # `yield_output=True` presumably makes `chat` return an iterable of
        # chunks for the SSE response to consume. TODO: confirm in
        # MessageStreamer.
        output = streamer.chat(
            prompt=composer.merged_str,
            temperature=item.temperature,
            max_new_tokens=item.max_tokens,
            stream=item.stream,
            yield_output=True,
        )
        return EventSourceResponse(output, media_type="text/event-stream")

    def setup_routes(self):
        """Register both endpoints at the root and again under ``/v1``.

        Routes are added in the order: root models, root chat completions,
        then the same two under ``/v1``.
        """
        route_table = (
            (
                self.app.get,
                "/models",
                "Get available models",
                self.get_available_models,
            ),
            (
                self.app.post,
                "/chat/completions",
                "Chat completions in conversation session",
                self.chat_completions,
            ),
        )
        for prefix in ("", "/v1"):
            for register, path, summary, handler in route_table:
                register(prefix + path, summary=summary)(handler)
# Module-level ASGI application, built at import time so an ASGI server can
# load it by the import string "<module>:app".
app = ChatAPIApp().app

if __name__ == "__main__":
    # Development entry point. The app is given to uvicorn as an import
    # string rather than as the object, because uvicorn requires that form
    # when reload=True.
    uvicorn.run(
        "__main__:app",
        host="0.0.0.0",
        port=23333,
        reload=True,
    )