cisemh commited on
Commit
acf19ac
·
verified ·
1 Parent(s): 4362dde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -83
app.py CHANGED
@@ -15,88 +15,14 @@ head = (
15
  # Model yükleniyor
16
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
17
 
 
 
 
18
 
19
- img_size = 28
 
 
 
 
 
20
 
21
- labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
22
-
23
- def predict(img):
24
- try:
25
- # Enhanced image validation and conversion
26
- if img is None:
27
- raise ValueError("No image provided")
28
-
29
- # Convert to numpy array if it's a PIL Image
30
- if isinstance(img, Image.Image):
31
- img = np.array(img)
32
-
33
- # Handle base64 image strings
34
- elif isinstance(img, str):
35
- # Check if it's a base64 data URL
36
- if img.startswith('data:image'):
37
- # Split and decode base64 part
38
- img = img.split(',')[1]
39
-
40
- # Decode base64 to image
41
- try:
42
- img = Image.open(io.BytesIO(base64.b64decode(img)))
43
- img = np.array(img)
44
- except Exception as e:
45
- print(f"Base64 decoding error: {e}")
46
- raise ValueError("Invalid base64 image")
47
-
48
- # Validate numpy array
49
- if not isinstance(img, np.ndarray):
50
- raise ValueError("Input could not be converted to a valid image")
51
-
52
- # Print initial image details for debugging
53
- print(f"Initial image type: {type(img)}, shape: {img.shape}")
54
-
55
- # Handle color channels
56
- if img.ndim == 3:
57
- if img.shape[-1] == 3: # Color image
58
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
- elif img.shape[-1] == 4: # RGBA image
60
- img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
61
-
62
- # Ensure single channel
63
- if img.ndim == 2:
64
- img = np.expand_dims(img, axis=-1)
65
-
66
- # Resize and normalize
67
- img = cv2.resize(img, (img_size, img_size))
68
- img = img.astype('float32') / 255.0
69
- img = img.reshape(1, img_size, img_size, 1)
70
-
71
- # Print processed image details
72
- print(f"Processed image shape: {img.shape}")
73
-
74
- # Get predictions from the model
75
- preds = model.predict(img)[0]
76
-
77
- # Print predictions for debugging
78
- print("Predictions:", preds)
79
-
80
- # Return predictions as a dictionary
81
- return {label: float(pred) for label, pred in zip(labels, preds)}
82
-
83
- except Exception as e:
84
- # Comprehensive error logging
85
- print(f"Full error during prediction: {e}")
86
- return {"Error": str(e)}
87
-
88
-
89
-
90
- # Set up the Gradio interface with the input as a sketchpad and output as labels
91
- label = gr.Label(num_top_classes=3)
92
-
93
- # Gradio arayüzü
94
- interface = gr.Interface(
95
- fn=predict,
96
- inputs=gr.Sketchpad(type="pil"),
97
- outputs=label,
98
- title="Sketch Recognition App",
99
- description="Draw a number (0-9) and see the model's top predictions."
100
- )
101
-
102
- interface.launch(debug=True, share=True)
 
15
  # Model yükleniyor
16
  model = tf.keras.models.load_model("number_recognition_model_colab.keras")
17
 
18
+ def recognize_digit(image):
19
+ prediction = model.predict(np.reshape(image, (1, 28, 28))).tolist()[0]
20
+ return {str(i): prediction[i] for i in range(10)}
21
 
22
+ sketchpad = gr.Sketchpad(shape=(28, 28))
23
+ gr.Interface(fn=recognize_digit,
24
+ inputs=sketchpad,
25
+ outputs="label",
26
+ title="Handwritten Digits Classifier",
27
+ description="This app uses lenet5 for handwritten digits classification").launch()
28