Husnain
committed on
💎 [Feature] Enable gpt-3.5 in chat_api
Browse files- apis/chat_api.py +31 -25
apis/chat_api.py
CHANGED
|
@@ -12,21 +12,24 @@ from fastapi.responses import HTMLResponse
|
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
from messagers.message_composer import MessageComposer
|
| 17 |
from mocks.stream_chat_mocker import stream_chat_mock
|
| 18 |
-
from networks.
|
| 19 |
-
from
|
| 20 |
-
from constants.models import AVAILABLE_MODELS_DICTS
|
| 21 |
|
| 22 |
|
| 23 |
class ChatAPIApp:
|
| 24 |
def __init__(self):
|
| 25 |
self.app = FastAPI(
|
| 26 |
docs_url="/",
|
| 27 |
-
title="
|
| 28 |
swagger_ui_parameters={"defaultModelsExpandDepth": -1},
|
| 29 |
-
version="
|
| 30 |
)
|
| 31 |
self.setup_routes()
|
| 32 |
|
|
@@ -86,19 +89,22 @@ class ChatAPIApp:
|
|
| 86 |
def chat_completions(
|
| 87 |
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
| 88 |
):
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
| 102 |
if item.stream:
|
| 103 |
event_source_response = EventSourceResponse(
|
| 104 |
streamer.chat_return_generator(stream_response),
|
|
@@ -152,17 +158,17 @@ class ArgParser(argparse.ArgumentParser):
|
|
| 152 |
|
| 153 |
self.add_argument(
|
| 154 |
"-s",
|
| 155 |
-
"--
|
| 156 |
type=str,
|
| 157 |
-
default="
|
| 158 |
-
help="
|
| 159 |
)
|
| 160 |
self.add_argument(
|
| 161 |
"-p",
|
| 162 |
"--port",
|
| 163 |
type=int,
|
| 164 |
-
default=
|
| 165 |
-
help="
|
| 166 |
)
|
| 167 |
|
| 168 |
self.add_argument(
|
|
@@ -181,9 +187,9 @@ app = ChatAPIApp().app
|
|
| 181 |
if __name__ == "__main__":
|
| 182 |
args = ArgParser().args
|
| 183 |
if args.dev:
|
| 184 |
-
uvicorn.run("__main__:app", host=args.
|
| 185 |
else:
|
| 186 |
-
uvicorn.run("__main__:app", host=args.
|
| 187 |
|
| 188 |
# python -m apis.chat_api # [Docker] on product mode
|
| 189 |
# python -m apis.chat_api -d # [Dev] on develop mode
|
|
|
|
| 12 |
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
| 13 |
from pydantic import BaseModel, Field
|
| 14 |
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
|
| 15 |
+
from tclogger import logger
|
| 16 |
+
|
| 17 |
+
from constants.models import AVAILABLE_MODELS_DICTS
|
| 18 |
+
from constants.envs import CONFIG
|
| 19 |
|
| 20 |
from messagers.message_composer import MessageComposer
|
| 21 |
from mocks.stream_chat_mocker import stream_chat_mock
|
| 22 |
+
from networks.huggingface_streamer import HuggingfaceStreamer
|
| 23 |
+
from networks.openai_streamer import OpenaiStreamer
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class ChatAPIApp:
|
| 27 |
def __init__(self):
|
| 28 |
self.app = FastAPI(
|
| 29 |
docs_url="/",
|
| 30 |
+
title=CONFIG["app_name"],
|
| 31 |
swagger_ui_parameters={"defaultModelsExpandDepth": -1},
|
| 32 |
+
version=CONFIG["version"],
|
| 33 |
)
|
| 34 |
self.setup_routes()
|
| 35 |
|
|
|
|
| 89 |
def chat_completions(
|
| 90 |
self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
|
| 91 |
):
|
| 92 |
+
if item.model == "gpt-3.5":
|
| 93 |
+
streamer = OpenaiStreamer()
|
| 94 |
+
stream_response = streamer.chat_response(messages=item.messages)
|
| 95 |
+
else:
|
| 96 |
+
streamer = HuggingfaceStreamer(model=item.model)
|
| 97 |
+
composer = MessageComposer(model=item.model)
|
| 98 |
+
composer.merge(messages=item.messages)
|
| 99 |
+
stream_response = streamer.chat_response(
|
| 100 |
+
prompt=composer.merged_str,
|
| 101 |
+
temperature=item.temperature,
|
| 102 |
+
top_p=item.top_p,
|
| 103 |
+
max_new_tokens=item.max_tokens,
|
| 104 |
+
api_key=api_key,
|
| 105 |
+
use_cache=item.use_cache,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
if item.stream:
|
| 109 |
event_source_response = EventSourceResponse(
|
| 110 |
streamer.chat_return_generator(stream_response),
|
|
|
|
| 158 |
|
| 159 |
self.add_argument(
|
| 160 |
"-s",
|
| 161 |
+
"--host",
|
| 162 |
type=str,
|
| 163 |
+
default=CONFIG["host"],
|
| 164 |
+
help=f"Host for {CONFIG['app_name']}",
|
| 165 |
)
|
| 166 |
self.add_argument(
|
| 167 |
"-p",
|
| 168 |
"--port",
|
| 169 |
type=int,
|
| 170 |
+
default=CONFIG["port"],
|
| 171 |
+
help=f"Port for {CONFIG['app_name']}",
|
| 172 |
)
|
| 173 |
|
| 174 |
self.add_argument(
|
|
|
|
| 187 |
if __name__ == "__main__":
|
| 188 |
args = ArgParser().args
|
| 189 |
if args.dev:
|
| 190 |
+
uvicorn.run("__main__:app", host=args.host, port=args.port, reload=True)
|
| 191 |
else:
|
| 192 |
+
uvicorn.run("__main__:app", host=args.host, port=args.port, reload=False)
|
| 193 |
|
| 194 |
# python -m apis.chat_api # [Docker] on product mode
|
| 195 |
# python -m apis.chat_api -d # [Dev] on develop mode
|