Nadil Karunarathna commited on
Commit
914f0b6
·
1 Parent(s): cc0fa13
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -14,12 +14,13 @@ def init():
14
  hf_token = os.environ.get("HF_TOKEN")
15
 
16
  model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
 
 
17
  tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
18
  tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
19
-
20
 
21
  def correct(text):
22
- model.eval()
23
 
24
  text = re.sub(r'\u200d', '<ZWJ>', text)
25
  inputs = tokenizer(
@@ -48,10 +49,6 @@ def correct(text):
48
  token for token in tokens_list
49
  if token == special_token_id_to_keep or token not in all_special_ids
50
  ]
51
- # filtered_tokens = [
52
- # token for token in prediction
53
- # if token == special_token_id_to_keep or token not in all_special_ids
54
- # ]
55
 
56
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
57
 
 
14
  hf_token = os.environ.get("HF_TOKEN")
15
 
16
  model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
17
+ model.eval()
18
+
19
  tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
20
  tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
21
+
22
 
23
  def correct(text):
 
24
 
25
  text = re.sub(r'\u200d', '<ZWJ>', text)
26
  inputs = tokenizer(
 
49
  token for token in tokens_list
50
  if token == special_token_id_to_keep or token not in all_special_ids
51
  ]
 
 
 
 
52
 
53
  prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54