"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain_community.vectorstores import Chroma  # New import for Chroma
from langchain_core.documents import Document  # New import for creating documents
import shutil  # For handling directories
load_dotenv()
@tool
def multiply(a: int, b: int) -> int:
"""Multiply two numbers.
Args:
a: first int
b: second int
"""
return a * b
@tool
def add(a: int, b: int) -> int:
"""Add two numbers.
Args:
a: first int
b: second int
"""
return a + b
@tool
def subtract(a: int, b: int) -> int:
"""Subtract two numbers.
Args:
a: first int
b: second int
"""
return a - b
@tool
def divide(a: int, b: int) -> float:
"""Divide two numbers.
Args:
a: first int
b: second int
"""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Get the modulus of two numbers.
Args:
a: first int
b: second int
"""
return a % b
@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return a maximum of 2 results.
Args:
query: The search query."""
search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
formatted_search_docs = "\n\n---\n\n".join(
[
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
for doc in search_docs
])
return {"wiki_results": formatted_search_docs}
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return a maximum of 3 results.
    Args:
        query: The search query."""
    # TavilySearchResults returns a list of dicts with "url" and "content" keys,
    # not Document objects, so index into each result directly.
    search_docs = TavilySearchResults(max_results=3).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc["url"]}">\n{doc["content"]}\n</Document>'
            for doc in search_docs
        ])
    return {"web_results": formatted_search_docs}
@tool
def arvix_search(query: str) -> dict:
    """Search arXiv for a query and return a maximum of 3 results.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            # ArxivLoader metadata may not include a "source" key, so fall back to the
            # paper title (or an empty string) to avoid a KeyError.
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return {"arvix_results": formatted_search_docs}
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
system_prompt = f.read()
# System message
sys_msg = SystemMessage(content=system_prompt)
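# Note: sys_msg is not injected automatically by the graph built below; an
# assistant-style node would typically prepend it when calling the model, e.g.
# llm_with_tools.invoke([sys_msg] + state["messages"]).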
# --- Start ChromaDB Setup ---
# Define the directory for ChromaDB persistence
CHROMA_DB_DIR = "./chroma_db"
# Build embeddings (this remains the same)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
# Initialize ChromaDB
# If the directory exists, load the existing vector store.
# Otherwise, create a new one and add some dummy documents.
if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
vector_store = Chroma(
persist_directory=CHROMA_DB_DIR,
embedding_function=embeddings
)
else:
print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and adding dummy documents.")
# Ensure the directory is clean before creating new
if os.path.exists(CHROMA_DB_DIR):
shutil.rmtree(CHROMA_DB_DIR)
os.makedirs(CHROMA_DB_DIR)
# Example dummy documents to populate the vector store
# In a real application, you would load your actual documents here
documents = [
Document(page_content="What is the capital of France?", metadata={"source": "internal", "answer": "Paris"}),
Document(page_content="Who wrote Hamlet?", metadata={"source": "internal", "answer": "William Shakespeare"}),
Document(page_content="What is the highest mountain in the world?", metadata={"source": "internal", "answer": "Mount Everest"}),
Document(page_content="When was the internet invented?", metadata={"source": "internal", "answer": "The internet, as we know it, evolved from ARPANET in the late 1960s and early 1970s. The TCP/IP protocol, which forms the basis of the internet, was standardized in 1978."}),
Document(page_content="What is the square root of 64?", metadata={"source": "internal", "answer": "8"}),
Document(page_content="Who is the current president of the United States?", metadata={"source": "internal", "answer": "Joe Biden"}),
Document(page_content="What is the chemical symbol for water?", metadata={"source": "internal", "answer": "H2O"}),
Document(page_content="What is the largest ocean on Earth?", metadata={"source": "internal", "answer": "Pacific Ocean"}),
Document(page_content="What is the speed of light?", metadata={"source": "internal", "answer": "Approximately 299,792,458 meters per second in a vacuum."}),
Document(page_content="What is the capital of Sweden?", metadata={"source": "internal", "answer": "Stockholm"}),
]
vector_store = Chroma.from_documents(
documents=documents,
embedding=embeddings,
persist_directory=CHROMA_DB_DIR
)
vector_store.persist() # Save the new vector store to disk
print("ChromaDB initialized and persisted with dummy documents.")
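# Optional sanity check (illustrative sketch; the RUN_CHROMA_SELF_TEST environment
# variable is an assumption used only here, nothing else in this file depends on it):
# run a single similarity search to confirm the dummy documents are retrievable.
if os.getenv("RUN_CHROMA_SELF_TEST"):
    _probe = vector_store.similarity_search("What is the capital of France?", k=1)
    if _probe:
        print("Self-test hit:", _probe[0].page_content, "->", _probe[0].metadata.get("answer"))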
# Create a retriever tool backed by the Chroma vector store
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question_Search",
    description="A tool to retrieve similar questions from a vector store and their answers.",
)
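# Note: when invoked as a tool (e.g. retriever_tool.invoke("capital of Sweden")),
# create_retriever_tool returns the matched documents' page_content joined into a
# single string; the metadata (including the stored "answer") is not part of that
# output, which is why the retriever node below queries vector_store directly.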
# Add the new retriever tool to your list of tools
tools = [
multiply,
add,
subtract,
divide,
modulus,
wiki_search,
web_search,
arvix_search,
    retriever_tool,
]
# Build graph function
def build_graph(provider: str = "google"):
    """Build the graph."""
    # Select the chat model based on the provider argument
if provider == "google":
# Google Gemini
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
elif provider == "groq":
# Groq https://console.groq.com/docs/models
llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
    elif provider == "huggingface":
        # TODO: verify the Hugging Face endpoint / model path before use
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
else:
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
# Bind tools to LLM
llm_with_tools = llm.bind_tools(tools)
# Node
def assistant(state: MessagesState):
"""Assistant node"""
return {"messages": [llm_with_tools.invoke(state["messages"])]}
    def retriever(state: MessagesState):
        """Retriever node: answer from the vector store if a similar question is stored."""
        query = state["messages"][-1].content
        # Query the vector store directly rather than through retriever_tool: the tool
        # wrapper only returns the joined page_content as a string, while the stored
        # answers live in the documents' metadata.
        similar_docs = vector_store.similarity_search(query, k=1)
        if similar_docs:
            similar_doc = similar_docs[0]
            # Prefer an explicit answer in metadata, then a "Final answer :" marker in
            # the text, and finally fall back to the raw page content.
            if "answer" in similar_doc.metadata:
                answer = similar_doc.metadata["answer"]
            elif "Final answer :" in similar_doc.page_content:
                answer = similar_doc.page_content.split("Final answer :")[-1].strip()
            else:
                answer = similar_doc.page_content.strip()
            return {"messages": [AIMessage(content=answer)]}
        else:
            # No similar documents found in the knowledge base
            return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever)
    # The retriever node is both the entry point and the finish point
builder.set_entry_point("retriever")
builder.set_finish_point("retriever")
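    # For reference: if the assistant/tool-calling loop were active instead of the
    # retriever-only graph above, the prebuilt helpers imported at the top would be
    # wired in roughly like this (sketch, not part of the current graph):
    #   builder.add_node("assistant", assistant)
    #   builder.add_node("tools", ToolNode(tools))
    #   builder.add_edge(START, "assistant")
    #   builder.add_conditional_edges("assistant", tools_condition)
    #   builder.add_edge("tools", "assistant")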
# Compile graph
return builder.compile()
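
if __name__ == "__main__":
    # Minimal smoke test (illustrative; the question below is just an example and the
    # default "google" provider requires the corresponding API key in the environment).
    question = "What is the capital of Sweden?"
    graph = build_graph(provider="google")
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()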