Spaces:
Runtime error
Runtime error
Commit
·
25f8353
1
Parent(s):
9467a8a
revert 2
Browse files
app.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
import torch
|
| 2 |
from torchvision import transforms as T
|
| 3 |
-
from transformers import AutoTokenizer
|
| 4 |
import gradio as gr
|
| 5 |
|
| 6 |
class App:
|
|
@@ -15,7 +14,6 @@ class App:
|
|
| 15 |
T.ToTensor(),
|
| 16 |
T.Normalize(0.5, 0.5)
|
| 17 |
])
|
| 18 |
-
self._tokenizer_cache = {}
|
| 19 |
|
| 20 |
def _get_model(self, name):
|
| 21 |
if name in self._model_cache:
|
|
@@ -24,33 +22,21 @@ class App:
|
|
| 24 |
self._model_cache[name] = model
|
| 25 |
return model
|
| 26 |
|
| 27 |
-
def _get_tokenizer(self, name):
|
| 28 |
-
if name in self._tokenizer_cache:
|
| 29 |
-
return self._tokenizer_cache[name]
|
| 30 |
-
tokenizer = AutoTokenizer.from_pretrained(name)
|
| 31 |
-
self._tokenizer_cache[name] = tokenizer
|
| 32 |
-
return tokenizer
|
| 33 |
-
|
| 34 |
@torch.inference_mode()
|
| 35 |
def __call__(self, model_name, image):
|
| 36 |
if image is None:
|
| 37 |
return '', []
|
| 38 |
model = self._get_model(model_name)
|
| 39 |
-
tokenizer = self._get_tokenizer(model_name)
|
| 40 |
-
|
| 41 |
image = self._preprocess(image.convert('RGB')).unsqueeze(0)
|
| 42 |
# Greedy decoding
|
| 43 |
pred = model(image).softmax(-1)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
label = tokenizer.decode(pred.argmax(-1)[0].tolist(), skip_special_tokens=True)
|
| 47 |
-
raw_label, raw_confidence = tokenizer.decode(pred.argmax(-1)[0].tolist(), raw=True)
|
| 48 |
-
|
| 49 |
# Format confidence values
|
| 50 |
-
max_len = 25 if model_name == 'crnn' else len(label) + 1
|
| 51 |
-
conf = list(map('{:0.1f}'.format,
|
| 52 |
-
|
| 53 |
-
|
| 54 |
|
| 55 |
def main():
|
| 56 |
app = App()
|
|
@@ -63,12 +49,14 @@ def main():
|
|
| 63 |
read_upload = gr.Button('Read Text')
|
| 64 |
|
| 65 |
output = gr.Textbox(max_lines=1, label='Model output')
|
|
|
|
| 66 |
raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
|
| 67 |
|
| 68 |
read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
|
| 69 |
-
|
| 70 |
demo.queue(max_size=20)
|
| 71 |
demo.launch()
|
| 72 |
|
|
|
|
| 73 |
if __name__ == '__main__':
|
| 74 |
-
main()
|
|
|
|
| 1 |
import torch
|
| 2 |
from torchvision import transforms as T
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
|
| 5 |
class App:
|
|
|
|
| 14 |
T.ToTensor(),
|
| 15 |
T.Normalize(0.5, 0.5)
|
| 16 |
])
|
|
|
|
| 17 |
|
| 18 |
def _get_model(self, name):
|
| 19 |
if name in self._model_cache:
|
|
|
|
| 22 |
self._model_cache[name] = model
|
| 23 |
return model
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@torch.inference_mode()
|
| 26 |
def __call__(self, model_name, image):
|
| 27 |
if image is None:
|
| 28 |
return '', []
|
| 29 |
model = self._get_model(model_name)
|
|
|
|
|
|
|
| 30 |
image = self._preprocess(image.convert('RGB')).unsqueeze(0)
|
| 31 |
# Greedy decoding
|
| 32 |
pred = model(image).softmax(-1)
|
| 33 |
+
label, _ = model.tokenizer.decode(pred)
|
| 34 |
+
raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
|
|
|
|
|
|
|
|
|
|
| 35 |
# Format confidence values
|
| 36 |
+
max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
|
| 37 |
+
conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
|
| 38 |
+
return label[0], [raw_label[0][:max_len], conf]
|
| 39 |
+
|
| 40 |
|
| 41 |
def main():
|
| 42 |
app = App()
|
|
|
|
| 49 |
read_upload = gr.Button('Read Text')
|
| 50 |
|
| 51 |
output = gr.Textbox(max_lines=1, label='Model output')
|
| 52 |
+
#adv_output = gr.Checkbox(label='Show detailed output')
|
| 53 |
raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
|
| 54 |
|
| 55 |
read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
|
| 56 |
+
#adv_output.change(lambda x: gr.update(visible=x), inputs=adv_output, outputs=raw_output)
|
| 57 |
demo.queue(max_size=20)
|
| 58 |
demo.launch()
|
| 59 |
|
| 60 |
+
|
| 61 |
if __name__ == '__main__':
|
| 62 |
+
main()
|