File size: 1,939 Bytes
051f92c
 
 
a49ba8d
d4b4b25
051f92c
a49ba8d
e1b26df
051f92c
a49ba8d
e1b26df
 
051f92c
 
 
 
 
d4b4b25
 
e1b26df
 
 
051f92c
d4b4b25
 
051f92c
e1b26df
d4b4b25
051f92c
 
 
 
 
 
448af73
e1b26df
051f92c
 
 
c40b85e
051f92c
 
 
c40b85e
051f92c
 
 
 
 
cd17133
051f92c
 
cd17133
051f92c
 
 
 
 
 
 
 
 
 
c40b85e
051f92c
 
 
 
 
c40b85e
 
051f92c
331746a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import numpy as np
import cv2
import gradio as gr
import tensorflow as tf
from PIL import Image

# App title shown at the top of the page.
title = "Welcome to your first sketch recognition app!"

# App description (HTML): an illustration of the classes plus usage hints.
# NOTE(review): a bare relative ``src`` may not be served by Gradio as-is;
# confirm the image actually renders in the deployed app.
head = (
    "<center>"
    "<img src='./mnist-classes.png' width=400>"
    "<p>The robot was trained to classify numbers (0 to 9). "
    "To test it, write your number in the space provided!</p>"
    "</center>"
)

# Footer (Markdown) linking to the GitHub repository with the full code.
ref = "Find the whole code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Side length, in pixels, of the square image fed to the model (28x28).
img_size = 28

# Class names, ordered to match the model's output vector (index i is digit i).
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]

# Load the trained Keras classifier once at start-up so every request reuses
# it. The relative path is resolved against the current working directory.
model_path = "./sketch_recognition_numbers_model.h5"
try:
    model = tf.keras.models.load_model(model_path)
except Exception as e:
    # Re-raise with the model path in the message. Chaining with ``from e``
    # marks the loader's error as the direct cause, so the real reason
    # (missing file, incompatible format, ...) stays visible in the traceback.
    raise FileNotFoundError(
        f"Model file '{model_path}' not found or failed to load. {e}"
    ) from e

def predict(img):
    """Classify a hand-drawn digit.

    Args:
        img: The sketch submitted by the UI. It may be ``None`` (nothing
            submitted), a PIL image, an array accepted by
            ``Image.fromarray``, or a dict holding the drawing under
            ``"composite"`` (the editor payload sent by newer Gradio
            ``Sketchpad`` versions).

    Returns:
        On success, ``{label: probability}`` for every entry of ``labels``.
        If there is no drawing, the message ``"No image provided."`` instead.
        The message is a plain string because ``gr.Label`` renders a string
        as a bare label, whereas a dict value must be a numeric confidence.
    """
    # Newer Gradio Sketchpads send an editor dict instead of an image. The
    # flattened drawing is stored under "composite" (which may be None).
    if isinstance(img, dict):
        img = img.get("composite")

    if img is None:
        return "No image provided."

    # Ensure the image is a PIL Image.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(np.uint8(img))

    # Reduce to a single 8-bit channel. For a drawing on a (partly)
    # transparent canvas, convert("L") discards alpha, so strokes and a
    # transparent background can collapse to the same grey value. The alpha
    # channel instead yields bright strokes on a dark background.
    # NOTE(review): this assumes the model expects light-on-dark digits
    # (MNIST style). Confirm against the training pipeline.
    if img.mode in ("RGBA", "LA") and img.getchannel("A").getextrema()[0] < 255:
        img = img.getchannel("A")
    else:
        img = img.convert("L")

    pixels = np.array(img, dtype=np.uint8)

    # Shrink to the model's input size. INTER_AREA avoids the aliasing that
    # the default bilinear sampling produces when downscaling a large canvas.
    pixels = cv2.resize(pixels, (img_size, img_size), interpolation=cv2.INTER_AREA)

    # Add batch and channel axes: (1, img_size, img_size, 1).
    batch = pixels.reshape(1, img_size, img_size, 1)

    probs = model.predict(batch)[0]
    return {label: float(p) for label, p in zip(labels, probs)}

# Wire the UI together. Input is a drawing canvas (type="pil" requests
# PIL-typed output) and output is the three most probable classes. The
# title, HTML description and Markdown footer come from the constants defined
# earlier in this file.
interface = gr.Interface(
    predict,
    gr.Sketchpad(type="pil"),
    gr.Label(num_top_classes=3),
    title=title,
    description=head,
    article=ref,
)

# Start the web server.
interface.launch()