nullHawk committed on
Commit
1b97263
·
verified ·
1 Parent(s): f9031d7

added main.py

Browse files
Files changed (1) hide show
  1. main.py +109 -0
main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import av
4
+ from ultralytics import YOLO
5
+ from PIL import Image
6
+ from datetime import timedelta
7
+
8
# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------

# Directory that holds the input video; '.' is the current working directory.
VIDEOS_DIR = '.'

# Default input video. When the file is run as a script, this name is rebound
# to the command-line argument in the ``__main__`` section.
video_path = os.path.join(VIDEOS_DIR, 'sample_video.mp4')

# Intended JSON output location.
# NOTE(review): not referenced by the functions visible in this file, which
# rely on the default argument of generate_output_json instead.
output_json_path = 'output.json'

# Location of the custom-trained detector weights.
model_path = os.path.join('.', 'runs', 'detect', 'train', 'weights', 'best.pt')

# The detector is loaded once, at import time, and shared by detect_logos().
model = YOLO(model_path)

# Detections must score strictly above this confidence to be recorded.
threshold = 0.5
18
+
19
def format_timestamp(seconds):
    """Render an offset in seconds as the string form of a ``timedelta``.

    Args:
        seconds: Offset in seconds (int or float).

    Returns:
        ``str(timedelta(seconds=seconds))``. Hours are not zero-padded and
        a fractional part adds a microsecond suffix, e.g. ``"1:01:01"`` for
        3661 and ``"0:00:01.500000"`` for 1.5.
    """
    return str(timedelta(seconds=seconds))
23
+
24
def extract_frames(video_path):
    """Decode every frame of the first video stream into memory.

    Args:
        video_path: Source passed straight to ``av.open``.

    Returns:
        A list of ``(image, timestamp)`` tuples, where ``image`` is the
        result of ``frame.to_image()`` and ``timestamp`` is the frame's
        presentation time in float seconds (``pts * time_base``).

    Note:
        All frames are held in the returned list at once.
    """
    frames = []
    # The context manager closes the container even if decoding raises;
    # previously the container was never closed.
    with av.open(video_path) as container:
        last_timestamp = 0.0
        for frame in container.decode(video=0):
            # A frame may carry no pts; multiplying None by the time base
            # would raise TypeError, so reuse the last known timestamp.
            if frame.pts is not None:
                last_timestamp = float(frame.pts * frame.time_base)
            frames.append((frame.to_image(), last_timestamp))
    return frames
33
+
34
def detect_logos(frames):
    """Run the module-level detector over frames and collect logo hits.

    Args:
        frames: Iterable of ``(image, timestamp)`` pairs. ``image`` is
            passed to ``model`` and must expose ``width`` and ``height``;
            ``timestamp`` is an offset in seconds.

    Returns:
        A ``(pepsi_pts, cocacola_pts)`` tuple of lists. Each element is a
        dict with the keys ``"timestamp"`` (formatted string), ``"size"``
        (box ``width``/``height``) and ``"distance_from_center"`` (distance
        between the box centre and the frame centre).
    """
    # Upper-cased class name -> list that receives matching detections.
    hits_by_label = {'PEPSI': [], 'COCA-COLA': []}

    for image, seconds in frames:
        for result in model(image):
            for box in result.boxes:
                left, top, right, bottom = box.xyxy[0].tolist()
                confidence = box.conf[0].item()
                label_index = int(box.cls[0].item())

                # Keep only detections scoring strictly above the threshold.
                if not (confidence > threshold):
                    continue

                label = result.names[label_index].upper()

                box_mid_x = (left + right) / 2
                box_mid_y = (top + bottom) / 2
                offset_x = box_mid_x - image.width / 2
                offset_y = box_mid_y - image.height / 2

                record = {
                    "timestamp": format_timestamp(seconds),
                    "size": {"width": right - left, "height": bottom - top},
                    "distance_from_center": (offset_x ** 2 + offset_y ** 2) ** 0.5,
                }

                # Detections of any other class are dropped.
                bucket = hits_by_label.get(label)
                if bucket is not None:
                    bucket.append(record)

    return hits_by_label['PEPSI'], hits_by_label['COCA-COLA']
73
+
74
def generate_output_json(pepsi_pts, cocacola_pts, output_path='output.json'):
    """Write the detection results to a JSON file.

    Args:
        pepsi_pts: List of detection dicts, each with a ``"timestamp"`` key.
        cocacola_pts: List of detection dicts with the same layout.
        output_path: Destination file; overwritten if it exists.

    The file holds four keys: ``"Pepsi_pts"`` and ``"CocaCola_pts"``
    (timestamps only), plus ``"Pepsi_details"`` and ``"CocaCola_details"``
    (the full entries, converted to JSON-compatible values).
    """
    def to_serializable(obj):
        """Recursively convert ``obj`` into something ``json.dump`` accepts."""
        # JSON-native scalars pass through unchanged. Previously plain
        # floats fell through to str(), so numbers were written as strings.
        if obj is None or isinstance(obj, (str, bool, int, float)):
            return obj
        # Recurse into containers so nested array/tensor values convert too.
        if isinstance(obj, dict):
            return {key: to_serializable(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [to_serializable(item) for item in obj]
        if hasattr(obj, 'tolist'):
            return obj.tolist()  # array-like objects
        if hasattr(obj, 'item'):
            return obj.item()  # single-element containers
        return str(obj)  # last resort for anything else

    output = {
        "Pepsi_pts": [entry["timestamp"] for entry in pepsi_pts],
        "CocaCola_pts": [entry["timestamp"] for entry in cocacola_pts],
        "Pepsi_details": [to_serializable(entry) for entry in pepsi_pts],
        "CocaCola_details": [to_serializable(entry) for entry in cocacola_pts],
    }
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=4)
94
+
95
+
96
+
97
+
98
def main(video_path):
    """Run the full pipeline: decode frames, detect logos, write the JSON.

    Args:
        video_path: Video source handed to ``extract_frames``.

    The report is written to the default output path of
    ``generate_output_json``.
    """
    pepsi_pts, cocacola_pts = detect_logos(extract_frames(video_path))
    generate_output_json(pepsi_pts, cocacola_pts)
101
+ generate_output_json(pepsi_pts, cocacola_pts)
102
+
103
# Script entry point: expects the video path as the first CLI argument.
if __name__ == "__main__":
    import sys

    cli_args = sys.argv[1:]
    if not cli_args:
        # No video given: print usage and exit with a failure status.
        print("Usage: python main.py <video_path>")
        sys.exit(1)

    video_path = cli_args[0]
    main(video_path)