# Scraped from a Hugging Face Space (status: Sleeping); file size 3,095 bytes; revision f2036f2.
import streamlit as st
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import numpy as np
# Dataset locations. They are relative paths, so they resolve against the
# working directory the app is launched from. Each one is read with
# flow_from_directory, which expects one sub-directory per class.
train_dir, validation_dir = 'data/train', 'data/validation'
# Streamlit page: the title, then the two user-tunable training knobs.
st.title("Transfer Learning with VGG16 for Image Classification")

# Slider arguments are spelled out by keyword (label, min, max, default, step).
batch_size = st.slider("Batch Size", min_value=16, max_value=128, value=32, step=16)
epochs = st.slider("Epochs", min_value=5, max_value=50, value=10, step=5)
# Data augmentation and preprocessing.
# The ImageNet VGG16 weights expect inputs prepared by vgg16.preprocess_input:
# 0-255 RGB converted to BGR, with each channel zero-centred on the ImageNet
# mean and no scaling. Pixels rescaled to [0, 1] do not match what the frozen
# base was trained on. ImageDataGenerator applies preprocessing_function after
# the image has been resized and augmented.
image_size = (150, 150)  # must match the (150, 150, 3) input_shape of the VGG16 base

train_datagen = ImageDataGenerator(
    preprocessing_function=applications.vgg16.preprocess_input,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images get the same VGG16 preprocessing but no augmentation.
validation_datagen = ImageDataGenerator(
    preprocessing_function=applications.vgg16.preprocess_input,
)

# Both iterators read one sub-directory per class and yield binary labels.
# This matches the model's single sigmoid output and binary_crossentropy loss.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary',
)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary',
)
# Pre-trained VGG16 convolutional base: ImageNet weights with the classifier
# head removed, sized for 150x150 RGB inputs.
base_model = applications.VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(150, 150, 3),
)
# Freeze the convolutional base so fit() only updates the new head below.
base_model.trainable = False

# Classifier head: flatten the conv features, then a 256-unit ReLU layer,
# dropout, and a single sigmoid unit.
# The single sigmoid output (and binary_crossentropy below) assumes two
# classes; a multi-class dataset needs an N-way softmax and a categorical loss.
model = models.Sequential()
model.add(base_model)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
def _plot_history_curve(metrics, key, label):
    """Render the training vs. validation curve for one Keras metric in the app.

    Args:
        metrics: the ``History.history`` dict mapping metric names to
            per-epoch values.
        key: metric name such as ``'accuracy'``; ``'val_' + key`` must also be
            present in *metrics*.
        label: human-readable name used in the subheader, legend and y-axis.
    """
    st.subheader(f"Training and Validation {label}")
    fig, ax = plt.subplots()
    ax.plot(metrics[key], label=f'Training {label}')
    ax.plot(metrics[f'val_{key}'], label=f'Validation {label}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.legend()
    st.pyplot(fig)
    # pyplot keeps every figure made by plt.subplots() open until it is closed.
    # Close it so figures do not pile up across Streamlit reruns.
    plt.close(fig)


# Streamlit re-executes this whole script on every widget interaction, so the
# module-level `model` is rebuilt (untrained) on each rerun. Clicking
# "Evaluate Model" is such a rerun. Keep the trained model in session state so
# evaluation uses the weights that were actually trained.
_TRAINED_MODEL_KEY = "trained_model"

# Train the model
if st.button("Train Model"):
    with st.spinner("Training the model..."):
        history = model.fit(
            train_generator,
            # len() of a directory iterator is ceil(samples / batch_size).
            # Every image is used, and the step count is never 0 when there
            # are fewer samples than one batch.
            steps_per_epoch=len(train_generator),
            epochs=epochs,
            validation_data=validation_generator,
            validation_steps=len(validation_generator),
        )
    st.session_state[_TRAINED_MODEL_KEY] = model
    st.success("Model training completed!")
    # Display training curves
    _plot_history_curve(history.history, 'accuracy', 'Accuracy')
    _plot_history_curve(history.history, 'loss', 'Loss')

# Evaluate the model
if st.button("Evaluate Model"):
    trained_model = st.session_state.get(_TRAINED_MODEL_KEY)
    if trained_model is None:
        # Without this guard, the freshly rebuilt, untrained model would be scored.
        st.warning("No trained model in this session yet - click 'Train Model' first.")
    else:
        test_loss, test_acc = trained_model.evaluate(validation_generator, verbose=2)
        st.write(f"Validation accuracy: {test_acc}")