EtienneB commited on
Commit
0728bfa
·
1 Parent(s): fbf50c4

Delete chroma.sqlite3

Browse files

Update .gitignore

updates

Update .gitignore

Update .gitignore

Files changed (4) hide show
  1. .gitignore +2 -0
  2. __pycache__/tools.cpython-310.pyc +0 -0
  3. agent.py +3 -3
  4. tools.py +67 -50
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .env
2
  .venv
 
 
 
1
  .env
2
  .venv
3
+ chroma_db/chroma.sqlite3
4
+ __pycache__
__pycache__/tools.cpython-310.pyc ADDED
Binary file (21.1 kB). View file
 
agent.py CHANGED
@@ -139,12 +139,12 @@ def build_graph():
139
 
140
  # --- Graph Definition ---
141
  builder = StateGraph(MessagesState)
142
- builder.add_node("retriever", retriever_node)
143
  builder.add_node("assistant", assistant)
144
  builder.add_node("tools", ToolNode(tools))
145
 
146
- builder.add_edge(START, "retriever")
147
- builder.add_edge("retriever", "assistant")
148
  builder.add_conditional_edges("assistant", tools_condition)
149
  builder.add_edge("tools", "assistant")
150
 
 
139
 
140
  # --- Graph Definition ---
141
  builder = StateGraph(MessagesState)
142
+ # builder.add_node("retriever", retriever_node)
143
  builder.add_node("assistant", assistant)
144
  builder.add_node("tools", ToolNode(tools))
145
 
146
+ builder.add_edge(START, "assistant")
147
+ # builder.add_edge("retriever", "assistant")
148
  builder.add_conditional_edges("assistant", tools_condition)
149
  builder.add_edge("tools", "assistant")
150
 
tools.py CHANGED
@@ -766,13 +766,20 @@ def reverse_sentence(text: str) -> str:
766
  return text[::-1]
767
 
768
 
769
- def get_max_bird_species_count_from_video(url: str) -> Dict:
 
 
 
 
 
770
  """
771
  Downloads a YouTube video and returns the maximum number of unique bird species
772
- visible in any frame, along with the timestamp.
773
 
774
- Parameters:
775
- url (str): YouTube video URL
 
 
776
 
777
  Returns:
778
  dict: {
@@ -781,50 +788,60 @@ def get_max_bird_species_count_from_video(url: str) -> Dict:
781
  "species_list": List[str],
782
  }
783
  """
784
- # 1. Download YouTube video
785
- yt = YouTube(url)
786
- stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
787
  temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
788
- stream.download(filename=temp_video_path)
789
-
790
- # 2. Load object detection model for bird species
791
- # Load a fine-tuned YOLOv5 model or similar pretrained on bird species
792
- model = torch.hub.load('ultralytics/yolov5', 'custom', path='best_birds.pt') # path to your trained model
793
-
794
- # 3. Process video frames
795
- cap = cv2.VideoCapture(temp_video_path)
796
- fps = cap.get(cv2.CAP_PROP_FPS)
797
- frame_interval = int(fps * 1) # 1 frame per second
798
-
799
- max_species_count = 0
800
- max_species_frame_time = 0
801
- species_at_max = []
802
-
803
- frame_idx = 0
804
- while cap.isOpened():
805
- ret, frame = cap.read()
806
- if not ret:
807
- break
808
- if frame_idx % frame_interval == 0:
809
- # Run detection
810
- results = model(frame)
811
- detected_species = set()
812
- for *box, conf, cls in results.xyxy[0]:
813
- species_name = model.names[int(cls)]
814
- detected_species.add(species_name)
815
-
816
- if len(detected_species) > max_species_count:
817
- max_species_count = len(detected_species)
818
- max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
819
- species_at_max = list(detected_species)
820
-
821
- frame_idx += 1
822
-
823
- cap.release()
824
- os.remove(temp_video_path)
825
-
826
- return {
827
- "max_species_count": max_species_count,
828
- "timestamp": f"{max_species_frame_time}s",
829
- "species_list": species_at_max
830
- }
 
 
 
 
 
 
 
 
 
 
 
 
766
  return text[::-1]
767
 
768
 
769
+ @tool
770
+ def get_max_bird_species_count_from_video(
771
+ url: str,
772
+ model_path: str = "best_birds.pt",
773
+ sample_rate_seconds: int = 1
774
+ ) -> dict:
775
  """
776
  Downloads a YouTube video and returns the maximum number of unique bird species
777
+ visible in any frame, along with the timestamp and species list.
778
 
779
+ Args:
780
+ url (str): YouTube video URL.
781
+ model_path (str): Path to the YOLOv5 bird species detection model.
782
+ sample_rate_seconds (int): How often (in seconds) to sample frames.
783
 
784
  Returns:
785
  dict: {
 
788
  "species_list": List[str],
789
  }
790
  """
791
+ import traceback
792
+
 
793
  temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
794
+ try:
795
+ # Download YouTube video
796
+ yt = YouTube(url)
797
+ stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
798
+ stream.download(filename=temp_video_path)
799
+
800
+ # Load object detection model
801
+ model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path)
802
+
803
+ # Process video frames
804
+ cap = cv2.VideoCapture(temp_video_path)
805
+ fps = cap.get(cv2.CAP_PROP_FPS)
806
+ if not fps or fps <= 0:
807
+ return {"error": "Could not determine video FPS."}
808
+ frame_interval = int(fps * sample_rate_seconds)
809
+
810
+ max_species_count = 0
811
+ max_species_frame_time = 0
812
+ species_at_max = []
813
+
814
+ frame_idx = 0
815
+ while cap.isOpened():
816
+ ret, frame = cap.read()
817
+ if not ret:
818
+ break
819
+ if frame_idx % frame_interval == 0:
820
+ # Run detection
821
+ results = model(frame)
822
+ detected_species = set()
823
+ for *box, conf, cls in results.xyxy[0]:
824
+ species_name = model.names[int(cls)]
825
+ detected_species.add(species_name)
826
+
827
+ if len(detected_species) > max_species_count:
828
+ max_species_count = len(detected_species)
829
+ max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
830
+ species_at_max = list(detected_species)
831
+
832
+ frame_idx += 1
833
+
834
+ cap.release()
835
+ return {
836
+ "max_species_count": max_species_count,
837
+ "timestamp": f"{max_species_frame_time}s",
838
+ "species_list": species_at_max
839
+ }
840
+ except Exception as e:
841
+ return {"error": f"Exception occurred: {str(e)}\n{traceback.format_exc()}"}
842
+ finally:
843
+ if os.path.exists(temp_video_path):
844
+ try:
845
+ os.remove(temp_video_path)
846
+ except Exception:
847
+ pass