Spaces:
Runtime error
Runtime error
# coding: UTF-8
"""Command-line entry point: train a Chinese text classifier on THUCNews.

Example::

    python run.py --model TextCNN --embedding random --word False
"""
import time
import argparse
from importlib import import_module

import numpy as np
import torch


def str2bool(value):
    """Convert a command-line string such as 'True'/'false'/'1'/'no' to a bool.

    ``argparse`` with ``type=bool`` is a trap: ``bool('False')`` is ``True``
    because any non-empty string is truthy, so it cannot express False.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


# Building the parser has no side effects; arguments are only parsed in main().
parser = argparse.ArgumentParser(description='Chinese Text Classification')
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
# nargs='?' + const=True: a bare `--word` means True, while the historical
# `--word True` / `--word False` forms keep working (and False now really is False).
parser.add_argument('--word', default=False, type=str2bool, nargs='?', const=True,
                    help='True for word, False for char (bare --word means True)')


def main(argv=None):
    """Parse CLI arguments, build the data iterators and model, then train.

    Args:
        argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.
    """
    # Imported lazily, like the utils modules below, so importing this module
    # (e.g. to reuse ``parser``) does not require the project training code.
    from train_eval import train, init_network

    args = parser.parse_args(argv)

    dataset = 'THUCNews'  # dataset
    # Embedding choices: Sogou News -> embedding_SougouNews.npz,
    # Tencent -> embedding_Tencent.npz, random initialisation -> 'random'.
    embedding = 'embedding_SougouNews.npz'
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
    if model_name == 'FastText':
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        # FastText always runs with randomly initialised embeddings.
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif
    model_module = import_module('models.' + model_name)
    config = model_module.Config(dataset, embedding)

    # Fix every RNG seed and pin cuDNN to deterministic kernels so that
    # repeated runs give the same results.
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    start_time = time.time()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # train
    config.n_vocab = len(vocab)
    model = model_module.Model(config).to(config.device)
    if model_name != 'Transformer':
        # Transformer skips init_network and keeps its constructor's initialisation.
        init_network(model)
    print(model)  # module repr shows the architecture
    train(config, model, train_iter, dev_iter, test_iter)


if __name__ == '__main__':
    main()