SimpleNN / app.py
ricardo-lsantos's picture
Fixed minor bugs. Fixed the Ui parameters to be more easy to train.
b7f5a9c
raw
history blame
2.77 kB
import streamlit as st
from nn import NeuralNetwork
import json
from utils import sigmoid, sigmoid_prime
# XOR truth table used for training and evaluation: every pair of binary
# inputs, and for each pair its target wrapped in a one-element list.
INPUTS = [[a, b] for a in (0, 1) for b in (0, 1)]
OUTPUTS = [[a ^ b] for a, b in INPUTS]
def resetSession():
    """Drop the current model and zero the training counter in session state."""
    state = st.session_state
    state.nn = None
    state.train_count = 0
## Controller Function
def runNN():
    """Evaluate the session's network on every XOR sample and show a results table.

    Reads the model from ``st.session_state.nn``, calls its ``predict`` method
    (with the ``sigmoid`` activation) on each input pair in ``INPUTS``, compares
    the rounded prediction with the matching target in ``OUTPUTS`` and renders
    one row per sample via ``st.dataframe`` with the columns ``input``,
    ``expected``, ``predicted``, ``rounded`` and ``correct``.

    If no model has been created yet, a warning is displayed instead of
    failing with an ``AttributeError`` on ``None``.
    """
    nn = st.session_state.get("nn")
    if nn is None:
        st.warning("No model available. Click 'New Model' first.")
        return

    # Column-oriented mapping handed to st.dataframe: one list per column.
    table = {
        "input": [],
        "expected": [],
        "predicted": [],
        "rounded": [],
        "correct": [],
    }
    # zip keeps inputs and targets in lockstep and covers however many
    # samples the tables hold (the original hard-coded range(4)).
    for sample, target in zip(INPUTS, OUTPUTS):
        a, b = sample[0], sample[1]
        expected = target[0]
        result = nn.predict(a, b, activation=sigmoid)
        rounded = round(result)  # computed once, reused for the comparison
        table["input"].append(f"{a} xor {b}")
        table["expected"].append(expected)
        table["predicted"].append(result)
        table["rounded"].append(rounded)
        table["correct"].append('correct' if rounded == expected else 'incorrect')
    st.dataframe(table)
def sidebar():
    """Render the sidebar controls and dispatch the button actions.

    Always shows the epoch and learning-rate sliders plus the "New Model" and
    "Reset Model" buttons. The "Train Model", "Run Neural Network" and
    "Save Model" controls are only shown while a model is stored in
    ``st.session_state.nn``.
    """
    # Neural network controls
    st.sidebar.header('Neural Network Controls')
    st.sidebar.text('Number of epochs')
    epochs = st.sidebar.slider('Epochs', 1, 10000, 500)
    st.sidebar.text('Learning rate')
    # NOTE(review): integer slider (1-100) forwarded unchanged as ``alpha`` to
    # NeuralNetwork.train; confirm that is the scale train() expects.
    alphas = st.sidebar.slider('Alphas', 1, 100, 20)

    col1, col2 = st.sidebar.columns(2)
    if col1.button('New Model'):
        btnNewModel()
    if col2.button('Reset Model'):
        # Route through the handler (previously resetSession() was called
        # directly, leaving btnResetModel unused) so the user gets the same
        # sidebar confirmation that "New Model" gives.
        btnResetModel()

    # Everything below needs a model; .get() also covers the very first run,
    # before any "nn" key exists in session state.
    if st.session_state.get("nn") is None:
        return
    if st.sidebar.button('Train Model'):
        btnTrainModel(epochs, alphas)
    if st.sidebar.button('Run Neural Network'):
        btnRunModel()
    st.sidebar.download_button(
        label="Save Model",
        data=json.dumps(st.session_state.nn.getModelJson()),
        file_name="model.json",
        mime="application/json",
    )
def btnNewModel():
    """Handler for "New Model": clear session state, then store a fresh network."""
    resetSession()
    fresh_model = NeuralNetwork()
    st.session_state.nn = fresh_model
    st.sidebar.text("New model created")
def btnTrainModel(epochs, alphas):
    """Handler for "Train Model": train the stored network on the XOR tables.

    Args:
        epochs: number of epochs, forwarded to ``train`` as ``epochs``.
        alphas: learning-rate value from the sidebar, forwarded as ``alpha``.
    """
    state = st.session_state
    state.nn.train(inputs=INPUTS, outputs=OUTPUTS, epochs=epochs, alpha=alphas)
    state.train_count += 1
    times = state.train_count
    st.sidebar.text(f"Model trained {times} times")
def btnRunModel():
    """Handler for "Run Neural Network": delegate to runNN (returns None)."""
    return runNN()
def btnResetModel():
    """Handler for "Reset Model": clear the session, then confirm in the sidebar."""
    resetSession()
    message = "Model reset"
    st.sidebar.text(message)
def app():
    """Page entry point: title, architecture image, sidebar controls, references."""
    from pathlib import Path  # local import keeps this block self-contained

    st.title('Simple Neural Network App')
    st.write('This is the Neural Network image we are trying to implement!')
    # Resolve the image next to this script rather than the process working
    # directory, so it still loads when the app is launched from elsewhere.
    st.image(str(Path(__file__).with_name('nn.png')), width=500)
    sidebar()
    st.markdown('''
### References
* https://www.codingame.com/playgrounds/59631/neural-network-xor-example-from-scratch-no-libs
''')
# Build the page when this file is executed as the main script.
if __name__ == "__main__":
    app()