|
|
|
import asyncio
import contextlib
import json

import websockets
from transformers import pipeline
|
# Intent classifier served by classify_intent1.
classifier1 = pipeline("text-classification", model="ja_gi_47")

# Seq2seq extractors, created in this order:
#   extractor1 -> name_extract      (name_extractor)
#   extractor2 -> postcode_extract  (postcode_extractor)
#   extractor3 -> time_extract      (time_extractor)
#   extractor4 -> task_extract      (task_extractor)
#   extractor5 -> occ_extract       (occ_extractor)
_EXTRACTOR_MODELS = (
    "name_extract",
    "postcode_extract",
    "time_extract",
    "task_extract",
    "occ_extract",
)
extractor1, extractor2, extractor3, extractor4, extractor5 = (
    pipeline("text2text-generation", model=model_name)
    for model_name in _EXTRACTOR_MODELS
)
|
|
|
|
|
def _encode_reply(message, predict, fallback=None):
    """Return the JSON-encoded reply for one message received from a client.

    Args:
        message: Raw message received on the websocket; expected to decode to
            a JSON object with a ``"prompt"`` key.
        predict: Callable that maps the prompt to a label string.
        fallback: Text to send instead of an empty label. When ``None`` the
            label is sent unchanged, even if it is empty.

    Returns:
        The label (or ``fallback``) serialized with ``json.dumps``.

    Raises:
        ValueError: ``message`` is not valid JSON.
        KeyError: The decoded object has no ``"prompt"`` key.
    """
    prompt = json.loads(message)["prompt"]
    label = predict(prompt)
    if fallback is not None and label == "":
        label = fallback
    return json.dumps(label)


async def _serve_requests(websocket, predict, fallback=None):
    """Answer every message on ``websocket`` until the connection closes.

    Each incoming message is turned into a reply by ``_encode_reply`` (using
    ``predict`` and ``fallback``) and sent back on the same connection. A
    closed connection ends the loop and prints "Connection closed". Any other
    exception, such as a malformed message, propagates out of this coroutine.
    """
    try:
        while True:
            message = await websocket.recv()
            # NOTE(review): predict() is a synchronous model call made on the
            # event loop. All servers share that loop, so every connection
            # stalls until the call returns.
            await websocket.send(_encode_reply(message, predict, fallback))
    except websockets.ConnectionClosed:
        print("Connection closed")


def _postcode_label(prompt):
    """Run the postcode model on ``prompt``; return "" for implausible output."""
    label = extractor2(prompt)[0]["generated_text"]
    # Output containing more than two spaces is discarded as "no postcode"
    # (presumably the model echoed prose instead of a postcode; confirm).
    return "" if label.count(" ") > 2 else label


# All handlers accept ``path`` but never use it. It defaults to None so the
# handler works both with websockets releases that pass the request path and
# with those that call connection handlers with the connection only.


async def time_extractor(websocket, path=None):
    """Reply to each prompt with the text generated by ``extractor3`` (time_extract).

    An empty generation is replaced by "No time intent detected".
    """
    await _serve_requests(
        websocket,
        lambda prompt: extractor3(prompt)[0]["generated_text"],
        fallback="No time intent detected",
    )


async def task_extractor(websocket, path=None):
    """Reply to each prompt with the text generated by ``extractor4`` (task_extract).

    An empty generation is replaced by "No request detected".
    """
    await _serve_requests(
        websocket,
        lambda prompt: extractor4(prompt)[0]["generated_text"],
        fallback="No request detected",
    )


async def postcode_extractor(websocket, path=None):
    """Reply to each prompt with the postcode produced by ``extractor2``.

    Outputs rejected by ``_postcode_label`` and empty generations are both
    replaced by "No postcode detected".
    """
    await _serve_requests(
        websocket,
        _postcode_label,
        fallback="No postcode detected",
    )


async def name_extractor(websocket, path=None):
    """Reply to each prompt with the text generated by ``extractor1`` (name_extract).

    There is no fallback text: an empty generation is sent as an empty string.
    """
    await _serve_requests(
        websocket,
        lambda prompt: extractor1(prompt)[0]["generated_text"],
    )


async def occ_extractor(websocket, path=None):
    """Reply to each prompt with the text generated by ``extractor5`` (occ_extract).

    An empty generation is replaced by "No occupation detected".
    """
    await _serve_requests(
        websocket,
        lambda prompt: extractor5(prompt)[0]["generated_text"],
        fallback="No occupation detected",
    )


async def classify_intent1(websocket, path=None):
    """Reply to each prompt with the top label from ``classifier1`` (ja_gi_47)."""
    await _serve_requests(
        websocket,
        lambda prompt: classifier1(prompt)[0]["label"],
    )
|
async def start_server(host="0.0.0.0"):
    """Start every model's websocket server and serve until cancelled.

    Args:
        host: Address the servers bind to. The default "0.0.0.0" listens on
            all IPv4 interfaces, as before.

    Each server is entered on an AsyncExitStack. If a later server fails to
    start, or this coroutine is cancelled, the servers already listening are
    closed instead of being left behind.
    """
    # (handler, port, message printed once that server is listening)
    servers = (
        (name_extractor, 8764, "Server name started"),
        (classify_intent1, 8765, "Server 47 started"),
        (occ_extractor, 8766, "Server 71 started"),
        (postcode_extractor, 8763, "Server postcode started"),
        (time_extractor, 8762, "Server time started"),
        (task_extractor, 8761, "Server task started"),
    )
    async with contextlib.AsyncExitStack() as stack:
        for handler, port, started_message in servers:
            await stack.enter_async_context(websockets.serve(handler, host, port))
            print(started_message)
        # Never resolves: keeps the servers running until cancellation.
        await asyncio.Future()
|
|
|
if __name__ == "__main__":
    # Start the servers only when run as a script; importing this module
    # (e.g. to reuse a handler) must not block forever serving sockets.
    asyncio.run(start_server())
|
|
|
|
|
|