File size: 5,323 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import threading
import time
from abc import ABC, abstractmethod
from queue import Empty, Queue
from .log_utils import LogType, nni_log
from .commands import CommandType
# Polling period, in seconds: the receive loop sleeps this long between polls,
# and the send loop uses it as the timeout of its blocking queue read.
INTERVAL_SECONDS = 0.5
class BaseChannel(ABC):
    """Abstract command channel backed by a receive thread and a send thread.

    Concrete channels implement the ``_inner_*`` hooks; this base class owns
    the queues, the worker threads and the message framing.

    Message framing, as produced by ``send`` and parsed by ``receive`` and
    ``_fetch_message``: a 16-byte header followed by the payload. The header
    is a 2-byte command code followed by the payload length as a 14-digit,
    zero-padded decimal number. The payload is UTF-8 encoded JSON.
    """

    def __init__(self, args):
        """Store the runner arguments.

        args: object exposing ``node_count`` and ``node_id`` (``runner_id``
            and ``exp_id`` are read later, in ``open``).
        """
        # True when more than one node takes part.
        # NOTE(review): the flag is not read in this class; presumably it is
        # consumed by subclasses or callers - confirm.
        self.is_keep_parsed = args.node_count > 1
        self.args = args
        self.node_id = self.args.node_id

    @abstractmethod
    def _inner_send(self, message):
        """Transmit one framed message (bytes). Called from the send thread."""
        pass

    @abstractmethod
    def _inner_receive(self):
        """Return an iterable of framed messages received so far, or None.

        Called from the receive thread once per polling interval.
        """
        return []

    @abstractmethod
    def _inner_open(self):
        """Open the underlying transport."""
        pass

    @abstractmethod
    def _inner_close(self):
        """Close the underlying transport."""
        pass

    def open(self):
        """Start the worker threads, open the transport and announce readiness.

        The receive and send threads are started before ``_inner_open`` is
        called, so the ``_inner_receive`` / ``_inner_send`` implementations
        may be invoked before the transport is fully open.
        """
        # initialize receive, send threads.
        self.is_running = True

        self.receive_queue = Queue()
        self.receive_thread = threading.Thread(target=self._receive_loop)
        self.receive_thread.start()

        self.send_queue = Queue()
        self.send_thread = threading.Thread(target=self._send_loop)
        self.send_thread.start()

        self._inner_open()

        client_info = {
            "isReady": True,
            "runnerId": self.args.runner_id,
            "expId": self.args.exp_id,
        }
        nni_log(LogType.Info, 'Channel: send ready information %s' % client_info)
        self.send(CommandType.Initialized, client_info)

    def close(self):
        """Ask the worker threads to stop and close the transport.

        The threads notice ``is_running`` on their next loop iteration; this
        method does not wait for them. Errors raised while closing the
        transport are reported on stdout and otherwise ignored.
        """
        self.is_running = False
        try:
            self._inner_close()
        except Exception as err:
            # ignore any error on closing
            print("error on closing channel: %s" % err)

    def send(self, command, data):
        """Queue a command for delivery to Training Service.

        command: CommandType object; its ``value`` (bytes) becomes the
            command code at the start of the header.
        data: dict payload. It is serialized as JSON together with an extra
            "node" entry holding this channel's node id. The caller's dict
            is left unmodified.

        The message is only placed on the send queue here; the send thread
        transmits it asynchronously. Use ``sent`` to check whether the queue
        has been drained.
        """
        # Work on a shallow copy so the caller's dict does not gain a "node" key.
        payload = dict(data)
        payload["node"] = self.node_id
        encoded = json.dumps(payload).encode('utf8')
        message = b'%b%014d%b' % (command.value, len(encoded), encoded)
        self.send_queue.put(message)

    def sent(self):
        """Return True when no message is waiting in the send queue."""
        return self.send_queue.qsize() == 0

    def received(self):
        """Return True when at least one message is waiting to be received."""
        return self.receive_queue.qsize() > 0

    def receive(self):
        """Receive a command from Training Service without blocking.

        Returns a tuple of command (CommandType) and payload (the decoded
        JSON value). Returns (None, None) when nothing is queued or when the
        queued message is malformed; malformed messages are logged and
        dropped.
        """
        try:
            command_content = self.receive_queue.get(False)
            if command_content is None:
                return None, None

            if len(command_content) < 16:
                # invalid header
                nni_log(LogType.Error, 'incorrect command is found, command must be greater than 16 bytes!')
                return None, None

            header = command_content[:16]
            command = CommandType(header[:2])
            length = int(header[2:])
            actual_length = len(command_content) - 16
            if actual_length != length:
                nni_log(LogType.Error, 'incorrect command length, length {}, actual data length is {}, header {}.'
                        .format(length, actual_length, header))
                return None, None

            data = json.loads(command_content[16:16 + length].decode('utf8'))

            if self.node_id is None:
                nni_log(LogType.Info, 'Received command, header: [%s], data: [%s]' % (header, data))
            else:
                nni_log(LogType.Info, 'Received command(%s), header: [%s], data: [%s]' % (self.node_id, header, data))
            return command, data
        except Empty:
            # do nothing, if no command received.
            return None, None
        except Exception as identifier:
            nni_log(LogType.Error, 'meet unhandled exception in base_channel: %s' % identifier)
            # Never hand back a command without its payload.
            return None, None

    def _fetch_message(self, buffer, has_new_line=False):
        """Split complete framed messages off the front of ``buffer``.

        buffer: bytes accumulated from the transport.
        has_new_line: when True, every message is expected to be followed by
            a single newline byte, which is consumed but not returned.

        Returns a tuple (messages, remaining): ``messages`` is the list of
        complete messages (header + payload) and ``remaining`` is the tail
        of the buffer that does not yet hold a complete message.
        """
        messages = []
        while len(buffer) >= 16:
            header = buffer[:16]
            length = int(header[2:])

            message_length = length + 16
            total_length = message_length + 1 if has_new_line else message_length

            # break, if buffer is too short.
            if len(buffer) < total_length:
                break

            data = buffer[16:message_length]
            terminator = buffer[total_length - 1]
            # 10 is the byte value of '\n'. A wrong terminator is only
            # reported; the message is still accepted and the byte consumed.
            if has_new_line and terminator != 10:
                nni_log(LogType.Error, 'end of message should be \\n, but got {}'.format(terminator))
            buffer = buffer[total_length:]
            messages.append(header + data)

        return messages, buffer

    def _receive_loop(self):
        """Receive-thread body: poll the transport and queue what arrives."""
        while self.is_running:
            messages = self._inner_receive()
            if messages is not None:
                for message in messages:
                    self.receive_queue.put(message)
            time.sleep(INTERVAL_SECONDS)

    def _send_loop(self):
        """Send-thread body: forward queued messages to the transport."""
        while self.is_running:
            message = None
            try:
                # no sleep, since it's a block call with INTERVAL_SECONDS second timeout
                message = self.send_queue.get(True, INTERVAL_SECONDS)
            except Empty:
                # do nothing, if no command received.
                pass
            if message is not None:
                self._inner_send(message)
|