Macropodus commited on
Commit
391e575
·
verified ·
1 Parent(s): 292c061

using marco-correct

Browse files
Files changed (1) hide show
  1. app.py +43 -246
app.py CHANGED
@@ -1,271 +1,68 @@
 
1
  # -*- coding: utf-8 -*-
2
-
3
- import operator
4
- import copy
5
- import re
6
-
7
- from transformers import BertTokenizer, BertForMaskedLM
 
 
 
 
 
 
 
 
 
 
 
8
  import gradio as gr
9
- import opencc
10
- import torch
11
-
12
-
13
- pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v2"
14
- tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
15
- model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
16
- vocab = tokenizer.vocab
17
-
18
 
19
- # from modelscope import AutoTokenizer, AutoModelForMaskedLM
20
- # pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
21
- # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
22
- # model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
23
- # vocab = tokenizer.vocab
24
- converter_t2s = opencc.OpenCC("t2s.json")
25
- context = converter_t2s.convert("汉字") # 漢字
26
- PUN_EN2ZH_DICT = {",": ",", ";": ";", "!": "!", "?": "?", ":": ":", "(": "(", ")": ")", "_": "—"}
27
- PUN_BERT_DICT = {"“":'"', "”":'"', "‘":'"', "’":'"', "—": "_", "——": "__"}
28
 
29
 
30
- def func_macro_correct(text):
31
- with torch.no_grad():
32
- outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
 
 
 
 
 
33
 
34
- def flag_total_chinese(text):
35
- """
36
- judge is total chinese or not, 判断是不是全是中文
37
- Args:
38
- text: str, eg. "macadam, 碎石路"
39
- Returns:
40
- bool, True or False
41
- """
42
- for word in text:
43
- if not "\u4e00" <= word <= "\u9fa5":
44
- return False
45
- return True
46
 
47
- def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
48
- """Get errors between corrected text and origin text
49
- code from: https://github.com/shibing624/pycorrector
50
- """
51
- new_corrected_text = ""
52
- errors = []
53
- i, j = 0, 0
54
- unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '']
55
- while i < len(origin_text) and j < len(corrected_text):
56
- if origin_text[i] in unk_tokens or origin_text[i] not in know_tokens:
57
- new_corrected_text += origin_text[i]
58
- i += 1
59
- elif corrected_text[j] in unk_tokens:
60
- new_corrected_text += corrected_text[j]
61
- j += 1
62
- # Deal with Chinese characters
63
- elif flag_total_chinese(origin_text[i]) and flag_total_chinese(corrected_text[j]):
64
- # If the two characters are the same, then the two pointers move forward together
65
- if origin_text[i] == corrected_text[j]:
66
- new_corrected_text += corrected_text[j]
67
- i += 1
68
- j += 1
69
- else:
70
- # Check for insertion errors
71
- if j + 1 < len(corrected_text) and origin_text[i] == corrected_text[j + 1]:
72
- errors.append(('', corrected_text[j], j))
73
- new_corrected_text += corrected_text[j]
74
- j += 1
75
- # Check for deletion errors
76
- elif i + 1 < len(origin_text) and origin_text[i + 1] == corrected_text[j]:
77
- errors.append((origin_text[i], '', i))
78
- i += 1
79
- # Check for replacement errors
80
- else:
81
- errors.append((origin_text[i], corrected_text[j], i))
82
- new_corrected_text += corrected_text[j]
83
- i += 1
84
- j += 1
85
- else:
86
- new_corrected_text += origin_text[i]
87
- if origin_text[i] == corrected_text[j]:
88
- j += 1
89
- i += 1
90
- errors = sorted(errors, key=operator.itemgetter(2))
91
- return new_corrected_text, errors
92
-
93
- def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
94
- """Get new corrected text and errors between corrected text and origin text
95
- code from: https://github.com/shibing624/pycorrector
96
- """
97
- errors = []
98
- unk_tokens = unk_tokens or [' ', '“', '”', '‘', '’', '琊', '\n', '…', '擤', '\t', '玕', '', ',']
99
-
100
- for i, ori_char in enumerate(origin_text):
101
- if i >= len(corrected_text):
102
- continue
103
- if ori_char in unk_tokens or ori_char not in know_tokens:
104
- # deal with unk word
105
- corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
106
- continue
107
- if ori_char != corrected_text[i]:
108
- if not flag_total_chinese(ori_char):
109
- # pass not chinese char
110
- corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
111
- continue
112
- if not flag_total_chinese(corrected_text[i]):
113
- corrected_text = corrected_text[:i] + corrected_text[i + 1:]
114
- continue
115
- errors.append([ori_char, corrected_text[i], i])
116
- errors = sorted(errors, key=operator.itemgetter(2))
117
- return corrected_text, errors
118
-
119
- _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
120
- corrected_text = _text[:len(text)]
121
- print("#" * 128)
122
  print(text)
