Nadil Karunarathna committed on
Commit
436d1f7
·
1 Parent(s): 8068f7e
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -25,8 +25,8 @@ def correct(text):
25
  text = re.sub(r'\u200d', '<ZWJ>', text)
26
  inputs = tokenizer(
27
  text,
28
- return_tensors='pt',
29
- padding='do_not_pad',
30
  max_length=1024
31
  )
32
  inputs = {k: v.to(device) for k, v in inputs.items()}
@@ -41,23 +41,23 @@ def correct(text):
41
  )
42
  prediction = outputs[0]
43
 
44
- # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
- # all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
46
- # special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
47
 
48
- # pred_tokens = prediction.to(device)
49
- # tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
50
- # mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
51
- # filtered_tokens = tokens_tensor[mask].tolist()
52
 
53
- # prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54
 
55
- # return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
56
 
57
- prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
58
- prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
59
 
60
- return prediction_decoded
61
 
62
 
63
  init()
 
25
  text = re.sub(r'\u200d', '<ZWJ>', text)
26
  inputs = tokenizer(
27
  text,
28
+ return_tensors='pt',
29
+ padding='do_not_pad',
30
  max_length=1024
31
  )
32
  inputs = {k: v.to(device) for k, v in inputs.items()}
 
41
  )
42
  prediction = outputs[0]
43
 
44
+ special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
+ all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
46
+ special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
47
 
48
+ pred_tokens = prediction.to(device)
49
+ tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
50
+ mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
51
+ filtered_tokens = tokens_tensor[mask].tolist()
52
 
53
+ prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54
 
55
+ return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
56
 
57
+ # prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
58
+ # prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
59
 
60
+ # return prediction_decoded
61
 
62
 
63
  init()