# DRS_AI / app.py
import streamlit as st
import cv2
import numpy as np
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
from PIL import Image
import os
import tempfile
import supervision as sv
# Title and description
st.title("DRS Review System - Ball Detection")
st.write("Upload an image or video to detect balls using a YOLOv5 model for Decision Review System (DRS).")
# Model loading
@st.cache_resource
def load_model():
    # Replace 'your-username/your-repo' with your Hugging Face repository and model file
    model_path = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
    model = YOLO(model_path)
    return model

model = load_model()
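# Note: the repo_id above is a placeholder; hf_hub_download will fail until it points at a
# real Hugging Face repo that contains best.pt. As a rough alternative (hypothetical path),
# the weights could also be loaded straight from disk:
# model = YOLO("weights/best.pt")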
# Confidence threshold slider
confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.7, 0.05)
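# Note: 0.70 is a fairly strict default for a small, fast-moving ball; lowering the threshold
# may recover missed detections at the cost of more false positives.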
# File uploader for image or video
uploaded_file = st.file_uploader("Upload an image or video", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
    # Create a temporary file to save the uploaded content
    tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.' + uploaded_file.name.split('.')[-1])
    tfile.write(uploaded_file.read())
    tfile.close()
    file_path = tfile.name

    # Check if the uploaded file is an image
    if uploaded_file.type in ["image/jpeg", "image/png"]:
        st.subheader("Image Detection Results")
        image = cv2.imread(file_path)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Run inference
        results = model(image, conf=confidence_threshold)
        detections = sv.Detections.from_ultralytics(results[0])

        # Annotate image
        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image_rgb, detections=detections)

        # Display result
        st.image(annotated_image, caption="Detected Balls", use_column_width=True)

        # Display detection details
        for score, label, box in zip(detections.confidence, detections.class_id, detections.xyxy):
            st.write(f"Detected ball with confidence {score:.2f} at coordinates {box.tolist()}")
    # Check if the uploaded file is a video
    elif uploaded_file.type == "video/mp4":
        st.subheader("Video Detection Results")
        output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name

        # Process video
        cap = cv2.VideoCapture(file_path)
        if not cap.isOpened():
            st.error("Error: Could not open video file.")
        else:
            # Get video properties
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = int(cap.get(cv2.CAP_PROP_FPS))
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
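            # Note: MP4 files written with the 'mp4v' fourcc often do not play inline in
            # browsers via st.video; if the local OpenCV build ships an H.264 encoder, the
            # 'avc1' fourcc is a commonly used alternative:
            # fourcc = cv2.VideoWriter_fourcc(*'avc1')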
            # Progress bar
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            progress = st.progress(0)
            frame_count = 0

            # Reuse a single annotator for all frames
            box_annotator = sv.BoxAnnotator()

            # Process frames
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                # Run inference on frame
                results = model(frame, conf=confidence_threshold)
                detections = sv.Detections.from_ultralytics(results[0])

                # Annotate frame
                annotated_frame = box_annotator.annotate(scene=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), detections=detections)
                annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)

                # Write to output video
                out.write(annotated_frame_bgr)

                # Update progress (frame counts reported by OpenCV can be inexact, so clamp to 1.0)
                frame_count += 1
                progress.progress(min(frame_count / max(total_frames, 1), 1.0))
            cap.release()
            out.release()

            # Display video
            st.video(output_path)

            # Provide download link for processed video
            with open(output_path, "rb") as file:
                st.download_button(
                    label="Download Processed Video",
                    data=file,
                    file_name="processed_drs_video.mp4",
                    mime="video/mp4"
                )
    # Clean up temporary files (output_path only exists for video uploads)
    os.remove(file_path)
    if uploaded_file.type == "video/mp4" and os.path.exists(output_path):
        os.remove(output_path)
else:
    st.info("Please upload an image or video to start the DRS review.")