123
- print(corrected_text)
124
- print(len(text), len(corrected_text))
125
- if len(corrected_text) == len(text):
126
- corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
127
- else:
128
- corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
129
- print(text, ' => ', corrected_text, details)
130
- # return corrected_text + ' ' + str(details)
131
- line_dict = {"source": text, "target": corrected_text, "errors": details}
132
- return line_dict
133
-
134
-
135
- def transfor_english_symbol_to_chinese(text, kv_dict=PUN_EN2ZH_DICT):
136
- """ 将英文标点符号转化为中文标点符号, 位数不能变防止pos_id变化 """
137
- for k, v in kv_dict.items(): # 英文替换
138
- text = text.replace(k, v)
139
- if text and text[-1] == ".": # 最后一个字符是英文.
140
- text = text[:-1] + "。"
141
-
142
- if text and "\"" in text: # 双引号
143
- index_list = [i.start() for i in re.finditer("\"", text)]
144
- if index_list:
145
- for idx, index in enumerate(index_list):
146
- symbol = "“" if idx % 2 == 0 else "”"
147
- text = text[:index] + symbol + text[index + 1:]
148
-
149
- if text and "'" in text: # 单引号
150
- index_list = [i.start() for i in re.finditer("'", text)]
151
- if index_list:
152
- for idx, index in enumerate(index_list):
153
- symbol = "‘" if idx % 2 == 0 else "’"
154
- text = text[:index] + symbol + text[index + 1:]
155
- return text
156
- def cut_sent_by_stay(text, return_length=True, add_semicolon=False):
157
- """ 分句但是保存原标点符号 """
158
- if add_semicolon:
159
- text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|;|!|?|。|…|\!|\?", text)
160
- conn_symbol = ";!?。…”;!?》)\n"
161
- else:
162
- text_sp = re.split(r"!”|?”|。”|……”|”!|”?|”。|”……|》。|)。|!|?|。|…|\!|\?", text)
163
- conn_symbol = "!?。…”!?》)\n"
164
- text_length_s = []
165
- text_cut = []
166
- len_text = len(text) - 1
167
- # signal_symbol = "—”>;?…)‘《’(·》“~,、!。:<"
168
- len_global = 0
169
- for idx, text_sp_i in enumerate(text_sp):
170
- text_cut_idx = text_sp[idx]
171
- len_global_before = copy.deepcopy(len_global)
172
- len_global += len(text_sp_i)
173
- while True:
174
- if len_global <= len_text and text[len_global] in conn_symbol:
175
- text_cut_idx += text[len_global]
176
- else:
177
- # len_global += 1
178
- if text_cut_idx:
179
- text_length_s.append([len_global_before, len_global])
180
- text_cut.append(text_cut_idx)
181
- break
182
- len_global += 1
183
- if return_length:
184
- return text_cut, text_length_s
185
- return text_cut
186
- def transfor_bert_unk_pun_to_know(text, kv_dict=PUN_BERT_DICT):
187
- """ 将英文标点符号转化为中文标点符号, 位数不能变防止pos_id变化 """
188
- for k, v in kv_dict.items(): # 英文替换
189
- text = text.replace(k, v)
190
- return text
191
- def tradition_to_simple(text):
192
- """ 繁体到简体 """
193
- return converter_t2s.convert(text)
194
- def string_q2b(ustring):
195
- """把字符串全角转半角"""
196
- return "".join([q2b(uchar) for uchar in ustring])
197
- def q2b(uchar):
198
- """全角转半角"""
199
- inside_code = ord(uchar)
200
- if inside_code == 0x3000:
201
- inside_code = 0x0020
202
- else:
203
- inside_code -= 0xfee0
204
- if inside_code < 0x0020 or inside_code > 0x7e: # 转完之后不是半角字符返回原来的字符
205
- return uchar
206
- return chr(inside_code)
207
-
208
-
209
- def func_macro_correct_long(text):
210
- """ 长句 """
211
- texts, length = cut_sent_by_stay(text, return_length=True, add_semicolon=True)
212
- text_correct = ""
213
- errors_new = []
214
- for idx, text in enumerate(texts):
215
- # 前处理
216
- text = transfor_english_symbol_to_chinese(text)
217
- text = string_q2b(text)
218
- text = tradition_to_simple(text)
219
- text = transfor_bert_unk_pun_to_know(text)
220
-
221
- text_out = func_macro_correct(text)
222
- source = text_out.get("source")
223
- target = text_out.get("target")
224
- errors = text_out.get("errors")
225
- text_correct += target
226
- for error in errors:
227
- if not error[0].strip() or not error[1].strip():
228
- continue
229
- pos = length[idx][0] + error[-1]
230
- error_1 = [error[0], error[1], pos]
231
- errors_new.append(error_1)
232
- return text_correct + '\n' + str(errors_new)
233
 
234
 
235
  if __name__ == '__main__':
