Spaces:
Sleeping
Sleeping
import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from tensorflow.keras import applications, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Streamlit app: train a small classifier head on top of a frozen,
# ImageNet-pretrained VGG16 convolutional base. Streamlit re-executes this
# whole script on every widget interaction, so everything that must outlive a
# single run (the trained model and its training history) is kept in
# st.session_state.

# Dataset locations. Each must contain one sub-directory per class; the
# sub-directory names become the class labels.
train_dir = 'data/train'
validation_dir = 'data/validation'

# Resize target shared by both generators and the VGG16 input shape.
IMG_SIZE = (150, 150)

# st.session_state keys for state that must survive reruns.
MODEL_KEY = 'trained_model'
HISTORY_KEY = 'training_history'


def _list_classes(directory):
    """Return the sorted names of the immediate sub-directories of *directory*.

    This matches the class list ``flow_from_directory`` infers when
    ``classes`` is not given, so passing it explicitly keeps the same
    label-index order for the training set.
    """
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )


def _build_model(output_units, output_activation, loss):
    """Build and compile a classifier on a frozen ImageNet VGG16 base.

    Args:
        output_units: Width of the final Dense layer.
        output_activation: Activation of the final Dense layer.
        loss: Keras loss identifier matching the generators' class_mode.

    Returns:
        The compiled ``Sequential`` model with an untrained head.
    """
    base_model = applications.VGG16(
        weights='imagenet', include_top=False, input_shape=IMG_SIZE + (3,)
    )
    # Freeze the convolutional base; only the new head is trained.
    base_model.trainable = False
    model = models.Sequential([
        base_model,
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(output_units, activation=output_activation),
    ])
    model.summary()
    model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
    return model


def _plot_curves(history, metric, title, label):
    """Render training vs. validation curves of *metric* from a History dict."""
    st.subheader(title)
    fig, ax = plt.subplots()
    ax.plot(history[metric], label=f'Training {label}')
    ax.plot(history[f'val_{metric}'], label=f'Validation {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    st.pyplot(fig)
    # Close the figure; otherwise every rerun leaves one more open figure in
    # pyplot's global state.
    plt.close(fig)


st.title("Transfer Learning with VGG16 for Image Classification")

# Input parameters
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)

# Fail with a readable message instead of a traceback when data is missing.
for data_dir in (train_dir, validation_dir):
    if not os.path.isdir(data_dir):
        st.error(f"Dataset directory not found: {data_dir}")
        st.stop()

class_names = _list_classes(train_dir)
num_classes = len(class_names)
if num_classes < 2:
    st.error(
        f"Expected at least 2 class sub-directories in {train_dir}, "
        f"found {num_classes}."
    )
    st.stop()

# Choose the head and loss from the class count instead of hard-coding
# binary: two classes keep the single sigmoid unit with 0/1 labels; more
# classes use one softmax unit per class with one-hot labels.
if num_classes == 2:
    class_mode, output_units, output_activation, loss = (
        'binary', 1, 'sigmoid', 'binary_crossentropy')
else:
    class_mode, output_units, output_activation, loss = (
        'categorical', num_classes, 'softmax', 'categorical_crossentropy')

# VGG16's ImageNet weights expect inputs normalised by
# vgg16.preprocess_input, not by a plain 1/255 rescale.
train_datagen = ImageDataGenerator(
    preprocessing_function=applications.vgg16.preprocess_input,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
validation_datagen = ImageDataGenerator(
    preprocessing_function=applications.vgg16.preprocess_input,
)

# Both generators get the same explicit class list, so a label index means
# the same class in training and validation (index = position in the list).
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode=class_mode,
    classes=class_names,
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=IMG_SIZE,
    batch_size=batch_size,
    class_mode=class_mode,
    classes=class_names,
)
if train_generator.samples == 0 or validation_generator.samples == 0:
    st.error("No images found in the training or validation directory.")
    st.stop()

# Train the model (built only on demand, not on every slider change).
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        model = _build_model(output_units, output_activation, loss)
        history = model.fit(
            train_generator,
            # len() includes the final partial batch, so every image is used
            # and the step count cannot drop to 0 when samples < batch_size.
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=len(validation_generator),
        )
    # Persist across reruns: clicking any other button re-executes the
    # script from the top and would otherwise discard the trained model.
    st.session_state[MODEL_KEY] = model
    st.session_state[HISTORY_KEY] = history.history
    st.success("Model training completed!")

# Display training curves (also on reruns triggered by other widgets).
if HISTORY_KEY in st.session_state:
    _plot_curves(st.session_state[HISTORY_KEY], 'accuracy',
                 "Training and Validation Accuracy", 'Accuracy')
    _plot_curves(st.session_state[HISTORY_KEY], 'loss',
                 "Training and Validation Loss", 'Loss')

# Evaluate the trained model stored in session state.
if st.button("Evaluate Model"):
    if MODEL_KEY not in st.session_state:
        st.warning("Train the model before evaluating it.")
    else:
        _, test_acc = st.session_state[MODEL_KEY].evaluate(
            validation_generator, verbose=2)
        st.write(f"Validation accuracy: {test_acc}")