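# NOTE: "Spaces: Running" -- apparently page-header residue from where this
# script was copied; it is not part of the program.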
# Command-line launcher: builds the datasets, model and runner described by a
# config file, then trains, resumes training, or evaluates depending on --mode.
import os
import sys

# Put the parent of this script's directory on sys.path; presumably this is
# what lets the `utils`, `models` and `datasets` imports below resolve.
p = os.path.split(os.path.dirname(os.path.abspath(__file__)))[0]
sys.path.append(p)

import ast
import logging
import argparse
import numpy as np
import tensorflow as tf
from pprint import pformat

from utils.hparams import HParams
from models.runner import Runner
from models import get_model
from datasets import get_dataset

# The only values of --mode the dispatch at the bottom of the script handles.
VALID_MODES = ('train', 'resume', 'test')


def _parse_override(item):
    """Split one ``key=value`` override and decode its value.

    The value is decoded with ``ast.literal_eval`` so that numbers, booleans,
    lists, etc. keep their Python type. Anything that is not a valid Python
    literal (for example a file path or a bare word) is kept as the raw string.

    Args:
        item: A single ``--config`` argument, expected to look like
            ``key=value``. Only the first ``=`` separates key from value.

    Returns:
        A ``(key, value)`` tuple.

    Raises:
        ValueError: If ``item`` has no ``=``, an empty key, or an empty value.
    """
    key, sep, raw = item.partition('=')
    if not sep:
        raise ValueError(f"Config item {item!r} must have the form key=value")
    if not key:
        raise ValueError("Config item can't have empty key")
    if not raw:
        raise ValueError("Config item can't have empty value")
    try:
        return key, ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # literal_eval raises SyntaxError (not just ValueError) for strings
        # such as "/tmp/run1" or "my model"; treat both as plain strings.
        return key, raw


parser = argparse.ArgumentParser()
parser.add_argument('--cfg_file', type=str)
parser.add_argument('--config', action='append', default=[])
parser.add_argument('--mode', type=str, default=None)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--gpu', type=str, default='0')
args = parser.parse_args()

# Fail fast on a bad mode, before any log file, dataset or model is created.
if args.mode not in VALID_MODES:
    raise ValueError(
        f"Unknown --mode {args.mode!r}; expected one of {VALID_MODES}")

params = HParams(args.cfg_file)

# parse config: apply each --config key=value on top of the config file
for kv in args.config:
    k, v = _parse_override(kv)
    params.update({k: v})

# logging.basicConfig() is a no-op when the root logger already has handlers,
# so drop any that are present before configuring the per-mode log file.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)
logging.basicConfig(filename=params.exp_dir + f'/{args.mode}.log',
                    filemode='w',
                    level=logging.INFO,
                    format='%(asctime)-15s %(message)s')
logging.info(pformat(params.dict))

# Select the GPU(s) and seed NumPy and TensorFlow with the same value.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
np.random.seed(args.seed)
tf.compat.v1.set_random_seed(args.seed)

# data
trainset = get_dataset(params, 'train')
validset = get_dataset(params, 'valid')
testset = get_dataset(params, 'test')
logging.info((f"trainset: {trainset.size}",
              f"validset: {validset.size}",
              f"testset: {testset.size}"))

# model
model = get_model(params)
runner = Runner(params, model)
runner.set_dataset(trainset, validset, testset)

# run
if args.mode == 'train':
    runner.run()
elif args.mode == 'resume':
    # Load the model's saved state first, then continue running.
    model.load()
    runner.run()
else:
    # 'test' -- the only value left after the VALID_MODES check above.
    runner.evaluate()