236
- text = """网购的烦脑
237
- emer 发布于 2025-7-3 18:20 阅读:73
238
-
239
- 最近网购遇到件恼火的事。我在网店看中件羽戎服,店家保正是正品,还承诺七天无里由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。
240
-
241
- 联系客服时,对方态度敷衔,先说让我自行缝补,后又说要扣除运废才给退。我在评沦区如实描述经历,结果发现好多消废者都有类似遭遇。
242
-
243
- 这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复合对商品信息。
244
- 网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。
245
- 网购的烦恼e发布于2025-7-3期期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。网购的烦恼发布于2025-7-310期阅读:最近网购遇到件恼火的事。我在网店看中件羽绒服,店家保证是正品,还承诺七天无理由退换。收到货后却发现袖口有开线,更糟的是拉链老是卡住。联系客服时,对方态度敷衍,先说让我自行缝补,后又说要扣除运废才给退。我在评论区如实描述经历,结果发现好多消废者都有类似遭遇。这次购物让我明白,不能光看店家的宣全,要多查考真实评价。现在我已经学精了,下单前总会反复核对商品信息。"""
246
- print(func_macro_correct_long(text))
247
 
248
  examples = [
249
- "夫谷之雨,犹复云之亦从的起,因与疾风俱飘,参于天,集于的。",
250
  "机七学习是人工智能领遇最能体现智能的一个分知",
251
- '他们的吵翻很不错,再说他们做的咖喱鸡也好吃',
252
- "抗疫路上,除了提心吊胆也有难的得欢笑。",
253
  "我是练习时长两念半的鸽仁练习生蔡徐坤",
254
- "清晨,如纱一般地薄雾笼罩着世界。",
255
- "得府许我立庙于此,故请君移去尔。",
256
  "他法语说的很好,的语也不错",
257
  "遇到一位很棒的奴生跟我疗天",
258
- "五年级得数学,我考的很差。",
259
  "我们为这个目标努力不解",
260
- '今天兴情很好',
261
  ]
262
-
263
  gr.Interface(
264
- func_macro_correct_long,
265
  inputs='text',
266
  outputs='text',
267
- title="Chinese Spelling Correction Model Macropodus/macbert4mdcspell_v2",
268
  description="Copy or input error Chinese text. Submit and the machine will correct text.",
269
  article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
270
  examples=examples
271
- ).launch()
 
 
 
1
+ # !/usr/bin/python
2
  # -*- coding: utf-8 -*-
3
+ # @time : 2021/2/29 21:41
4
+ # @author : Mo
5
+ # @function: transformers直接加载bert类模型测试
6
+
7
+
8
+ import traceback
9
+ import time
10
+ import sys
11
+ import os
12
+ os.environ["MACRO_CORRECT_FLAG_CSC_TOKEN"] = "1"
13
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
14
+ os.environ["USE_TORCH"] = "1"
15
+
16
+ from macro_correct.pytorch_textcorrection.tcTools import cut_sent_by_stay
17
+ from macro_correct import correct_basic
18
+ from macro_correct import correct_long
19
+ from macro_correct import correct
20
  import gradio as gr
 
 
 
 
 
 
 
 
 
21
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
+ # pretrained_model_name_or_path = "shibing624/macbert4csc-base-chinese"
25
+ pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
26
+ # pretrained_model_name_or_path = "Macropodus/macbert4mdcspell_v1"
27
+ # pretrained_model_name_or_path = "Macropodus/macbert4csc_v1"
28
+ # pretrained_model_name_or_path = "Macropodus/macbert4csc_v2"
29
+ # pretrained_model_name_or_path = "Macropodus/bert4csc_v1"
30
+ # device = torch.device("cpu")
31
+ # device = torch.device("cuda")
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ def macro_correct(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  print(text)
36
+ text_csc = correct_long(text)
37
+ print(text_csc)
38
+ print("#"*128)
39
+ text_out = ""
40
+ for t in text_csc:
41
+ for k, v in t.items():
42
+ text_out += f"{k}: {v}\n"
43
+ text_out += "\n"
44
+ return text_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
 
47
  if __name__ == '__main__':
48
+ print(macro_correct('少先队员因该为老人让坐'))
 
 
 
 
 
 
 
 
 
 
49
 
50
  examples = [
 
51
  "机七学习是人工智能领遇最能体现智能的一个分知",
 
 
52
  "我是练习时长两念半的鸽仁练习生蔡徐坤",
53
+ "真麻烦你了。希望你们好好的跳无",
 
54
  "他法语说的很好,的语也不错",
55
  "遇到一位很棒的奴生跟我疗天",
 
56
  "我们为这个目标努力不解",
 
57
  ]
 
58
  gr.Interface(
59
+ macro_correct,
60
  inputs='text',
61
  outputs='text',
62
+ title="Chinese Spelling Correction Model Macropodus/macbert4csc_v2",
63
  description="Copy or input error Chinese text. Submit and the machine will correct text.",
64
  article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
65
  examples=examples
66
+ ).launch()
67
+ # ).launch(server_name="0.0.0.0", server_port=8066, share=False, debug=True)
68
+