Michele De Stefano
Now using Tavily for web searches. It's a lot more powerful than DuckDuckGo.
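Note: the Tavily tools expect an API key, so the .env file loaded by dotenv.load_dotenv() below is assumed to contain a TAVILY_API_KEY entry.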
import datetime as dt
import dotenv
import re
from typing import Any, Literal

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import SystemMessage, AnyMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_ollama import ChatOllama
from langchain_tavily import TavilySearch, TavilyExtract
from langgraph.constants import START, END
from langgraph.graph import MessagesState, StateGraph
from langgraph.graph.graph import CompiledGraph
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel

from tools import (
    get_excel_table_content,
    get_youtube_video_transcript,
    reverse_string,
    transcribe_audio_file,
    web_page_info_retriever,
    youtube_video_to_frame_captions,
    sum_list,
    execute_python_script,
)

# Loads environment variables (e.g. the Tavily API key) from a local .env file.
dotenv.load_dotenv()

class AgentFactory:
    """
    A factory for the agent. It is assumed that an Ollama server is running
    on the machine where the factory is used.
    """

    __system_prompt: str = (
        "You have to answer some test questions.\n"
        "Sometimes auxiliary files may be attached to the question.\n"
        "Each question is a JSON string with the following fields:\n"
        "1. task_id: unique hash identifier of the question.\n"
        "2. question: the text of the question.\n"
        "3. Level: ignore this field.\n"
        "4. file_name: the name of the file needed to answer the question. "
        "This is empty if the question does not refer to any file. "
        "IMPORTANT: The text of the question may mention a file name that is "
        "different from what is reported in the \"file_name\" JSON field. "
        "YOU HAVE TO IGNORE THE FILE NAME MENTIONED IN \"question\" AND "
        "YOU MUST USE THE FILE NAME PROVIDED IN THE \"file_name\" FIELD.\n"
        "\n"
        "Achieve the solution by dividing your reasoning into steps, and\n"
        "provide an explanation for each step.\n"
        "\n"
        "The format of your final answer must be\n"
        "\n"
        "<ANSWER>your_final_answer</ANSWER>, where your_final_answer is a\n"
        "number OR as few words as possible OR a comma separated list of\n"
        "numbers and/or strings. If you are asked for a number, don't use\n"
        "commas to write your number, nor units such as $ or percent signs,\n"
        "unless specified otherwise. If you are asked for a string, don't\n"
        "use articles or abbreviations (e.g. for cities), and write digits\n"
        "in plain text unless specified otherwise. If you are asked for a\n"
        "comma separated list, apply the above rules to each element\n"
        "depending on whether it is a number or a string.\n"
        "ALWAYS PRESENT THE FINAL ANSWER BETWEEN THE <ANSWER> AND </ANSWER>\n"
        "TAGS.\n"
        "\n"
        "When, for achieving the solution, you have to perform a sum, DON'T\n"
        "try to do that yourself. Exploit the tool that is able to sum a\n"
        "list of numbers. If you have to sum the results of previous sums,\n"
        "call the same tool again.\n"
        "You are advised to cycle between reasoning and tool calling, even\n"
        "multiple times. Provide an answer only when you are sure you don't\n"
        "have to call any tool again.\n"
        "\n"
        f"If you need it, the date today is {dt.date.today()}."
    )

    __llm: Runnable
    __tools: list[BaseTool]

    def __init__(
        self,
        model: str = "qwen2.5-coder:32b",
        # model: str = "mistral-small3.1",
        # model: str = "phi4-mini",
        temperature: float = 0.0,
        num_ctx: int = 8192
    ) -> None:
        """
        Constructor.

        Args:
            model: The name of the Ollama model to use.
            temperature: Temperature parameter.
            num_ctx: Size of the context window used to generate the
                next token.
        """
        # Kept for reference: the DuckDuckGo search tool used before the
        # switch to Tavily.
        # search_tool = DuckDuckGoSearchResults(
        #     description=(
        #         "A wrapper around Duck Duck Go Search. Useful for when you "
        #         "need to answer questions about information you can find on "
        #         "the web. Input should be a search query. It is advisable to "
        #         "use this tool to retrieve web page URLs and use another tool "
        #         "to analyze the pages. If the web source is suggested by the "
        #         "user query, prefer retrieving information from that source. "
        #         "For example, the query may suggest to search on Wikipedia or "
        #         "Medium. In those cases, prepend the query with "
        #         "'site: <name of the source>'. For example: "
        #         "'site: wikipedia.org'"
        #     ),
        #     output_format="list"
        # )
        search_tool = TavilySearch(
            topic="general",
            max_results=5,
            include_answer="advanced",
        )
        # search_tool.with_retry()
        extract_tool = TavilyExtract(
            extract_depth="advanced",
            include_images=False,
        )
        self.__tools = [
            execute_python_script,
            get_excel_table_content,
            get_youtube_video_transcript,
            reverse_string,
            search_tool,
            extract_tool,
            sum_list,
            transcribe_audio_file,
            # web_page_info_retriever,
            youtube_video_to_frame_captions
        ]
        self.__llm = ChatOllama(
            model=model,
            temperature=temperature,
            num_ctx=num_ctx
        ).bind_tools(tools=self.__tools)
        # Alternative backend: a Hugging Face Inference endpoint instead of
        # the local Ollama server.
        # llm_endpoint = HuggingFaceEndpoint(
        #     repo_id="Qwen/Qwen2.5-72B-Instruct",
        #     task="text-generation",
        #     max_new_tokens=num_ctx,
        #     do_sample=False,
        #     repetition_penalty=1.03,
        #     temperature=temperature,
        # )
        #
        # self.__llm = (
        #     ChatHuggingFace(llm=llm_endpoint)
        #     .bind_tools(tools=self.__tools)
        # )

    def __run_llm(self, state: MessagesState) -> dict[str, Any]:
        answer = self.__llm.invoke(state["messages"])
        # Strip the <think>...</think> block that reasoning models may emit,
        # so it does not pollute the conversation history.
        pattern = r'\n*<think>.*?</think>\n*'
        answer.content = re.sub(
            pattern, "", answer.content, flags=re.DOTALL
        )
        return {"messages": [answer]}

    @staticmethod
    def __extract_last_message(
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str
    ) -> AnyMessage:
        # Handles the three state shapes LangGraph supports (message list,
        # dict, or Pydantic model), like langgraph's prebuilt tools_condition.
        if isinstance(state, list):
            last_message = state[-1]
        elif isinstance(state, dict) and (messages := state.get(messages_key, [])):
            last_message = messages[-1]
        elif messages := getattr(state, messages_key, []):
            last_message = messages[-1]
        else:
            raise ValueError(
                f"No messages found in input state to tool_edge: {state}"
            )
        return last_message

    def __route_from_llm(
        self,
        state: list[AnyMessage] | dict[str, Any] | BaseModel,
        messages_key: str = "messages",
    ) -> Literal["tools", "extract_final_answer"]:
        ai_message = self.__extract_last_message(state, messages_key)
        # If the LLM requested tool calls, run them; otherwise it is done
        # and the final answer can be extracted.
        if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
            return "tools"
        return "extract_final_answer"

    @staticmethod
    def __extract_final_answer(state: MessagesState) -> dict[str, Any]:
        last_message = state["messages"][-1].content
        # Pull the text between the <ANSWER> tags requested in the system
        # prompt; fall back to an empty string if the tags are missing.
        pattern = r"<ANSWER>(?P<answer>.*?)</ANSWER>"
        m = re.search(pattern, last_message, flags=re.DOTALL)
        answer = m.group("answer").strip() if m else ""
        return {"messages": [answer]}

    def system_prompt(self) -> SystemMessage:
        """
        Returns:
            The system prompt to use with the agent.
        """
        return SystemMessage(content=self.__system_prompt)

    def get(self) -> CompiledGraph:
        """
        Factory method.

        Returns:
            The instance of the agent.
        """
        graph_builder = StateGraph(MessagesState)
        graph_builder.add_node("LLM", self.__run_llm)
        graph_builder.add_node("tools", ToolNode(tools=self.__tools))
        graph_builder.add_node(
            "extract_final_answer",
            self.__extract_final_answer
        )
        graph_builder.add_edge(start_key=START, end_key="LLM")
        # The LLM node either loops through the tool node or exits through
        # the answer-extraction node.
        graph_builder.add_conditional_edges(
            source="LLM",
            path=self.__route_from_llm,
            path_map={
                "tools": "tools",
                "extract_final_answer": "extract_final_answer"
            }
        )
        graph_builder.add_edge(start_key="tools", end_key="LLM")
        graph_builder.add_edge(start_key="extract_final_answer", end_key=END)
        return graph_builder.compile()
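
For reference, here is a minimal usage sketch. It is not part of the factory itself: it assumes a local Ollama server with the default model pulled and the API keys loaded from .env, and the question JSON is a made-up example in the format described by the system prompt.

from langchain_core.messages import HumanMessage

factory = AgentFactory()
agent = factory.get()
# Hypothetical question in the JSON format the system prompt describes.
question = (
    '{"task_id": "0000", "question": "What is 2 + 3 + 4?", '
    '"Level": 1, "file_name": ""}'
)
result = agent.invoke(
    {"messages": [factory.system_prompt(), HumanMessage(content=question)]}
)
# extract_final_answer appends the bare answer string as the last message.
print(result["messages"][-1].content)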