File size: 3,599 Bytes
89fb082
 
 
 
 
d34f45c
 
e0b823f
a51dc81
d34f45c
89fb082
f817463
db4a829
89fb082
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d34f45c
3dffb6e
 
 
 
d25500f
2f7cbf4
d3d4e02
 
 
 
 
 
 
 
 
89fb082
 
 
0a2e4df
d01708c
89fb082
8664f05
 
89fb082
 
 
 
 
 
 
 
 
 
 
 
 
f2e2bcd
89fb082
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117

import base64
import mimetypes
import os
import subprocess

import numpy as np
import streamlit as st
import tensorflow as tf
from keras.models import model_from_json
from PIL import Image


# Page header rendered as raw HTML so the text colour can be set inline.
st.markdown('<h1 style="color:white;">Image Classification App</h1>', unsafe_allow_html=True)
# Markdown emphasis (**bold**) is not interpreted inside a raw HTML block such
# as <h2>...</h2>, so emphasis is expressed with HTML tags instead.
st.markdown('<h2 style="color:white;">for classifying <strong>zebras</strong> and <strong>horses</strong></h2>', unsafe_allow_html=True)

@st.cache_data
def get_base64_of_bin_file(bin_file):
    """Return the raw bytes of *bin_file* encoded as a base64 ASCII string.

    The result is cached per path, so the file is read and encoded once
    instead of on every Streamlit rerun. Because the cache is keyed on the
    path, later changes to the file's contents are not picked up until the
    cache is cleared.

    Args:
        bin_file: Path of the file to encode.

    Returns:
        The base64 encoding of the file's bytes, as a ``str``.
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()

def set_png_as_page_bg(png_file):
    """Use an image file as the full-page background of the Streamlit app.

    The file is inlined as a base64 ``data:`` URI inside a CSS rule that
    targets the ``.stApp`` container, and the rule is injected via
    ``st.markdown``.

    Args:
        png_file: Path to the background image. Despite the parameter name,
            any image type works: the MIME type in the data URI is guessed
            from the file extension (the app passes a ``.webp`` file), and
            ``image/png`` is used when the extension is not recognised.
    """
    mime_type = mimetypes.guess_type(png_file)[0] or 'image/png'
    encoded = get_base64_of_bin_file(png_file)
    # CSS only has /* ... */ comments; a '#' annotation inside the rule would
    # be parsed as an invalid declaration, so the rule carries no comments.
    page_bg_img = f'''
    <style>
    .stApp {{
    background-image: url("data:{mime_type};base64,{encoded}");
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: scroll;
    }}
    </style>
    '''
    st.markdown(page_bg_img, unsafe_allow_html=True)


set_png_as_page_bg('background.webp')
        

# def load_model():
#     # load json and create model
#     json_file = open('model.json', 'r')
#     loaded_model_json = json_file.read()
#     json_file.close()
#     CNN_class_index = model_from_json(loaded_model_json)
#     # load weights into new model
#     model = CNN_class_index.load_weights("model.h5")

#     #model= tf.keras.load_model('model.h5')
#     #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
#     return model, CNN_class_index

@st.cache_resource
def load_model(
    model_path='model.h5',
    model_url=(
        'https://github.com/KaburaJ/Binary-Image-classification/raw/main/'
        'ZebraHorse/CNN%20Application/model.h5'
    ),
):
    """Return the trained Keras model, downloading its file first if absent.

    Cached with ``st.cache_resource`` so the model is loaded once per server
    process rather than on every Streamlit rerun.

    Args:
        model_path: Local path of the HDF5 model file.
        model_url: URL to download the file from when ``model_path`` does
            not exist. GitHub's ``/raw/`` endpoint serves the file's bytes,
            whereas a ``/blob/`` URL serves an HTML page.

    Returns:
        The model produced by ``tf.keras.models.load_model`` (not compiled).

    Raises:
        subprocess.CalledProcessError: if ``curl`` exits with an error.
        FileNotFoundError: if the ``curl`` executable is not available.
    """
    if not os.path.isfile(model_path):
        partial_path = model_path + '.part'
        # --location follows GitHub's redirect to its raw-content host;
        # --fail makes HTTP errors a non-zero exit instead of saving the
        # error page as if it were the model.
        subprocess.run(
            ['curl', '--location', '--fail', '--output', partial_path, model_url],
            check=True,
        )
        # Move into place only after a complete download, so a failed attempt
        # is retried on the next call instead of leaving a truncated file that
        # the isfile() check above would accept forever.
        os.replace(partial_path, model_path)
    return tf.keras.models.load_model(model_path, compile=False)
# def load_model():
#     # Load the model architecture
#     with open('model.json', 'r') as f:
#         model = model_from_json(f.read())

#     # Load the model weights
#     model.load_weights('model.h5')
#     #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
#     return model


def image_transformation(image):
    """Convert *image* into a NumPy array of its pixel data.

    The conversion happens in memory. No temporary file is written, so
    concurrent sessions cannot overwrite each other's data, and nothing is
    ever unpickled.

    Args:
        image: A PIL image or any array-like accepted by ``np.array``.

    Returns:
        A new ``np.ndarray`` holding a copy of the data.
    """
    # NOTE(review): no resizing or normalisation is done here. If the model
    # expects a fixed input size (earlier commented-out code hinted at
    # 256x256), resize before prediction — confirm against the model.
    return np.array(image)


def image_prediction(image, model):
    """Classify a single image and return the winning class index as a string.

    Args:
        image: One image (not a batch) — a PIL image or array-like, converted
            with ``image_transformation``.
        model: Object with a Keras-style ``predict`` method that returns
            per-class scores with the class axis at position 1.

    Returns:
        The index of the highest-scoring class, e.g. ``'0'`` or ``'1'``.
    """
    image = image_transformation(image=image)
    # predict() expects a leading batch axis; wrap the single image into a
    # batch of one.
    batch = np.expand_dims(image, axis=0)
    outputs = model.predict(batch)
    # ndarray.max(1) returns only the maxima (it is not torch's
    # (values, indices) pair); argmax gives the class index per sample.
    # NOTE(review): assumes one score per class (e.g. softmax over 2 units);
    # a single sigmoid output would always give index 0 — confirm.
    y_hat = np.argmax(outputs, axis=1)
    # .item() requires exactly one element, i.e. the batch of one above.
    return str(y_hat.item())

def main():
    """Render the upload widget and, on request, the model's prediction.

    Nothing beyond the uploader is drawn until a file has been uploaded.
    The model is loaded only after the "Predict" button is pressed, so
    merely uploading an image does not trigger a model load.
    """
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return

    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)
    image = Image.open(image_file)

    if st.button("Predict"):
        model = load_model()
        # image_prediction performs the array conversion itself, so the PIL
        # image is passed through unchanged.
        predicted_idx = image_prediction(image, model)
        right_column.title("Prediction")
        # NOTE(review): this shows the raw class index; which index means
        # 'Zebra' vs 'Horse' depends on how the model was trained — confirm
        # before mapping it to a label.
        right_column.write(predicted_idx)
        

# Script entry point: build the page by calling main().
if __name__ == '__main__':
    main()