# Gradio demo: sentence spell correction ("ssc") with a fine-tuned mT5 model.
# NOTE(review): this file was recovered from a single collapsed line; only
# comments and formatting were added — every code token is unchanged.
import gradio as gr
import torch
import re

# Populated by init(); module-level globals so correct() can reach them.
model = None
tokenizer = None

def init():
    """Load the fine-tuned mT5 checkpoint and its tokenizer into module globals."""
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast
    import os
    global model, tokenizer
    # HF_TOKEN is only needed if the checkpoint repo is gated/private.
    hf_token = os.environ.get("HF_TOKEN")
    model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
    model.eval()
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
    # NOTE(review): the added special token renders as an empty string here —
    # it is presumably an invisible character (e.g. U+200D ZERO WIDTH JOINER,
    # which correct() strips and re-inserts) that the model was fine-tuned to
    # emit as a placeholder. Confirm against the training code before editing.
    tokenizer.add_special_tokens({'additional_special_tokens': ['']})

def correct(text):
    """Run greedy spell correction over *text* and return the corrected string.

    ZERO WIDTH JOINER (U+200D) is stripped before tokenization and re-inserted
    after decoding (see the NOTE on the final re.sub below).
    """
    # Drop ZWJ so the tokenizer sees the bare character sequence.
    text = re.sub(r'\u200d', '', text)
    inputs = tokenizer(
        text,
        return_tensors='pt',
        padding='do_not_pad',
        # NOTE(review): max_length has no effect without truncation=True — long
        # inputs pass through untruncated as written. Confirm intent.
        max_length=1024
    )
    with torch.inference_mode():
        # num_beams=1 + do_sample=False -> deterministic greedy decoding.
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,
            do_sample=False,
        )
    prediction = outputs[0]
    # Keep only the custom placeholder token while dropping every other special
    # id (pad/eos/unk, ...) so the placeholder survives decoding below.
    # NOTE(review): the token looks empty here — same invisible-character
    # caveat as in init().
    special_token_id_to_keep = tokenizer.convert_tokens_to_ids('')
    all_special_ids = set(tokenizer.all_special_ids)
    pred_tokens = prediction.cpu()
    tokens_list = pred_tokens.tolist()
    filtered_tokens = [
        token for token in tokens_list
        if token == special_token_id_to_keep or token not in all_special_ids
    ]
    prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()
    # NOTE(review): as rendered, r'\s?' also matches the empty string, which
    # would insert U+200D at every position of the output. The pattern most
    # likely contained an invisible placeholder character (the special token
    # added in init()) that was lost in transit — verify before relying on
    # this output.
    return re.sub(r'\s?', '\u200d', prediction_decoded)

# Module-level side effects: download/load the model, then serve the UI.
init()
demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()