File size: 10,944 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
import ctypes
import json
import logging
import os
import re
import shlex
import sys
import threading
import time
from subprocess import Popen

import pkg_resources
from pyhdfs import HdfsClient

from .constants import (LOG_DIR, MULTI_PHASE, NNI_EXP_ID, NNI_PLATFORM,
                        NNI_SYS_DIR, NNI_TRIAL_JOB_ID)
from .hdfsClientUtility import (copyDirectoryToHdfs, copyHdfsDirectoryToLocal,
                                copyHdfsFileToLocal)
from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
from .rest_utils import rest_get, rest_post
from .url_utils import gen_parameter_meta_url, gen_send_version_url

logger = logging.getLogger('trial_keeper')
# Extracts the leading "<major>[.<minor>]" part of a version string, tolerating
# an optional "v" prefix and any trailing text (patch level, dev suffix, ...).
# A raw string avoids the invalid "\." escape warning, and "[0-9]+" keeps
# multi-digit components intact so that e.g. "2.10" is not truncated to "2.1".
regular = re.compile(r'v?(?P<version>[0-9]+(\.[0-9]+)?).*')

# Module-level cache for the HDFS client built by get_hdfs_client().
_hdfs_client = None


def get_hdfs_client(args):
    """Return the cached HDFS client, building it on the first successful call.

    args: parsed command-line namespace. Reads ``hdfs_host`` (falling back to
        ``pai_hdfs_host`` for backward compatibility), ``nni_hdfs_exp_dir``,
        ``webhdfs_path`` and ``pai_user_name``.

    Returns the client, or None when no HDFS host is configured or
    ``nni_hdfs_exp_dir`` is None. Any exception raised while constructing the
    client is logged and re-raised.
    """
    global _hdfs_client

    # A client built by an earlier call is reused as-is.
    if _hdfs_client is not None:
        return _hdfs_client

    # --hdfs_host wins; --pai_hdfs_host is the backward-compatible spelling.
    host = args.hdfs_host or args.pai_hdfs_host
    if not host:
        return None

    # Without an experiment directory no client is created.
    if args.nni_hdfs_exp_dir is None:
        return _hdfs_client

    try:
        if args.webhdfs_path:
            client_options = dict(hosts='{0}:80'.format(host), user_name=args.pai_user_name,
                                  webhdfs_path=args.webhdfs_path, timeout=5)
        else:
            # backward compatibility: no webhdfs path given, use port 50070
            client_options = dict(hosts='{0}:{1}'.format(host, '50070'), user_name=args.pai_user_name,
                                  timeout=5)
        _hdfs_client = HdfsClient(**client_options)
    except Exception as e:
        nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
        raise e
    return _hdfs_client


