# Source page metadata: nazemi · commit 93b0f77 "End of training" · raw · history · blame · 3.95 kB
import asyncio
import contextlib
import json

import websockets
from transformers import pipeline
# Hugging Face pipelines are created once at import time and shared by every
# WebSocket connection handler below. The model= values are presumably local
# checkpoint directories or Hub repo ids -- confirm against the deployment.
# Intent classifier used by classify_intent1.
classifier1 = pipeline("text-classification", model="ja_gi_47")
# Disabled second classifier; nothing in the visible code references classifier2.
#classifier2 = pipeline("text-classification", model="ja_gi_71")
# Seq2seq extractors, one per handler:
extractor1 = pipeline("text2text-generation", model="name_extract")  # name_extractor
extractor2 = pipeline("text2text-generation", model="postcode_extract")  # postcode_extractor
extractor3 = pipeline("text2text-generation", model="time_extract")  # time_extractor
extractor4 = pipeline("text2text-generation", model="task_extract")  # task_extractor
extractor5 = pipeline("text2text-generation", model="occ_extract")  # occ_extractor
async def time_extractor(websocket, path=None):
    """Serve time-expression extraction over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``time_extract`` pipeline (``extractor3``). The first result's
    ``"generated_text"`` is sent back JSON-encoded; an empty generation is
    replaced by ``"No time intent detected"``. Loops until the client
    disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: the pipeline call is synchronous, so it blocks the event
            # loop (and every server sharing it) while inference runs.
            label = extractor3(intent)[0]["generated_text"]
            if label == "":
                label = "No time intent detected"
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def task_extractor(websocket, path=None):
    """Serve task/request extraction over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``task_extract`` pipeline (``extractor4``). The first result's
    ``"generated_text"`` is sent back JSON-encoded; an empty generation is
    replaced by ``"No request detected"``. Loops until the client disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: synchronous inference; blocks the shared event loop.
            label = extractor4(intent)[0]["generated_text"]
            if label == "":
                label = "No request detected"
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def postcode_extractor(websocket, path=None):
    """Serve postcode extraction over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``postcode_extract`` pipeline (``extractor2``). The first result's
    ``"generated_text"`` is sent back JSON-encoded. If the generation is empty
    or contains more than two spaces, ``"No postcode detected"`` is sent
    instead. Loops until the client disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: synchronous inference; blocks the shared event loop.
            label = extractor2(intent)[0]["generated_text"]
            # More than two spaces is rejected as "not a postcode" -- presumably
            # the model produced free text rather than a short code.
            if label == "" or label.count(" ") > 2:
                label = "No postcode detected"
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def name_extractor(websocket, path=None):
    """Serve name extraction over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``name_extract`` pipeline (``extractor1``). The first result's
    ``"generated_text"`` is sent back JSON-encoded. Unlike the other
    extractors, an empty generation is sent as-is (``""``) rather than being
    replaced by a "No ... detected" message. Loops until the client
    disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: synchronous inference; blocks the shared event loop.
            label = extractor1(intent)[0]["generated_text"]
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def occ_extractor(websocket, path=None):
    """Serve occupation extraction over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``occ_extract`` pipeline (``extractor5``). The first result's
    ``"generated_text"`` is sent back JSON-encoded; an empty generation is
    replaced by ``"No occupation detected"``. Loops until the client
    disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: synchronous inference; blocks the shared event loop.
            label = extractor5(intent)[0]["generated_text"]
            if label == "":
                label = "No occupation detected"
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def classify_intent1(websocket, path=None):
    """Serve intent classification over one WebSocket connection.

    Each incoming message is decoded as JSON and its ``"prompt"`` field is fed
    to the ``ja_gi_47`` text-classification pipeline (``classifier1``). The
    ``"label"`` of the first result is sent back JSON-encoded. Loops until the
    client disconnects.

    Args:
        websocket: Connection object supplied by ``websockets.serve``.
        path: Request path passed by older ``websockets`` servers; unused.
            Defaults to ``None`` because newer servers call the handler with
            the connection only.

    Exceptions other than ``websockets.ConnectionClosed`` (e.g. malformed JSON
    or a payload without a ``"prompt"`` key) propagate and end the handler.
    """
    try:
        while True:
            data = await websocket.recv()
            payload = json.loads(data)
            intent = payload["prompt"]
            # NOTE: synchronous inference; blocks the shared event loop.
            label = classifier1(intent)[0]["label"]
            await websocket.send(json.dumps(label))
    except websockets.ConnectionClosed:
        print("Connection closed")
async def start_server(host="0.0.0.0"):
    """Start every model-serving WebSocket server and run until cancelled.

    One server is started per handler, each on its own port:
    8764 name, 8765 intent classifier (ja_gi_47), 8766 occupation,
    8763 postcode, 8762 time, 8761 task. The servers are registered on an
    ``AsyncExitStack``, so they are closed when this coroutine is cancelled
    or exits with an error instead of being left open.

    Args:
        host: Interface every server binds to. Defaults to ``"0.0.0.0"``
            (all interfaces), the value previously hard-coded.
    """
    # (handler, port, startup message), in the original startup order.
    services = (
        (name_extractor, 8764, "Server name started"),
        (classify_intent1, 8765, "Server 47 started"),
        # Logged as "occ": this port serves the occupation extractor, not the
        # disabled ja_gi_71 classifier the old "Server 71" message suggested.
        (occ_extractor, 8766, "Server occ started"),
        (postcode_extractor, 8763, "Server postcode started"),
        (time_extractor, 8762, "Server time started"),
        (task_extractor, 8761, "Server task started"),
    )
    async with contextlib.AsyncExitStack() as stack:
        for handler, port, message in services:
            await stack.enter_async_context(websockets.serve(handler, host, port))
            print(message)
        # Never resolves: keep serving until the task is cancelled.
        await asyncio.Future()
if __name__ == "__main__":
    # Run all WebSocket servers until the process is interrupted.
    asyncio.run(start_server())