# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys
import shutil
import json

from . import code_generator
from . import search_space_generator
from . import specific_code_generator

__all__ = ['generate_search_space', 'expand_annotations']

# Path separator used when converting directory paths to dotted module names.
slash = '/'
if sys.platform == "win32":
    slash = '\\'


def _strip_trailing_slash(path):
    """Return `path` as a str with a single trailing separator removed.

    Uses `endswith` so that an empty string is returned unchanged instead of
    raising IndexError.
    """
    path = str(path)
    if path.endswith(slash):
        path = path[:-1]
    return path


def _package_prefix(root_dir, subdir):
    """Return the dotted package prefix of `subdir` relative to `root_dir`.

    The result is '' when `subdir` is `root_dir` itself; otherwise it is the
    relative path with separators replaced by '.', followed by a trailing '.'.
    """
    if subdir == root_dir:
        return ''
    assert subdir.startswith(root_dir + slash), subdir
    prefix_len = len(root_dir) + 1
    return subdir[prefix_len:].replace(slash, '.') + '.'


def _wrapped_error(path, exc, action):
    """Build the RuntimeError that reports `exc` for the file at `path`.

    Every element of `exc.args` is converted with `str` before joining, since
    exception arguments are not guaranteed to be strings (SyntaxError, for
    example, carries a tuple).
    """
    if exc.args:
        return RuntimeError(path + ' ' + '\n'.join(str(arg) for arg in exc.args))
    return RuntimeError('Failed to %s for %s: %r' % (action, path, exc))


def _write_or_copy(src_path, dst_path, annotated_code):
    """Write `annotated_code` to `dst_path`, or copy the source when it is None.

    Return True if generated code was written, False if the source file was
    copied verbatim. Copy and write failures are reported as RuntimeError.
    """
    if annotated_code is None:
        try:
            shutil.copyfile(src_path, dst_path)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrapped_error(src_path, exc, 'expand annotations') from exc
        return False

    with open(dst_path, 'w') as dst:
        try:
            dst.write(annotated_code)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrapped_error(src_path, exc, 'expand annotations') from exc
    return True


def generate_search_space(code_dir):
    """Generate search space from Python source code.
    Return a serializable search space object.
    code_dir: directory path of source files (str)

    Every `.py` file found under `code_dir` is processed, and each processed
    file is rewritten in place with the code returned by the generator.
    """
    code_dir = _strip_trailing_slash(code_dir)
    search_space = {}

    for subdir, _, files in os.walk(code_dir):
        # generate module name from path
        package = _package_prefix(code_dir, subdir)

        for file_name in files:
            if file_name.endswith('.py'):
                path = os.path.join(subdir, file_name)
                module = package + file_name[:-3]
                search_space.update(_generate_file_search_space(path, module))

    return search_space


def _generate_file_search_space(path, module):
    """Generate the search space of one file and overwrite it with new code.

    path: path of the Python source file (str)
    module: dotted module name passed to the search space generator (str)

    Return the search space produced by `search_space_generator.generate`.
    Raise RuntimeError if reading or generation fails; the file is left
    untouched in that case.
    """
    with open(path) as src:
        try:
            search_space, code = search_space_generator.generate(module, src.read())
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrapped_error(path, exc, 'generate search space') from exc

    # The source handle is closed before the same path is reopened for writing.
    with open(path, 'w') as dst:
        dst.write(code)
    return search_space


def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
    """Expand annotations in user code.
    Return dst_dir if annotation detected; return src_dir if not.

    src_dir: directory path of user code (str)
    dst_dir: directory to place generated files (str)
    exp_id: experiment ID, used with trial_id to locate the parameter file
    trial_id: trial ID; when it is '' the generic expansion is used, otherwise
        trial-specific code is generated
    nas_mode: the mode of NAS given that NAS interface is used
    """
    src_dir = _strip_trailing_slash(src_dir)
    dst_dir = _strip_trailing_slash(dst_dir)

    annotated = False

    for src_subdir, dirs, files in os.walk(src_dir):
        assert src_subdir.startswith(src_dir)
        # Mirror the source tree under dst_dir.
        dst_subdir = src_subdir.replace(src_dir, dst_dir, 1)
        os.makedirs(dst_subdir, exist_ok=True)

        # generate module name from path
        package = _package_prefix(src_dir, src_subdir)

        for file_name in files:
            src_path = os.path.join(src_subdir, file_name)
            dst_path = os.path.join(dst_subdir, file_name)
            if file_name.endswith('.py'):
                if trial_id == '':
                    annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
                else:
                    module = package + file_name[:-3]
                    annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
            else:
                # Non-Python files are copied unchanged.
                shutil.copyfile(src_path, dst_path)

        for dir_name in dirs:
            os.makedirs(os.path.join(dst_subdir, dir_name), exist_ok=True)

    return dst_dir if annotated else src_dir


def _expand_file_annotations(src_path, dst_path, nas_mode):
    """Expand the annotations of one file into `dst_path`.

    Return True if `code_generator.parse` produced code (written to
    `dst_path`), False if it returned None (source copied verbatim).
    Raise RuntimeError if reading, parsing, copying or writing fails.
    """
    with open(src_path) as src:
        try:
            annotated_code = code_generator.parse(src.read(), nas_mode)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrapped_error(src_path, exc, 'expand annotations') from exc

    # The destination is only touched once parsing has succeeded, so a failed
    # parse does not leave an empty file behind, and copyfile never targets a
    # path that is still held open for writing.
    return _write_or_copy(src_path, dst_path, annotated_code)


def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
    """Generate trial-specific code for one file into `dst_path`.

    Parameters are loaded from
    `~/nni-experiments/<exp_id>/trials/<trial_id>/parameter.cfg` (JSON) and the
    value under its "parameters" key is passed to
    `specific_code_generator.parse` together with `module`.

    Return True if code was generated (written to `dst_path`), False if the
    parser returned None (source copied verbatim).
    Raise RuntimeError if loading parameters, reading, parsing, copying or
    writing fails.
    """
    with open(src_path) as src:
        try:
            cfg_path = os.path.expanduser(
                '~/nni-experiments/%s/trials/%s/parameter.cfg' % (exp_id, trial_id))
            with open(cfg_path) as fd:
                para_cfg = json.load(fd)
            annotated_code = specific_code_generator.parse(
                src.read(), para_cfg["parameters"], module)
        except Exception as exc:  # pylint: disable=broad-except
            raise _wrapped_error(src_path, exc, 'expand annotations') from exc

    # See _expand_file_annotations: the destination is written only after the
    # code has been generated successfully.
    return _write_or_copy(src_path, dst_path, annotated_code)