Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,12 +8,23 @@ TITLE = "Handwritten Digit Recognition Demo"
|
|
| 8 |
DESCRIPTION = "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "\
|
| 9 |
"It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
model = tf.keras.saving.load_model("tf_model_mnist")
|
| 12 |
|
| 13 |
|
| 14 |
-
def preprocess(image):
|
| 15 |
""" Normalize Gradio image to MNIST format """
|
| 16 |
-
image = image.resize((28, 28),
|
| 17 |
img_array = np.asarray(image, dtype=np.float32)
|
| 18 |
for i in range(img_array.shape[0]):
|
| 19 |
for j in range(img_array.shape[1]):
|
|
@@ -30,14 +41,19 @@ def preprocess(image):
|
|
| 30 |
return image_array, new_image
|
| 31 |
|
| 32 |
|
| 33 |
-
def predict(img):
|
| 34 |
img = img["composite"]
|
| 35 |
-
input_arr, new_image = preprocess(img)
|
| 36 |
print("input:", input_arr.shape)
|
| 37 |
predictions = model.predict(input_arr)
|
| 38 |
return {str(i): predictions[0][i] for i in range(10)}, new_image
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
input_image = gr.Sketchpad(
|
| 42 |
layers=False,
|
| 43 |
type="pil",
|
|
@@ -47,7 +63,7 @@ demo = gr.Interface(
|
|
| 47 |
predict,
|
| 48 |
title=TITLE,
|
| 49 |
description=DESCRIPTION,
|
| 50 |
-
inputs=input_image,
|
| 51 |
outputs=['label', 'image']
|
| 52 |
)
|
| 53 |
|
|
|
|
| 8 |
DESCRIPTION = "This demo employs a basic CNN architecture inspired by [MIT 6.S191’s Lab2 Part1](https://github.com/aamini/introtodeeplearning/blob/master/lab2/Part1_MNIST.ipynb). "\
|
| 9 |
"It achieves about 98% accuracy on the MNIST test dataset but may perform poorly, particularly with digits 8 and 9, likely due to suboptimal image preprocessing."
|
| 10 |
|
| 11 |
+
|
| 12 |
+
PIL_INTERPOLATION_METHODS = {
|
| 13 |
+
"nearest": Image.Resampling.NEAREST,
|
| 14 |
+
"bilinear": Image.Resampling.BILINEAR,
|
| 15 |
+
"bicubic": Image.Resampling.BICUBIC,
|
| 16 |
+
"hamming": Image.Resampling.HAMMING,
|
| 17 |
+
"box": Image.Resampling.BOX,
|
| 18 |
+
"lanczos": Image.Resampling.LANCZOS,
|
| 19 |
+
}
|
| 20 |
+
|
| 21 |
+
|
| 22 |
model = tf.keras.saving.load_model("tf_model_mnist")
|
| 23 |
|
| 24 |
|
| 25 |
+
def preprocess(image, resample_method):
|
| 26 |
""" Normalize Gradio image to MNIST format """
|
| 27 |
+
image = image.resize((28, 28), PIL_INTERPOLATION_METHODS[resample_method])
|
| 28 |
img_array = np.asarray(image, dtype=np.float32)
|
| 29 |
for i in range(img_array.shape[0]):
|
| 30 |
for j in range(img_array.shape[1]):
|
|
|
|
| 41 |
return image_array, new_image
|
| 42 |
|
| 43 |
|
| 44 |
+
def predict(img, resample_method):
|
| 45 |
img = img["composite"]
|
| 46 |
+
input_arr, new_image = preprocess(img, resample_method)
|
| 47 |
print("input:", input_arr.shape)
|
| 48 |
predictions = model.predict(input_arr)
|
| 49 |
return {str(i): predictions[0][i] for i in range(10)}, new_image
|
| 50 |
|
| 51 |
|
| 52 |
+
resample_method = gr.Dropdown(
|
| 53 |
+
choices=list(PIL_INTERPOLATION_METHODS.keys()),
|
| 54 |
+
value='bilinear',
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
input_image = gr.Sketchpad(
|
| 58 |
layers=False,
|
| 59 |
type="pil",
|
|
|
|
| 63 |
predict,
|
| 64 |
title=TITLE,
|
| 65 |
description=DESCRIPTION,
|
| 66 |
+
inputs=[input_image, resample_method],
|
| 67 |
outputs=['label', 'image']
|
| 68 |
)
|
| 69 |
|