def main_loop(args):
    """Main loop logic for trial keeper.

    Creates LOG_DIR if missing, redirects this process's stdout/stderr to a
    RemoteLogger, optionally copies the experiment directory from HDFS into
    the current working directory, optionally writes this process's pid to
    ``args.job_id_file``, launches ``args.trial_command`` through the shell and
    polls it every 2 seconds. Once the subprocess has exited and its output has
    been fully read, the local NNI_OUTPUT_DIR is copied to HDFS (when an HDFS
    output dir was given) and the function exits with the subprocess's return
    code.

    args: parsed command-line namespace (see the argument parser in this file).

    Raises SystemExit carrying the trial's return code; exceptions from the
    HDFS upload are logged and re-raised.
    """

    if not os.path.exists(LOG_DIR):
        os.makedirs(LOG_DIR)

    trial_keeper_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial_keeper',
                                          StdOutputType.Stdout, args.log_collection)
    # logger that receives the trial subprocess's stdout and stderr
    trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout,
                                          args.log_collection)
    # redirect trial keeper's stdout and stderr to syslog
    sys.stdout = sys.stderr = trial_keeper_syslogger

    # --hdfs_output_dir wins; --pai_hdfs_output_dir is kept for backward compatibility
    hdfs_output_dir = None
    if args.hdfs_output_dir:
        hdfs_output_dir = args.hdfs_output_dir
    elif args.pai_hdfs_output_dir:
        hdfs_output_dir = args.pai_hdfs_output_dir

    hdfs_client = get_hdfs_client(args)

    if hdfs_client is not None:
        copyHdfsDirectoryToLocal(args.nni_hdfs_exp_dir, os.getcwd(), hdfs_client)

    if args.job_id_file:
        with open(args.job_id_file, 'w') as job_file:
            job_file.write("%d" % os.getpid())

    # Notice: We don't appoint env, which means subprocess will inherit current environment and that is
    # expected behavior.
    # NOTE(review): the command is deliberately executed through the shell (shell=True); it must only
    # ever come from a trusted source.
    log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader()
    process = Popen(args.trial_command, shell=True, stdout=log_pipe_stdout, stderr=log_pipe_stdout)
    nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(
        process.pid, shlex.split(args.trial_command)))

    while True:
        retCode = process.poll()
        # Keep waiting until the child worker process has exited and all of its stdout data is read.
        # set_process_exit() is only called once the process has exited (short-circuit evaluation).
        if not (retCode is not None and log_pipe_stdout.set_process_exit()
                and log_pipe_stdout.is_read_completed == True):
            time.sleep(2)
            continue

        # On Windows the retCode -1 is reported as 4294967295, which is larger than a C long and
        # raises OverflowError, so convert it through ctypes.c_long first.
        retCode = ctypes.c_long(retCode).value
        nni_log(LogType.Info, 'subprocess terminated. Exit code is {}. Quit'.format(retCode))
        if hdfs_output_dir is not None:
            # Copy local directory to hdfs for OpenPAI
            nni_local_output_dir = os.environ['NNI_OUTPUT_DIR']
            try:
                if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client):
                    nni_log(LogType.Info,
                            'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir))
                else:
                    nni_log(LogType.Info,
                            'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir))
            except Exception as e:
                nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e))
                raise

        # Exit with the retCode of the subprocess (trial). sys.exit is used instead of the
        # site-provided exit(), which is meant for the interactive shell and may be absent.
        sys.exit(retCode)


def trial_keeper_help_info(*args):
    """Print a hint telling the user to run with --help; any arguments are ignored."""
    hint = 'please run --help to see guidance'
    print(hint)


def check_version(args):
    """Compare the installed nni package version with the NNI manager's version.

    args: parsed command-line namespace; reads ``nni_manager_version``,
        ``nnimanager_ip`` and ``nnimanager_port``.

    The process is terminated with exit code 1 when the nni package cannot be
    resolved or when the two versions differ. The comparison outcome is posted
    to the NNI manager ('VCSuccess' / 'VCFail'). When no manager version was
    supplied the check is skipped with a warning. A version string that the
    ``regular`` pattern cannot parse only results in an error log entry.
    """
    try:
        installed_version = pkg_resources.get_distribution('nni').version
    except pkg_resources.ResolutionError:
        # package nni cannot be resolved: nothing to compare against, give up
        nni_log(LogType.Error, 'Package nni does not exist!')
        os._exit(1)

    if not args.nni_manager_version:
        # skip version check
        nni_log(LogType.Warning, 'Skipping version check!')
        return

    try:
        keeper_version = regular.search(installed_version).group('version')
        nni_log(LogType.Info, 'trial_keeper_version is {0}'.format(keeper_version))
        manager_version = regular.search(args.nni_manager_version).group('version')
        nni_log(LogType.Info, 'nni_manager_version is {0}'.format(manager_version))

        if keeper_version == manager_version:
            nni_log(LogType.Info, 'Version match!')
            report = {'tag': 'VCSuccess'}
            rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(report), 10,
                      False)
        else:
            nni_log(LogType.Error, 'Version does not match!')
            mismatch_text = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(
                manager_version, keeper_version)
            report = {'tag': 'VCFail', 'msg': mismatch_text}
            rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(report), 10,
                      False)
            os._exit(1)
    except AttributeError as err:
        # regular.search() returned None for one of the version strings
        nni_log(LogType.Error, err)


def is_multi_phase():
    """Tell whether MULTI_PHASE (imported from .constants) is 'True' or 'true'.

    Returns MULTI_PHASE itself when it is falsy, otherwise a bool.
    """
    if not MULTI_PHASE:
        return MULTI_PHASE
    return MULTI_PHASE in ('True', 'true')


