Spaces:
Build error
Build error
| import json | |
| from comet.utils.utils import DD | |
# Module-level run settings and their defaults.
# `device` is not modified anywhere in this section; the remaining names are
# overwritten by get_parameters() from the parsed options.
device = "cpu"
save = test_save = toy = do_gen = False
save_strategy = "all"
def get_parameters(opt, exp_type="model"):
    """Assemble the experiment parameter tree from parsed options.

    Builds a ``DD`` holding network, training, data and evaluation settings,
    plus a separate ``meta`` record with loop-control values. Also copies the
    run-level flags ``toy``, ``do_gen``, ``save``, ``test_save`` and
    ``save_strategy`` from ``opt`` into this module's globals.

    Args:
        opt: Parsed option container; read via attribute access.
        exp_type: Unused; kept only for backward compatibility with callers.

    Returns:
        A ``(params, meta)`` tuple of ``DD`` objects.
    """
    global toy, do_gen, save, test_save, save_strategy

    params = DD()
    # ``net`` is assigned first so it remains the first field of ``params``
    # (the previous throw-away ``DD()`` placeholder occupied that slot).
    params.net = get_net_parameters(opt)
    params.mle = 0
    params.dataset = opt.dataset
    params.train = get_training_parameters(opt)
    params.model = params.net.model
    params.exp = opt.exp
    params.data = get_data_parameters(opt, params.exp, params.dataset)
    # For the "atomic" dataset get_data_parameters always sets ``categories``,
    # so it is passed as force_categories and wins over opt.eval_categories.
    params.eval = get_eval_parameters(opt, params.data.get("categories", None))
    params.trainer = opt.trainer
    params.cycle = opt.cycle
    params.iters = int(opt.iterations)

    meta = DD()
    meta.iterations = int(opt.iterations)
    meta.cycle = opt.cycle

    # Publish run-level flags as module globals for other code to read.
    toy = opt.toy
    do_gen = opt.do_gen
    save = opt.save
    test_save = opt.test_save
    save_strategy = opt.save_strategy

    # Echo the assembled parameters to stdout.
    print(params)
    return params, meta
def get_eval_parameters(opt, force_categories=None):
    """Build the evaluation/generation settings record.

    Args:
        opt: Parsed option container.
        force_categories: For the "atomic" dataset, categories to evaluate on.
            When this is None and ``opt`` contains ``eval_categories``, those
            are used instead; otherwise this value (possibly None) is stored.

    Returns:
        A ``DD`` with sampler settings (``bs`` for beam/greedy, ``k`` for
        top-k), sequence settings and, for "atomic", ``categories``.
    """
    settings = DD()
    sampler = opt.eval_sampler

    # Width parameter depends on the decoding strategy; any other sampler
    # name sets neither ``bs`` nor ``k``.
    if sampler == "beam":
        settings.bs = opt.beam_size
    elif sampler == "greedy":
        settings.bs = 1
    elif sampler == "topk":
        settings.k = opt.topk_size

    settings.smax = opt.gen_seqlength
    settings.sample = sampler
    settings.numseq = opt.num_sequences
    settings.gs = opt.generate_sequences
    settings.es = opt.evaluate_sequences

    if opt.dataset != "atomic":
        return settings

    use_requested = "eval_categories" in opt and force_categories is None
    settings.categories = opt.eval_categories if use_requested else force_categories
    return settings
def get_data_parameters(opt, experiment, dataset):
    """Collect dataset-specific options and size limits.

    Args:
        opt: Parsed option container.
        experiment: Not used here; kept for the call signature.
        dataset: "atomic" or "conceptnet"; any other value yields an empty DD.

    Returns:
        A ``DD`` with ``maxe1``/``maxe2``/``maxr`` (presumably maximum lengths
        of event 1, event 2 and the relation -- confirm) plus dataset extras.
    """
    data = DD()

    if dataset == "atomic":
        data.categories = sorted(opt.categories)
        # Fixed limits for ATOMIC; hard-coded rather than read from opt.
        data.maxe1, data.maxe2, data.maxr = 17, 35, 1
    elif dataset == "conceptnet":
        data.rel = opt.relation_format
        data.trainsize = opt.training_set_size
        data.devversion = opt.development_set_versions_to_use
        data.maxe1 = opt.max_event_1_size
        data.maxe2 = opt.max_event_2_size
        # Hard-coded: 5 when relations use the "language" format, else 1.
        data.maxr = 5 if data.rel == "language" else 1

    return data
