# Spaces:
# Sleeping
# Sleeping
import os
import re
import traceback
from collections import Counter

import cv2
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import BatchNormalization, Conv3D, Dense, Dropout, Flatten, Input, MaxPooling3D
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.utils import Sequence
# === CONFIG ===
# --- Dataset ---
# Root folder scanned by load_data(); it looks for "Fighting" and "Normal" sub-folders.
DATA_DIR = r"D:\K_REPO\ComV\train"
# Videos reporting fewer frames than this are dropped by the data generator.
MIN_REQUIRED_FRAMES = 10

# --- Clip sampling ---
N_FRAMES = 30          # frames sampled per video
IMG_SIZE = (96, 96)    # spatial size of every sampled frame

# --- Training ---
EPOCHS = 10
BATCH_SIZE = 14
RESUME_TRAINING = 1    # truthy -> resume from the newest checkpoint if one exists

# --- Outputs ---
CHECKPOINT_DIR = r"D:\K_REPO\ComV\AI_made\trainnig_output\checkpoint"
OUTPUT_PATH = r"D:\K_REPO\ComV\AI_made\trainnig_output\final_model_2.h5"
# Optimize OpenCV: enable its optimized code paths and request a fixed
# number of threads for its parallel regions.
_OPENCV_THREADS = 8
cv2.setUseOptimized(True)
cv2.setNumThreads(_OPENCV_THREADS)
# === VIDEO DATA GENERATOR ===
class VideoDataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of fixed-length RGB video clips.

    Each item is ``(X, y)``:

    * ``X`` is float32 with shape
      ``(batch, n_frames, img_size[0], img_size[1], 3)``, values in ``[0, 1]``.
    * ``y`` holds the matching labels as float32.

    ``img_size`` is interpreted as ``(height, width)``, matching how the
    model input shape is built from it. Videos that cannot be opened, or
    that report fewer than ``MIN_REQUIRED_FRAMES`` frames, are dropped once
    at construction time.
    """

    def __init__(self, video_paths, labels, batch_size, n_frames, img_size):
        """Filter unusable videos and set up batching state.

        Args:
            video_paths: Paths of candidate video files.
            labels: Labels aligned one-to-one with ``video_paths``.
            batch_size: Number of clips per batch.
            n_frames: Number of frames sampled from each video.
            img_size: ``(height, width)`` of every output frame.

        Raises:
            ValueError: if ``video_paths`` and ``labels`` differ in length.
        """
        # Initialise the Keras base class; newer Keras versions warn when a
        # Sequence/PyDataset subclass does not call super().__init__().
        super().__init__()
        video_paths = list(video_paths)
        labels = list(labels)
        if len(video_paths) != len(labels):
            # zip() would silently drop the extra entries otherwise.
            raise ValueError(
                f"video_paths ({len(video_paths)}) and labels ({len(labels)}) must have the same length"
            )
        self.video_paths, self.labels = self._filter_invalid_videos(video_paths, labels)
        self.batch_size = batch_size
        self.n_frames = n_frames
        self.img_size = img_size
        self.indices = np.arange(len(self.video_paths))
        print(f"[INFO] Final dataset size: {len(self.video_paths)} videos")

    def _filter_invalid_videos(self, paths, labels):
        """Return ``(paths, labels)`` restricted to videos that open and are long enough."""
        valid_paths = []
        valid_labels = []
        for path, label in zip(paths, labels):
            cap = cv2.VideoCapture(path)
            try:
                if not cap.isOpened():
                    print(f"[WARNING] Could not open video: {path}")
                    continue
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            finally:
                # Release the handle on every path, including failed opens.
                cap.release()
            if total_frames < MIN_REQUIRED_FRAMES:
                print(f"[WARNING] Skipping {path} - only {total_frames} frames (needs at least {MIN_REQUIRED_FRAMES})")
                continue
            valid_paths.append(path)
            valid_labels.append(label)
        return valid_paths, valid_labels

    def __len__(self):
        """Number of batches per epoch (the last batch may be partial)."""
        return int(np.ceil(len(self.video_paths) / self.batch_size))

    def __getitem__(self, index):
        """Load batch ``index`` as a ``(X, y)`` pair of float32 arrays.

        A video that raises while decoding is replaced by an all-zero clip
        (its label is kept), so one bad file does not abort the epoch.
        """
        start = index * self.batch_size
        batch_indices = self.indices[start:start + self.batch_size]
        X, y = [], []
        for i in batch_indices:
            path = self.video_paths[i]
            label = self.labels[i]
            try:
                frames = self._load_video_frames(path)
            except Exception as e:
                print(f"[WARNING] Error processing {path} - {str(e)}")
                frames = np.zeros((self.n_frames, *self.img_size, 3), dtype=np.float32)
            X.append(frames)
            y.append(label)
        # float32 halves memory versus the float64 produced by plain division.
        return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

    def _load_video_frames(self, path):
        """Sample ``n_frames`` evenly spaced RGB frames from the video at ``path``.

        Returns a float32 array of shape ``(n_frames, img_size[0], img_size[1], 3)``
        scaled to ``[0, 1]``. Frames that fail to read become black frames.
        Videos with fewer than ``n_frames`` frames are padded by repeating
        the last sampled frame (or black frames if none were sampled).
        """
        height, width = self.img_size
        blank = np.zeros((height, width, 3), dtype=np.uint8)
        frames = []
        cap = cv2.VideoCapture(path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Never sample more frames than the video reports. A non-positive
            # count yields no samples; the padding below then fills the clip.
            n_samples = max(0, min(total_frames, self.n_frames))
            frame_indices = np.linspace(0, total_frames - 1, n_samples, dtype=np.int32)
            for idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
                ret, frame = cap.read()
                if not ret:
                    frames.append(blank)
                    continue
                # cv2.resize takes dsize as (width, height); img_size is (height, width).
                frame = cv2.resize(frame, (width, height))
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()
        while len(frames) < self.n_frames:
            frames.append(frames[-1] if frames else blank)
        return np.asarray(frames, dtype=np.float32) / 255.0

    def on_epoch_end(self):
        """Shuffle the sample order between epochs."""
        np.random.shuffle(self.indices)
def create_model():
    """Build and compile the binary 3D-CNN clip classifier.

    Input clips have shape ``(N_FRAMES, *IMG_SIZE, 3)``. The body has three
    Conv3D stages (32, 64 and 128 filters; 3x3x3 kernels; 'same' padding;
    ReLU). Each stage is followed by max-pooling and batch normalisation.
    The first two stages pool only spatially; the last also halves the
    time axis. The head is Flatten, then a 256-unit ReLU layer with 50%
    dropout, then a single sigmoid output. The model is compiled with Adam
    and binary cross-entropy, tracking accuracy.
    """
    # (filters, pool_size) for each convolution stage, in order.
    conv_stages = (
        (32, (1, 2, 2)),
        (64, (1, 2, 2)),
        (128, (2, 2, 2)),
    )
    stack = [Input(shape=(N_FRAMES, *IMG_SIZE, 3))]
    for n_filters, pool in conv_stages:
        stack.append(Conv3D(n_filters, kernel_size=(3, 3, 3), activation='relu', padding='same'))
        stack.append(MaxPooling3D(pool_size=pool))
        stack.append(BatchNormalization())
    stack.append(Flatten())
    stack.append(Dense(256, activation='relu'))
    stack.append(Dropout(0.5))
    stack.append(Dense(1, activation='sigmoid'))

    net = Sequential(stack)
    net.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return net
def load_data(data_dir=None, test_size=0.2, random_state=42):
    """Collect labelled video paths and split them into train/validation sets.

    Videos are read from the ``Fighting`` (label 1) and ``Normal`` (label 0)
    sub-folders of ``data_dir``. Only files ending in .mp4, .mpeg, .avi or
    .mov (case-insensitive) are kept. File names are sorted so that, for a
    fixed ``random_state``, the split does not depend on directory-listing
    order.

    Args:
        data_dir: Dataset root. Defaults to ``DATA_DIR``, read at call time.
        test_size: Fraction of videos assigned to the validation split.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        ``(train_paths, val_paths, train_labels, val_labels)`` from
        ``train_test_split``. The split is stratified when both classes
        are present.

    Raises:
        FileNotFoundError: a class sub-folder does not exist.
        ValueError: no matching video files were found. Errors raised by
            ``train_test_split`` (e.g. too few samples) also propagate.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    video_exts = (".mp4", ".mpeg", ".avi", ".mov")
    video_paths, labels = [], []
    for label_name in ["Fighting", "Normal"]:
        label_dir = os.path.join(data_dir, label_name)
        if not os.path.isdir(label_dir):
            raise FileNotFoundError(f"Directory not found: {label_dir}")
        label = 1 if label_name.lower() == "fighting" else 0
        # os.listdir order is arbitrary; sort so the seeded split is reproducible.
        for file in sorted(os.listdir(label_dir)):
            if file.lower().endswith(video_exts):
                video_paths.append(os.path.join(label_dir, file))
                labels.append(label)
    if not video_paths:
        raise ValueError(f"No videos found in {data_dir}")
    print(f"[INFO] Total videos: {len(video_paths)} (Fighting: {labels.count(1)}, Normal: {labels.count(0)})")
    if len(set(labels)) > 1:
        return train_test_split(video_paths, labels, test_size=test_size, stratify=labels, random_state=random_state)
    print("[WARNING] Only one class found. Splitting without stratification.")
    return train_test_split(video_paths, labels, test_size=test_size, random_state=random_state)