def download_parameter(meta_list, args):
    """
    Copy this trial's parameter files from HDFS into NNI_SYS_DIR.

    Only entries whose experimentId and trialId equal NNI_EXP_ID and
    NNI_TRIAL_JOB_ID are handled, and a file is fetched only when no file with
    the same base name already exists in NNI_SYS_DIR.

    meta_list format is defined in paiJobRestServer.ts
    example meta_list:
    [
        {"experimentId":"yWFJarYa","trialId":"UpPkl","filePath":"/chec/nni/experiments/yWFJarYa/trials/UpPkl/parameter_1.cfg"},
        {"experimentId":"yWFJarYa","trialId":"aIUMA","filePath":"/chec/nni/experiments/yWFJarYa/trials/aIUMA/parameter_1.cfg"}
    ]

    args: parsed command-line namespace, forwarded to get_hdfs_client().
    """
    nni_log(LogType.Debug, str(meta_list))
    context_line = 'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID)
    nni_log(LogType.Debug, context_line)
    nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR)))

    for entry in meta_list:
        # ignore parameter files that belong to another experiment or trial
        if entry['experimentId'] != NNI_EXP_ID or entry['trialId'] != NNI_TRIAL_JOB_ID:
            continue

        remote_path = entry['filePath']
        local_path = os.path.join(NNI_SYS_DIR, os.path.basename(remote_path))
        if os.path.exists(local_path):
            # already present locally
            continue

        copyHdfsFileToLocal(remote_path, local_path, get_hdfs_client(args), override=False)


def fetch_parameter_file(args):
    """Start a background thread that polls the NNI manager for parameter files.

    args: parsed command-line namespace; ``nnimanager_ip`` and
        ``nnimanager_port`` locate the REST server, and the whole namespace is
        passed on to download_parameter().

    The thread loops forever: every 5 seconds it requests the parameter meta
    list and, on HTTP 200, hands the decoded JSON to download_parameter().
    """
    class FetchThread(threading.Thread):
        """Polling thread; its loop has no exit condition of its own."""

        def __init__(self, args):
            super(FetchThread, self).__init__()
            self.args = args

        def run(self):
            meta_url = gen_parameter_meta_url(self.args.nnimanager_ip, self.args.nnimanager_port)
            nni_log(LogType.Info, meta_url)

            while True:
                response = rest_get(meta_url, 10)
                status = response.status_code
                nni_log(LogType.Debug, 'status code: {}'.format(status))
                if status != 200:
                    nni_log(LogType.Warning, 'rest response: {}'.format(str(response)))
                else:
                    download_parameter(response.json(), self.args)
                time.sleep(5)

    poller = FetchThread(args)
    poller.start()


if __name__ == '__main__':
    # NNI Trial Keeper main function: parse the command line, check the version
    # against the NNI manager, then run the trial and exit with its return code.
    PARSER = argparse.ArgumentParser()
    PARSER.set_defaults(func=trial_keeper_help_info)
    PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
    PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
    PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
    PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs')  # backward compatibility
    PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs')
    PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs')  # backward compatibility
    PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs')
    PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
    PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
    PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL')
    PARSER.add_argument('--nni_manager_version', type=str, help='the nni version transmitted from nniManager')
    PARSER.add_argument('--log_collection', type=str, help='set the way to collect log in trialkeeper')
    PARSER.add_argument('--job_id_file', type=str, help='set job id file for operating and monitoring job.')
    # Unrecognised options are tolerated rather than rejected.
    args, unknown = PARSER.parse_known_args()
    if args.trial_command is None:
        # Nothing to run. sys.exit is used instead of the site-provided exit(),
        # which is intended for the interactive shell.
        sys.exit(1)
    check_version(args)
    try:
        if NNI_PLATFORM == 'paiYarn' and is_multi_phase():
            fetch_parameter_file(args)
        main_loop(args)
    except SystemExit as se:
        nni_log(LogType.Info, 'NNI trial keeper exit with code {}'.format(se.code))
        # os._exit() only accepts an integer, while SystemExit.code may be None
        # (clean exit) or an arbitrary object (treated as failure).
        exit_code = se.code
        if exit_code is None:
            exit_code = 0
        elif not isinstance(exit_code, int):
            exit_code = 1
        # os._exit ends the process immediately, even if the parameter-fetching
        # thread is still running.
        os._exit(exit_code)
    except Exception as e:
        nni_log(LogType.Error, 'Exit trial keeper with code 1 because Exception: {} is catched'.format(str(e)))
        os._exit(1)