zoya23 commited on
Commit
07277d7
·
verified ·
1 Parent(s): 6b158f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -57
app.py CHANGED
@@ -1,70 +1,47 @@
1
  import streamlit as st
2
- import cv2
3
- import numpy as np
4
- from keras.models import load_model
5
  from streamlit_drawable_canvas import st_canvas
 
 
 
 
6
 
7
- # Sidebar options
8
- drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw",))
9
- stroke_width = st.sidebar.slider("Stroke width:", 10, 25, 20)
10
- realtime_update = st.sidebar.checkbox("Update in realtime", True)
11
 
12
- # Load your trained MNIST model
13
  @st.cache_resource
14
  def load_mnist_model():
15
- return load_model("mnist.keras")
16
 
17
  model = load_mnist_model()
18
 
19
- # Create the canvas
20
  canvas_result = st_canvas(
21
- fill_color="rgba(0,0,0,0)", # Transparent fill
22
- stroke_width=stroke_width,
23
- stroke_color="#000000", # Black drawing
24
- background_color="#FFFFFF", # White background
25
- update_streamlit=realtime_update,
26
- height=280,
27
- width=280,
28
- drawing_mode=drawing_mode,
29
- key="canvas",
30
  )
31
 
32
- def preprocess(image):
33
- # Convert RGBA to grayscale
34
- gray = cv2.cvtColor(image.astype("uint8"), cv2.COLOR_RGBA2GRAY)
35
-
36
- # Invert so the digit is white on black (like MNIST)
37
- gray = 255 - gray
38
-
39
- # Apply binary threshold
40
- _, thresh = cv2.threshold(gray, 30, 255, cv2.THRESH_BINARY)
41
-
42
- # Find bounding box and crop
43
- if np.sum(thresh) == 0:
44
- return None
45
- coords = cv2.findNonZero(thresh)
46
- x, y, w, h = cv2.boundingRect(coords)
47
- cropped = thresh[y:y+h, x:x+w]
48
-
49
- # Resize to 20x20
50
- resized = cv2.resize(cropped, (20, 20), interpolation=cv2.INTER_AREA)
51
-
52
- # Pad to 28x28
53
- padded = np.pad(resized, ((4,4),(4,4)), mode='constant', constant_values=0)
54
-
55
- # Normalize and reshape for model
56
- norm = padded / 255.0
57
- return norm.reshape(1, 28, 28, 1)
58
-
59
- # Handle prediction
60
  if canvas_result.image_data is not None:
61
- st.image(canvas_result.image_data, caption="Your Drawing")
62
-
63
- processed = preprocess(canvas_result.image_data)
64
-
65
- if processed is not None:
66
- st.image(processed.reshape(28, 28), caption="Processed Image (28x28)")
67
- prediction = model.predict(processed)
68
- st.subheader(f"Prediction: {np.argmax(prediction)}")
69
- else:
70
- st.warning("Please draw a digit first.")
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
2
  from streamlit_drawable_canvas import st_canvas
3
+ import numpy as np
4
+ from tensorflow.keras.models import load_model
5
+ from PIL import Image
6
+ import cv2
7
 
8
+ st.set_page_config(page_title="MNIST Digit Recognizer", layout="centered")
9
+ st.title("🖌️ Draw a digit (0-9)")
 
 
10
 
11
+ # Load pre-trained model (you can upload your own model to the space)
12
  @st.cache_resource
13
  def load_mnist_model():
14
+ return load_model("digit_recog.keras") # You must upload this file to your Space
15
 
16
  model = load_mnist_model()
17
 
18
+ # Create canvas component
19
  canvas_result = st_canvas(
20
+ fill_color="#000000", # Black background
21
+ stroke_width=10,
22
+ stroke_color="#FFFFFF", # White digit
23
+ background_color="#000000",
24
+ width=200,
25
+ height=200,
26
+ drawing_mode="freedraw",
27
+ key="canvas"
 
28
  )
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  if canvas_result.image_data is not None:
31
+ # Preprocess the image for prediction
32
+ img = canvas_result.image_data[:, :, 0] # Get only one channel
33
+ img = cv2.resize(img, (28, 28)) # Resize to 28x28
34
+ img = img.astype("float32") / 255.0
35
+ img = np.expand_dims(img, axis=0)
36
+ img = np.expand_dims(img, axis=-1)
37
+
38
+ st.subheader("🧠 Model Prediction")
39
+ pred = model.predict(img)[0]
40
+ predicted_class = np.argmax(pred)
41
+
42
+ st.write(f"**Predicted Digit:** `{predicted_class}`")
43
+ st.bar_chart(pred)
44
+ else:
45
+ st.info("Draw a digit above to see the prediction.")
46
+
47
+ st.caption("Made with Streamlit ✨")