sundea commited on
Commit
d4ef0b3
·
1 Parent(s): 9b11d56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -87
app.py CHANGED
@@ -33,64 +33,6 @@ def build_vocab(file_path, tokenizer, max_size, min_freq):
33
  return vocab_dic
34
 
35
 
36
- # parser = argparse.ArgumentParser(description='Chinese Text Classification')
37
- # parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')
38
- # args = parser.parse_args()
39
- # model_name = 'TextCNN'
40
- # dataset = 'THUCNews' # 数据集
41
- # embedding = 'embedding_SougouNews.npz'
42
- # x = import_module('models.' + model_name)
43
- #
44
- # config = x.Config(dataset, embedding)
45
- # device = 'cuda:0'
46
- # model = models.TextCNN.Model(config)
47
- #
48
- # # vocab, train_data, dev_data, test_data = build_dataset(config, args.word)
49
- # model.load_state_dict(torch.load('THUCNews/saved_dict/TextCNN.ckpt'))
50
- # model.to(device)
51
- # model.eval()
52
- #
53
- # tokenizer = lambda x: [y for y in x] # char-level
54
- # if os.path.exists(config.vocab_path):
55
- # vocab = pkl.load(open(config.vocab_path, 'rb'))
56
- # else:
57
- # vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1)
58
- # pkl.dump(vocab, open(config.vocab_path, 'wb'))
59
- # print(f"Vocab size: {len(vocab)}")
60
- #
61
- # # content='时评:“国学小天才”录取缘何少佳话'
62
- # content = input('输入语句:')
63
- #
64
- # words_line = []
65
- # token = tokenizer(content)
66
- # seq_len = len(token)
67
- # pad_size = 32
68
- # contents = []
69
- #
70
- # if pad_size:
71
- # if len(token) < pad_size:
72
- # token.extend([PAD] * (pad_size - len(token)))
73
- # else:
74
- # token = token[:pad_size]
75
- # seq_len = pad_size
76
- # # word to id
77
- # for word in token:
78
- # words_line.append(vocab.get(word, vocab.get(UNK)))
79
- #
80
- # contents.append((words_line, seq_len))
81
- # print(words_line)
82
- # # input = torch.LongTensor(words_line).unsqueeze(1).to(device) # convert words_line to LongTensor and add batch dimension
83
- # x = torch.LongTensor([_[0] for _ in contents]).to(device)
84
- #
85
- # # pad前的长度(超过pad_size的设为pad_size)
86
- # seq_len = torch.LongTensor([_[1] for _ in contents]).to(device)
87
- # input = (x, seq_len)
88
- # # print(input)
89
- # with torch.no_grad():
90
- # output = model(input)
91
- # predic = torch.max(output.data, 1)[1].cpu().numpy()
92
- # print(predic)
93
- # print('类别为:{}'.format(classes[predic[0]]))
94
 
95
 
96
 
@@ -156,44 +98,97 @@ def greet(text):
156
  # print(predic)
157
  # print('类别为:{}'.format(classes[predic[0]]))
158
  return classes[predic[0]]
159
- #
160
  css = """
161
  body {
162
  background-color: #f6f6f6;
163
- font-family: Arial, sans-serif;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  }
165
- .btn-primary {
 
166
  background-color: #1abc9c;
167
  border-color: #1abc9c;
168
  color: #ffffff;
169
  }
170
- """
171
- demo = gr.Interface(fn=greet, inputs="text", outputs="text",title="text-classification app",css=css)
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  demo.launch()
174
- # with torch.no_grad():
175
- # output=model(input)
176
- # print(output)
177
-
178
- #
179
- # start_time = time.time()
180
- # test_iter = build_iterator(test_data, config)
181
- # with torch.no_grad():
182
- # predict_all = np.array([], dtype=int)
183
- # labels_all = np.array([], dtype=int)
184
- # for texts, labels in test_iter:
185
- # # texts=texts.to(device)
186
- # print(texts)
187
- # outputs = model(texts)
188
- # loss = F.cross_entropy(outputs, labels)
189
- # labels = labels.data.cpu().numpy()
190
- # predic = torch.max(outputs.data, 1)[1].cpu().numpy()
191
- # labels_all = np.append(labels_all, labels)
192
- # predict_all = np.append(predict_all, predic)
193
- # break
194
- # print(labels_all)
195
- # print(predict_all)
196
- #
197
- #
198
 
199
 
 
33
  return vocab_dic
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
 
 
98
  # print(predic)
99
  # print('类别为:{}'.format(classes[predic[0]]))
100
  return classes[predic[0]]
101
+
102
  css = """
103
  body {
104
  background-color: #f6f6f6;
105
+ font-family:Arial, sans-serif;
106
+ }
107
+
108
+ .gradio-interface {
109
+ padding-top: 2rem;
110
+ }
111
+
112
+ .gradio-interface-header-logo {
113
+ display: flex;
114
+ align-items: center;
115
+ }
116
+
117
+ .gradio-interface-header-logo img {
118
+ height: 3rem;
119
+ margin-right: 1rem;
120
+ }
121
+
122
+ .gradio-interface-header-title {
123
+ font-size: 2rem;
124
+ font-weight: bold;
125
+ margin: 0;
126
+ }
127
+
128
+ .gradio-interface-inputs label {
129
+ font-weight: bold;
130
+ }
131
+
132
+ .gradio-interface-inputs gr-input input[type="text"], .gradio-interface-inputs gr-output textarea {
133
+ border: 1px solid #ccc;
134
+ border-radius: 0.25rem;
135
+ padding: 0.5rem;
136
+ font-size: 1rem;
137
+ width: 100%;
138
+ margin-bottom: 1rem;
139
+ resize: none;
140
+ height: 6rem;
141
+ }
142
+
143
+ .gradio-interface-outputs gr-output div {
144
+ border: 1px solid #ccc;
145
+ border-radius: 0.25rem;
146
+ padding: 0.5rem;
147
+ font-size: 1rem;
148
+ width: 100%;
149
+ margin-bottom: 1rem;
150
+ min-height: 6rem;
151
+ }
152
+
153
+ .gradio-interface-footer {
154
+ margin-top: 2rem;
155
  }
156
+
157
+ .gradio-interface-footer .btn-primary {
158
  background-color: #1abc9c;
159
  border-color: #1abc9c;
160
  color: #ffffff;
161
  }
 
 
162
 
163
+ .gradio-interface-header-icon {
164
+ font-size: 2rem;
165
+ margin-right: 1rem;
166
+ }
167
+
168
+ .gradio-interface-footer-icon {
169
+ font-size: 2rem;
170
+ margin-left: 1rem;
171
+ }
172
+
173
+ .gradio-interface-header-icon.emoji-icon {
174
+ display: none;
175
+ }
176
+
177
+ .gradio-interface-header-icon.fa-icon {
178
+ display: inline-block;
179
+ font-family: 'Font Awesome 5 Free';
180
+ font-weight: 900;
181
+ }
182
+
183
+ .gradio-interface-header-icon.fa-icon:before {
184
+ content: '\f007';
185
+ }
186
+
187
+ """
188
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text", title="text-classification app",
189
+ icon="&#x1F60E;", css=css)
190
  demo.launch()
191
+
192
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
194