# fight-detection-live-demo / FeatureExtraction.py
# Last commit by sdafd: "Update FeatureExtraction.py" (09165b3, verified), 1.19 kB
import cv2
import numpy as np
from tensorflow.keras.applications.resnet import ResNet152
from tensorflow.keras.layers import AveragePooling2D, Flatten
from tensorflow.keras.models import Model
class FeatureExtractor:
    """Per-frame ResNet152 feature extractor for a fixed-length frame sequence.

    Wraps an ImageNet-pretrained ResNet152 backbone (classification head
    removed, every layer frozen) followed by a 7x7 average pooling and a
    flatten, so each 224x224 frame is mapped to a single feature vector
    (2048 values for this backbone).
    """

    # Spatial size the backbone is built for; every frame is resized to it.
    INPUT_SIZE = 224

    def __init__(self, img_shape, channels, seq_length):
        """Build the frozen backbone and the pooled feature model.

        Args:
            img_shape: Sequence whose first two items are stored as
                ``self.height`` and ``self.width``. NOTE(review): neither
                these nor ``channels`` are used by this class -- frames are
                always resized to 224x224 and the backbone is built for
                3 channels.
            channels: Stored as ``self.channels`` (see note above).
            seq_length: Number of frame slots (columns) in the array returned
                by :meth:`extract_feature`.
        """
        self.seq_length = seq_length
        self.height = img_shape[0]
        self.width = img_shape[1]
        self.channels = channels

        size = self.INPUT_SIZE
        self.base_model = ResNet152(
            include_top=False, input_shape=(size, size, 3), weights='imagenet'
        )
        # Used for inference only: freeze every backbone layer.
        for layer in self.base_model.layers:
            layer.trainable = False
        self.op = self.base_model.output
        # ResNet downsamples by 32, so a 224x224 input leaves a 7x7 map here;
        # a 7x7 average pool collapses it to 1x1 before flattening.
        self.x_model = AveragePooling2D((7, 7), name='avg_pool')(self.op)
        self.x_model = Flatten()(self.x_model)
        self.model = Model(self.base_model.input, self.x_model)
        # Length of one per-frame feature vector (2048 for ResNet152).
        self.feature_dim = int(self.model.output.shape[-1])

    def extract_feature(self, frames_buffer):
        """Return backbone features for up to ``seq_length`` frames.

        Args:
            frames_buffer: Sized iterable of images accepted by
                ``cv2.resize`` -- presumably HxWx3 uint8 frames as read by
                OpenCV (TODO confirm against the caller). Must hold at most
                ``seq_length`` frames.

        Returns:
            ``np.ndarray`` of shape ``(feature_dim, seq_length)`` (float64).
            Column ``i`` holds the features of frame ``i``; columns beyond
            ``len(frames_buffer)`` stay zero.

        Raises:
            IndexError: If ``frames_buffer`` holds more than ``seq_length``
                frames.

        NOTE(review): frames are fed without
        ``tensorflow.keras.applications.resnet.preprocess_input``, although
        the ImageNet weights were trained with that preprocessing. Left as-is
        because a downstream model was presumably trained on features
        extracted exactly this way -- confirm before changing.
        """
        x_op = np.zeros((self.feature_dim, self.seq_length))
        n_frames = len(frames_buffer)
        if n_frames > self.seq_length:
            raise IndexError(
                f"frames_buffer has {n_frames} frames but seq_length is "
                f"{self.seq_length}"
            )
        if n_frames == 0:
            # Nothing to run; np.stack would fail on an empty list.
            return x_op

        size = self.INPUT_SIZE
        batch = np.stack([cv2.resize(frame, (size, size)) for frame in frames_buffer])
        # One batched call instead of one `predict` per frame: Keras `predict`
        # has large per-call overhead and is not intended for use inside
        # loops. verbose=0 suppresses the per-call progress bar.
        features = self.model.predict(batch, verbose=0)
        x_op[:, :n_frames] = features.reshape(n_frames, -1).T
        return x_op