Salimshakeel committed on
Commit
6a6ee7b
·
1 Parent(s): 7db9110
config.py CHANGED
@@ -1,7 +1,8 @@
1
  # config.py
2
  import torch
3
- UPLOAD_DIR = "/code/static/uploads"
4
- OUTPUT_DIR = "/code/static/outputs"
 
5
  FRAME_RATE = 15
6
  SCORE_THRESHOLD = 0.4
7
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
1
  # config.py
2
  import torch
3
+ import os
4
+ UPLOAD_DIR = os.path.join(os.getcwd(), "static/uploads")
5
+ OUTPUT_DIR = os.path.join(os.getcwd(), "static/outputs")
6
  FRAME_RATE = 15
7
  SCORE_THRESHOLD = 0.4
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
routes/summarize.py CHANGED
@@ -1,8 +1,7 @@
1
  from fastapi import APIRouter, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
  from utils.file_utils import save_uploaded_file
4
- from services.extractor import extract_features
5
- from services.model_loader import load_model
6
  from services.summarizer import get_scores, get_selected_indices, save_summary_video
7
  from config import UPLOAD_DIR, OUTPUT_DIR
8
 
@@ -13,15 +12,27 @@ def summarize_video(video: UploadFile = File(...)):
13
  if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
14
  return JSONResponse(content={"error": "Unsupported file format"}, status_code=400)
15
 
 
16
  video_path = save_uploaded_file(video, UPLOAD_DIR)
17
- features, picks = extract_features(video_path)
18
- model = load_model("Model/epoch-199.pkl")
19
- scores = get_scores(model, features)
 
 
 
 
 
 
 
 
20
  selected = get_selected_indices(scores, picks)
21
  output_path = f"{OUTPUT_DIR}/summary_{video.filename}"
 
 
22
  save_summary_video(video_path, selected, output_path)
23
  summary_url = f"/static/outputs/summary_{video.filename}"
