zoya23 commited on
Commit
b9deab1
·
verified ·
1 Parent(s): 6d29c98

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -3,25 +3,25 @@ import cv2
3
  from streamlit_drawable_canvas import st_canvas
4
  from keras.models import load_model
5
  import numpy as np
 
6
 
7
- # Sidebar configs
8
- drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw", "line", "rect", "circle", "transform"))
9
- stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 15)
10
- stroke_color = st.sidebar.color_picker("Stroke color hex: ", "#000000") # black
11
- bg_color = st.sidebar.color_picker("Background color hex: ", "#FFFFFF") # white
12
- bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
13
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
14
 
15
- # Load trained MNIST model
16
  @st.cache_resource
17
  def load_mnist_model():
18
  return load_model("mnist.keras")
19
 
20
  model = load_mnist_model()
21
 
22
- # Canvas for user drawing
23
  canvas_result = st_canvas(
24
- fill_color="rgba(255, 165, 0, 0.3)", # translucent fill
25
  stroke_width=stroke_width,
26
  stroke_color=stroke_color,
27
  background_color=bg_color,
@@ -32,31 +32,52 @@ canvas_result = st_canvas(
32
  key="canvas",
33
  )
34
 
35
- # Prediction block
36
- if canvas_result.image_data is not None:
37
- st.image(canvas_result.image_data, caption="Original Drawing")
 
 
 
38
 
39
- # Convert RGBA to grayscale
40
- img = cv2.cvtColor(canvas_result.image_data.astype("uint8"), cv2.COLOR_RGBA2GRAY)
41
 
42
- # Invert colors: black background with white digit
43
- img = 255 - img
 
 
 
 
44
 
45
- # Resize to 28x28
46
- img_resized = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
 
 
 
 
47
 
48
- # Apply binary threshold to remove noise
49
- _, img_thresh = cv2.threshold(img_resized, 100, 255, cv2.THRESH_BINARY)
 
 
 
 
 
50
 
51
  # Normalize
52
- img_normalized = img_thresh / 255.0
53
 
54
- # Reshape to match model input
55
- final_img = img_normalized.reshape(1, 28, 28, 1)
 
 
 
56
 
57
- # Show processed image
58
- st.image(img_thresh, caption="Preprocessed (28x28 Thresholded)")
59
 
60
- # Predict and display result
61
- prediction = model.predict(final_img)
62
- st.write("Prediction:", np.argmax(prediction))
 
 
 
 
3
  from streamlit_drawable_canvas import st_canvas
4
  from keras.models import load_model
5
  import numpy as np
6
+ from scipy.ndimage import interpolation
7
 
8
+ # Sidebar: Canvas controls
9
+ drawing_mode = st.sidebar.selectbox("Drawing tool:", ("freedraw",))
10
+ stroke_width = st.sidebar.slider("Stroke width: ", 10, 25, 20)
11
+ stroke_color = "#000000" # Black
12
+ bg_color = "#FFFFFF" # White
 
13
  realtime_update = st.sidebar.checkbox("Update in realtime", True)
14
 
15
+ # Load MNIST model
16
  @st.cache_resource
17
  def load_mnist_model():
18
  return load_model("mnist.keras")
19
 
20
  model = load_mnist_model()
21
 
22
+ # Streamlit drawable canvas
23
  canvas_result = st_canvas(
24
+ fill_color="rgba(0,0,0,0)", # Transparent fill
25
  stroke_width=stroke_width,
26
  stroke_color=stroke_color,
27
  background_color=bg_color,
 
32
  key="canvas",
33
  )
34
 
35
+ def preprocess(img):
36
+ # Convert to grayscale
37
+ gray = cv2.cvtColor(img.astype("uint8"), cv2.COLOR_RGBA2GRAY)
38
+
39
+ # Invert (black bg, white digit)
40
+ gray = 255 - gray
41
 
42
+ # Apply threshold
43
+ _, thresh = cv2.threshold(gray, 50, 255, cv2.THRESH_BINARY)
44
 
45
+ # Crop the digit (remove empty rows/cols)
46
+ if np.sum(thresh) == 0:
47
+ return None # blank canvas
48
+ coords = cv2.findNonZero(thresh)
49
+ x, y, w, h = cv2.boundingRect(coords)
50
+ cropped = thresh[y:y+h, x:x+w]
51
 
52
+ # Resize keeping aspect ratio
53
+ h, w = cropped.shape
54
+ if h > w:
55
+ resized = cv2.resize(cropped, (int(20 * w / h), 20))
56
+ else:
57
+ resized = cv2.resize(cropped, (20, int(20 * h / w)))
58
 
59
+ # Add padding to get 28x28
60
+ h, w = resized.shape
61
+ top = (28 - h) // 2
62
+ bottom = 28 - h - top
63
+ left = (28 - w) // 2
64
+ right = 28 - w - left
65
+ padded = cv2.copyMakeBorder(resized, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
66
 
67
  # Normalize
68
+ norm = padded / 255.0
69
 
70
+ # Final reshape
71
+ return norm.reshape(1, 28, 28, 1)
72
+
73
+ if canvas_result.image_data is not None:
74
+ st.image(canvas_result.image_data, caption="Original Drawing")
75
 
76
+ processed = preprocess(canvas_result.image_data)
 
77
 
78
+ if processed is not None:
79
+ st.image(processed.reshape(28, 28), caption="Processed Input (28x28)")
80
+ pred = model.predict(processed)
81
+ st.subheader(f"Prediction: {np.argmax(pred)}")
82
+ else:
83
+ st.warning("Please draw a digit!")