AGAZO_Final_Assignment / count_max_bird_species_tool.py
agazo's picture
Update count_max_bird_species_tool.py
ad46712 verified
import os
from sample_youtube_video import sample_youtube_video
from ultralytics import YOLO
from langchain_core.tools import tool
from utils import decode_base64_to_frame
@tool
def count_max_bird_species_in_video(youtube_url: str) -> int:
"""
Count the maximum number of bird species to be on camera simultaneously in a YouTube video.
Args:
- youtube_url: str: the URL of the YouTube video.
Returns:
- int: the maximum number of bird species detected in the YouTube video.
"""
# frames = sample_youtube_video(youtube_url, 5)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
youtube_video_in_base_64 = os.path.join(BASE_DIR, 'files', 'frames_youtube_birds_video.b64')
print('youtube video path: ' + youtube_video_in_base_64)
frames = read_list_from_txt(youtube_video_in_base_64)
max_species = count_max_bird_species(frames)
return max_species
def count_max_bird_species(framesStr: list[str]) -> int:
"""
Count the maximum number of bird species in a single frame considering a list of frames.
Args:
- frames (List[cv2.Mat]): A list of frames (images) to analyze.
Returns:
- int: The maximum number of bird species detected in any single frame.
"""
# List of bird class indices from COCO dataset (class IDs for birds)
bird_class_ids = [
14, 15, 16, 17, 18, 19, 20, 21 # Example bird class IDs in COCO dataset
]
# Load YOLOv8 model for object detection (pre-trained or custom-trained for animals)
model = YOLO("yolov8n.pt") # Use a pre-trained YOLOv8 model
max_bird_species_count = 0 # To track the maximum count of distinct bird species
frames = [decode_base64_to_frame(b64) for b64 in framesStr]
# Iterate through all frames in the list
for frame in frames:
# Perform inference on the current frame
results = model(frame)
bird_species_detected = set() # To track distinct bird species in this frame
# Parse the detected objects
for result in results:
# Convert the class indices from tensor to integer
detected_classes = result.boxes.cls.int() # Convert tensor to int
# Iterate over detected classes and filter only bird species
for detected_class in detected_classes:
# If the class is a bird (based on COCO bird class IDs)
if detected_class.item() in bird_class_ids:
species_name = result.names[detected_class.item()] # Get the species name
bird_species_detected.add(species_name)
# Update the maximum number of distinct bird species detected in any frame
max_bird_species_count = max(max_bird_species_count, len(bird_species_detected))
# Return the maximum number of distinct bird species detected in any frame
return max_bird_species_count
def read_list_from_txt(file_path: str) -> list[str]:
with open(file_path, 'r', encoding='utf-8') as f:
return [line.rstrip('\n') for line in f]
if __name__ == "__main__":
sampled_frames = sample_youtube_video.invoke({
"youtube_url": "https://www.youtube.com/watch?v=L1vXCYZAYYM" ,
"sample_rate": 5
}
)
distinct_species_count = count.invoke({
"framesStr": sampled_frames
}
)
print(f"Distinct species detected: {distinct_species_count}")