EtienneB committed on
Commit 1e05108 · 1 Parent(s): 7faf23e
Files changed (6)
  1. .gitignore +2 -0
  2. agent.py +41 -76
  3. app.py +2 -32
  4. old-tools.py +71 -0
  5. requirements.txt +12 -14
  6. tools.py +0 -68
.gitignore CHANGED
@@ -1,2 +1,4 @@
 .env
 .venv
+/__pycache__
+/chroma_db
agent.py CHANGED
@@ -1,7 +1,8 @@
+import json
 import os
+import re
 
 from dotenv import load_dotenv
-from langchain_community.vectorstores import Chroma
 from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage
 from langchain_huggingface import (ChatHuggingFace, HuggingFaceEmbeddings,
                                    HuggingFaceEndpoint)
@@ -12,13 +13,11 @@ from tools import (absolute, add, analyze_csv_file, analyze_excel_file,
                    arvix_search, audio_transcription, compound_interest,
                    convert_temperature, divide, exponential,
                    extract_text_from_image, factorial, floor_divide,
-                   get_current_time_in_timezone,
-                   get_max_bird_species_count_from_video,
-                   greatest_common_divisor, is_prime, least_common_multiple,
-                   logarithm, modulus, multiply, percentage_calculator, power,
-                   python_code_parser, reverse_sentence,
-                   roman_calculator_converter, square_root, subtract,
-                   web_content_extract, web_search, wiki_search)
+                   get_current_time_in_timezone, greatest_common_divisor,
+                   is_prime, least_common_multiple, logarithm, modulus,
+                   multiply, percentage_calculator, power, python_code_parser,
+                   reverse_sentence, roman_calculator_converter, square_root,
+                   subtract, web_content_extract, web_search, wiki_search)
 
 # Load Constants
 load_dotenv()
@@ -34,8 +33,7 @@ tools = [
     is_prime, least_common_multiple, percentage_calculator,
     wiki_search, analyze_excel_file, arvix_search,
     audio_transcription, python_code_parser, analyze_csv_file,
-    extract_text_from_image, reverse_sentence, web_content_extract,
-    get_max_bird_species_count_from_video
+    extract_text_from_image, reverse_sentence, web_content_extract
 ]
 
 # Load system prompt
@@ -47,54 +45,13 @@ If you are asked for a number, don't use a comma to write your number, nor use u
 If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
 If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
 Format your output as: Answers (answers): [{"task_id": ..., "submitted_answer": ...}]
+Do not repeat the format or include any nested JSON. Output only one flat list as: Answers (answers): [{...}]
 """
 
-
 # System message
 sys_msg = SystemMessage(content=system_prompt)
 
 
-def get_vector_store(persist_directory="chroma_db"):
-    """
-    Initializes and returns a Chroma vector store.
-    If the database exists, it loads it. If not, it creates it,
-    adds some initial documents, and persists them.
-    """
-    embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-
-    if os.path.exists(persist_directory) and os.listdir(persist_directory):
-        print("Loading existing vector store...")
-        vector_store = Chroma(
-            persist_directory=persist_directory,
-            embedding_function=embedding_function
-        )
-    else:
-        print("Creating new vector store...")
-        os.makedirs(persist_directory, exist_ok=True)
-        # Example documents to add
-        initial_documents = [
-            "The Principle of Double Effect is an ethical theory that distinguishes between the intended and foreseen consequences of an action.",
-            "St. Thomas Aquinas is often associated with the development of the Principle of Double Effect.",
-            "LangGraph is a library for building stateful, multi-actor applications with LLMs.",
-            "Chroma is a vector database used for storing and retrieving embeddings."
-        ]
-        vector_store = Chroma.from_texts(
-            texts=initial_documents,
-            embedding=embedding_function,
-            persist_directory=persist_directory
-        )
-        # No need to call persist() when using from_texts with a persist_directory
-
-    return vector_store
-
-# --- Initialize Vector Store and Retriever ---
-vector_store = get_vector_store()
-retriever_component = vector_store.as_retriever(
-    search_type="mmr",  # Use Maximum Marginal Relevance for diverse results
-    search_kwargs={'k': 2, 'lambda_mult': 0.5}  # Retrieve 2 documents
-)
-
-
 def build_graph():
     """Build the graph"""
     # First create the HuggingFaceEndpoint
@@ -127,32 +84,12 @@ def build_graph():
         formatted = f'Answers (answers): [{{"task_id": "{task_id}", "submitted_answer": "{answer_text}"}}]'
         return {"messages": [formatted]}
 
-    def retriever_node(state: MessagesState):
-        """
-        Retrieves relevant documents from the vector store based on the latest human message.
-        """
-        last_human_message = state["messages"][-1].content
-        retrieved_docs = retriever_component.invoke(last_human_message)
-
-        if retrieved_docs:
-            retrieved_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
-            # Create a ToolMessage to hold the retrieved context
-            context_message = ToolMessage(
-                content=f"Retrieved context from vector store:\n\n{retrieved_context}",
-                tool_call_id="retriever"  # A descriptive ID
-            )
-            return {"messages": [context_message]}
-
-        return {"messages": []}
-
     # --- Graph Definition ---
     builder = StateGraph(MessagesState)
-    # builder.add_node("retriever", retriever_node)
     builder.add_node("assistant", assistant)
     builder.add_node("tools", ToolNode(tools))
 
     builder.add_edge(START, "assistant")
-    # builder.add_edge("retriever", "assistant")
     builder.add_conditional_edges("assistant", tools_condition)
     builder.add_edge("tools", "assistant")
 
@@ -160,6 +97,30 @@
     return builder.compile()
 
 
+def is_valid_agent_output(output):
+    """
+    Checks if the output matches the required format:
+    Answers (answers): [{"task_id": ..., "submitted_answer": ...}]
+    """
+    # Basic regex to check the format
+    pattern = r'^Answers \(answers\): \[(\{.*\})\]$'
+    match = re.match(pattern, output.strip())
+    if not match:
+        return False
+
+    # Try to parse the JSON part
+    try:
+        answers_list = json.loads(f'[{match.group(1)}]')
+        # Check required keys
+        for ans in answers_list:
+            if not isinstance(ans, dict):
+                return False
+            if "task_id" not in ans or "submitted_answer" not in ans:
+                return False
+        return True
+    except Exception:
+        return False
+
 # test
 if __name__ == "__main__":
     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
@@ -168,8 +129,8 @@ if __name__ == "__main__":
     # Run the graph
     messages = [HumanMessage(content=question)]
     # The initial state for the graph
-    initial_state = {"messages": messages}
-
+    initial_state = {"messages": messages, "task_id": "test123"}
+
     # Invoke the graph stream to see the steps
     for s in graph.stream(initial_state, stream_mode="values"):
         message = s["messages"][-1]
@@ -178,5 +139,9 @@ if __name__ == "__main__":
             print(message.content)
             print("-----------------------")
         else:
-            message.pretty_print()
-
+            output = str(message)
+            print("Agent Output:", output)
+            if is_valid_agent_output(output):
+                print("✅ Output is in the correct format!")
+            else:
+                print("❌ Output is NOT in the correct format!")
app.py CHANGED
@@ -54,38 +54,8 @@ class BasicAgent:
         # The answer is expected to be in the 'content' of the last message.
         answer = response_messages['messages'][-1].content
         print(f"Agent full response: {answer}")
-
-        final_answer = ""
-        if not messages:
-            # print(f"No messages found in the result state for task {task_id}.")
-            return "AGENT ERROR: No messages returned by the agent."
-
-        for msg in reversed(messages):
-            if hasattr(msg, "content") and msg.content:
-                content = msg.content
-                if isinstance(content, str):
-                    if "FINAL ANSWER:" in content:
-                        final_answer = content.split("FINAL ANSWER:", 1)[1].strip()
-                        break
-                elif isinstance(msg, AIMessage):
-                    # If it's an AIMessage and no "FINAL ANSWER:" has been found yet,
-                    # tentatively set it. This will be overridden if a "FINAL ANSWER:" is found later.
-                    if not final_answer:
-                        final_answer = content
-
-        # If after checking all messages, final_answer is still from a non-"FINAL ANSWER:" AIMessage, that's our best guess.
-        # If final_answer is empty, it means no AIMessage with content or "FINAL ANSWER:" was found.
-        if not final_answer:  # This means no "FINAL ANSWER:" and no AIMessage content was suitable
-            final_answer = "AGENT ERROR: Could not extract a final answer from the agent's messages."
-            # print(f"Could not extract final answer for task {task_id}. Messages: {messages}")
-
-        # print(f"FinalAgent returning answer for task_id '{task_id}': {final_answer[:100]}...")
-        print(f"FinalAgent returning answer: {final_answer[:100]}...")
-        return final_answer
-
-        # return answer
-
-
+        return answer
+
 
 def run_and_submit_all( profile: gr.OAuthProfile | None):
     """
old-tools.py CHANGED
@@ -1,3 +1,9 @@
+import tempfile
+
+import cv2
+import torch
+from pytube import YouTube
+
 
 @tool
 def web_search(query: str) -> str:
@@ -23,3 +29,68 @@ def web_search(query: str) -> str:
         return results
     except Exception as e:
         return f"Error performing web search: {str(e)}"
+
+
+@tool
+def get_max_bird_species_count_from_video(url: str) -> Dict:
+    """
+    Downloads a YouTube video and returns the maximum number of unique bird species
+    visible in any frame, along with the timestamp.
+
+    Parameters:
+        url (str): YouTube video URL
+
+    Returns:
+        dict: {
+            "max_species_count": int,
+            "timestamp": str,
+            "species_list": List[str],
+        }
+    """
+    # 1. Download YouTube video
+    yt = YouTube(url)
+    stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
+    temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
+    stream.download(filename=temp_video_path)
+
+    # 2. Load object detection model for bird species
+    # Load a fine-tuned YOLOv5 model or similar pretrained on bird species
+    model = torch.hub.load('ultralytics/yolov5', 'custom', path='best_birds.pt')  # path to your trained model
+
+    # 3. Process video frames
+    cap = cv2.VideoCapture(temp_video_path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    frame_interval = int(fps * 1)  # 1 frame per second
+
+    max_species_count = 0
+    max_species_frame_time = 0
+    species_at_max = []
+
+    frame_idx = 0
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if frame_idx % frame_interval == 0:
+            # Run detection
+            results = model(frame)
+            detected_species = set()
+            for *box, conf, cls in results.xyxy[0]:
+                species_name = model.names[int(cls)]
+                detected_species.add(species_name)
+
+            if len(detected_species) > max_species_count:
+                max_species_count = len(detected_species)
+                max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
+                species_at_max = list(detected_species)
+
+        frame_idx += 1
+
+    cap.release()
+    os.remove(temp_video_path)
+
+    return {
+        "max_species_count": max_species_count,
+        "timestamp": f"{max_species_frame_time}s",
+        "species_list": species_at_max
+    }
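If the video dependencies that requirements.txt now comments out are reinstalled, this parked tool can still be exercised through the standard LangChain tool interface. A minimal sketch, assuming torch, opencv-python, pytube and a custom 'best_birds.pt' checkpoint are available; the URL below is a placeholder:

# Hypothetical invocation of the parked @tool; heavy dependencies required.
result = get_max_bird_species_count_from_video.invoke(
    {"url": "https://www.youtube.com/watch?v=EXAMPLE"}
)
print(result["max_species_count"], result["timestamp"], result["species_list"])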
requirements.txt CHANGED
@@ -10,9 +10,7 @@ langchain-core
 langchain-community
 langgraph
 langchain-huggingface
-langchain-chroma
-chromadb # Explicitly add the Chroma database
-sentence-transformers
+# sentence-transformers
 langfuse
 langchain-google-genai
 langchain-tavily
@@ -40,18 +38,18 @@ typing-extensions
 #tenacity
 # loguru
 
-torch
-torchvision
-opencv-python
-pytube
+# torch
+# torchvision
+# opencv-python
+# pytube
 
 # YOLOv5 and dependencies
-numpy
-matplotlib
-scipy
-seaborn
-tqdm
-pyyaml
-pillow
+# numpy
+# matplotlib
+# scipy
+# seaborn
+# tqdm
+# pyyaml
+# pillow
 
 # git+https://github.com/ultralytics/yolov5.git
tools.py CHANGED
@@ -2,13 +2,10 @@ import base64
 import datetime
 import math
 import os
-import tempfile
 from typing import Dict, Union
 
-import cv2
 import pandas
 import pytz
-import torch
 from bs4 import BeautifulSoup
 from langchain_community.document_loaders import (
     ArxivLoader, AssemblyAIAudioTranscriptLoader, WikipediaLoader)
@@ -19,7 +16,6 @@ from langchain_core.messages import HumanMessage
 from langchain_core.tools import tool
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_tavily import TavilySearch
-from pytube import YouTube
 
 
 @tool
@@ -742,70 +738,6 @@ def reverse_sentence(text: str) -> str:
     """
     return text[::-1]
 
-@tool
-def get_max_bird_species_count_from_video(url: str) -> Dict:
-    """
-    Downloads a YouTube video and returns the maximum number of unique bird species
-    visible in any frame, along with the timestamp.
-
-    Parameters:
-        url (str): YouTube video URL
-
-    Returns:
-        dict: {
-            "max_species_count": int,
-            "timestamp": str,
-            "species_list": List[str],
-        }
-    """
-    # 1. Download YouTube video
-    yt = YouTube(url)
-    stream = yt.streams.filter(file_extension='mp4').get_highest_resolution()
-    temp_video_path = os.path.join(tempfile.gettempdir(), "video.mp4")
-    stream.download(filename=temp_video_path)
-
-    # 2. Load object detection model for bird species
-    # Load a fine-tuned YOLOv5 model or similar pretrained on bird species
-    model = torch.hub.load('ultralytics/yolov5', 'custom', path='best_birds.pt')  # path to your trained model
-
-    # 3. Process video frames
-    cap = cv2.VideoCapture(temp_video_path)
-    fps = cap.get(cv2.CAP_PROP_FPS)
-    frame_interval = int(fps * 1)  # 1 frame per second
-
-    max_species_count = 0
-    max_species_frame_time = 0
-    species_at_max = []
-
-    frame_idx = 0
-    while cap.isOpened():
-        ret, frame = cap.read()
-        if not ret:
-            break
-        if frame_idx % frame_interval == 0:
-            # Run detection
-            results = model(frame)
-            detected_species = set()
-            for *box, conf, cls in results.xyxy[0]:
-                species_name = model.names[int(cls)]
-                detected_species.add(species_name)
-
-            if len(detected_species) > max_species_count:
-                max_species_count = len(detected_species)
-                max_species_frame_time = int(cap.get(cv2.CAP_PROP_POS_MSEC)) // 1000
-                species_at_max = list(detected_species)
-
-        frame_idx += 1
-
-    cap.release()
-    os.remove(temp_video_path)
-
-    return {
-        "max_species_count": max_species_count,
-        "timestamp": f"{max_species_frame_time}s",
-        "species_list": species_at_max
-    }
-
 @tool
 def web_search(query: str) -> str:
     """