# Source: fight-detection-live-demo / FeatureExtraction.py
# Uploaded by Aquibjaved ("Upload 3 files"), commit f126c7c (verified), 1.62 kB.
import cv2
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.applications.resnet import ResNet152
from tensorflow.keras.layers import AveragePooling2D, Flatten
from tensorflow.keras.models import Model
from tqdm import tqdm
import os
class FeatureExtractor:
    """Turn a buffer of video frames into a matrix of per-frame ResNet152 features.

    Each frame is resized to 224x224 and passed through a frozen ResNet152
    backbone (ImageNet weights, classification head removed). The backbone
    output is average-pooled and flattened into one feature vector per frame.
    ``extract_feature`` stacks those vectors as the columns of a
    ``(feature_dim, seq_length)`` array.
    """

    def __init__(self, img_shape, channels, seq_length=40):
        """Build the frozen feature-extraction model.

        Args:
            img_shape: Indexable whose first two entries are stored as
                ``self.height`` and ``self.width``.
                NOTE(review): these values are not used anywhere in this class;
                the backbone input is fixed at 224x224x3.
            channels: Stored as ``self.channels``; likewise unused here.
            seq_length: Number of frame slots (columns) in the matrix returned
                by ``extract_feature``. Defaults to 40, the previously
                hard-coded value.
        """
        self.seq_length = seq_length  # Number of frames to process
        self.height = img_shape[0]
        self.width = img_shape[1]
        self.channels = channels
        # ImageNet-pretrained ResNet152 without its fully connected top layer.
        self.base_model = ResNet152(include_top=False, input_shape=(224, 224, 3), weights='imagenet')
        # Freeze the backbone: it is used only as a fixed feature extractor.
        for layer in self.base_model.layers:
            layer.trainable = False
        # A 7x7 average pool over the backbone's final feature map, followed by
        # Flatten, produces one feature vector per input image.
        self.op = self.base_model.output
        self.x_model = AveragePooling2D((7, 7), name='avg_pool')(self.op)
        self.x_model = Flatten()(self.x_model)
        # Length of each per-frame feature vector (2048 for this configuration).
        self.feature_dim = int(self.x_model.shape[-1])
        # Create the feature extraction model
        self.model = Model(self.base_model.input, self.x_model)

    def extract_feature(self, frames_buffer):
        """Return the features of every frame, laid out column-wise.

        Args:
            frames_buffer: Sequence (list, deque, ...) of at most
                ``self.seq_length`` images that ``cv2.resize`` accepts.
                NOTE(review): presumably 3-channel uint8 OpenCV frames; TODO
                confirm colour order and dtype against the caller.

        Returns:
            ``np.ndarray`` of shape ``(self.feature_dim, self.seq_length)``
            (float64). Column ``i`` holds the features of frame ``i``. Columns
            beyond ``len(frames_buffer)`` stay zero, which zero-pads short
            buffers.

        Raises:
            ValueError: If more than ``self.seq_length`` frames are supplied.
        """
        n_frames = len(frames_buffer)
        if n_frames > self.seq_length:
            raise ValueError(
                f"expected at most {self.seq_length} frames, got {n_frames}"
            )
        x_op = np.zeros((self.feature_dim, self.seq_length))  # (features_dim, seq_length)
        if n_frames == 0:
            # Nothing to run through the network; predict() would reject an
            # empty batch.
            return x_op
        # Resize every frame to the backbone's fixed input size and run them all
        # as a single batch instead of calling predict() once per frame.
        # NOTE(review): frames are fed without
        # tf.keras.applications.resnet.preprocess_input. This is left as-is
        # because any downstream classifier was presumably trained on these
        # exact features; confirm before adding it.
        batch = np.stack([cv2.resize(frame, (224, 224)) for frame in frames_buffer])
        features = self.model.predict(batch, verbose=0)  # (n_frames, feature_dim)
        x_op[:, :n_frames] = features.T
        return x_op