# NOTE(review): removed web-scrape residue (file-size header, git blame hashes,
# and a line-number gutter) that was not part of the Python source.
import gradio as gr
import torch
import re
# Lazily-initialized globals; populated once by init() before the app serves requests.
model = None
tokenizer = None
def init():
    """Load the fine-tuned spell-correction model and its tokenizer into the globals."""
    global model, tokenizer
    import os
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast

    # Token is optional; needed only if the checkpoint repo is gated/private.
    auth_token = os.environ.get("HF_TOKEN")
    model = MT5ForConditionalGeneration.from_pretrained(
        "lm-spell/mt5-base-ft-ssc", token=auth_token
    )
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # Register the sentinel used to round-trip zero-width joiners through the model.
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
def correct(text):
    """Spell-correct *text* with the fine-tuned mT5 model and return the result.

    Zero-width joiners (U+200D — significant in Sinhala ligatures) are not
    preserved by the mT5 tokenizer, so they are swapped for the sentinel token
    <ZWJ> before encoding and restored after decoding.
    """
    model.eval()
    # Protect ZWJ characters from being dropped during tokenization.
    text = text.replace('\u200d', '<ZWJ>')
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        # Bug fix: without an active truncation strategy, HF tokenizers silently
        # ignore max_length, so over-long inputs were passed through unclamped.
        truncation=True,
        max_length=1024,
    )
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,       # greedy decoding
            do_sample=False,
        )
    prediction = outputs[0]
    # Strip every special token EXCEPT the ZWJ sentinel, which must survive decoding
    # (skip_special_tokens=True would discard it along with pad/eos).
    special_token_id_to_keep = tokenizer.convert_tokens_to_ids('<ZWJ>')
    all_special_ids = set(tokenizer.all_special_ids)
    filtered_tokens = [
        token for token in prediction.cpu().tolist()
        if token == special_token_id_to_keep or token not in all_special_ids
    ]
    prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
    # Restore ZWJ characters; the decoder may emit a space after the sentinel token.
    return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)
# Eagerly load model weights at import time, then serve a minimal
# text-in/text-out UI (blocks until the Gradio server is stopped).
init()
demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()