Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
import json | |
import os | |
from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response | |
from .url_utils import experiment_url, import_data_url | |
from .config_utils import Config | |
from .common_utils import get_json_content, print_normal, print_error, print_warning | |
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process | |
from .launcher_utils import parse_time | |
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA | |
def validate_digit(value, start, end):
    '''Validate that ``value`` is a non-negative integer within ``[start, end]``.

    ``value`` may be an int or a string of decimal digits; it is checked via
    its ``str()`` form, so signs, decimal points and blanks are rejected.

    Raises:
        ValueError: if ``value`` is not made of decimal digits only, or if it
            lies outside the inclusive range ``start``..``end``.
    '''
    text = str(value)
    # str.isdigit() also accepts characters such as superscripts ('²') that
    # int() cannot parse, which leaked a differently worded ValueError.
    # str.isdecimal() only accepts characters that int() understands.
    if not text.isdecimal() or not start <= int(text) <= end:
        raise ValueError('value (%s) must be a digit from %s to %s' % (value, start, end))
def validate_file(path):
    '''Validate that ``path`` points to something that can be opened as a file.

    Raises:
        FileNotFoundError: if ``path`` does not exist or is a directory.
    '''
    # os.path.exists() alone also accepts directories, which callers cannot
    # read as a file; reject them here with the same error.
    if not os.path.exists(path) or os.path.isdir(path):
        raise FileNotFoundError('%s is not a valid file path' % path)
def validate_dispatcher(args):
    '''Validate that the dispatcher of the experiment supports importing data.

    The dispatcher name is read from the stored ``experimentConfig``: the
    builtin tuner name if one is set, otherwise the builtin advisor name.
    When neither is set the dispatcher is treated as a customized one and no
    check is performed.

    Terminates the process with status 0 (after a warning) when the
    dispatcher does not need imported data, and with status 1 (after an
    error) when it does not support importing data.
    '''
    # Imported locally so this block is self-contained. sys.exit() is used
    # because the bare exit() helper is installed by the ``site`` module for
    # interactive use and is not guaranteed to exist in every interpreter.
    import sys

    nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
    if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
        dispatcher_name = nni_config['tuner']['builtinTunerName']
    elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
        dispatcher_name = nni_config['advisor']['builtinAdvisorName']
    else:
        # otherwise it should be a customized one
        return

    if dispatcher_name in TUNERS_SUPPORTING_IMPORT_DATA:
        return
    if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
        print_warning("There is no need to import data for %s" % dispatcher_name)
        sys.exit(0)
    print_error("%s does not support importing additional data" % dispatcher_name)
    sys.exit(1)
def load_search_space(path):
    '''Load the JSON content of ``path`` and return it as a JSON string.

    Raises:
        ValueError: if the loaded content is empty (or falsy, e.g. ``{}``,
            ``[]`` or ``None``).
    '''
    parsed = get_json_content(path)
    # The emptiness check must run on the parsed content: json.dumps() never
    # returns an empty string (an empty object becomes '{}', None becomes
    # 'null'), so checking the serialized text could never fail.
    if not parsed:
        raise ValueError('searchSpace file should not be empty')
    return json.dumps(parsed)
def get_query_type(key):
    '''Return the ``update_type`` query string for an experiment profile key.

    Returns None when ``key`` is not one of the updatable fields.
    '''
    query_by_key = {
        'trialConcurrency': '?update_type=TRIAL_CONCURRENCY',
        'maxExecDuration': '?update_type=MAX_EXEC_DURATION',
        'searchSpace': '?update_type=SEARCH_SPACE',
        'maxTrialNum': '?update_type=MAX_TRIAL_NUM',
    }
    return query_by_key.get(key)
