import cv2
import torch
from PIL import Image
from torchvision import models, transforms
from torchvision.models import GoogLeNet_Weights
from tqdm import tqdm

from config import DEVICE, FRAME_RATE

# Load GoogLeNet once at import time and drop the classifier head, keeping
# everything up to the global average pool so each frame maps to a
# 1024-dimensional feature vector.
weights = GoogLeNet_Weights.DEFAULT
googlenet = models.googlenet(weights=weights).to(DEVICE).eval()

feature_extractor = torch.nn.Sequential(
    googlenet.conv1,
    googlenet.maxpool1,
    googlenet.conv2,
    googlenet.conv3,
    googlenet.maxpool2,
    googlenet.inception3a,
    googlenet.inception3b,
    googlenet.maxpool3,
    googlenet.inception4a,
    googlenet.inception4b,
    googlenet.inception4c,
    googlenet.inception4d,
    googlenet.inception4e,
    googlenet.maxpool4,
    googlenet.inception5a,
    googlenet.inception5b,
    googlenet.avgpool,
    torch.nn.Flatten(),
)

# Standard ImageNet preprocessing, matching the statistics GoogLeNet was
# trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def extract_frames(video_path):
    """Sample every FRAME_RATE-th frame from the video as RGB PIL images."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    for idx in tqdm(range(0, total_frames, FRAME_RATE)):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV decodes frames as BGR; convert to RGB before wrapping in PIL,
        # otherwise the ImageNet normalization is applied to swapped channels.
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        indices.append(idx)
    cap.release()
    return frames, indices


def extract_features(frames):
    """Run sampled frames through GoogLeNet; returns one feature row per frame."""
    batch = torch.stack([transform(frame) for frame in frames]).to(DEVICE)
    # Inference only: no_grad avoids building the autograd graph and keeps
    # memory flat across large frame batches.
    with torch.no_grad():
        features = feature_extractor(batch)
    return features
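
# Minimal usage sketch. Assumptions: config.py defines DEVICE (e.g. "cuda"
# or "cpu") and FRAME_RATE (the sampling stride in frames), and
# "example.mp4" is a hypothetical placeholder path. With GoogLeNet's final
# avgpool + Flatten, `features` has shape (num_sampled_frames, 1024).
if __name__ == "__main__":
    frames, indices = extract_frames("example.mp4")
    features = extract_features(frames)
    print(f"Sampled {len(frames)} frames; feature tensor shape: {tuple(features.shape)}")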