File size: 2,680 Bytes
b0a80fd
 
60802a9
4d7777a
 
133bc8b
4d7777a
133bc8b
 
b0a80fd
133bc8b
 
 
 
b0a80fd
 
 
60802a9
b0a80fd
133bc8b
 
 
 
 
 
b0a80fd
60802a9
 
 
 
 
 
 
133bc8b
 
 
 
 
60802a9
 
133bc8b
 
 
60802a9
 
 
 
 
133bc8b
60802a9
 
 
 
b0a80fd
133bc8b
 
b0a80fd
133bc8b
 
 
 
d7658d1
133bc8b
b0a80fd
133bc8b
 
 
 
60802a9
0a497da
133bc8b
 
 
60802a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import torch
from torchvision import transforms as T
from transformers import AutoTokenizer
import gradio as gr

class App:
    """Gradio callable wrapping pretrained STR models from the baudm/parseq hub repo.

    Models and their tokenizers are lazily loaded and cached per model name.
    An instance is used directly as the Gradio event handler (see ``__call__``).
    """

    title = 'Scene Text Recognition with<br/>Permuted Autoregressive Sequence Models'
    models = ['parseq', 'parseq_tiny', 'abinet', 'crnn', 'trba', 'vitstr']

    def __init__(self):
        self._model_cache = {}
        # 32x128 BICUBIC resize + [-1, 1] normalization: the canonical input
        # pipeline for the parseq family of STR models.
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])
        self._tokenizer_cache = {}

    def _get_model(self, name):
        """Load (and cache) the pretrained model ``name`` from torch.hub, in eval mode."""
        if name in self._model_cache:
            return self._model_cache[name]
        model = torch.hub.load('baudm/parseq', name, pretrained=True).eval()
        self._model_cache[name] = model
        return model

    def _get_tokenizer(self, name):
        """Return the charset tokenizer bundled with model ``name``.

        The parseq hub models carry their own tokenizer (``model.tokenizer``);
        they are not Hugging Face hub repos, so ``AutoTokenizer.from_pretrained``
        cannot resolve names like 'parseq' and must not be used here.
        """
        if name in self._tokenizer_cache:
            return self._tokenizer_cache[name]
        tokenizer = self._get_model(name).tokenizer
        self._tokenizer_cache[name] = tokenizer
        return tokenizer

    @torch.inference_mode()
    def __call__(self, model_name, image):
        """Run STR inference on a PIL image.

        Returns ``(label, [raw_label, confidences])`` where ``label`` is the
        decoded text and the second element feeds the raw-output dataframe.
        Returns ``('', [])`` when no image was provided.
        """
        if image is None:
            return '', []
        model = self._get_model(model_name)
        tokenizer = self._get_tokenizer(model_name)

        image = self._preprocess(image.convert('RGB')).unsqueeze(0)
        # Per-position class probabilities; greedy decoding happens inside
        # tokenizer.decode(), which takes the probability tensor directly and
        # returns per-sample (labels, confidences).
        pred = model(image).softmax(-1)

        label, _ = tokenizer.decode(pred)
        # raw=True keeps special tokens ([B]/[E]) and per-token confidences.
        raw_label, raw_confidence = tokenizer.decode(pred, raw=True)

        # CRNN has no EOS token, so show a fixed 25-char window; otherwise
        # show the decoded label plus its EOS position.
        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))

        return label[0], [raw_label[0][:max_len], conf]

def main():
    """Assemble the Gradio UI around an App instance and launch the server."""
    app = App()

    # The browser tab title must be plain text, so strip the HTML line break
    # used in the on-page heading.
    page_title = app.title.replace('<br/>', ' ')
    with gr.Blocks(analytics_enabled=False, title=page_title) as demo:
        model_choice = gr.Radio(app.models, value=app.models[0], label='The STR model to use')
        with gr.Tabs():
            with gr.TabItem('Image Upload'):
                uploaded_image = gr.Image(type='pil', label='Image')
                run_button = gr.Button('Read Text')

        text_output = gr.Textbox(max_lines=1, label='Model output')
        table_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')

        # The App instance itself is the click handler.
        run_button.click(app, inputs=[model_choice, uploaded_image], outputs=[text_output, table_output])

    demo.queue(max_size=20)
    demo.launch()


if __name__ == '__main__':
    main()