nagasurendra committed on
Commit
4527f8f
·
verified ·
1 Parent(s): 2e1fb9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -47
app.py CHANGED
@@ -3,32 +3,26 @@ import torch
3
  import gradio as gr
4
  import numpy as np
5
  from ultralytics import YOLO
 
6
 
7
  # Load YOLOv8 model
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
  model = YOLO('./data/best.pt') # Path to your model
10
  model.to(device)
11
 
12
- # Store frames with detected objects
13
  frames_with_detections = []
14
- detection_counts = []
15
 
16
- # Define the function that processes the uploaded video
17
  def process_video(video):
18
- # video is now the file path string, not a file object
19
- input_video = cv2.VideoCapture(video) # Directly pass the path to cv2.VideoCapture
20
-
21
- # Get frame width, height, and fps from input video
22
  frame_width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH))
23
  frame_height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
24
  fps = input_video.get(cv2.CAP_PROP_FPS)
25
 
26
- # Resize to reduce computation (optional)
27
- new_width, new_height = 640, 480 # Resize to 640x480 resolution
28
- frame_width, frame_height = new_width, new_height
29
-
30
- # Track detected objects by their bounding box coordinates
31
- detected_boxes = set()
32
 
33
  while True:
34
  # Read a frame from the video
@@ -36,56 +30,54 @@ def process_video(video):
36
  if not ret:
37
  break # End of video
38
 
39
- # Resize the frame to reduce computational load
40
  frame = cv2.resize(frame, (new_width, new_height))
41
 
42
  # Perform inference on the frame
43
  results = model(frame) # Automatically uses GPU if available
44
 
45
- # Check if any object was detected
46
- if len(results[0].boxes) > 0: # If there are detected objects
47
- # Get the bounding boxes for each detected object
48
- boxes = results[0].boxes.xyxy.cpu().numpy() # Get xyxy coordinates
49
 
50
- # Loop through each detection and only show the frame for new objects
51
- for box in boxes:
52
- x1, y1, x2, y2 = box
53
- detection_box = (x1, y1, x2, y2)
54
 
55
- # Check if this box was already processed
56
- if detection_box not in detected_boxes:
57
- # Add the box to the set to avoid repeating the detection
58
- detected_boxes.add(detection_box)
59
 
60
- # Annotate the frame with bounding boxes
61
- annotated_frame = results[0].plot() # Plot the frame with bounding boxes
62
 
63
- # Convert the annotated frame to RGB format for displaying
64
- annotated_frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
 
 
 
 
65
 
66
- # Add this frame to the list of frames with detections
67
- frames_with_detections.append(annotated_frame_rgb)
 
68
 
69
- # Yield the latest frame immediately for Gradio's real-time display
70
- yield annotated_frame_rgb
 
 
 
 
71
 
72
  # Release resources
73
  input_video.release()
74
 
75
- # Create a Gradio Blocks interface
76
  with gr.Blocks() as demo:
77
- # Define a file input for video upload
78
- video_input = gr.Video(label="Upload Video")
79
-
80
- # Define the output area to show processed frames (gallery for continuous update)
81
- gallery_output = gr.Gallery(label="Detection Album", show_label=True, columns=3) # Display images in a row (album)
82
-
83
- # Define the function to update frames in the album
84
- def update_gallery(video):
85
- return process_video(video) # Return frames one by one as they are detected
86
-
87
- # Connect the video input to the gallery update
88
- video_input.change(update_gallery, inputs=video_input, outputs=gallery_output)
89
 
90
  # Launch the interface
91
  demo.launch()
 
3
  import gradio as gr
4
  import numpy as np
5
  from ultralytics import YOLO
6
+ import matplotlib.pyplot as plt
7
 
8
  # Load YOLOv8 model
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
  model = YOLO('./data/best.pt') # Path to your model
11
  model.to(device)
12
 
13
+ # List to store frames with detections
14
  frames_with_detections = []
 
15
 
16
+ # Define the function to process the video
17
  def process_video(video):
18
+ # Open the video file
19
+ input_video = cv2.VideoCapture(video)
 
 
20
  frame_width = int(input_video.get(cv2.CAP_PROP_FRAME_WIDTH))
21
  frame_height = int(input_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
22
  fps = input_video.get(cv2.CAP_PROP_FPS)
23
 
24
+ # Resize frames to 640x480 (optional, to reduce computational load)
25
+ new_width, new_height = 640, 480
 
 
 
 
26
 
27
  while True:
28
  # Read a frame from the video
 
30
  if not ret:
31
  break # End of video
32
 
33
+ # Resize the frame
34
  frame = cv2.resize(frame, (new_width, new_height))
35
 
36
  # Perform inference on the frame
37
  results = model(frame) # Automatically uses GPU if available
38
 
39
+ # If there are detections
40
+ if len(results[0].boxes) > 0:
41
+ boxes = results[0].boxes.xyxy.cpu().numpy() # Get the bounding boxes
 
42
 
43
+ # Annotate the frame with bounding boxes
44
+ annotated_frame = results[0].plot()
 
 
45
 
46
+ # Convert the frame to RGB
47
+ annotated_frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
 
 
48
 
49
+ # Append the frame with detection to list
50
+ frames_with_detections.append(annotated_frame_rgb)
51
 
52
+ # Create a simple bar chart to show the count of detected objects
53
+ fig, ax = plt.subplots()
54
+ ax.bar([1], [len(boxes)], color='blue') # Bar for the current frame detection
55
+ ax.set_xlabel('Frame')
56
+ ax.set_ylabel('Number of Detections')
57
+ ax.set_title('Detection Count per Frame')
58
 
59
+ # Convert plot to an image to return it in Gradio output
60
+ plt.tight_layout()
61
+ plt.close(fig)
62
 
63
+ # Save the plot as an image in memory
64
+ buf = np.frombuffer(fig.canvas.print_to_buffer()[0], dtype=np.uint8)
65
+ img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
66
+
67
+ # Yield the detected frame and the graph at the same time
68
+ yield annotated_frame_rgb, img
69
 
70
  # Release resources
71
  input_video.release()
72
 
73
+ # Gradio interface
74
  with gr.Blocks() as demo:
75
+ with gr.Row():
76
+ video_input = gr.Video(label="Upload Video")
77
+ gallery_output = gr.Gallery(label="Detection Album").style(columns=3) # Display images in a row
78
+ graph_output = gr.Image(label="Detection Counts Graph", type="numpy") # For displaying graph
79
+
80
+ video_input.change(process_video, inputs=video_input, outputs=[gallery_output, graph_output])
 
 
 
 
 
 
81
 
82
  # Launch the interface
83
  demo.launch()