|
import os |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
from tensorflow.keras.applications import EfficientNetV2B1 |
|
from tensorflow.keras.models import Model |
|
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D |
|
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau |
|
from tensorflow.keras.optimizers import Adam |
|
from sklearn.utils.class_weight import compute_class_weight |
|
|
|
|
|
|
|
# Project root: the parent of the directory that contains this script.
script_dir = os.path.dirname(__file__)
base_dir = os.path.abspath(os.path.join(script_dir, '../'))

# Dataset splits live under <project root>/combine_dataset/.
# NOTE(review): val_dir points at the *test* folder and is passed to model.fit
# as validation_data, where the checkpoint / early-stopping / LR callbacks
# select on it. Any score later measured on this same folder is therefore
# optimistic -- consider carving out a separate validation split.
train_dir, val_dir = (
    os.path.join(base_dir, split)
    for split in ('combine_dataset/train', 'combine_dataset/test')
)
|
|
|
|
|
# Training hyper-parameters.
img_height = img_width = 192
img_size = (img_height, img_width)  # (height, width) target_size for the loaders
batch_size = 32
epochs = 30  # upper bound; EarlyStopping may end training earlier
num_classes = 7  # number of units in the softmax output layer
|
|
|
|
|
# Image pipelines for the train / validation folders.
#
# No `rescale` on purpose: EfficientNetV2B1 is constructed with Keras's default
# include_preprocessing=True, which adds its own Rescaling(1/255) and ImageNet
# mean/std Normalization layers and therefore expects raw pixel values in
# [0, 255]. Rescaling here as well would divide by 255 twice and feed the
# pretrained backbone almost constant inputs.

# Light augmentation, applied to the training split only.
train_datagen = ImageDataGenerator(
    rotation_range=10,
    zoom_range=0.1,
    width_shift_range=0.05,
    height_shift_range=0.05,
    brightness_range=[0.9, 1.1],
    horizontal_flip=True,
    fill_mode='nearest',
)

# Validation images are used as loaded: no augmentation, no rescaling.
val_datagen = ImageDataGenerator()

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
)

# shuffle=False keeps the validation sample order fixed across epochs.
val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=img_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
)

# flow_from_directory derives class indices from the sub-folder names of each
# split independently. A missing/extra folder in one split would silently
# remap labels, and a head size that differs from the data would only fail
# later inside fit(), so check both up front.
if val_generator.class_indices != train_generator.class_indices:
    raise ValueError(
        f"Class folders differ between {train_dir!r} and {val_dir!r}: "
        f"{train_generator.class_indices} vs {val_generator.class_indices}"
    )
if train_generator.num_classes != num_classes:
    raise ValueError(
        f"num_classes={num_classes}, but {train_dir!r} contains "
        f"{train_generator.num_classes} class folders"
    )
|
|
|
|
|
# Per-class loss weights ("balanced": n_samples / (n_classes * class_count))
# to counter class imbalance in the training folder.
labels = train_generator.classes
present_classes = np.unique(labels)
balanced_weights = compute_class_weight(
    class_weight='balanced', classes=present_classes, y=labels
)
# Key each weight by its real class id. The previous dict(enumerate(...))
# assumed every index 0..k-1 occurs in `labels`; if a class folder were empty,
# every later class would silently receive its neighbour's weight.
class_weights = {
    int(cls): float(weight)
    for cls, weight in zip(present_classes, balanced_weights)
}
|
|
|
|
|
# ImageNet-pretrained EfficientNetV2-B1 backbone without its classifier head.
# include_preprocessing=True is the Keras default, spelled out here because it
# determines the input contract: the rescaling/normalization happens inside
# the model, so it expects raw RGB pixel values in [0, 255].
base_model = EfficientNetV2B1(
    include_top=False,
    # Derived from img_size so the model input always matches the loaders'
    # target_size instead of duplicating the literal (192, 192, 3).
    input_shape=(*img_size, 3),
    weights='imagenet',
    include_preprocessing=True,
)

# Classification head: pooled features -> dropout -> softmax over num_classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.4)(x)
output = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=output)
|
|
|
|
|
# Nothing in this script freezes the backbone, so every layer is trained;
# a small learning rate limits how far the pretrained weights move.
optimizer = Adam(learning_rate=1e-5)

model.compile(
    optimizer=optimizer,
    loss='categorical_crossentropy',  # pairs with class_mode='categorical' one-hot labels
    metrics=['accuracy'],
)
|
|
|
|
|
# Persist only the best model seen so far, judged by validation accuracy.
# NOTE(review): "/content/" is the Google Colab working directory and likely
# does not exist on other machines -- confirm the target before running there.
# NOTE(review): this monitors val_accuracy while EarlyStopping below tracks
# val_loss with restore_best_weights=True, so the file on disk and the
# in-memory model after fit() may come from different epochs.
checkpoint_path = "/content/emotion_model.keras"
checkpoint = ModelCheckpoint(
    checkpoint_path, monitor='val_accuracy', save_best_only=True, verbose=1
)

# Stop once val_loss has not improved for 7 epochs; restore_best_weights asks
# Keras to restore the weights from the best val_loss epoch.
early_stop = EarlyStopping(
    monitor='val_loss', patience=7, restore_best_weights=True, verbose=1
)

# Halve the learning rate after 3 epochs without val_loss improvement,
# never dropping below 1e-6.
lr_schedule = ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6, verbose=1
)
|
|
|
|
|
# Train the model. class_weight scales each sample's loss by the weight of its
# class; the callbacks handle checkpointing, early stopping and LR decay.
# The returned History object is kept so per-epoch metrics can be inspected.
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    class_weight=class_weights,
    callbacks=[checkpoint, early_stop, lr_schedule],
)
|
|