Spaces: Build error
Daniel Marques committed · Commit abff149
1 Parent(s): e04f8da
fix: add websocket in handlerToken

main.py CHANGED

@@ -20,8 +20,6 @@ from langchain.memory import ConversationBufferMemory
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult
 
-from varstate import State
-
 # from langchain.embeddings import HuggingFaceEmbeddings
 from load_models import load_model
 
@@ -58,9 +56,9 @@ DB = Chroma(
 RETRIEVER = DB.as_retriever()
 
 class MyCustomSyncHandler(BaseCallbackHandler):
-    def __init__(self, update):
+    def __init__(self):
         self.end = False
-        self.callback = update
+        self.callback = None
 
     def on_llm_start(
         self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -77,14 +75,8 @@ class MyCustomSyncHandler(BaseCallbackHandler):
 
         print(token)
 
-
 # Create State
-
-tokenMessageLLM = State()
-
-get, update = tokenMessageLLM.create('')
-
-handlerToken = MyCustomSyncHandler(update)
+handlerToken = MyCustomSyncHandler()
 
 LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
 
@@ -253,8 +245,8 @@ async def create_upload_file(file: UploadFile):
 
     return {"filename": file.filename}
 
-@api_app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
+@api_app.websocket("/ws/{client_id}")
+async def websocket_endpoint(websocket: WebSocket, client_id: int):
     global QA
 
     await websocket.accept()
@@ -265,16 +257,11 @@ async def websocket_endpoint(websocket: WebSocket):
         while True:
             prompt = await websocket.receive_text()
 
-
+            handlerToken.callback = websocket.send_text;
 
-            if (oldReceiveText != prompt)
+            if (oldReceiveText != prompt):
                 oldReceiveText = prompt
-
-
-                print(statusProcess);
-
-                tokenState = get()
-                await websocket.send_text(f"token: {tokenState}")
+                asyncio.run(QA(prompt))
 
     except WebSocketDisconnect:
         print('disconnect')
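
The commit only wires the pieces together: inside the socket loop, handlerToken.callback is assigned websocket.send_text, and the callback handler is expected to push each generated token back to the client. The body of that forwarding hook is not part of this diff; below is a minimal sketch of one way it could look, assuming the LLM call runs off the main thread while the websocket lives on the event loop. The class name TokenForwardingHandler and the loop attribute are illustrative, not taken from the commit.

import asyncio
from typing import Any

from langchain.callbacks.base import BaseCallbackHandler


class TokenForwardingHandler(BaseCallbackHandler):
    """Sketch: push LLM tokens to a per-connection coroutine callback."""

    def __init__(self) -> None:
        self.end = False
        self.callback = None   # set per connection, e.g. websocket.send_text
        self.loop = None       # illustrative: the event loop serving the websocket

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # websocket.send_text is a coroutine function, so schedule it on the
        # websocket's loop instead of calling it directly from this sync hook.
        if self.callback is not None and self.loop is not None:
            asyncio.run_coroutine_threadsafe(self.callback(f"token: {token}"), self.loop)

With a handler like this, the endpoint would also capture the running loop once (e.g. handlerToken.loop = asyncio.get_running_loop()) before the first prompt is processed.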
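
For completeness, a hedged client-side sketch for the renamed /ws/{client_id} route; the host, port, path prefix, and client id are assumptions, not values taken from this repository.

import asyncio

import websockets  # pip install websockets


async def ask(prompt: str) -> None:
    # Guess at how the Space exposes the websocket route; adjust URL as needed.
    async with websockets.connect("ws://localhost:7860/ws/1") as ws:
        await ws.send(prompt)
        try:
            while True:
                print(await ws.recv(), end="", flush=True)  # streamed tokens
        except websockets.ConnectionClosed:
            pass


asyncio.run(ask("What does the uploaded document say?"))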