def get_latest_checkpoint(checkpoint_dir=None):
    """Return the path of the highest-epoch ``ckpt_<epoch>.h5`` file, or ``None``.

    Args:
        checkpoint_dir: Directory to scan. Defaults to ``CHECKPOINT_DIR``,
            read at call time.

    Returns:
        Full path of the checkpoint with the largest integer epoch, or
        ``None`` when no file matches. If the directory does not exist, it
        is created and ``None`` is returned.

    Files that start with ``ckpt_`` and end with ``.h5`` but whose middle is
    not an integer (e.g. ``ckpt_best.h5``) are ignored rather than crashing
    the numeric comparison.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    if not os.path.exists(checkpoint_dir):
        # exist_ok guards against another process creating it concurrently.
        os.makedirs(checkpoint_dir, exist_ok=True)
        return None
    latest_epoch, latest_name = None, None
    for name in os.listdir(checkpoint_dir):
        match = re.fullmatch(r"ckpt_(\d+)\.h5", name)
        if match is None:
            continue
        epoch = int(match.group(1))
        if latest_epoch is None or epoch > latest_epoch:
            latest_epoch, latest_name = epoch, name
    if latest_name is None:
        return None
    return os.path.join(checkpoint_dir, latest_name)
def main():
    """Train, or resume training, the video classifier and save the final model.

    Steps:

    1. Load and split the dataset and build the train/validation generators.
    2. When ``RESUME_TRAINING`` is truthy, restore the newest checkpoint;
       otherwise build a fresh model.
    3. Train with per-epoch checkpointing, CSV logging and early stopping.
    4. Save the model to ``OUTPUT_PATH``. This happens even if ``fit``
       raised, once a model exists.

    Errors are reported on stdout instead of being raised.
    """
    # Load and split data
    try:
        train_paths, val_paths, train_labels, val_labels = load_data()
    except Exception as e:
        print(f"[ERROR] Failed to load data: {str(e)}")
        return

    # Create data generators
    try:
        train_gen = VideoDataGenerator(train_paths, train_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
        val_gen = VideoDataGenerator(val_paths, val_labels, BATCH_SIZE, N_FRAMES, IMG_SIZE)
    except Exception as e:
        print(f"[ERROR] Failed to create data generators: {str(e)}")
        return

    # Callbacks
    callbacks = [
        ModelCheckpoint(
            os.path.join(CHECKPOINT_DIR, 'ckpt_{epoch}.h5'),
            save_best_only=False,
            save_weights_only=False
        ),
        CSVLogger('training_log.csv', append=True),
        EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
    ]

    # Handle resume training
    initial_epoch = 0
    try:
        if RESUME_TRAINING:
            ckpt = get_latest_checkpoint()
            if ckpt:
                print(f"[INFO] Resuming training from checkpoint: {ckpt}")
                model = load_model(ckpt)
                # Parse the epoch from the file name ("ckpt_<epoch>.h5") only:
                # the directory part of the path contains underscores of its own.
                stem = os.path.splitext(os.path.basename(ckpt))[0]
                initial_epoch = int(stem.split('_')[1])
            else:
                print("[INFO] No checkpoint found, starting new training")
                model = create_model()
        else:
            model = create_model()
    except Exception as e:
        print(f"[ERROR] Failed to initialize model: {str(e)}")
        return

    # Display model summary
    model.summary()

    # Train model
    training_ok = False
    try:
        print("[INFO] Starting training...")
        # ModelCheckpoint writes into CHECKPOINT_DIR every epoch; ensure it
        # exists even when resuming is disabled.
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=EPOCHS,
            initial_epoch=initial_epoch,
            callbacks=callbacks,
            verbose=1
        )
        training_ok = True
    except Exception as e:
        print(f"[ERROR] Training failed: {str(e)}")
        traceback.print_exc()
    finally:
        # Persist whatever state the model reached, even after a failure or
        # an interrupt; a failing save is reported rather than raised.
        try:
            model.save(OUTPUT_PATH)
        except Exception as e:
            print(f"[ERROR] Failed to save model to {OUTPUT_PATH}: {str(e)}")
        else:
            if training_ok:
                print(f"[INFO] Training completed. Model saved to {OUTPUT_PATH}")
            else:
                print(f"[INFO] Training did not complete. Model saved to {OUTPUT_PATH}")
if __name__ == "__main__":
    # Bracket the run with start/end banners. The end banner is printed only
    # if main() returns normally; an exception propagates before reaching it.
    start_banner = "[INFO] Starting script..."
    end_banner = "[INFO] Script execution completed."
    print(start_banner)
    main()
    print(end_banner)