import os

import numpy as np
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Project root: the parent of the directory that contains this script.
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

# Image folders under <project root>/Data.  Note that the 'test' folder is
# what the training run below uses as its validation set.
train_dir = os.path.join(base_dir, 'Data', 'train')
val_dir = os.path.join(base_dir, 'Data', 'test')

# Training pipeline: pixel values are multiplied by 1/255 and every image is
# randomly rotated, zoomed, sheared, shifted and horizontally flipped.
# NOTE(review): the MobileNetV2 ImageNet weights used further down are
# normally paired with `mobilenet_v2.preprocess_input` rather than a plain
# 1/255 rescale -- confirm this scaling is intentional, and keep any
# inference code in sync with whatever is used here.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    zoom_range=0.2,
    horizontal_flip=True,
    shear_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

# Validation pipeline: same rescaling, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255)

# Edge length (pixels) every image is resized to; the same value is used for
# the MobileNetV2 input shape.
img_size = 128
# Number of images per batch, shared by training and validation.
batch_size = 32

# Options common to both directory iterators, kept in one place so the
# training and validation pipelines cannot drift apart.
_flow_kwargs = dict(
    target_size=(img_size, img_size),
    batch_size=batch_size,
    color_mode='rgb',
    class_mode='categorical',
)

# Batches of augmented training images with one-hot ('categorical') labels.
train_generator = train_datagen.flow_from_directory(train_dir, **_flow_kwargs)

# Batches of validation images with one-hot labels (rescaled only).
validation_generator = val_datagen.flow_from_directory(val_dir, **_flow_kwargs)

# MobileNetV2 backbone with ImageNet weights and without its classification
# top; it takes img_size x img_size RGB input.
base_model = MobileNetV2(
    include_top=False,
    input_shape=(img_size, img_size, 3),
    weights='imagenet',
)
# Freeze the backbone so only the layers added on top of it are trained.
base_model.trainable = False

# Classification head stacked on the backbone output:
# global average pooling -> Dense(256, relu) -> Dropout(0.5) -> softmax.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
# One output unit per class reported by the training generator, instead of a
# hard-coded 7, so the head always matches the one-hot label width.
predictions = Dense(train_generator.num_classes, activation='softmax')(x)

# Full network: backbone input -> class probabilities.
model = Model(inputs=base_model.input, outputs=predictions)
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Both callbacks watch the validation loss.
callbacks = [
    # Stop after 5 epochs without improvement and roll the model back to the
    # best weights seen.
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    # Keep only the best-scoring model on disk; the path is relative to the
    # current working directory.
    ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True),
]

# Train for at most 30 epochs; the callbacks may end the run earlier.
model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=30,
    callbacks=callbacks,
)

# Save the model as it stands after training; the path is relative to the
# current working directory.
model.save("emotion_model.keras")