Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,70 +1,47 @@
|
|
1 |
import streamlit as st
|
2 |
-
import cv2
|
3 |
-
import numpy as np
|
4 |
-
from keras.models import load_model
|
5 |
from streamlit_drawable_canvas import st_canvas
|
|
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
stroke_width = st.sidebar.slider("Stroke width:", 10, 25, 20)
|
10 |
-
realtime_update = st.sidebar.checkbox("Update in realtime", True)
|
11 |
|
12 |
-
# Load your
|
13 |
@st.cache_resource
|
14 |
def load_mnist_model():
|
15 |
-
return load_model("
|
16 |
|
17 |
model = load_mnist_model()
|
18 |
|
19 |
-
# Create
|
20 |
canvas_result = st_canvas(
|
21 |
-
fill_color="
|
22 |
-
stroke_width=
|
23 |
-
stroke_color="#
|
24 |
-
background_color="#
|
25 |
-
|
26 |
-
height=
|
27 |
-
|
28 |
-
|
29 |
-
key="canvas",
|
30 |
)
|
31 |
|
32 |
-
def preprocess(image):
|
33 |
-
# Convert RGBA to grayscale
|
34 |
-
gray = cv2.cvtColor(image.astype("uint8"), cv2.COLOR_RGBA2GRAY)
|
35 |
-
|
36 |
-
# Invert so the digit is white on black (like MNIST)
|
37 |
-
gray = 255 - gray
|
38 |
-
|
39 |
-
# Apply binary threshold
|
40 |
-
_, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)
|
41 |
-
|
42 |
-
# Find bounding box and crop
|
43 |
-
if np.sum(thresh) == 0:
|
44 |
-
return None
|
45 |
-
coords = cv2.findNonZero(thresh)
|
46 |
-
x, y, w, h = cv2.boundingRect(coords)
|
47 |
-
cropped = thresh[y:y+h, x:x+w]
|
48 |
-
|
49 |
-
# Resize to 20x20
|
50 |
-
resized = cv2.resize(cropped, (20, 20), interpolation=cv2.INTER_AREA)
|
51 |
-
|
52 |
-
# Pad to 28x28
|
53 |
-
padded = np.pad(resized, ((4,4),(4,4)), mode='constant', constant_values=0)
|
54 |
-
|
55 |
-
# Normalize and reshape for model
|
56 |
-
norm = padded / 255.0
|
57 |
-
return norm.reshape(1, 28, 28, 1)
|
58 |
-
|
59 |
-
# Handle prediction
|
60 |
if canvas_result.image_data is not None:
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
2 |
from streamlit_drawable_canvas import st_canvas
|
3 |
+
import numpy as np
|
4 |
+
from tensorflow.keras.models import load_model
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
|
8 |
+
st.set_page_config(page_title="MNIST Digit Recognizer", layout="centered")
|
9 |
+
st.title("🖌️ Draw a digit (0-9)")
|
|
|
|
|
10 |
|
11 |
+
# Load pre-trained model (you can upload your own model to the space)
|
12 |
@st.cache_resource
|
13 |
def load_mnist_model():
|
14 |
+
return load_model("digit_recog.keras") # You must upload this file to your Space
|
15 |
|
16 |
model = load_mnist_model()
|
17 |
|
18 |
+
# Create canvas component
|
19 |
canvas_result = st_canvas(
|
20 |
+
fill_color="#000000", # Black background
|
21 |
+
stroke_width=10,
|
22 |
+
stroke_color="#FFFFFF", # White digit
|
23 |
+
background_color="#000000",
|
24 |
+
width=200,
|
25 |
+
height=200,
|
26 |
+
drawing_mode="freedraw",
|
27 |
+
key="canvas"
|
|
|
28 |
)
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
if canvas_result.image_data is not None:
|
31 |
+
# Preprocess the image for prediction
|
32 |
+
img = canvas_result.image_data[:, :, 0] # Get only one channel
|
33 |
+
img = cv2.resize(img, (28, 28)) # Resize to 28x28
|
34 |
+
img = img.astype("float32") / 255.0
|
35 |
+
img = np.expand_dims(img, axis=0)
|
36 |
+
img = np.expand_dims(img, axis=-1)
|
37 |
+
|
38 |
+
st.subheader("🧠 Model Prediction")
|
39 |
+
pred = model.predict(img)[0]
|
40 |
+
predicted_class = np.argmax(pred)
|
41 |
+
|
42 |
+
st.write(f"**Predicted Digit:** `{predicted_class}`")
|
43 |
+
st.bar_chart(pred)
|
44 |
+
else:
|
45 |
+
st.info("Draw a digit above to see the prediction.")
|
46 |
+
|
47 |
+
st.caption("Made with Streamlit ✨")
|