import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetV2B1
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from sklearn.utils.class_weight import compute_class_weight
# ==================== Paths ====================
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))
train_dir = os.path.join(base_dir, 'combine_dataset/train')
val_dir = os.path.join(base_dir, 'combine_dataset/test')
# ==================== Parameters ====================
img_size = (192, 192) # Recommended for EfficientNetV2B1
batch_size = 32
epochs = 30
num_classes = 7
# ==================== Data Augmentation ====================
# Note: the Keras EfficientNetV2 application includes its own input preprocessing
# and expects raw pixel values in [0, 255], so no rescale=1./255 is applied here.
train_datagen = ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.05,
    height_shift_range=0.05,
    brightness_range=[0.9, 1.1],
    horizontal_flip=True,
    fill_mode='nearest'
)
val_datagen = ImageDataGenerator()  # no augmentation or rescaling for validation
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True
)
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False
)
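# Optional sanity check (not in the original script): both generators should infer
# the same class-to-index mapping from the directory names, and its size should
# match the hard-coded num_classes above.
assert train_generator.class_indices == val_generator.class_indices, "train/val class mismatch"
assert train_generator.num_classes == num_classes, "num_classes does not match the dataset"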
# ==================== Compute Class Weights ====================
# Balance the loss so under-represented emotion classes are not drowned out.
labels = train_generator.classes
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
class_weights = dict(enumerate(class_weights))
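# Optional (not in the original script): print the discovered mapping and the
# balanced weights so heavily skewed classes are easy to spot before training.
print("Class indices:", train_generator.class_indices)
print("Class weights:", class_weights)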
# ==================== Build Model ====================
# ImageNet-pretrained EfficientNetV2-B1 backbone, fine-tuned end to end (no layers
# are frozen), with global pooling, dropout, and a softmax classification head.
base_model = EfficientNetV2B1(include_top=False, input_shape=(192, 192, 3), weights='imagenet')
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
output = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
# ==================== Compile Model ====================
# Low learning rate since the pretrained backbone is updated end to end.
optimizer = Adam(learning_rate=1e-5)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
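# Optional (not in the original script): print the architecture and parameter
# counts to confirm the classification head was attached as intended.
model.summary()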
# ==================== Callbacks ====================
# Keep only the best weights by validation accuracy. The path below is a Colab
# path; adjust it when running the script locally.
checkpoint = ModelCheckpoint(
    "/content/emotion_model.keras",
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1
)
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=7,
    restore_best_weights=True,
    verbose=1
)
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    verbose=1,
    min_lr=1e-6
)
# ==================== Train Model ====================
model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=[checkpoint, early_stop, lr_schedule],
    class_weight=class_weights
)
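# --- Not in the original script: a minimal evaluation/inference sketch. ---
# It assumes the best checkpoint written by ModelCheckpoint above is readable from
# the same path; adjust the path when running outside Colab.
best_model = tf.keras.models.load_model("/content/emotion_model.keras")
val_loss, val_acc = best_model.evaluate(val_generator, verbose=1)
print(f"Best-checkpoint validation accuracy: {val_acc:.4f}")

# Map predicted indices back to the class names inferred from the folder structure;
# val_generator uses shuffle=False, so prediction order matches the file order.
idx_to_class = {v: k for k, v in val_generator.class_indices.items()}
probs = best_model.predict(val_generator, verbose=0)
pred_labels = [idx_to_class[i] for i in np.argmax(probs, axis=1)]
print("First few predictions:", pred_labels[:5])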