Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import tensorflow as tf | |
| from tensorflow.keras import layers, models, applications | |
| import tensorflow_datasets as tfds | |
| import matplotlib.pyplot as plt | |
# Load the TFDS "cats_vs_dogs" dataset as (image, label) pairs. Its "train"
# split is carved into two slices: the first 80% for training and the last
# 20% for validation. Dataset metadata is kept in `ds_info`.
dataset_name = "cats_vs_dogs"
(ds_train, ds_val), ds_info = tfds.load(
    dataset_name,
    split=['train[:80%]', 'train[80%:]'],
    as_supervised=True,
    with_info=True,
)
# Preprocess the dataset
def preprocess_image(image, label):
    """Resize an image and apply VGG16's ImageNet input preprocessing.

    The model's convolutional base is VGG16 loaded with ImageNet weights.
    Keras documents that VGG16 expects its inputs to go through
    ``vgg16.preprocess_input``: RGB->BGR conversion and per-channel ImageNet
    mean subtraction on 0-255 pixel values. Plain ``/ 255.0`` scaling would
    feed the frozen base inputs from a range it was not trained on.

    Args:
        image: Image tensor from the ``(image, label)`` dataset pairs.
            Assumed to be RGB with 0-255 pixel values (TODO confirm for
            other datasets).
        label: Class label, passed through unchanged.

    Returns:
        Tuple ``(image, label)``. ``image`` is a float32 tensor resized to
        150x150 to match the model's ``input_shape``.
    """
    # tf.image.resize returns float32, which preprocess_input needs for the
    # mean subtraction.
    image = tf.image.resize(image, (150, 150))
    image = applications.vgg16.preprocess_input(image)
    return image, label
# Streamlit app
st.title("Transfer Learning with VGG16 for Image Classification")

# Input parameters. They are read BEFORE the input pipelines are built so that
# the "Batch Size" slider actually controls the batch size. Batching with a
# hard-coded 32 ahead of the slider would leave the slider with no effect.
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

# Input pipelines: preprocess (in parallel), batch, and prefetch.
# Only the training set is shuffled; `shuffle` reshuffles on every epoch by
# default, so batches differ between epochs. Validation order does not
# affect the reported metrics.
ds_train = (
    ds_train.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
ds_val = (
    ds_val.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
# Build the transfer-learning model.
#
# Streamlit re-executes this whole script from the top on every widget
# interaction, including each button click. If the model were rebuilt on
# every run, the "Evaluate Model" button would always see a freshly
# initialised, untrained model. The compiled model is therefore created once
# per browser session and kept in st.session_state, so the weights learned by
# "Train Model" persist across reruns. Clicking "Train Model" again continues
# training from the current weights.


def _build_model():
    """Create and compile the VGG16-based binary classifier.

    Returns:
        A compiled ``Sequential`` model with two parts:
        - a frozen VGG16 convolutional base (ImageNet weights, no top,
          150x150x3 input);
        - a trainable dense head ending in a single sigmoid unit.
    """
    # Load the pre-trained VGG16 convolutional base without its classifier.
    base = applications.VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
    # Freeze the convolutional base so only the new head is trained.
    base.trainable = False
    # Add custom layers on top
    clf = models.Sequential([
        base,
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid'),  # Change the output layer based on the number of classes
    ])
    clf.compile(
        optimizer='adam',
        loss='binary_crossentropy',  # Change loss function based on the number of classes
        metrics=['accuracy'],
    )
    return clf


if "model" not in st.session_state:
    st.session_state["model"] = _build_model()
    # Print the architecture once per session. Keras' default print_fn
    # writes to the server console, not to the Streamlit page.
    st.session_state["model"].summary()
model = st.session_state["model"]
# Module-level alias for the frozen VGG16 base, kept for compatibility with
# the original script (Sequential.layers[0] is the first layer added).
base_model = model.layers[0]
def _plot_history_metric(history, metric, title, ylabel):
    """Render training vs. validation curves for one Keras metric in the app.

    Args:
        history: The ``History`` object returned by ``model.fit``.
        metric: Key in ``history.history`` (e.g. ``'accuracy'``, ``'loss'``).
            The validation curve is read from ``'val_' + metric``.
        title: Subheader shown above the chart.
        ylabel: Y-axis label. Also used in the legend entries
            ("Training <ylabel>", "Validation <ylabel>").
    """
    st.subheader(title)
    fig, ax = plt.subplots()
    ax.plot(history.history[metric], label=f'Training {ylabel}')
    ax.plot(history.history[f'val_{metric}'], label=f'Validation {ylabel}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    st.pyplot(fig)
    # Close the figure after rendering. Otherwise every training run leaves
    # figures open in pyplot's global figure registry.
    plt.close(fig)


# Train the model
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            ds_train,
            epochs=epochs,
            validation_data=ds_val,
        )
    st.success("Model training completed!")

    # Display training curves
    _plot_history_metric(history, 'accuracy', "Training and Validation Accuracy", 'Accuracy')
    _plot_history_metric(history, 'loss', "Training and Validation Loss", 'Loss')
# On request, score the current model on the held-out validation slice and
# show its accuracy. `model.evaluate` returns [loss, accuracy] because the
# model was compiled with a single 'accuracy' metric.
evaluate_requested = st.button("Evaluate Model")
if evaluate_requested:
    eval_results = model.evaluate(ds_val, verbose=2)
    test_loss, test_acc = eval_results
    st.write(f"Validation accuracy: {test_acc}")