Nadil Karunarathna commited on
Commit
6baca04
·
1 Parent(s): 03ce3fe
Files changed (1) hide show
  1. app.py +43 -48
app.py CHANGED
@@ -1,65 +1,60 @@
1
  import gradio as gr
2
- # import torch
3
- # import re
4
 
5
- # model = None
6
- # tokenizer = None
7
 
8
- # def init():
9
- # from transformers import MT5ForConditionalGeneration, T5TokenizerFast
10
- # import os
11
-
12
- # global model, tokenizer
13
 
14
- # hf_token = os.environ.get("HF_TOKEN")
15
 
16
- # model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token).to('cpu')
17
- # torch.set_num_threads(2)
18
- # tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
19
- # tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
20
 
 
 
 
 
21
 
22
- # def correct(text):
23
 
24
- # model.eval()
25
 
26
- # text = re.sub(r'\u200d', '<ZWJ>', text)
27
- # inputs = tokenizer(
28
- # text,
29
- # return_tensors='pt',
30
- # padding='do_not_pad',
31
- # max_length=1024
32
- # )
33
 
34
- # with torch.inference_mode():
35
- # outputs = model.generate(
36
- # input_ids=inputs["input_ids"],
37
- # attention_mask=inputs["attention_mask"],
38
- # max_length=1024,
39
- # num_beams=1,
40
- # do_sample=False,
41
- # )
42
- # prediction = outputs[0]
43
 
44
- # special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
- # all_special_ids = set(tokenizer.all_special_ids)
46
- # pred_tokens = prediction.cpu()
 
 
 
 
 
 
47
 
48
- # tokens_list = pred_tokens.tolist()
49
- # filtered_tokens = [
50
- # token for token in tokens_list
51
- # if token == special_token_id_to_keep or token not in all_special_ids
52
- # ]
53
 
54
- # prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
 
 
 
 
55
 
56
- # return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
57
 
58
- def test(text):
59
- import os
60
- return f"vCPUs: {os.cpu_count()}"
61
 
62
- # init()
63
- # demo = gr.Interface(fn=correct, inputs="text", outputs="text")
64
- demo = gr.Interface(fn=test, inputs="text", outputs="text")
65
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import re
4
 
5
+ model = None
6
+ tokenizer = None
7
 
8
+ def init():
9
+ from transformers import MT5ForConditionalGeneration, T5TokenizerFast
10
+ import os
 
 
11
 
12
+ global model, tokenizer
13
 
14
+ hf_token = os.environ.get("HF_TOKEN")
 
 
 
15
 
16
+ model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
17
+ torch.set_num_threads(16)
18
+ tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
19
+ tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
20
 
 
21
 
22
+ def correct(text):
23
 
24
+ model.eval()
 
 
 
 
 
 
25
 
26
+ text = re.sub(r'\u200d', '<ZWJ>', text)
27
+ inputs = tokenizer(
28
+ text,
29
+ return_tensors='pt',
30
+ padding='do_not_pad',
31
+ max_length=1024
32
+ )
 
 
33
 
34
+ with torch.inference_mode():
35
+ outputs = model.generate(
36
+ input_ids=inputs["input_ids"],
37
+ attention_mask=inputs["attention_mask"],
38
+ max_length=1024,
39
+ num_beams=1,
40
+ do_sample=False,
41
+ )
42
+ prediction = outputs[0]
43
 
44
+ special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
45
+ all_special_ids = set(tokenizer.all_special_ids)
46
+ pred_tokens = prediction.cpu()
 
 
47
 
48
+ tokens_list = pred_tokens.tolist()
49
+ filtered_tokens = [
50
+ token for token in tokens_list
51
+ if token == special_token_id_to_keep or token not in all_special_ids
52
+ ]
53
 
54
+ prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
55
 
56
+ return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
 
 
57
 
58
+ init()
59
+ demo = gr.Interface(fn=correct, inputs="text", outputs="text")
 
60
  demo.launch()