Nadil Karunarathna committed on
Commit
8068f7e
·
1 Parent(s): 90b5d4b
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -23,29 +23,42 @@ def correct(text):
23
  model.eval()
24
 
25
  text = re.sub(r'\u200d', '<ZWJ>', text)
26
- inputs = tokenizer(text, return_tensors='pt', padding='do_not_pad', max_length=1024)
 
 
 
 
 
27
  inputs = {k: v.to(device) for k, v in inputs.items()}
28
 
29
- with torch.no_grad():
30
  outputs = model.generate(
31
  input_ids=inputs["input_ids"],
32
  attention_mask=inputs["attention_mask"],
33
  max_length=1024,
 
 
34
  )
35
- prediction = outputs[0]
36
 
37
- special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
38
- all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
39
- special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
40
 
41
- pred_tokens = prediction.to(device)
42
- tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
43
- mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
44
- filtered_tokens = tokens_tensor[mask].tolist()
45
 
46
- prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
47
 
48
- return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
 
 
 
 
 
 
49
 
50
  init()
51
 
 
23
  model.eval()
24
 
25
  text = re.sub(r'\u200d', '<ZWJ>', text)
26
+ inputs = tokenizer(
27
+ text,
28
+ return_tensors='pt',
29
+ padding='do_not_pad',
30
+ max_length=1024
31
+ )
32
  inputs = {k: v.to(device) for k, v in inputs.items()}
33
 
34
+ with torch.inference_mode():
35
  outputs = model.generate(
36
  input_ids=inputs["input_ids"],
37
  attention_mask=inputs["attention_mask"],
38
  max_length=1024,
39
+ num_beams=1,
40
+ do_sample=False,
41
  )
42
+ prediction = outputs[0]
43
 
44
+ # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
+ # all_special_ids = torch.tensor(tokenizer.all_special_ids, dtype=torch.int64).to(device)
46
+ # special_token_tensor = torch.tensor([special_token_id_to_keep], dtype=torch.int64).to(device)
47
 
48
+ # pred_tokens = prediction.to(device)
49
+ # tokens_tensor = pred_tokens.clone().detach().to(dtype=torch.int64)
50
+ # mask = (tokens_tensor == special_token_tensor) | (~torch.isin(tokens_tensor, all_special_ids))
51
+ # filtered_tokens = tokens_tensor[mask].tolist()
52
 
53
+ # prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
54
 
55
+ # return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
56
+
57
+ prediction_decoded = tokenizer.decode(prediction, skip_special_tokens=True).replace('\n', '').strip()
58
+ prediction_decoded = re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
59
+
60
+ return prediction_decoded
61
+
62
 
63
  init()
64