# Cat-vs-dog binary image classifier, written for Google Colab:
# upload a zipped dataset, train an augmented CNN, evaluate on a held-out
# split, and export a TensorFlow Lite model.

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import pandas as pd  # used below to save the training history as CSV
import matplotlib.pyplot as plt
import datetime
import os
import zipfile
from google.colab import files
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from sklearn.utils import class_weight

print("TensorFlow version:", tf.__version__)

# Upload the dataset zip (expected layout: a top-level train/ directory with
# one subfolder per class) and extract it into ./dataset.
uploaded = files.upload()
zip_filename = list(uploaded.keys())[0]

extract_path = 'dataset'
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("\nExtracted files:")
!ls {extract_path}
print("\nTrain folder contents:")
!ls {extract_path}/train

IMG_SIZE = (150, 150)
BATCH_SIZE = 32

# Heavy augmentation for the training images; validation_split reserves 20%
# of the files for validation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.3,
    height_shift_range=0.3,
    shear_range=0.3,
    zoom_range=0.3,
    horizontal_flip=True,
    vertical_flip=True,
    brightness_range=[0.8, 1.2],
    validation_split=0.2,
    fill_mode='nearest'
)

train_generator = train_datagen.flow_from_directory(
    os.path.join(extract_path, 'train'),
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
    shuffle=True
)

# Validation images get rescaling only, no augmentation; the matching
# validation_split keeps the same 80/20 file split as the training generator.
# shuffle=False so that predictions later line up with
# validation_generator.classes for the confusion matrix and report.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

validation_generator = val_datagen.flow_from_directory(
    os.path.join(extract_path, 'train'),
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False
)

class_weights = class_weight.compute_class_weight(
    'balanced',
    classes=np.unique(train_generator.classes),
    y=train_generator.classes
)
class_weights = dict(enumerate(class_weights))
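# The 'balanced' mode weights each class by
# n_samples / (n_classes * n_samples_in_class), so the rarer class
# contributes proportionally more to the loss.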

class_names = list(train_generator.class_indices.keys())
print("\nDetected classes:", class_names)
print("Training samples:", train_generator.samples)
print("Validation samples:", validation_generator.samples)
print("Class weights:", class_weights)

def build_enhanced_model(input_shape):
    model = models.Sequential([
        # Explicit Input layer rather than passing input_shape to the first Conv2D.
        layers.Input(shape=input_shape),

        # Block 1: 64 filters.
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.3),

        # Block 2: 128 filters.
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.3),

        # Block 3: 256 filters.
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
        layers.BatchNormalization(),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(0.4),

        # Classifier head: one sigmoid unit for binary output.
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(1, activation='sigmoid')
    ])

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    model.compile(
        optimizer=optimizer,
        loss='binary_crossentropy',
        metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
    )
    return model


model = build_enhanced_model(input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3))
model.summary()

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# Named callback_list rather than "callbacks", which would shadow the
# tensorflow.keras.callbacks module imported above.
callback_list = [
    callbacks.TensorBoard(log_dir=log_dir),
    callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=3,
        min_lr=1e-7,
        verbose=1
    ),
    callbacks.ModelCheckpoint(
        'best_model.keras',
        monitor='val_auc',
        mode='max',
        save_best_only=True,
        save_weights_only=False,
        verbose=1
    )
]
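
# To watch metrics live in Colab, TensorBoard can be opened in its own cell
# before training starts:
# %load_ext tensorboard
# %tensorboard --logdir logs/fit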

print("\nStarting training (30 epochs)...")
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=30,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // BATCH_SIZE,
    callbacks=callback_list,
    class_weight=class_weights,
    verbose=1
)

print("\nTraining complete. Saving final model...")
model.save('final_model.keras')

# Persist per-epoch metrics for later analysis.
history_df = pd.DataFrame(history.history)
history_df.to_csv('training_history.csv', index=False)

# Training curves: accuracy and loss, train vs. validation.
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

# Predict on the validation split. Because validation_generator was created
# with shuffle=False, prediction order matches validation_generator.classes.
val_probs = model.predict(validation_generator)
val_preds = (val_probs > 0.5).astype(int).ravel()
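
# Optional extra metric (an added check): ROC AUC from the raw probabilities,
# independent of the 0.5 threshold.
from sklearn.metrics import roc_auc_score
print("Validation ROC AUC:", roc_auc_score(validation_generator.classes, val_probs.ravel()))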

cm = confusion_matrix(validation_generator.classes, val_preds)
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

print("\nClassification Report:")
print(classification_report(validation_generator.classes, val_preds,
                            target_names=class_names))

# Convert the trained Keras model to TensorFlow Lite for on-device inference.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('cat_dog.tflite', 'wb') as f:
    f.write(tflite_model)
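
# Smoke test of the converted model (a minimal sketch using the standard
# tf.lite.Interpreter API): run one validation image through it.
interpreter = tf.lite.Interpreter(model_path='cat_dog.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

sample_batch, _ = next(validation_generator)
sample = sample_batch[:1].astype(np.float32)  # shape (1, 150, 150, 3)
interpreter.set_tensor(input_details[0]['index'], sample)
interpreter.invoke()
print("TFLite sample prediction:", interpreter.get_tensor(output_details[0]['index']))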

print("\nAll models saved successfully:")
print("- final_model.keras (model after all epochs)")
print("- best_model.keras (best validation AUC model)")
print("- cat_dog.tflite (TFLite version)")