murtazadahmardeh committed on
Commit 60802a9 · 1 Parent(s): 0a497da

code changed

Files changed (1)
  1. app.py +22 -10
app.py CHANGED

@@ -1,5 +1,6 @@
 import torch
 from torchvision import transforms as T
+from transformers import AutoTokenizer
 import gradio as gr
 
 class App:
@@ -14,6 +15,7 @@ class App:
             T.ToTensor(),
             T.Normalize(0.5, 0.5)
         ])
+        self._tokenizer_cache = {}
 
     def _get_model(self, name):
         if name in self._model_cache:
@@ -22,21 +24,33 @@ class App:
         self._model_cache[name] = model
         return model
 
+    def _get_tokenizer(self, name):
+        if name in self._tokenizer_cache:
+            return self._tokenizer_cache[name]
+        tokenizer = AutoTokenizer.from_pretrained(name)
+        self._tokenizer_cache[name] = tokenizer
+        return tokenizer
+
     @torch.inference_mode()
     def __call__(self, model_name, image):
         if image is None:
             return '', []
         model = self._get_model(model_name)
+        tokenizer = self._get_tokenizer(model_name)
+
         image = self._preprocess(image.convert('RGB')).unsqueeze(0)
         # Greedy decoding
         pred = model(image).softmax(-1)
-        label, _ = model.tokenizer.decode(pred)
-        raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
+
+        # Tokenize input data
+        label = tokenizer.decode(pred.argmax(-1)[0].tolist(), skip_special_tokens=True)
+        raw_label, raw_confidence = tokenizer.decode(pred.argmax(-1)[0].tolist(), raw=True)
+
         # Format confidence values
-        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
-        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
-        return label[0], [raw_label[0][:max_len], conf]
-
+        max_len = 25 if model_name == 'crnn' else len(label) + 1
+        conf = list(map('{:0.1f}'.format, pred[0, :, :max_len].tolist()))
+
+        return label, [raw_label[:max_len], conf]
 
 def main():
     app = App()
@@ -49,14 +63,12 @@ def main():
         read_upload = gr.Button('Read Text')
 
         output = gr.Textbox(max_lines=1, label='Model output')
-        #adv_output = gr.Checkbox(label='Show detailed output')
         raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
 
         read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
-        #adv_output.change(lambda x: gr.update(visible=x), inputs=adv_output, outputs=raw_output)
+
     demo.queue(max_size=20)
     demo.launch()
 
-
 if __name__ == '__main__':
-    main()
+    main()
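
For context on the decoding step this commit rewrites, the sketch below shows one way to turn pred (shape [1, seq_len, num_classes], already softmax-ed) into a readable label plus per-step confidence strings using plain PyTorch. The charset, the EOS index 0, and greedy_decode are illustrative assumptions, not code from this repository or from the tokenizer it loads.

import torch

# Illustrative character set; index 0 stands in for the EOS/blank class.
charset = ['[E]'] + list('0123456789abcdefghijklmnopqrstuvwxyz')

def greedy_decode(pred: torch.Tensor, max_len: int = 25):
    # pred: [1, T, C] class probabilities, one row per decoding step
    probs, ids = pred[0].max(-1)          # best class index and its probability per step
    raw_label = [charset[i] for i in ids.tolist()[:max_len]]
    conf = ['{:0.1f}'.format(p) for p in probs.tolist()[:max_len]]
    chars = []
    for i in ids.tolist():                # stop the readable label at the first EOS
        if i == 0:
            break
        chars.append(charset[i])
    return ''.join(chars), raw_label, conf

# Example with random scores shaped like a model output
pred = torch.randn(1, 25, len(charset)).softmax(-1)
label, raw_label, conf = greedy_decode(pred)
print(label, [raw_label, conf])

The returned label and [raw_label, conf] pair mirror the (output, raw_output) values that __call__ feeds to the Gradio Textbox and Dataframe.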
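
The main() hunk only shows part of the interface. As a usage illustration, here is a hedged sketch of how the visible widgets could sit inside a gr.Blocks layout; the gr.Radio selector, the gr.Image uploader, the model names other than 'crnn', and the recognize stand-in are assumptions, not the repository's actual code.

import gradio as gr

def recognize(model_name, image):
    # Stand-in for App.__call__: returns (label, [raw_label, confidences])
    return 'demo', [['d', 'e', 'm', 'o'], ['0.9', '0.8', '0.9', '0.7']]

def main():
    with gr.Blocks() as demo:
        # Illustrative input widgets; only 'crnn' and the button label come from the diff
        model_name = gr.Radio(['crnn', 'other-model'], value='crnn', label='Model')
        image_upload = gr.Image(type='pil', label='Image')
        read_upload = gr.Button('Read Text')

        output = gr.Textbox(max_lines=1, label='Model output')
        raw_output = gr.Dataframe(row_count=2, label='Raw output with confidence values')

        read_upload.click(recognize, inputs=[model_name, image_upload], outputs=[output, raw_output])
    demo.queue(max_size=20)
    demo.launch()

if __name__ == '__main__':
    main()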