Spaces:
Runtime error
Runtime error
Commit
·
60802a9
1
Parent(s):
0a497da
code changed
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import torch
|
2 |
from torchvision import transforms as T
|
|
|
3 |
import gradio as gr
|
4 |
|
5 |
class App:
|
@@ -14,6 +15,7 @@ class App:
|
|
14 |
T.ToTensor(),
|
15 |
T.Normalize(0.5, 0.5)
|
16 |
])
|
|
|
17 |
|
18 |
def _get_model(self, name):
|
19 |
if name in self._model_cache:
|
@@ -22,21 +24,33 @@ class App:
|
|
22 |
self._model_cache[name] = model
|
23 |
return model
|
24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
@torch.inference_mode()
|
26 |
def __call__(self, model_name, image):
|
27 |
if image is None:
|
28 |
return '', []
|
29 |
model = self._get_model(model_name)
|
|
|
|
|
30 |
image = self._preprocess(image.convert('RGB')).unsqueeze(0)
|
31 |
# Greedy decoding
|
32 |
pred = model(image).softmax(-1)
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
35 |
# Format confidence values
|
36 |
-
max_len = 25 if model_name == 'crnn' else len(label
|
37 |
-
conf = list(map('{:0.1f}'.format,
|
38 |
-
|
39 |
-
|
40 |
|
41 |
def main():
|
42 |
app = App()
|
@@ -49,14 +63,12 @@ def main():
|
|
49 |
read_upload = gr.Button('Read Text')
|
50 |
|
51 |
output = gr.Textbox(max_lines=1, label='Model output')
|
52 |
-
#adv_output = gr.Checkbox(label='Show detailed output')
|
53 |
raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
|
54 |
|
55 |
read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
|
56 |
-
|
57 |
demo.queue(max_size=20)
|
58 |
demo.launch()
|
59 |
|
60 |
-
|
61 |
if __name__ == '__main__':
|
62 |
-
main()
|
|
|
1 |
import torch
|
2 |
from torchvision import transforms as T
|
3 |
+
from transformers import AutoTokenizer
|
4 |
import gradio as gr
|
5 |
|
6 |
class App:
|
|
|
15 |
T.ToTensor(),
|
16 |
T.Normalize(0.5, 0.5)
|
17 |
])
|
18 |
+
self._tokenizer_cache = {}
|
19 |
|
20 |
def _get_model(self, name):
|
21 |
if name in self._model_cache:
|
|
|
24 |
self._model_cache[name] = model
|
25 |
return model
|
26 |
|
27 |
+
def _get_tokenizer(self, name):
|
28 |
+
if name in self._tokenizer_cache:
|
29 |
+
return self._tokenizer_cache[name]
|
30 |
+
tokenizer = AutoTokenizer.from_pretrained(name)
|
31 |
+
self._tokenizer_cache[name] = tokenizer
|
32 |
+
return tokenizer
|
33 |
+
|
34 |
@torch.inference_mode()
|
35 |
def __call__(self, model_name, image):
|
36 |
if image is None:
|
37 |
return '', []
|
38 |
model = self._get_model(model_name)
|
39 |
+
tokenizer = self._get_tokenizer(model_name)
|
40 |
+
|
41 |
image = self._preprocess(image.convert('RGB')).unsqueeze(0)
|
42 |
# Greedy decoding
|
43 |
pred = model(image).softmax(-1)
|
44 |
+
|
45 |
+
# Tokenize input data
|
46 |
+
label = tokenizer.decode(pred.argmax(-1)[0].tolist(), skip_special_tokens=True)
|
47 |
+
raw_label, raw_confidence = tokenizer.decode(pred.argmax(-1)[0].tolist(), raw=True)
|
48 |
+
|
49 |
# Format confidence values
|
50 |
+
max_len = 25 if model_name == 'crnn' else len(label) + 1
|
51 |
+
conf = list(map('{:0.1f}'.format, pred[0, :, :max_len].tolist()))
|
52 |
+
|
53 |
+
return label, [raw_label[:max_len], conf]
|
54 |
|
55 |
def main():
|
56 |
app = App()
|
|
|
63 |
read_upload = gr.Button('Read Text')
|
64 |
|
65 |
output = gr.Textbox(max_lines=1, label='Model output')
|
|
|
66 |
raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
|
67 |
|
68 |
read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
|
69 |
+
|
70 |
demo.queue(max_size=20)
|
71 |
demo.launch()
|
72 |
|
|
|
73 |
if __name__ == '__main__':
|
74 |
+
main()
|