Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
def load_and_preprocess_mnist():
    """Load MNIST through Keras and prepare it for a convolutional classifier.

    Images are cast to float32, divided by 255.0 and reshaped to
    ``(-1, 28, 28, 1)``; labels are one-hot encoded over 10 classes.

    Returns:
        A pair of pairs: ``((x_train, y_train), (x_test, y_test))``.
    """

    def _prepare_images(raw_images):
        # Scale first, then append the trailing single-channel axis.
        scaled = raw_images.astype('float32') / 255.0
        return scaled.reshape((-1, 28, 28, 1))

    def _prepare_labels(raw_labels):
        # One-hot encode the integer labels into 10 classes.
        return keras.utils.to_categorical(raw_labels, 10)

    train_split, test_split = keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    return (
        (_prepare_images(train_images), _prepare_labels(train_labels)),
        (_prepare_images(test_images), _prepare_labels(test_labels)),
    )
def create_mnist_model():
    """Build and compile the CNN used to classify 28x28x1 digit images.

    The network is two Conv2D + MaxPooling2D stages followed by Flatten,
    Dropout(0.5), a 64-unit ReLU dense layer and a 10-way softmax output.
    It is compiled with the Adam optimizer, categorical cross-entropy loss
    and accuracy as the reported metric.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    layers = keras.layers

    # Convolutional feature extractor.
    feature_extractor = [
        layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
    ]

    # Classifier head operating on the flattened feature maps.
    classifier_head = [
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]

    network = keras.Sequential(feature_extractor + classifier_head)
    network.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return network
def train_model(model, x_train, y_train, epochs, batch_size):
    """Fit ``model`` on the training data and return what ``fit`` returns.

    Args:
        model: Object exposing a Keras-style ``fit`` method.
        x_train: Training inputs, passed to ``fit`` unchanged.
        y_train: Training targets, passed to ``fit`` unchanged.
        epochs: Number of epochs to train for.
        batch_size: Mini-batch size.

    Returns:
        The value returned by ``model.fit``.
    """
    fit_options = {
        'batch_size': batch_size,
        'epochs': epochs,
        # Hold out 10% of the training data for validation.
        'validation_split': 0.1,
    }
    return model.fit(x_train, y_train, **fit_options)
def plot_training_history(history):
    """Draw accuracy and loss curves side by side for a training run.

    Args:
        history: Object whose ``history`` mapping holds the per-epoch series
            ``'accuracy'``, ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``.

    Returns:
        The matplotlib figure containing both panels (accuracy on the left,
        loss on the right).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # One entry per panel:
    # (train key, train label, validation key, validation label, title, y label)
    panel_specs = (
        ('accuracy', 'Training Accuracy',
         'val_accuracy', 'Validation Accuracy',
         'Model Accuracy', 'Accuracy'),
        ('loss', 'Training Loss',
         'val_loss', 'Validation Loss',
         'Model Loss', 'Loss'),
    )

    for axis, spec in zip(axes, panel_specs):
        train_key, train_label, val_key, val_label, title, y_label = spec
        axis.plot(history.history[train_key], label=train_label)
        axis.plot(history.history[val_key], label=val_label)
        axis.set_title(title)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(y_label)
        axis.legend()

    return fig
def main():
    """Render the Streamlit page: train an MNIST classifier and try it out.

    The page has two parts:

    * a sidebar where the user picks the number of epochs and the batch size
      and starts training; after training, the accuracy/loss curves and the
      test-set accuracy are shown;
    * a "Test on Random Image" section that classifies one randomly chosen
      test image with the trained model.

    The model and the "has been trained" flag live in ``st.session_state`` so
    they are still available the next time this function runs for the same
    session.
    """
    st.title("MNIST Digit Classification with Keras and Streamlit")

    # Load and preprocess data
    (x_train, y_train), (x_test, y_test) = load_and_preprocess_mnist()

    # Create the model only once per session; later runs reuse the stored
    # instance (and therefore any weights it has already learned).
    if 'model' not in st.session_state:
        st.session_state.model = create_mnist_model()

    # Sidebar for training parameters
    st.sidebar.header("Training Parameters")
    epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
    batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)

    # Train model button
    if st.sidebar.button("Train Model"):
        with st.spinner("Training in progress..."):
            history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
        st.success("Training completed!")

        # Plot training history
        st.subheader("Training History")
        fig = plot_training_history(history)
        st.pyplot(fig)
        # Figures created through pyplot stay registered until closed; release
        # this one once rendered so repeated runs do not accumulate figures.
        plt.close(fig)

        # Evaluate model (the loss value is not displayed)
        _, test_acc = st.session_state.model.evaluate(x_test, y_test)
        st.write(f"Test accuracy: {test_acc:.4f}")

        # Set a flag to indicate the model has been trained
        st.session_state.model_trained = True

    # Test on random image
    st.subheader("Test on Random Image")
    if st.button("Select Random Image"):
        # Same membership test as used for 'model' above.
        if 'model_trained' not in st.session_state:
            st.warning("Please train the model first before testing.")
        else:
            # Select a random image from the test set
            idx = np.random.randint(0, x_test.shape[0])
            image = x_test[idx]
            true_label = np.argmax(y_test[idx])

            # Make prediction; np.newaxis adds the leading batch dimension
            prediction = st.session_state.model.predict(image[np.newaxis, ...])[0]
            predicted_label = np.argmax(prediction)

            # Display image and prediction
            fig, ax = plt.subplots()
            ax.imshow(image.reshape(28, 28), cmap='gray')
            ax.axis('off')
            st.pyplot(fig)
            plt.close(fig)  # release the figure once it has been rendered

            st.write(f"True Label: {true_label}")
            st.write(f"Predicted Label: {predicted_label}")
            st.write(f"Confidence: {prediction[predicted_label]:.4f}")
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported from elsewhere.
if __name__ == "__main__":
    main()