import torch
from torchvision import transforms as T
import gradio as gr
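

# Gradio demo for the scene-text-recognition (STR) models published in the
# baudm/parseq torch.hub repo (https://github.com/baudm/parseq).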
class App:

    title = 'Scene Text Recognition with<br/>Permuted Autoregressive Sequence Models'
    models = ['parseq', 'parseq_tiny', 'abinet', 'crnn', 'trba', 'vitstr']

    def __init__(self):
        self._model_cache = {}
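        # Every model here takes a 32x128 (HxW) crop normalized to [-1, 1].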
        self._preprocess = T.Compose([
            T.Resize((32, 128), T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])

    def _get_model(self, name):
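        # Download pretrained weights via torch.hub on first use, then keep
        # the model in memory so switching between models stays fast.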
        if name in self._model_cache:
            return self._model_cache[name]
        model = torch.hub.load('baudm/parseq', name, pretrained=True).eval()
        self._model_cache[name] = model
        return model
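
    # Defining __call__ lets the App instance itself be wired into
    # read_upload.click() below as the event handler.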
    @torch.inference_mode()
    def __call__(self, model_name, image):
        # gr.Image(type='pil') delivers a PIL Image directly (or None when
        # nothing was uploaded), so no manual decoding is needed; just make
        # sure the input has three channels.
        if image is None:
            return '', []
        model = self._get_model(model_name)
        image = self._preprocess(image.convert('RGB')).unsqueeze(0)
        # Greedy decoding
        pred = model(image).softmax(-1)
        label, _ = model.tokenizer.decode(pred)
        raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
        # Format confidence values
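        # CRNN's CTC output has no EOS marker, so cap its raw sequence at the
        # 25-character training limit; the other models get one extra slot
        # for the [E] (EOS) token.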
        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
        return label[0], [raw_label[0][:max_len], conf]


def main():
    app = App()

    with gr.Blocks(analytics_enabled=False, title=app.title.replace('<br/>', ' ')) as demo:
        model_name = gr.Radio(app.models, value=app.models[0], label='The STR model to use')
        with gr.Tabs():
            with gr.TabItem('Image Upload'):
                image_upload = gr.Image(type='pil', label='Image')
                read_upload = gr.Button('Read Text')
        output = gr.Textbox(max_lines=1, label='Model output')
        #adv_output = gr.Checkbox(label='Show detailed output')
        raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
        read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
        #adv_output.change(lambda x: gr.update(visible=x), inputs=adv_output, outputs=raw_output)
    demo.queue(max_size=20)
    demo.launch()


if __name__ == '__main__':
    main()