24
 
 
25
  return JSONResponse(content={
26
  "message": "Summarization complete",
27
  "summary_video_url": summary_url
 
1
  from fastapi import APIRouter, UploadFile, File
2
  from fastapi.responses import JSONResponse
3
  from utils.file_utils import save_uploaded_file
4
+ from services.extractor import extract_frames, extract_features
 
5
  from services.summarizer import get_scores, get_selected_indices, save_summary_video
6
  from config import UPLOAD_DIR, OUTPUT_DIR
7
 
 
12
  if not video.filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
13
  return JSONResponse(content={"error": "Unsupported file format"}, status_code=400)
14
 
15
+ print("\n-----------> Uploading Video ....")
16
  video_path = save_uploaded_file(video, UPLOAD_DIR)
17
+
18
+ print("\n-----------> Extracting Frames ....")
19
+ frames, picks = extract_frames(video_path)
20
+
21
+ print("\n-----------> Extracting Features ....")
22
+ features = extract_features(frames)
23
+
24
+ print("\n-----------> Getting Scores ....")
25
+ scores = get_scores(features)
26
+
27
+ print("\n-----------> Selecting Indices ....")
28
  selected = get_selected_indices(scores, picks)
29
  output_path = f"{OUTPUT_DIR}/summary_{video.filename}"
30
+
31
+ print("\n-----------> Saving Video ....")
32
  save_summary_video(video_path, selected, output_path)
33
  summary_url = f"/static/outputs/summary_{video.filename}"
34
 
35
+ print("\n-----------> Returning Response ....")
36
  return JSONResponse(content={
37
  "message": "Summarization complete",
38
  "summary_video_url": summary_url
services/extractor.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  from PIL import Image
5
  from torchvision import models, transforms
6
  from config import DEVICE, FRAME_RATE
 
7
 
8
  # Load GoogLeNet once
9
  from torchvision.models import GoogLeNet_Weights
@@ -40,23 +41,26 @@ transform = transforms.Compose([
40
  )
41
  ])
42
 
43
- def extract_features(video_path):
44
  cap = cv2.VideoCapture(video_path)
45
- fps = cap.get(cv2.CAP_PROP_FPS)
46
- picks, frames = [], []
47
- count = 0
 
48
 
49
- while cap.isOpened():
 
50
  ret, frame = cap.read()
51
  if not ret:
52
  break
53
- if int(count % round(fps // FRAME_RATE)) == 0:
54
- image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
55
- input_tensor = transform(image).unsqueeze(0).to(DEVICE)
56
- with torch.no_grad():
57
- feature = feature_extractor(input_tensor).squeeze(0).cpu().numpy()
58
- frames.append(feature)
59
- picks.append(count)
60
- count += 1
61
  cap.release()
62
- return np.stack(frames), picks
 
 
 
 
 
 
 
4
  from PIL import Image
5
  from torchvision import models, transforms
6
  from config import DEVICE, FRAME_RATE
7
+ from tqdm import tqdm
8
 
9
  # Load GoogLeNet once
10
  from torchvision.models import GoogLeNet_Weights
 
41
  )
42
  ])
43
 
44
def extract_frames(video_path, max_frames=None):
    """Sample frames from a video at a fixed stride as RGB PIL images.

    Args:
        video_path: Path of the video file, opened with ``cv2.VideoCapture``.
        max_frames: Optional upper bound on the frame positions scanned
            (useful for quick debugging runs). ``None`` scans the whole video.

    Returns:
        A ``(frames, indices)`` tuple. ``frames`` is a list of ``PIL.Image``
        objects and ``indices`` lists the frame positions they were read
        from. Both lists are empty when no frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if max_frames is not None:
            total_frames = min(total_frames, max_frames)

        # NOTE(review): FRAME_RATE is used here as a stride in frames, not as
        # a target fps -- confirm this is the intended sampling policy.
        for idx in tqdm(range(0, total_frames, FRAME_RATE)):
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; convert so PIL (and the downstream
            # transform) see RGB channel order.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(rgb_frame))
            indices.append(idx)
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()
    return frames, indices
61
+
62
def extract_features(frames):
    """Run the module-level feature extractor on a batch of frames.

    Args:
        frames: Sequence of images accepted by the module-level ``transform``
            (``extract_frames`` produces PIL images).

    Returns:
        The output of ``feature_extractor`` for the stacked batch, which is
        moved to ``DEVICE`` before the forward pass.
    """
    batch = torch.stack([transform(frame) for frame in frames]).to(DEVICE)
    # Inference only: without no_grad the autograd graph for the whole batch
    # is built and kept alive by the returned tensor.
    with torch.no_grad():
        return feature_extractor(batch)
services/summarizer.py CHANGED
@@ -1,16 +1,19 @@
1
  import cv2
2
  import torch
3
  from config import SCORE_THRESHOLD
 
4
 
5
- def get_scores(model, features):
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
- model = model.to(device)
 
 
 
 
8
  with torch.no_grad():
9
- features_tensor = torch.tensor(features, dtype=torch.float32).to(device)
10
- scores, _ = model(features_tensor)
11
  return scores.squeeze().cpu().numpy()
12
 
13
-
14
  def get_selected_indices(scores, picks, threshold=SCORE_THRESHOLD):
15
  return [picks[i] for i, score in enumerate(scores) if score >= threshold]
16
 
 
1
  import cv2
2
  import torch
3
  from config import SCORE_THRESHOLD
4
+ from services.model_loader import load_model
5
 
6
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
+ model = load_model("Model/epoch-199.pkl")
8
+ model = model.to(device)
9
+ model = model.eval()
10
+
11
def get_scores(features):
    """Score each frame feature vector with the module-level model.

    Args:
        features: Tensor passed directly to ``model``; the original note
            gives its shape as (N, 1024).

    Returns:
        A NumPy array holding the first output of ``model``, squeezed.
        The result is always at least 1-D so callers can iterate over it
        even when only a single score is produced.
    """
    # features.shape: (N, 1024)
    with torch.no_grad():
        scores, _ = model(features)
    scores = scores.squeeze().cpu().numpy()
    if scores.ndim == 0:
        # squeeze() collapses a single score to a 0-d array, which cannot be
        # enumerated by get_selected_indices; restore one dimension.
        scores = scores.reshape(1)
    return scores
16
 
 
17
  def get_selected_indices(scores, picks, threshold=SCORE_THRESHOLD):
18
  return [picks[i] for i, score in enumerate(scores) if score >= threshold]
19