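"""Streamlit app for DRS-style ball detection.

Downloads a YOLO ball-detection model from the Hugging Face Hub and runs it on an
uploaded image or MP4 video, drawing bounding boxes with the supervision library.

Run with (assuming this file is saved as app.py):
    streamlit run app.py
"""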
import streamlit as st
import cv2
import numpy as np
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
import os
import tempfile
import supervision as sv

# Title and description
st.title("DRS Review System - Ball Detection")
st.write("Upload an image or video to detect balls using a YOLOv5 model for Decision Review System (DRS).")

# Model loading
@st.cache_resource
def load_model():
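    # hf_hub_download fetches the weights once and caches them locally; for a private
    # model repo, pass token=... or set the HF_TOKEN environment variable (a public repo is assumed here).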
    # Replace 'your-username/your-repo' with your Hugging Face repository and model file
    model_path = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
    model = YOLO(model_path)
    return model

model = load_model()

# Confidence threshold slider
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.7, 0.05)

# File uploader for image or video
uploaded_file = st.file_uploader("Upload an image or video", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    # Create a temporary file to save the uploaded content
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.' + uploaded_file.name.split('.')[-1])
    tfile.write(uploaded_file.read())
    tfile.close()
    file_path = tfile.name

    # Check if the uploaded file is an image
    if uploaded_file.type in ["image/jpeg", "image/png"]:
        st.subheader("Image Detection Results")
        image = cv2.imread(file_path)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Run inference
        results = model(image, conf=confidence_threshold)
        detections = sv.Detections.from_ultralytics(results[0])
        
        # Annotate image
        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image_rgb, detections=detections)
        
        # Display result
        st.image(annotated_image, caption="Detected Balls", use_column_width=True)
        
        # Display detection details
        for score, label, box in zip(detections.confidence, detections.class_id, detections.xyxy):
            st.write(f"Detected ball with confidence {score:.2f} at coordinates {box.tolist()}")

        # Remove the temporary upload once the image has been processed
        os.remove(file_path)

    # Check if the uploaded file is a video
    elif uploaded_file.type == "video/mp4":
        st.subheader("Video Detection Results")
        output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
        
        # Process video
        cap = cv2.VideoCapture(file_path)
        if not cap.isOpened():
            st.error("Error: Could not open video file.")
        else:
            # Get video properties
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 25  # fall back to 25 if FPS metadata is missing
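            # Note: OpenCV's 'mp4v' codec may not play inline in every browser;
            # re-encoding the output to H.264 (e.g. with ffmpeg) is a common workaround.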
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
            
            # Progress bar
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            progress = st.progress(0)
            frame_count = 0
            
            # Annotator is created once and reused for every frame
            box_annotator = sv.BoxAnnotator()

            # Process frames
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                
                # Run inference on frame
                results = model(frame, conf=confidence_threshold)
                detections = sv.Detections.from_ultralytics(results[0])
                
                # Annotate frame (convert BGR -> RGB for annotation, then back for the writer)
                annotated_frame = box_annotator.annotate(scene=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), detections=detections)
                annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
                
                # Write to output video
                out.write(annotated_frame_bgr)
                
                # Update progress (frame-count metadata can be unreliable, so guard the math)
                frame_count += 1
                if total_frames > 0:
                    progress.progress(min(frame_count / total_frames, 1.0))
            
            cap.release()
            out.release()
            
            # Read the processed video into memory before the temp files are removed,
            # then display it and offer it as a download
            with open(output_path, "rb") as file:
                video_bytes = file.read()
            st.video(video_bytes)
            st.download_button(
                label="Download Processed Video",
                data=video_bytes,
                file_name="processed_drs_video.mp4",
                mime="video/mp4"
            )
        
        # Clean up temporary files
        os.remove(file_path)
        if os.path.exists(output_path):
            os.remove(output_path)

else:
    st.info("Please upload an image or video to start the DRS review.")