|
|
|
|
|
|
|
import asyncio |
|
import os |
|
import websockets |
|
|
|
from .base_channel import BaseChannel |
|
from .log_utils import LogType, nni_log |
|
|
|
|
|
class WebChannel(BaseChannel):
    """Channel that talks to ``ws://<nnimanager_ip>:<nnimanager_port>``.

    The public API of this class is synchronous; every WebSocket coroutine
    is driven to completion on an asyncio event loop from inside the
    ``_inner_*`` hooks.
    """

    def __init__(self, args):
        """Record connection settings; the socket is opened by `_inner_open`.

        Args:
            args: Object exposing ``node_id``, ``nnimanager_ip`` and
                ``nnimanager_port``. It is also forwarded to the base class.
        """
        self.node_id = args.node_id
        self.args = args
        # WebSocket connection; None while the channel is not open.
        self.client = None
        # Bytes received so far that have not been split into messages yet.
        self.in_cache = b""
        # Seconds to wait for the connection (and for the close handshake).
        self.timeout = 10

        super().__init__(args)

        # Loop used for connecting and receiving; assigned by `_inner_open`.
        self._event_loop = None

    def _inner_open(self):
        """Connect to the configured endpoint.

        If the connection is not established within ``self.timeout``
        seconds, an error is logged and the process is terminated
        immediately with exit status 1 (``os._exit``).
        """
        url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port)
        try:
            connect = asyncio.wait_for(websockets.connect(url), self.timeout)
            self._event_loop = asyncio.get_event_loop()
            client = self._event_loop.run_until_complete(connect)
            self.client = client
            nni_log(LogType.Info, 'WebChannel: connected with info %s' % url)
        except asyncio.TimeoutError:
            nni_log(LogType.Error, 'connect to %s timeout! Please make sure NNIManagerIP configured correctly, and accessable.' % url)
            os._exit(1)

    def _inner_close(self):
        """Close the WebSocket connection and detach from the event loop.

        Safe to call when the channel was never opened or is already
        closed; in that case it only clears the stored references.
        """
        # Local import: concurrent futures raise their own TimeoutError on
        # older Python versions, distinct from asyncio.TimeoutError.
        from concurrent.futures import TimeoutError as FutureTimeoutError

        client, loop = self.client, self._event_loop
        self.client = None
        self._event_loop = None

        if client is None or loop is None or loop.is_closed():
            return

        if loop.is_running():
            # The loop is being driven elsewhere -- presumably by a thread
            # blocked in `_inner_receive`; verify against the base class.
            # Loops are not thread-safe, so the close handshake has to be
            # scheduled onto the loop rather than run from here.
            closing = asyncio.run_coroutine_threadsafe(client.close(), loop)
            try:
                closing.result(self.timeout)
            except FutureTimeoutError:
                closing.cancel()
                nni_log(LogType.Error, 'WebChannel: timeout while closing connection')
            finally:
                loop.call_soon_threadsafe(loop.stop)
        else:
            # `close()` is a coroutine: it must actually be run, otherwise
            # the connection is never shut down.
            loop.run_until_complete(client.close())
        # The loop is deliberately not closed: it came from
        # `asyncio.get_event_loop()` and may be shared with other code.

    def _inner_send(self, message):
        """Send one message over the open connection.

        Args:
            message: Payload handed unchanged to the client's ``send``.
        """
        # A private loop is used because ``self._event_loop`` may be busy in
        # `_inner_receive`. It is closed afterwards so its selector and
        # wake-up sockets are not leaked on every send.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.client.send(message))
        finally:
            loop.close()

    def _inner_receive(self):
        """Block until one frame arrives and return the parsed messages.

        Returns:
            The messages produced by ``self._fetch_message`` from the
            buffered bytes, or an empty list when the channel is not open.
        """
        messages = []
        # Snapshot both references so a concurrent `_inner_close` cannot
        # null one of them between the check and the use.
        client, loop = self.client, self._event_loop
        if client is not None and loop is not None:
            received = loop.run_until_complete(client.recv())
            # Text frames arrive as str and binary frames as bytes; the
            # cache is bytes, so only text needs encoding.
            if isinstance(received, str):
                received = received.encode("utf8")
            self.in_cache += received
            messages, self.in_cache = self._fetch_message(self.in_cache)

        return messages
|
|