Prime810 committed on
Commit aa8d5ee · verified · 1 Parent(s): f920e65

Update Training/Code/train.py

Files changed (1)
  1. Training/Code/train.py (+77, -38)
Training/Code/train.py CHANGED
@@ -1,60 +1,99 @@
import os
import numpy as np
from tensorflow.keras.models import Model
- from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.optimizers import Adam
- from tensorflow.keras.preprocessing.image import ImageDataGenerator
- from tensorflow.keras.applications import MobileNetV2
- from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

- # Define paths
- base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))
- train_dir = os.path.join(base_dir, 'Data/train')
- val_dir = os.path.join(base_dir, 'Data/test')

- # Image generators with augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
-     rotation_range=30,
-     zoom_range=0.2,
    horizontal_flip=True,
-     shear_range=0.2,
-     width_shift_range=0.2,
-     height_shift_range=0.2
)
- val_datagen = ImageDataGenerator(rescale=1./255)

- # Use a larger image size for better accuracy
- img_size = 128

train_generator = train_datagen.flow_from_directory(
-     train_dir, target_size=(img_size, img_size), batch_size=32, color_mode='rgb', class_mode='categorical')

- validation_generator = val_datagen.flow_from_directory(
-     val_dir, target_size=(img_size, img_size), batch_size=32, color_mode='rgb', class_mode='categorical')

- # Load base model
- base_model = MobileNetV2(include_top=False, input_shape=(img_size, img_size, 3), weights='imagenet')
- base_model.trainable = False  # Freeze base layers

- # Add custom classification head
x = base_model.output
x = GlobalAveragePooling2D()(x)
- x = Dense(256, activation='relu')(x)
- x = Dropout(0.5)(x)
- predictions = Dense(7, activation='softmax')(x)

- model = Model(inputs=base_model.input, outputs=predictions)
- model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])

- # Callbacks
- callbacks = [
-     EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
-     ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True)
- ]

- # Train the model
- model.fit(train_generator, validation_data=validation_generator, epochs=30, callbacks=callbacks)

- # Save the final model
- model.save("emotion_model.keras")
import os
import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
+ from tensorflow.keras.applications import EfficientNetV2B1
from tensorflow.keras.models import Model
+ from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
+ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
+ from sklearn.utils.class_weight import compute_class_weight

+ # ==================== Paths ====================
+ train_dir = "/content/combine_dataset/train"
+ val_dir = "/content/combine_dataset/test"

+ # ==================== Parameters ====================
+ img_size = (192, 192)  # Recommended for EfficientNetV2B1
+ batch_size = 32
+ epochs = 30
+ num_classes = 7

+ # ==================== Data Augmentation ====================
train_datagen = ImageDataGenerator(
    rescale=1./255,
+     rotation_range=10,
+     zoom_range=0.1,
+     width_shift_range=0.05,
+     height_shift_range=0.05,
+     brightness_range=[0.9, 1.1],
    horizontal_flip=True,
+     fill_mode='nearest'
)

+ val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
+     train_dir,
+     target_size=img_size,
+     batch_size=batch_size,
+     class_mode='categorical',
+     shuffle=True
+ )

+ val_generator = val_datagen.flow_from_directory(
+     val_dir,
+     target_size=img_size,
+     batch_size=batch_size,
+     class_mode='categorical',
+     shuffle=False
+ )

+ # ==================== Compute Class Weights ====================
+ labels = train_generator.classes
+ class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)
+ class_weights = dict(enumerate(class_weights))

+ # ==================== Build Model ====================
+ base_model = EfficientNetV2B1(include_top=False, input_shape=(192, 192, 3), weights='imagenet')
x = base_model.output
x = GlobalAveragePooling2D()(x)
+ x = Dropout(0.4)(x)
+ output = Dense(num_classes, activation='softmax')(x)
+ model = Model(inputs=base_model.input, outputs=output)

+ # ==================== Compile Model ====================
+ optimizer = Adam(learning_rate=1e-5)
+ model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

+ # ==================== Callbacks ====================
+ checkpoint = ModelCheckpoint(
+     "/content/emotion_model.keras",
+     monitor='val_accuracy',
+     save_best_only=True,
+     verbose=1
+ )

+ early_stop = EarlyStopping(
+     monitor='val_loss',
+     patience=7,
+     restore_best_weights=True,
+     verbose=1
+ )

+ lr_schedule = ReduceLROnPlateau(
+     monitor='val_loss',
+     factor=0.5,
+     patience=3,
+     verbose=1,
+     min_lr=1e-6
+ )

+ # ==================== Train Model ====================
+ model.fit(
+     train_generator,
+     validation_data=val_generator,
+     epochs=epochs,
+     callbacks=[checkpoint, early_stop, lr_schedule],
+     class_weight=class_weights
+ )
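
For quick verification after training, here is a minimal sketch (not part of the commit) of how the checkpoint written by the ModelCheckpoint callback could be loaded and checked. It assumes the paths, generators, and 192x192 / 7-class setup defined in the updated train.py above; the sample image path is hypothetical.

# Not part of the commit: a minimal follow-up sketch, assuming the generators and
# checkpoint path ("/content/emotion_model.keras") from the updated train.py above.
import numpy as np
import tensorflow as tf

# Load the best checkpoint saved by the ModelCheckpoint callback.
best_model = tf.keras.models.load_model("/content/emotion_model.keras")

# Evaluate on the validation generator (same preprocessing: rescale=1./255).
val_loss, val_acc = best_model.evaluate(val_generator, verbose=1)
print(f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

# Predict a single image (hypothetical path) and map the argmax back to a class name.
img = tf.keras.preprocessing.image.load_img("/content/sample.jpg", target_size=(192, 192))
x = tf.keras.preprocessing.image.img_to_array(img) / 255.0
probs = best_model.predict(np.expand_dims(x, axis=0))[0]
idx_to_class = {v: k for k, v in train_generator.class_indices.items()}
print(idx_to_class[int(np.argmax(probs))], float(np.max(probs)))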