File size: 1,905 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import os
import websockets

from .base_channel import BaseChannel
from .log_utils import LogType, nni_log


class WebChannel(BaseChannel):
    """Channel that talks to the NNI manager over a WebSocket connection.

    The ``_inner_*`` methods are synchronous wrappers that drive the
    asynchronous ``websockets`` client on an asyncio event loop.
    NOTE(review): they are presumably the hooks ``BaseChannel`` invokes from
    its own open/close/send/receive logic -- confirm against ``BaseChannel``.
    """

    def __init__(self, args):
        """Record the connection settings; no connection is opened here.

        ``args`` must expose ``node_id`` here and ``nnimanager_ip`` /
        ``nnimanager_port`` later, when ``_inner_open`` builds the URL.
        """
        self.node_id = args.node_id
        self.args = args
        # WebSocket client connection; None while the channel is not open.
        self.client = None
        # Bytes received so far that do not yet form a complete message.
        self.in_cache = b""
        # Seconds to wait for the initial connection in _inner_open.
        self.timeout = 10

        super().__init__(args)

        # Event loop that owns the connection; assigned by _inner_open.
        self._event_loop = None

    def _inner_open(self):
        """Connect to ``ws://<nnimanager_ip>:<nnimanager_port>``.

        Blocks until the connection is established or ``self.timeout``
        seconds have elapsed. On timeout an error is logged and the whole
        process is terminated with ``os._exit(1)``. Any other connection
        failure is not handled here and propagates to the caller.
        """
        url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port)
        try:
            connect = asyncio.wait_for(websockets.connect(url), self.timeout)
            self._event_loop = asyncio.get_event_loop()
            client = self._event_loop.run_until_complete(connect)
            self.client = client
            nni_log(LogType.Info, 'WebChannel: connected with info %s' % url)
        except asyncio.TimeoutError:
            nni_log(LogType.Error, 'connect to %s timeout! Please make sure NNIManagerIP configured correctly, and accessable.' % url)
            os._exit(1)

    def _inner_close(self):
        """Close the WebSocket connection, if one is open, and forget the loop.

        ``client.close()`` is a coroutine, so it has to be driven by the event
        loop; merely calling it would create a coroutine object and never
        perform the closing handshake.
        """
        client, loop = self.client, self._event_loop
        if client is None:
            return

        # Drop the references first so the channel reads as closed even if
        # the shutdown below raises.
        self.client = None
        self._event_loop = None

        if loop is None or loop.is_closed():
            return

        if loop.is_running():
            # The loop is being driven elsewhere (a synchronous caller can only
            # observe a running loop from another thread), so neither
            # run_until_complete() nor a direct stop() is safe from here.
            # Hand the shutdown over to the loop itself.
            async def _close_then_stop():
                try:
                    await client.close()
                finally:
                    loop.stop()

            asyncio.run_coroutine_threadsafe(_close_then_stop(), loop)
        else:
            loop.run_until_complete(client.close())

    def _inner_send(self, message):
        """Send one message over the open connection, blocking until done.

        A private, short-lived event loop drives the send instead of
        ``self._event_loop``.
        NOTE(review): presumably this is because ``self._event_loop`` may be
        busy inside ``run_until_complete(recv())`` on another thread, where a
        second ``run_until_complete`` would raise -- confirm against the
        threading model of ``BaseChannel``.
        """
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.client.send(message))
        finally:
            # Release the loop's selector and self-pipe right away instead of
            # leaving one unclosed loop behind per message sent.
            loop.close()

    def _inner_receive(self):
        """Block for the next WebSocket frame and return the decoded messages.

        Returns an empty list when the channel is not open. Otherwise the
        frame is appended to ``self.in_cache`` and ``self._fetch_message``
        (not defined in this class) splits the cache into the list of
        messages to return and the remainder, which stays in
        ``self.in_cache``.
        """
        if self.client is None:
            return []

        received = self._event_loop.run_until_complete(self.client.recv())
        # Text frames arrive as str and binary frames as bytes; the cache is
        # bytes, so encode only when needed to keep the result consistent.
        if isinstance(received, str):
            received = received.encode("utf8")
        self.in_cache += received
        messages, self.in_cache = self._fetch_message(self.in_cache)
        return messages