Salimshakeel committed
Commit 1579b70 · 1 Parent(s): 7cd255e

routes/summarize.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import APIRouter, UploadFile, File
 from fastapi.responses import JSONResponse
-from services.extractor import extract_frames, extract_features
+from services.extractor import extract_features
 from services.summarizer import get_scores, get_selected_indices, save_summary_video
 from uuid import uuid4
 import time
@@ -25,11 +25,8 @@ def summarize_video(video: UploadFile = File(...)):
     with open(filepath, "wb") as f:
         f.write(video.file.read())
 
-    print("\n-----------> Extracting Frames ....")
-    frames, picks = extract_frames(filepath)
-
     print("\n-----------> Extracting Features ....")
-    features = extract_features(frames)
+    features, picks = extract_features(filepath)
 
     print("\n-----------> Getting Scores ....")
     scores = get_scores(features)
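
Frame extraction and feature extraction are now a single call: extract_features(filepath) returns both the feature tensor and the sampled frame indices (picks). A quick way to exercise the updated endpoint is sketched below; it assumes the router is registered on an app in main.py and exposed as POST /summarize, neither of which is shown in this diff.

import io

from fastapi.testclient import TestClient
from main import app  # hypothetical entry point, not part of this commit

client = TestClient(app)

# Upload a local file the way the endpoint's UploadFile parameter expects it.
with open("sample.mp4", "rb") as f:
    response = client.post(
        "/summarize",
        files={"video": ("sample.mp4", f, "video/mp4")},
    )
print(response.status_code, response.json())
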
services/extractor.py CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
 from torchvision import models, transforms
 from config import DEVICE, FRAME_RATE
 from tqdm import tqdm
+from services.model_loader import batch_inference
 
 # Load GoogLeNet once
 from torchvision.models import GoogLeNet_Weights
@@ -31,6 +32,7 @@ feature_extractor = torch.nn.Sequential(
     googlenet.avgpool,
     torch.nn.Flatten()
 )
+feature_extractor = feature_extractor.eval()
 
 transform = transforms.Compose([
     transforms.Resize((224, 224)),
@@ -41,33 +43,32 @@ transform = transforms.Compose([
     )
 ])
 
-def extract_frames(video_path):
+def extract_features(video_path):
     cap = cv2.VideoCapture(video_path)
     frames = []
     indices = []
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     # total_frames = 300  # TEMP
     print(f"Total frames in video: {total_frames}")
-    print(f"Extracting frames at every {FRAME_RATE} frames...")
 
-    for idx in tqdm(range(0, total_frames, FRAME_RATE)):
+    for idx in tqdm(range(total_frames)):
         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
         ret, frame = cap.read()
         if not ret:
             break
-        frames.append(Image.fromarray(frame))
-        indices.append(idx)
 
-    print(f"Indices of extracted frames: {indices}")
-    print(f"Total frames extracted: {len(frames)}")
+        # Convert OpenCV's BGR frame to RGB before the torchvision transform.
+        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+        frame = transform(frame)
+
+        frames.append(frame)
+        indices.append(idx)
 
     cap.release()
-    return frames, indices
 
-def extract_features(frames):
-    features = [transform(frame) for frame in frames]
-    features = torch.stack(features).to(DEVICE)
-    print("Features before GoogleNet extraction:", features.shape)
-    features = feature_extractor(features)
-    print("Features after GoogleNet extraction:", features.shape)
-    return features
+    frames = torch.stack(frames).to(DEVICE)
+    print("Features before GoogleNet extraction:", frames.shape)
+    frames = batch_inference(model=feature_extractor, input=frames, batch_size=32)
+    print("Features after GoogleNet extraction:", frames.shape)
+
+    return frames, indices
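
Two observations on this hunk: with the FRAME_RATE stride removed, the FRAME_RATE import in this module is now unused; and since the loop now reads every frame in order, the per-frame cap.set(cv2.CAP_PROP_POS_FRAMES, idx) forces a redundant seek before each read. A sequential-read variant (a sketch, not part of this commit) returns the same frames without seeking:

import cv2
import torch
from PIL import Image

def decode_frames_sequential(video_path, transform):
    """Decode every frame in display order; no per-frame seeking needed."""
    cap = cv2.VideoCapture(video_path)
    frames, indices = [], []
    idx = 0
    while True:
        ret, frame = cap.read()  # advances exactly one frame per call
        if not ret:
            break
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(transform(frame))
        indices.append(idx)
        idx += 1
    cap.release()
    return torch.stack(frames), indices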
 
services/model_loader.py CHANGED
@@ -4,6 +4,7 @@ import os
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
 from layers.summarizer import PGL_SUM
 from config import DEVICE
+from tqdm import tqdm
 
 def load_model(weights_path):
     model = PGL_SUM(
@@ -17,3 +18,13 @@ def load_model(weights_path):
     model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
     model.eval()
     return model
+
+def batch_inference(model, input, batch_size=128):
+    model.eval()
+    output = []
+    with torch.no_grad():
+        for i in tqdm(range(0, input.size(0), batch_size)):
+            batch = input[i:i + batch_size].to(DEVICE)
+            out = model(batch)
+            output.append(out.cpu())
+    return torch.cat(output, dim=0)
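
The new helper bounds GPU memory to one batch at a time: each slice is moved to DEVICE, run under torch.no_grad(), and the result is brought back to the CPU before the next slice. A minimal sanity check with a stand-in model (torch.nn.Identity here is only for illustration):

import torch
from services.model_loader import batch_inference

model = torch.nn.Identity()         # stand-in for the real feature extractor
x = torch.randn(1000, 3, 224, 224)  # e.g. 1000 preprocessed frames
y = batch_inference(model, x, batch_size=128)
assert y.shape[0] == x.shape[0]     # 8 chunks, concatenated back in order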
services/summarizer.py CHANGED
@@ -60,7 +60,7 @@ def save_summary_video(video_path, selected_indices, output_path, fps=15):
     out.release()
 
     print("Fixing the video with ffmpeg")
-    # fix_video_with_ffmpeg(output_path)
+    fix_video_with_ffmpeg(output_path)
 
 def fix_video_with_ffmpeg(path):
     temp_path = path + ".fixed.mp4"
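
This re-enables the previously commented-out ffmpeg fix-up. The body of fix_video_with_ffmpeg beyond the temp_path line is not shown in this diff; one plausible implementation, assuming its purpose is to re-encode OpenCV's raw output into a browser-playable H.264 file, is sketched below (standard ffmpeg flags, but the actual function may differ):

import os
import subprocess

def fix_video_with_ffmpeg(path):
    temp_path = path + ".fixed.mp4"
    # Re-encode to H.264/yuv420p for broad playback support; +faststart
    # moves the moov atom so playback can begin before the download finishes.
    subprocess.run(
        ["ffmpeg", "-y", "-i", path,
         "-c:v", "libx264", "-pix_fmt", "yuv420p",
         "-movflags", "+faststart", temp_path],
        check=True,
    )
    os.replace(temp_path, path)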