EtienneB commited on
Commit
eed7f48
·
1 Parent(s): 0728bfa

Revert "Delete chroma.sqlite3"

Browse files

This reverts commit 0728bfa0c0de754b797326662801c3abf913d31c.

Files changed (4) hide show
  1. .gitignore +0 -2
  2. __pycache__/tools.cpython-310.pyc +0 -0
  3. agent.py +3 -3
  4. tools.py +50 -67
.gitignore CHANGED
@@ -1,4 +1,2 @@
1
  .env
2
  .venv
3
- chroma_db/chroma.sqlite3
4
- __pycache__
 
1
  .env
2
  .venv
 
 
__pycache__/tools.cpython-310.pyc DELETED
Binary file (21.1 kB)
 
agent.py CHANGED
@@ -139,12 +139,12 @@ def build_graph():
139
 
140
  # --- Graph Definition ---
141
  builder = StateGraph(MessagesState)
142
- # builder.add_node("retriever", retriever_node)
143
  builder.add_node("assistant", assistant)
144
  builder.add_node("tools", ToolNode(tools))
145
 
146
- builder.add_edge(START, "assistant")
147
- # builder.add_edge("retriever", "assistant")
148
  builder.add_conditional_edges("assistant", tools_condition)
149
  builder.add_edge("tools", "assistant")
150
 
 
139
 
140
  # --- Graph Definition ---
141
  builder = StateGraph(MessagesState)
142
+ builder.add_node("retriever", retriever_node)
143
  builder.add_node("assistant", assistant)
144
  builder.add_node("tools", ToolNode(tools))
145
 
146
+ builder.add_edge(START, "retriever")
147
+ builder.add_edge("retriever", "assistant")
148
  builder.add_conditional_edges("assistant", tools_condition)
149
  builder.add_edge("tools", "assistant")
150
 
tools.py CHANGED
@@ -766,20 +766,13 @@ def reverse_sentence(text: str) -> str:
766
  return text[::-1]
767
 
768
 
769
- @tool
770
- def get_max_bird_species_count_from_video(
771
- url: str,
772
- model_path: str = "best_birds.pt",
773
- sample_rate_seconds: int = 1
774
- ) -> dict:
775
  """
776
  Downloads a YouTube video and returns the maximum number of unique bird species
777
- visible in any frame, along with the timestamp and species list.
778
 
779
- Args:
780
- url (str): YouTube video URL.
781
- model_path (str): Path to the YOLOv5 bird species detection model.
782
- sample_rate_seconds (int): How often (in seconds) to sample frames.
783
 
784
  Returns:
785
  dict: {
@@ -788,60 +781,50 @@ def get_max_bird_species_count_from_video(
788
  "species_list": List[str],
789
  }
790
  """
791
- import traceback
792
-
 
793
  temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
794
- try:
795
- # Download YouTube video
796
- yt = YouTube(url)
797
- stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
798
- stream.download(filename=temp_video_path)
799
-
800
- # Load object detection model
801
- model = torch.hub.load('ultralytics/yolov5', 'custom', path=model_path)
802
-
803
- # Process video frames
804
- cap = cv2.VideoCapture(temp_video_path)
805
- fps = cap.get(cv2.CAP_PROP_FPS)
806
- if not fps or fps <= 0:
807
- return {"error": "Could not determine video FPS."}
808
- frame_interval = int(fps * sample_rate_seconds)
809
-
810
- max_species_count = 0
811
- max_species_frame_time = 0
812
- species_at_max = []
813
-
814
- frame_idx = 0
815
- while cap.isOpened():
816
- ret, frame = cap.read()
817
- if not ret:
818
- break
819
- if frame_idx % frame_interval == 0:
820
- # Run detection
821
- results = model(frame)
822
- detected_species = set()
823
- for *box, conf, cls in results.xyxy[0]:
824
- species_name = model.names[int(cls)]
825
- detected_species.add(species_name)
826
-
827
- if len(detected_species) > max_species_count:
828
- max_species_count = len(detected_species)
829
- max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
830
- species_at_max = list(detected_species)
831
-
832
- frame_idx += 1
833
-
834
- cap.release()
835
- return {
836
- "max_species_count": max_species_count,
837
- "timestamp": f"{max_species_frame_time}s",
838
- "species_list": species_at_max
839
- }
840
- except Exception as e:
841
- return {"error": f"Exception occurred: {str(e)}\n{traceback.format_exc()}"}
842
- finally:
843
- if os.path.exists(temp_video_path):
844
- try:
845
- os.remove(temp_video_path)
846
- except Exception:
847
- pass
 
766
  return text[::-1]
767
 
768
 
769
+ def get_max_bird_species_count_from_video(url: str) -> Dict:
 
 
 
 
 
770
  """
771
  Downloads a YouTube video and returns the maximum number of unique bird species
772
+ visible in any frame, along with the timestamp.
773
 
774
+ Parameters:
775
+ url (str): YouTube video URL
 
 
776
 
777
  Returns:
778
  dict: {
 
781
  "species_list": List[str],
782
  }
783
  """
784
+ # 1. Download YouTube video
785
+ yt = YouTube(url)
786
+ stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
787
  temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
788
+ stream.download(filename=temp_video_path)
789
+
790
+ # 2. Load object detection model for bird species
791
+ # Load a fine-tuned YOLOv5 model or similar pretrained on bird species
792
+ model = torch.hub.load('ultralytics/yolov5', 'custom', path='best_birds.pt') # path to your trained model
793
+
794
+ # 3. Process video frames
795
+ cap = cv2.VideoCapture(temp_video_path)
796
+ fps = cap.get(cv2.CAP_PROP_FPS)
797
+ frame_interval = int(fps * 1) # 1 frame per second
798
+
799
+ max_species_count = 0
800
+ max_species_frame_time = 0
801
+ species_at_max = []
802
+
803
+ frame_idx = 0
804
+ while cap.isOpened():
805
+ ret, frame = cap.read()
806
+ if not ret:
807
+ break
808
+ if frame_idx % frame_interval == 0:
809
+ # Run detection
810
+ results = model(frame)
811
+ detected_species = set()
812
+ for *box, conf, cls in results.xyxy[0]:
813
+ species_name = model.names[int(cls)]
814
+ detected_species.add(species_name)
815
+
816
+ if len(detected_species) > max_species_count:
817
+ max_species_count = len(detected_species)
818
+ max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
819
+ species_at_max = list(detected_species)
820
+
821
+ frame_idx += 1
822
+
823
+ cap.release()
824
+ os.remove(temp_video_path)
825
+
826
+ return {
827
+ "max_species_count": max_species_count,
828
+ "timestamp": f"{max_species_frame_time}s",
829
+ "species_list": species_at_max
830
+ }