# pages/13_TransferLearning.py - Transfer Learning with VGG16 for Image Classification (Streamlit app)
import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
# Set dataset paths
train_dir = 'data/train'
validation_dir = 'data/validation'
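# Assumed dataset layout (the images themselves are not part of this repo):
#   data/train/<class_0>/..., data/train/<class_1>/...
#   data/validation/<class_0>/..., data/validation/<class_1>/...
# flow_from_directory infers the two binary labels from these subdirectory names.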
# Streamlit app
st.title("Transfer Learning with VGG16 for Image Classification")
# Input parameters
batch_size = st.slider("Batch Size", 16, 128, 32, 16)
epochs = st.slider("Epochs", 5, 50, 10, 5)
# Data augmentation and preprocessing
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)
validation_datagen = ImageDataGenerator(rescale=1./255)
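# The validation set is only rescaled; augmentation is reserved for the training images.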
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=batch_size,
    class_mode='binary'
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(150, 150),
    batch_size=batch_size,
    class_mode='binary'
)
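# Optional sanity check (commented out): show which folder maps to label 0 and which to 1.
# st.write("Class indices:", train_generator.class_indices)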
# Load the pre-trained VGG16 model
base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
# Freeze the convolutional base
base_model.trainable = False
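# Optional fine-tuning sketch (commented out; an assumption, not part of the original flow):
# after the new head has been trained once, the last VGG16 block can be unfrozen and the
# model recompiled with a much lower learning rate before a second call to model.fit.
# base_model.trainable = True
# for layer in base_model.layers:
#     if not layer.name.startswith('block5'):
#         layer.trainable = False
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
#               loss='binary_crossentropy', metrics=['accuracy'])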
# Add custom layers on top
model = models.Sequential([
    base_model,
    layers.Flatten(),
    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(1, activation='sigmoid')  # Change the output layer based on the number of classes
])
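# A smaller alternative head (an assumption, not used here): layers.GlobalAveragePooling2D()
# in place of Flatten() cuts the parameter count of the first Dense layer considerably.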
model.summary()
# Compile the model
model.compile(optimizer='adam',
              loss='binary_crossentropy',  # Change loss function based on the number of classes
              metrics=['accuracy'])
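# For more than two classes (hypothetical variant, not used above): set class_mode='categorical'
# in both generators, replace the output layer with layers.Dense(num_classes, activation='softmax'),
# and compile with loss='categorical_crossentropy'.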
# Train the model
if st.button("Train Model"):
with st.spinner("Training the model..."):
history = model.fit(
train_generator,
steps_per_epoch=train_generator.samples // train_generator.batch_size,
epochs=epochs,
validation_data=validation_generator,
validation_steps=validation_generator.samples // validation_generator.batch_size
)
st.success("Model training completed!")
# Display training curves
st.subheader("Training and Validation Accuracy")
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], label='Training Accuracy')
ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
st.pyplot(fig)
st.subheader("Training and Validation Loss")
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='Training Loss')
ax.plot(history.history['val_loss'], label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
st.pyplot(fig)
# Evaluate the model
if st.button("Evaluate Model"):
    # The script reruns on every click, so reuse the trained model kept in session state.
    if "model" in st.session_state:
        test_loss, test_acc = st.session_state["model"].evaluate(validation_generator, verbose=2)
        st.write(f"Validation accuracy: {test_acc:.4f}")
    else:
        st.warning("Train the model first.")