|
import streamlit as st |
|
import cv2 |
|
import numpy as np |
|
from ultralytics import YOLO |
|
from huggingface_hub import hf_hub_download |
|
from PIL import Image |
|
import os |
|
import tempfile |
|
import supervision as sv |
|
|
|
|
|
# Page header rendered at the top of the app on every rerun.
st.title("DRS Review System - Ball Detection")
st.write(
    "Upload an image or video to detect balls using a YOLOv5 model "
    "for Decision Review System (DRS)."
)
|
|
|
|
|
@st.cache_resource
def load_model(repo_id="your-username/your-repo", filename="best.pt"):
    """Download YOLO weights from the Hugging Face Hub and build the model.

    Decorated with ``st.cache_resource`` so the download and model
    construction are reused across Streamlit reruns; the cache is keyed on
    the arguments, so different repos/files get separate cached models.

    Args:
        repo_id: Hub repository holding the weights. The default is a
            placeholder and must be replaced with a real repository.
        filename: Name of the weights file inside the repository.

    Returns:
        An ``ultralytics.YOLO`` model loaded from the downloaded weights.
    """
    # hf_hub_download returns the local filesystem path of the fetched file.
    model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    return YOLO(model_path)
|
|
|
# load_model is wrapped in st.cache_resource, so reruns reuse the same model.
model = load_model()

# Detections scoring below this value are discarded by the model call.
confidence_threshold = st.slider(
    "Confidence Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.7,
    step=0.05,
)

# Only these extensions are accepted by the uploader widget.
uploaded_file = st.file_uploader(
    "Upload an image or video",
    type=["jpg", "jpeg", "png", "mp4"],
)
|
|
|
# Extensions accepted by the uploader, split by how they are processed.
# Dispatching on the extension (which st.file_uploader already filters on)
# avoids depending on the exact MIME type string the browser reports.
_IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
_VIDEO_EXTENSIONS = {".mp4"}

# Frame rate used when the container does not report one (CAP_PROP_FPS == 0).
_FALLBACK_FPS = 30.0


def _save_upload_to_temp(upload):
    """Persist an uploaded file to disk and return the temp file's path.

    OpenCV reads from a path, so the in-memory upload is written to a named
    temporary file that keeps the original extension. The caller must delete
    the returned file.
    """
    suffix = os.path.splitext(upload.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        # getvalue() returns the whole buffer regardless of the read position.
        tmp.write(upload.getvalue())
    return tmp.name


def _show_image_detections(path, detector, conf):
    """Run detection on one image file and render the annotated result.

    Args:
        path: Path to an image file on disk.
        detector: The ultralytics model, called as ``detector(image, conf=...)``.
        conf: Minimum confidence for a detection to be kept.
    """
    st.subheader("Image Detection Results")

    image = cv2.imread(path)
    if image is None:  # imread signals unreadable input with None, not an exception
        st.error("Error: Could not read image file.")
        return

    results = detector(image, conf=conf, verbose=False)
    detections = sv.Detections.from_ultralytics(results[0])

    # supervision's annotators hand colours to OpenCV in BGR order, so draw on
    # the BGR image and convert to RGB only for display.
    annotated = sv.BoxAnnotator().annotate(scene=image, detections=detections)
    st.image(
        cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB),
        caption="Detected Balls",
        use_column_width=True,
    )

    if len(detections) == 0:
        st.write("No balls detected at the current confidence threshold.")
        return
    for score, box in zip(detections.confidence, detections.xyxy):
        st.write(f"Detected ball with confidence {score:.2f} at coordinates {box.tolist()}")


def _open_video_writer(path, fps, size):
    """Open a cv2.VideoWriter for ``path``, preferring a browser-playable codec.

    'avc1' (H.264) can be played inline by browsers through st.video, but it
    is only available in OpenCV builds that ship an H.264 encoder; otherwise
    fall back to 'mp4v' (the codec this app used originally).

    Returns:
        An opened writer, or None if neither codec could be opened.
    """
    for codec in ("avc1", "mp4v"):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
        writer.release()
    return None


def _annotate_video(src_path, dst_path, detector, conf):
    """Run detection on every frame of ``src_path`` and write boxes to ``dst_path``.

    Shows a progress bar while processing.

    Returns:
        True if the output video was written, False if an error was reported.
    """
    cap = cv2.VideoCapture(src_path)
    if not cap.isOpened():
        st.error("Error: Could not open video file.")
        return False

    writer = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the FPS as a float: truncating e.g. 29.97 to 29 makes the
        # output drift; 0 means "unknown", so fall back to a sane default.
        fps = cap.get(cv2.CAP_PROP_FPS) or _FALLBACK_FPS
        writer = _open_video_writer(dst_path, fps, (width, height))
        if writer is None:
            st.error("Error: Could not create output video.")
            return False

        # The frame count is container metadata: it may be 0 or inaccurate,
        # so guard the division and clamp to st.progress's 0.0-1.0 range.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        progress = st.progress(0)
        annotator = sv.BoxAnnotator()  # stateless drawer; build once, not per frame

        frame_count = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            results = detector(frame, conf=conf, verbose=False)
            detections = sv.Detections.from_ultralytics(results[0])
            # Frames are BGR, which is what both supervision and VideoWriter use.
            writer.write(annotator.annotate(scene=frame, detections=detections))
            frame_count += 1
            if total_frames > 0:
                progress.progress(min(frame_count / total_frames, 1.0))
        progress.progress(1.0)
    finally:
        cap.release()
        if writer is not None:
            writer.release()
    return True


def _show_video_detections(path, detector, conf):
    """Annotate a video file, then show it and offer it for download."""
    st.subheader("Video Detection Results")

    fd, output_path = tempfile.mkstemp(suffix=".mp4")
    os.close(fd)  # VideoWriter opens the path itself; don't hold a second handle
    try:
        if not _annotate_video(path, output_path, detector, conf):
            return
        with open(output_path, "rb") as fh:
            video_bytes = fh.read()
    finally:
        os.remove(output_path)

    st.video(video_bytes)
    st.download_button(
        label="Download Processed Video",
        data=video_bytes,
        file_name="processed_drs_video.mp4",
        mime="video/mp4",
    )


if uploaded_file is not None:
    file_path = _save_upload_to_temp(uploaded_file)
    try:
        extension = os.path.splitext(uploaded_file.name)[1].lower()
        if extension in _IMAGE_EXTENSIONS:
            _show_image_detections(file_path, model, confidence_threshold)
        elif extension in _VIDEO_EXTENSIONS:
            _show_video_detections(file_path, model, confidence_threshold)
        else:
            st.error(f"Unsupported file type: {extension or uploaded_file.name}")
    finally:
        # Remove the upload copy even if detection raised.
        os.remove(file_path)
else:
    st.info("Please upload an image or video to start the DRS review.")