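"""Gradio demo: spelling correction with a fine-tuned mT5 model
(lm-spell/mt5-base-ft-ssc).

Zero-width joiners (U+200D) are meaningful in the target script, so they are
round-tripped through a custom <ZWJ> special token instead of being dropped
by the tokenizer.
"""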
import gradio as gr
import torch
import re

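# Populated once by init() at startup and shared across all requests.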
model = None
tokenizer = None

def init():
    from transformers import MT5ForConditionalGeneration, T5TokenizerFast
    import os

    global model, tokenizer

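    # HF_TOKEN is only needed if the model repository is private or gated.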
    hf_token = os.environ.get("HF_TOKEN")

    model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
    tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
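    # Register <ZWJ> as an extra special token so U+200D can be represented
    # explicitly. The fine-tuned checkpoint is assumed to already carry an
    # embedding for this token, so no resize_token_embeddings() call is made.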
    tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})


def correct(text):
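    """Correct the spelling of a single input string and return the result."""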

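    # from_pretrained() already returns the model in eval mode; this call is
    # a harmless safeguard.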
    model.eval()

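    # Swap the invisible zero-width joiner for an explicit sentinel token so
    # the tokenizer does not silently discard it.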
    text = re.sub(r'\u200d', '<ZWJ>', text)
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,  # max_length is ignored unless truncation is enabled
        max_length=1024,
    )

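    # Greedy decoding: num_beams=1 with do_sample=False is deterministic and
    # keeps latency low for an interactive demo.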
    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=1024,
            num_beams=1,
            do_sample=False,
        )
    # generate() returns a batch of sequences; a single input means sequence 0.
    prediction = outputs[0]

    # skip_special_tokens=True during decoding would also strip <ZWJ>, so drop
    # the special ids manually while keeping the <ZWJ> marker.
    zwj_id = tokenizer.convert_tokens_to_ids('<ZWJ>')
    all_special_ids = set(tokenizer.all_special_ids)
    filtered_tokens = [
        token for token in prediction.tolist()
        if token == zwj_id or token not in all_special_ids
    ]

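    # Decode with skip_special_tokens=False so the kept <ZWJ> markers survive.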
    prediction_decoded = tokenizer.decode(filtered_tokens, skip_special_tokens=False).replace('\n', '').strip()

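    # decode() may emit a space after the special token; \s? in the pattern
    # absorbs it while restoring the real U+200D character.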
    return re.sub(r'<ZWJ>\s?', '\u200d', prediction_decoded)

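# Load the weights at import time so the app is ready before the first request.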
init()

demo = gr.Interface(fn=correct, inputs="text", outputs="text")
demo.launch()