cisemh commited on
Commit
a12fda9
·
verified ·
1 Parent(s): d1c466d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -77
app.py CHANGED
@@ -1,83 +1,25 @@
1
- import numpy as np
 
 
2
  import tensorflow as tf
3
- from tensorflow import keras
 
4
  import gradio as gr
5
 
6
- # Load and preprocess the MNIST dataset
7
- def load_data():
8
- """Load and preprocess the MNIST dataset."""
9
- (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
10
- X_train = X_train.astype("float32") / 255
11
- X_test = X_test.astype("float32") / 255
12
- X_train = X_train.reshape(-1, 28, 28, 1)
13
- X_test = X_test.reshape(-1, 28, 28, 1)
14
- return (X_train, y_train), (X_test, y_test)
15
-
16
- # Build the CNN model
17
- def build_model(input_shape, num_classes):
18
- """Build the CNN model."""
19
- inputs = keras.layers.Input(input_shape)
20
- x = keras.layers.Conv2D(28, kernel_size=(3, 3), activation='relu')(inputs)
21
- x = keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
22
- x = keras.layers.Flatten()(x)
23
- x = keras.layers.Dense(128, activation='relu')(x)
24
- outputs = keras.layers.Dense(num_classes, activation='softmax')(x)
25
- return keras.models.Model(inputs=inputs, outputs=outputs)
26
-
27
- # Preprocess input for prediction
28
- def preprocess_image(image):
29
- """Resize and normalize the input image for prediction."""
30
- image = np.array(image.convert('L')) # Convert to grayscale
31
- image = image.astype("float32") / 255 # Normalize
32
- image = image.reshape(1, 28, 28, 1) # Reshape to model's input
33
- return image
34
-
35
- # Predict digit
36
- def predict_digit(image):
37
- """Predict the digit in the uploaded image."""
38
- processed_image = preprocess_image(image)
39
- prediction = model.predict(processed_image)
40
- class_id = np.argmax(prediction)
41
- confidence = prediction[0][class_id]
42
- label = classes_names[class_id]
43
- results = {name: float(prediction[0][i]) for i, name in enumerate(classes_names)}
44
- return label, results
45
-
46
- if __name__ == "__main__":
47
- # Parameters
48
- classes_names = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
49
- input_shape = (28, 28, 1)
50
- num_classes = len(classes_names)
51
-
52
- # Load data
53
- (X_train, y_train), (X_test, y_test) = load_data()
54
-
55
- # Build and train model
56
- model = build_model(input_shape, num_classes)
57
- model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
58
- print("Training model...")
59
- model.fit(X_train, y_train, epochs=3, batch_size=64) # Quick training for demonstration
60
 
61
- # Gradio Interface
62
- title = "Welcome to Your First Sketch Recognition App!"
63
- description = (
64
- "The robot was trained to classify numbers (from 0 to 9). To test it, draw your number in the space provided."
65
- )
66
- examples = [["example_image.png"]] # You can add example images here.
 
67
 
68
- interface = gr.Interface(
69
- fn=predict_digit,
70
- inputs=gr.inputs.Image(shape=(28, 28), image_mode="L", invert_colors=True, label="Draw a Digit"),
71
- outputs=[
72
- gr.outputs.Textbox(label="Predicted Digit"),
73
- gr.outputs.Label(num_top_classes=10, label="Prediction Confidence"),
74
- ],
75
- title=title,
76
- description=description,
77
- examples=examples,
78
- live=True,
79
- )
80
 
81
- # Launch Gradio interface
82
- print("Launching Gradio interface...")
83
- interface.launch()
 
1
+ import os
2
+ os.system("pip uninstall -y gradio")
3
+ os.system("pip install gradio==3.50.2")
4
  import tensorflow as tf
5
+ from matplotlib import pyplot as plt
6
+ import numpy as np
7
  import gradio as gr
8
 
9
+ # Load the model
10
+ model = tf.keras.models.load_model('number_recognition_model_colab.keras')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def recognize_digit(image):
13
+ if image is not None:
14
+ image = image.reshape((1,28,28,1)).astype('float32')/255
15
+ prediction = model.predict(image)
16
+ return {str(i) : float(prediction[0][i]) for i in range(10)}
17
+ else:
18
+ return ''
19
 
20
+ iface = gr.Interface(
21
+ fn = recognize_digit,
22
+ inputs=gr.Image(shape=(28,28),image_mode = 'L',invert_colors=True, source = 'canvas'),
23
+ outputs=gr.Label(top_num_classes=3))
 
 
 
 
 
 
 
 
24
 
25
+ iface.launch()