Nadil Karunarathna
commited on
Commit
·
914f0b6
1
Parent(s):
cc0fa13
wip
Browse files
app.py
CHANGED
@@ -14,12 +14,13 @@ def init():
|
|
14 |
hf_token = os.environ.get("HF_TOKEN")
|
15 |
|
16 |
model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
|
|
|
|
|
17 |
tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
|
18 |
tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
|
19 |
-
|
20 |
|
21 |
def correct(text):
|
22 |
-
model.eval()
|
23 |
|
24 |
text = re.sub(r'\u200d', '<ZWJ>', text)
|
25 |
inputs = tokenizer(
|
@@ -48,10 +49,6 @@ def correct(text):
|
|
48 |
token for token in tokens_list
|
49 |
if token == special_token_id_to_keep or token not in all_special_ids
|
50 |
]
|
51 |
-
# filtered_tokens = [
|
52 |
-
# token for token in prediction
|
53 |
-
# if token == special_token_id_to_keep or token not in all_special_ids
|
54 |
-
# ]
|
55 |
|
56 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
57 |
|
|
|
14 |
hf_token = os.environ.get("HF_TOKEN")
|
15 |
|
16 |
model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
|
17 |
+
model.eval()
|
18 |
+
|
19 |
tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
|
20 |
tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
|
21 |
+
|
22 |
|
23 |
def correct(text):
|
|
|
24 |
|
25 |
text = re.sub(r'\u200d', '<ZWJ>', text)
|
26 |
inputs = tokenizer(
|
|
|
49 |
token for token in tokens_list
|
50 |
if token == special_token_id_to_keep or token not in all_special_ids
|
51 |
]
|
|
|
|
|
|
|
|
|
52 |
|
53 |
prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
|
54 |
|