def get_training_parameters(opt):
    """Split training options into ``static`` and ``dynamic`` groups.

    ``static`` holds per-run settings (experiment, seed, weight decay, LR
    schedule, clipping, loss) and is then updated with every entry of
    ``opt[<optimizer name>]``. ``dynamic`` holds learning rate, batch size
    and the optimizer name.

    Args:
        opt: Parsed option container; must also be subscriptable by the
            optimizer name.

    Returns:
        A ``DD`` with ``static`` and ``dynamic`` sub-records.
    """
    fixed = DD()
    fixed.exp = opt.exp
    fixed.seed = opt.random_seed
    fixed.l2 = opt.l2  # weight decay
    fixed.vl2 = True
    fixed.lrsched = opt.learning_rate_schedule  # e.g. 'warmup_linear'
    fixed.lrwarm = opt.learning_rate_warmup  # e.g. 0.002
    fixed.clip = opt.clip  # gradient clipping
    fixed.loss = opt.loss  # which loss function to use

    tunable = DD()
    tunable.lr = opt.learning_rate
    tunable.bs = opt.batch_size
    tunable.optim = opt.optimizer  # e.g. adam, rmsprop

    # Optimizer-specific hyper-parameters are stored under the optimizer's
    # name in opt and merged into the static group.
    fixed.update(opt[tunable.optim])

    result = DD()
    result.static = fixed
    result.dynamic = tunable
    return result
def get_net_parameters(opt):
    """Copy network-architecture options from ``opt`` into a new DD.

    Each option is stored under a short field name (e.g. ``num_layers`` ->
    ``nL``). ``init`` is the parameter-initialization spec: either
    ``"gauss+<mean>+<std>"`` or ``"n"`` for PyTorch's default initialization.

    Returns:
        A ``DD`` with the fields listed in ``field_map`` below.
    """
    # (field on the returned record, attribute read from opt), in output order.
    field_map = (
        ("model", "model"),
        ("nL", "num_layers"),
        ("nH", "num_heads"),
        ("hSize", "hidden_dim"),
        ("edpt", "embedding_dropout"),
        ("adpt", "attention_dropout"),
        ("rdpt", "residual_dropout"),
        ("odpt", "output_dropout"),
        ("pt", "pretrain"),
        ("afn", "activation"),
        ("init", "init"),
    )
    net = DD()
    for short_name, option_name in field_map:
        setattr(net, short_name, getattr(opt, option_name))
    return net
def read_config(file_):
    """Recursively convert a parsed config mapping into a DD.

    String spellings of booleans ("True"/"T"/"true" and "False"/"F"/"false")
    become real ``bool`` values. Nested dicts, including dict subclasses such
    as ``OrderedDict``, are converted recursively. Every other value is
    copied unchanged.

    Args:
        file_: Mapping to convert, e.g. the result of ``load_config``.

    Returns:
        A new ``DD`` mirroring ``file_``.
    """
    true_spellings = ("True", "T", "true")
    false_spellings = ("False", "F", "false")

    config = DD()
    # Debug echo of the raw mapping; runs again at every nesting level.
    print(file_)
    for key, value in file_.items():
        # Tuple membership compares with ==, so unhashable values such as
        # lists or dicts are handled safely (a set lookup would raise).
        if value in true_spellings:
            config[key] = True
        elif value in false_spellings:
            config[key] = False
        elif isinstance(value, dict):
            # isinstance (not type(...) == dict) so dict subclasses are also
            # converted instead of keeping their string booleans.
            config[key] = read_config(value)
        else:
            config[key] = value
    return config
def load_config(name):
    """Load and return the JSON document stored at path ``name``.

    The file is decoded explicitly as UTF-8, the encoding JSON text must use,
    instead of the platform-dependent locale default.

    Args:
        name: Path to the JSON config file.

    Returns:
        The decoded JSON value (normally a dict).
    """
    with open(name, "r", encoding="utf-8") as f:
        return json.load(f)