Spaces:
Runtime error
Runtime error
File size: 2,680 Bytes
b0a80fd 60802a9 4d7777a 133bc8b 4d7777a 133bc8b b0a80fd 133bc8b b0a80fd 60802a9 b0a80fd 133bc8b b0a80fd 60802a9 133bc8b 60802a9 133bc8b 60802a9 133bc8b 60802a9 b0a80fd 133bc8b b0a80fd 133bc8b d7658d1 133bc8b b0a80fd 133bc8b 60802a9 0a497da 133bc8b 60802a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import torch
from torchvision import transforms as T
from transformers import AutoTokenizer
import gradio as gr
class App:
    """Gradio callable wrapping torch.hub-loaded scene-text-recognition models.

    An instance is used directly as the Gradio event handler: calling it with a
    model name and a PIL image returns the decoded text plus raw per-token
    confidence values.
    """

    title = 'Scene Text Recognition with<br/>Permuted Autoregressive Sequence Models'
    models = ['parseq', 'parseq_tiny', 'abinet', 'crnn', 'trba', 'vitstr']

    def __init__(self):
        # Lazily-populated cache of torch.hub models, keyed by model name.
        self._model_cache = {}
        # All supported models consume 32x128 crops normalized to [-1, 1].
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])
        # NOTE(review): retained for interface compatibility, but no longer
        # used — decoding now goes through the model's own tokenizer.
        self._tokenizer_cache = {}

    def _get_model(self, name):
        """Return a cached eval-mode model, downloading it on first request."""
        if name in self._model_cache:
            return self._model_cache[name]
        model = torch.hub.load('baudm/parseq', name, pretrained=True).eval()
        self._model_cache[name] = model
        return model

    def _get_tokenizer(self, name):
        """Return a cached HuggingFace tokenizer for `name`.

        NOTE(review): unused. `AutoTokenizer.from_pretrained` cannot resolve
        the torch.hub model names in `self.models`; the hub models ship their
        own tokenizer (`model.tokenizer`), which `__call__` uses instead.
        """
        if name in self._tokenizer_cache:
            return self._tokenizer_cache[name]
        tokenizer = AutoTokenizer.from_pretrained(name)
        self._tokenizer_cache[name] = tokenizer
        return tokenizer

    @torch.inference_mode()
    def __call__(self, model_name, image):
        """Recognize the text in `image` with the model named `model_name`.

        Returns a pair: the decoded label string, and a 2-row list of
        [raw token labels, formatted confidence values] for the Dataframe.
        Returns ('', []) when no image was provided.
        """
        if image is None:
            return '', []
        model = self._get_model(model_name)
        image = self._preprocess(image.convert('RGB')).unsqueeze(0)
        # Greedy decoding over class probabilities.
        pred = model(image).softmax(-1)
        # BUGFIX: decode with the model's own tokenizer, which returns batched
        # (labels, confidences). The previous code fed argmax ids to a
        # HuggingFace AutoTokenizer (whose decode() returns a single str, so
        # unpacking it into two names breaks) and then formatted
        # pred[0, :, :max_len] — a slice over the *class* dimension — while
        # the computed raw_confidence went unused.
        label, _ = model.tokenizer.decode(pred)
        raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
        # Format confidence values. CRNN emits a fixed-length sequence; the
        # autoregressive models emit len(label) chars plus the EOS token.
        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
        return label[0], [raw_label[0][:max_len], conf]
def main():
    """Assemble the Gradio UI around an App instance and serve it."""
    recognizer = App()
    # The HTML line break in the class-level title is only for the page body;
    # the browser tab title gets a plain-space version.
    tab_title = recognizer.title.replace('<br/>', ' ')
    with gr.Blocks(analytics_enabled=False, title=tab_title) as demo:
        model_choice = gr.Radio(recognizer.models, value=recognizer.models[0], label='The STR model to use')
        with gr.Tabs():
            with gr.TabItem('Image Upload'):
                uploaded_image = gr.Image(type='pil', label='Image')
                run_button = gr.Button('Read Text')
        text_out = gr.Textbox(max_lines=1, label='Model output')
        table_out = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
        # Clicking the button invokes the App instance itself as the handler.
        run_button.click(recognizer, inputs=[model_choice, uploaded_image], outputs=[text_out, table_out])
    demo.queue(max_size=20)
    demo.launch()
# Launch the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|