|
import cv2 |
|
import torch |
|
import numpy as np |
|
from PIL import Image |
|
from torchvision import models, transforms |
|
from config import DEVICE, FRAME_RATE |
|
from tqdm import tqdm |
|
|
|
|
|
from torchvision.models import GoogLeNet_Weights

# Load an ImageNet-pretrained GoogLeNet, move it to the configured device,
# and switch to eval mode (disables dropout / fixes batch-norm stats).
weights = GoogLeNet_Weights.DEFAULT
googlenet = models.googlenet(weights=weights)
googlenet = googlenet.to(DEVICE)
googlenet.eval()
|
|
|
# GoogLeNet backbone truncated before the fully-connected classifier:
# everything from the stem convolutions through global average pooling,
# followed by a flatten. Produces one 1024-d feature vector per image.
_backbone = [
    googlenet.conv1,
    googlenet.maxpool1,
    googlenet.conv2,
    googlenet.conv3,
    googlenet.maxpool2,
    googlenet.inception3a,
    googlenet.inception3b,
    googlenet.maxpool3,
    googlenet.inception4a,
    googlenet.inception4b,
    googlenet.inception4c,
    googlenet.inception4d,
    googlenet.inception4e,
    googlenet.maxpool4,
    googlenet.inception5a,
    googlenet.inception5b,
    googlenet.avgpool,
    torch.nn.Flatten(),
]
feature_extractor = torch.nn.Sequential(*_backbone)
|
|
|
# Standard ImageNet preprocessing: resize to the 224x224 input expected by
# GoogLeNet, convert to a float tensor, and normalize with ImageNet stats.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
])
|
|
|
def extract_frames(video_path):
    """Sample frames from a video at a fixed stride.

    Reads every FRAME_RATE-th frame from `video_path`, converts each from
    OpenCV's BGR ordering to RGB, and returns PIL images plus their frame
    indices.

    Args:
        video_path: Path to a video file readable by cv2.VideoCapture.

    Returns:
        (frames, indices): a list of PIL.Image.Image frames and the list of
        the corresponding 0-based frame indices. Both are empty if the video
        cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []

    # Bail out early on an unreadable/missing file instead of looping on a
    # bogus frame count.
    if not cap.isOpened():
        cap.release()
        return frames, indices

    # BUGFIX: a leftover debug override (total_frames = 100) previously
    # truncated every video to its first 100 frames; use the real count.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    for idx in tqdm(range(0, total_frames, FRAME_RATE)):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if not ret:
            break
        # BUGFIX: OpenCV decodes frames as BGR; convert to RGB so the PIL
        # image matches the RGB channel order assumed by the ImageNet
        # normalization downstream.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(frame))
        indices.append(idx)

    cap.release()
    return frames, indices
|
|
|
def extract_features(frames):
    """Compute GoogLeNet feature vectors for a batch of frames.

    Applies the module-level `transform` to each PIL image, stacks them into
    a single batch on DEVICE, and runs the truncated GoogLeNet backbone.

    Args:
        frames: Non-empty list of PIL.Image.Image objects.

    Returns:
        A (len(frames), 1024) feature tensor on DEVICE.

    Raises:
        ValueError: If `frames` is empty (torch.stack cannot build a batch
        from zero tensors).
    """
    if not frames:
        raise ValueError("extract_features requires at least one frame")

    batch = torch.stack([transform(frame) for frame in frames]).to(DEVICE)

    # BUGFIX: run inference under no_grad — the model is in eval mode but
    # autograd was still recording, wasting memory and compute.
    with torch.no_grad():
        features = feature_extractor(batch)
    return features
|
|