# Training/Code/train.py — author: Prime810
# Last commit a425514: "Update Training/Code/train.py" (file size 2.15 kB)
import os
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# Define paths relative to this script's parent directory (not the current
# working directory). Path components are passed to os.path.join separately so
# the platform's own separator is used instead of a hard-coded '/'.
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
train_dir = os.path.join(base_dir, 'Data', 'train')
# NOTE(review): the test split doubles as the validation set that drives
# EarlyStopping/ModelCheckpoint (model selection), so accuracy measured on it
# afterwards is optimistically biased — consider a separate validation split.
val_dir = os.path.join(base_dir, 'Data', 'test')
# Image generators. Training images get random geometric augmentation
# (rotation, zoom, flips, shear, shifts); validation images are only rescaled,
# so evaluation runs on unmodified data. Both map pixel values to [0, 1].
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    zoom_range=0.2,
    horizontal_flip=True,
    shear_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
)
val_datagen = ImageDataGenerator(rescale=1./255)

# Every image is resized to img_size x img_size before reaching the network.
img_size = 128

# Both iterators load 3-channel RGB images (to match the model's 3-channel
# input_shape) and one-hot encode each label from its subdirectory name.
train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=(img_size, img_size), batch_size=32, color_mode='rgb', class_mode='categorical')
validation_generator = val_datagen.flow_from_directory(
    val_dir, target_size=(img_size, img_size), batch_size=32, color_mode='rgb', class_mode='categorical')

# flow_from_directory derives label indices separately for each directory
# (alphanumeric order of its subfolders). If the two folder sets differ, the
# one-hot labels would silently mean different classes or have different
# widths, so fail fast instead of training/validating on mismatched labels.
if validation_generator.class_indices != train_generator.class_indices:
    raise ValueError(
        f"Class folders differ between {train_dir} and {val_dir}: "
        f"{sorted(train_generator.class_indices)} vs "
        f"{sorted(validation_generator.class_indices)}"
    )
# Load base model: MobileNetV2 with ImageNet weights, without its classifier head.
# NOTE(review): Keras' ImageNet MobileNetV2 weights expect inputs scaled to
# [-1, 1] (mobilenet_v2.preprocess_input), while this script's generators use
# rescale=1./255, i.e. [0, 1]. Changing it would also require changing any
# inference code that loads the saved model, so it is left as-is — confirm intent.
base_model = MobileNetV2(include_top=False, input_shape=(img_size, img_size, 3), weights='imagenet')
base_model.trainable = False  # Freeze base layers; only the new head is trained

# Add custom classification head: pooled backbone features -> 256-unit ReLU
# layer -> dropout -> softmax over the classes.
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
# Size the output from the class subdirectories flow_from_directory found,
# rather than hard-coding 7, so it always matches the one-hot label width.
num_classes = len(train_generator.class_indices)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)
# Categorical cross-entropy matches the one-hot labels from class_mode='categorical'.
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])
# Callbacks, both watching validation loss:
#  - EarlyStopping ends training after 5 epochs without improvement and is
#    configured to restore the best-scoring weights.
#  - ModelCheckpoint writes best_model.keras (relative to the current working
#    directory) each time val_loss improves.
# NOTE(review): in some tf.keras 2.x releases restore_best_weights only applies
# when early stopping actually triggers; if all epochs run, the in-memory
# weights may be the last epoch's rather than the best — confirm against the
# installed Keras version (best_model.keras holds the best either way).
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
best_checkpoint = ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)
callbacks = [early_stopping, best_checkpoint]

# Train the frozen-backbone model for at most 30 epochs.
model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=30,
    callbacks=callbacks,
)

# Persist the final in-memory model (path is relative to the current working directory).
model.save("emotion_model.keras")