Spaces:
Runtime error
Runtime error
Commit
·
d8bfc92
1
Parent(s):
f1f5e4d
reverted again
Browse files
app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
from torchvision import transforms as T
|
| 3 |
import gradio as gr
|
| 4 |
-
import base64
|
| 5 |
|
| 6 |
class App:
|
| 7 |
|
|
@@ -24,26 +23,18 @@ class App:
|
|
| 24 |
return model
|
| 25 |
|
| 26 |
@torch.inference_mode()
|
| 27 |
-
def __call__(self, model_name,
|
| 28 |
-
if
|
| 29 |
return '', []
|
| 30 |
-
|
| 31 |
-
# Decode base64 image blob and convert to PIL Image
|
| 32 |
-
image_data = base64.b64decode(image_blob)
|
| 33 |
-
image = Image.open(BytesIO(image_data)).convert('RGB')
|
| 34 |
-
|
| 35 |
model = self._get_model(model_name)
|
| 36 |
-
image = self._preprocess(image).unsqueeze(0)
|
| 37 |
-
|
| 38 |
# Greedy decoding
|
| 39 |
pred = model(image).softmax(-1)
|
| 40 |
label, _ = model.tokenizer.decode(pred)
|
| 41 |
raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
|
| 42 |
-
|
| 43 |
# Format confidence values
|
| 44 |
max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
|
| 45 |
conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
|
| 46 |
-
|
| 47 |
return label[0], [raw_label[0][:max_len], conf]
|
| 48 |
|
| 49 |
|
|
|
|
| 1 |
import torch
|
| 2 |
from torchvision import transforms as T
|
| 3 |
import gradio as gr
|
|
|
|
| 4 |
|
| 5 |
class App:
|
| 6 |
|
|
|
|
| 23 |
return model
|
| 24 |
|
| 25 |
@torch.inference_mode()
|
| 26 |
+
def __call__(self, model_name, image):
|
| 27 |
+
if image is None:
|
| 28 |
return '', []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
model = self._get_model(model_name)
|
| 30 |
+
image = self._preprocess(image.convert('RGB')).unsqueeze(0)
|
|
|
|
| 31 |
# Greedy decoding
|
| 32 |
pred = model(image).softmax(-1)
|
| 33 |
label, _ = model.tokenizer.decode(pred)
|
| 34 |
raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
|
|
|
|
| 35 |
# Format confidence values
|
| 36 |
max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
|
| 37 |
conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
|
|
|
|
| 38 |
return label[0], [raw_label[0][:max_len], conf]
|
| 39 |
|
| 40 |
|