def update_experiment_profile(args, key, value):
    '''Update one field of the experiment profile through the REST server.

    Fetches the current profile, sets ``params[key]`` to ``value`` and PUTs
    the whole profile back with the query string matching ``key``.

    Returns the PUT response when both requests pass ``check_response``,
    otherwise None. An error is printed when the REST server is not running.
    '''
    port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running, _ = check_rest_server_quick(port)
    if not is_running:
        print_error('Restful server is not running...')
        return None

    current = rest_get(experiment_url(port), REST_TIME_OUT)
    if not (current and check_response(current)):
        return None

    profile = json.loads(current.text)
    profile['params'][key] = value
    target_url = experiment_url(port) + get_query_type(key)
    updated = rest_put(target_url, json.dumps(profile), REST_TIME_OUT)
    if updated and check_response(updated):
        return updated
    return None
def update_searchspace(args):
    '''Update the experiment's search space from the file ``args.filename``.

    Stores the experiment port on ``args.port``; nothing is updated or
    printed when that port is None.
    '''
    validate_file(args.filename)
    search_space = load_search_space(args.filename)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if not update_experiment_profile(args, 'searchSpace', search_space):
        print_error('Update %s failed!' % 'searchSpace')
        return
    print_normal('Update %s success!' % 'searchSpace')
def update_concurrency(args):
    '''Update the experiment's trial concurrency to ``args.value``.

    ``args.value`` must be a digit from 1 to 1000. Stores the experiment port
    on ``args.port``; nothing is updated or printed when that port is None.
    '''
    validate_digit(args.value, 1, 1000)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if not update_experiment_profile(args, 'trialConcurrency', int(args.value)):
        print_error('Update %s failed!' % 'concurrency')
        return
    print_normal('Update %s success!' % 'concurrency')
def update_duration(args):
    '''Update the experiment's maximum execution duration.

    ``args.value`` is overwritten with the result of ``parse_time`` (the
    duration converted to seconds) before it is sent. Stores the experiment
    port on ``args.port``; nothing is updated or printed when it is None.
    '''
    # parse time, change time unit to seconds
    args.value = parse_time(args.value)
    args.port = get_experiment_port(args)
    if args.port is None:
        return
    if not update_experiment_profile(args, 'maxExecDuration', int(args.value)):
        print_error('Update %s failed!' % 'duration')
        return
    print_normal('Update %s success!' % 'duration')
def update_trialnum(args):
    '''Update the experiment's maximum trial number to ``args.value``.

    ``args.value`` must be a digit from 1 to 999999999.
    '''
    validate_digit(args.value, 1, 999999999)
    updated = update_experiment_profile(args, 'maxTrialNum', int(args.value))
    if not updated:
        print_error('Update %s failed!' % 'trialnum')
        return
    print_normal('Update %s success!' % 'trialnum')
def import_data(args):
    '''Import additional data from ``args.filename`` into the experiment.

    The file and the dispatcher are validated first. An error is printed and
    nothing is sent when the REST server process or the REST server itself
    is not running. Stores the REST port on ``args.port``.
    '''
    validate_file(args.filename)
    validate_dispatcher(args)
    payload = load_search_space(args.filename)

    config = Config(get_config_filename(args))
    port = config.get_config('restServerPort')
    pid = config.get_config('restServerPid')
    if not detect_process(pid):
        print_error('Experiment is not running...')
        return
    server_up, _ = check_rest_server_quick(port)
    if not server_up:
        print_error('Restful server is not running')
        return

    args.port = port
    if args.port is None:
        return
    # Success is silent; only a failed import is reported.
    if not import_data_to_restful_server(args, payload):
        print_error('Import data failed!')
def import_data_to_restful_server(args, content):
    '''POST ``content`` to the experiment's import-data REST endpoint.

    Returns the response when it passes ``check_response``, otherwise None.
    An error is printed when the REST server is not running.
    '''
    port = Config(get_config_filename(args)).get_config('restServerPort')
    is_running, _ = check_rest_server_quick(port)
    if not is_running:
        print_error('Restful server is not running...')
        return None

    reply = rest_post(import_data_url(port), content, REST_TIME_OUT)
    if reply and check_response(reply):
        return reply
    return None