Nadil Karunarathna committed on
Commit
9ba0dd3
·
1 Parent(s): faa4aa2
Files changed (2) hide show
  1. app.py +37 -7
  2. requirements.txt +3 -1
app.py CHANGED
@@ -1,16 +1,46 @@
1
  import gradio as gr
 
 
2
 
3
# Suffix appended to every greeting; filled in by init().
x = ""


def init():
    """Populate the module-level ``x`` and announce that start-up is done.

    Rebinds the global ``x`` to ``'Karu'`` and prints a confirmation line
    to standard output. Returns ``None``.
    """
    global x
    ready_message = "Model or environment initialized."
    x = "Karu"
    print(ready_message)
9
 
10
def correct(name):
    """Return the greeting for *name*.

    The result is ``"Hello "`` followed by *name*, the module-level
    suffix ``x`` and ``"!!"``.
    """
    pieces = ("Hello ", name, x, "!!")
    return "".join(pieces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
# Run the one-time initialisation before the UI is built.
init()

# Single text box in, single text box out, backed by correct().
demo = gr.Interface(
    fn=correct,
    inputs="text",
    outputs="text",
)
demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import re
4
 
5
# Populated by init(); None until then.
model = None
tokenizer = None
# Torch device the model and the request tensors are moved to.
device = "cpu"


def init():
    """Load the model and tokenizer into the module-level globals.

    Loads the ``lm-spell/mt5-base-ft-ssc`` checkpoint onto ``device`` and
    the ``google/mt5-base`` tokenizer, then registers ``<ZWJ>`` as an
    additional special token on that tokenizer. Returns ``None``.
    """
    # Imported lazily so that importing this module does not pull in
    # transformers until the model is actually needed.
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast

    global model, tokenizer

    checkpoint = "lm-spell/mt5-base-ft-ssc"
    loaded = MT5ForConditionalGeneration.from_pretrained(checkpoint)
    model = loaded.to(device)

    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # correct() substitutes this placeholder for U+200D in the text it
    # sends to the model, so the tokenizer must know it as a token.
    zwj_spec = {"additional_special_tokens": ["<ZWJ>"]}
    tokenizer.add_special_tokens(zwj_spec)
17
+
18
+
19
def correct(text):
    """Return the model's corrected version of *text*.

    Zero-width joiners (U+200D) in *text* are replaced by the ``<ZWJ>``
    placeholder token before tokenization and restored in the result.
    Every special token other than ``<ZWJ>`` is removed from the
    generated sequence before decoding.

    Args:
        text: The user-supplied string. Input beyond 128 tokens is
            truncated by the tokenizer.

    Returns:
        The decoded prediction with newlines removed and surrounding
        whitespace stripped, or ``""`` when *text* is empty, ``None`` or
        whitespace only.
    """
    # Nothing to correct: avoid a pointless (and slow) generation pass.
    if not text or not text.strip():
        return ""

    # One limit for both directions. Without an explicit limit generate()
    # falls back to the generation config, whose library default is only
    # 20 tokens, which would cut off corrections of longer inputs.
    max_length = 128

    model.eval()

    text = text.replace('\u200d', '<ZWJ>')
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=max_length,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
        )

    # Drop every special token except the <ZWJ> placeholder, which still
    # has to be mapped back to U+200D below.
    keep_id = tokenizer.convert_tokens_to_ids('<ZWJ>')
    dropped_ids = set(tokenizer.all_special_ids) - {keep_id}
    kept_ids = [tid for tid in outputs[0].tolist() if tid not in dropped_ids]

    decoded = tokenizer.decode(kept_ids, skip_special_tokens=False)
    decoded = decoded.replace('\n', '').strip()

    # Restore the joiner, swallowing at most one whitespace character
    # directly after the placeholder (presumably inserted by decode()).
    return re.sub(r'<ZWJ>\s?', '\u200d', decoded)
42
 
# Load the model and tokenizer once, before any request is served.
init()

# Single text box in, single text box out, backed by correct().
demo = gr.Interface(fn=correct, inputs="text", outputs="text")

# `share` is an option of launch(), not of the Interface constructor;
# passing it to the constructor is rejected as an unexpected keyword.
demo.launch(share=True)
requirements.txt CHANGED
@@ -1 +1,3 @@
1
- gradio
 
 
 
1
+ gradio
2
+ torch==2.5.1
3
+ transformers==4.51.3