# NOTE: residue from the page this script was scraped from, not part of the
# program. The page header read: "Spaces: Runtime error".
# Train a small dense classifier on MNIST with Keras, report its accuracy on
# the test split, and save the trained model to disk as an HDF5 file.
import os
import random
import numpy as np

# Disable TensorFlow warnings. This is set before the tensorflow import so the
# setting is already in the environment when TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
from keras.datasets import mnist

# Set the random seed for reproducibility: Python's `random`, NumPy and
# TensorFlow each keep their own generator, so all three are seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Where the trained model is written at the end of the script.
MODEL_PATH = 'artifacts/models/mnist_model.h5'

# Load the dataset from keras.datasets (so no one needs to download it
# manually from any other source).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess the dataset: convert pixels to float32 and divide by 255.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Define the model architecture: flatten each 28x28 image, one hidden ReLU
# layer, and a 10-way softmax output (one unit per digit class).
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
])

# Compile and train the model.
# The loss must match the label encoding:
#   categorical_crossentropy        -> one-hot targets, e.g. 3 is
#                                      [0,0,0,1,0,0,0,0,0,0]
#   sparse_categorical_crossentropy -> integer targets, e.g. 3
# The labels from mnist.load_data() are passed through unchanged, so the
# sparse variant is used.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# NOTE: the original author observed that the 4th epoch overfits while the 3rd
# is okay, yet the script trains for 4 epochs.
# TODO(review): confirm whether epochs should be lowered to 3.
# NOTE(review): the test split is also passed as validation data, so the
# accuracy printed below is measured on data that was monitored during
# training.
model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    epochs=4,
    shuffle=True,
    batch_size=32,
)

# Evaluate the model on the test split; the loss value is discarded.
print('\n')
_, test_accuracy = model.evaluate(x_test, y_test)
print('Test accuracy:', test_accuracy)

# Save the model. The target directory is created explicitly so the save does
# not depend on it already existing.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
model.save(MODEL_PATH)