import os

from dotenv import load_dotenv

# Load environment variables from a local .env file
load_dotenv()

# Force the pure-Python protobuf implementation to avoid C++ extension issues
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Load API keys from the environment
hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
serper_api_key = os.getenv("SERPER_API_KEY")
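# Expected .env contents, inferred from the getenv calls above. TAVILY_API_KEY
# is an assumption: TavilySearchResults (used in web_search below) reads it from
# the environment on its own; serper_api_key is loaded but never used in this file.
#
#   HUGGINGFACE_INFERENCE_TOKEN=hf_...
#   SERPER_API_KEY=...
#   TAVILY_API_KEY=tvly-...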
# ---- Imports ----
import json

from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
# ---- Tools ----
# @tool wraps each function as a StructuredTool, which both bind_tools and
# ToolNode accept (the `tool` import was previously unused)
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers together."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers together."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second number from the first."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first number by the second. Raises ValueError on division by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder after dividing the first number by the second."""
    return a % b
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for information. Useful for factual questions about people, places, events, etc."""
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        formatted = "\n\n---\n\n".join(
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        )
        # Return the string directly: the annotation (and the ToolMessage content)
        # is a str, not a dict
        return formatted
    except Exception as e:
        return f"Wikipedia search failed: {e}"
@tool
def web_search(query: str) -> str:
    """Search the web for current information. Useful when you need recent or non-Wikipedia information."""
    try:
        # TavilySearchResults reads TAVILY_API_KEY from the environment
        search = TavilySearchResults(max_results=3)
        search_docs = search.invoke(query)
        formatted = "\n\n---\n\n".join(
            f'<Document source="{doc["url"]}">\n{doc["content"]}\n</Document>'
            for doc in search_docs
        )
        return formatted
    except Exception as e:
        return f"Web search failed: {e}"
@tool
def arxiv_search(query: str) -> str:
    """Search academic papers on arXiv. Useful for technical or scientific questions."""
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=2).load()
        formatted = "\n\n---\n\n".join(
            # ArxivLoader metadata may not include a "source" key; fall back to the paper title
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}">\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        )
        return formatted
    except Exception as e:
        return f"ArXiv search failed: {e}"
# ---- Embedding & Vector Store Setup ----
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load QA pairs (one JSON object per line)
json_QA = []
try:
    with open("metadata.jsonl", "r") as jsonl_file:
        for line in jsonl_file:
            json_QA.append(json.loads(line))
except Exception as e:
    print(f"Error loading metadata.jsonl: {e}")
    json_QA = []

documents = [
    Document(
        page_content=f"Question: {sample['Question']}\n\nAnswer: {sample['Final answer']}",
        metadata={"source": sample["task_id"], "question": sample["Question"], "answer": sample["Final answer"]},
    )
    for sample in json_QA
]

try:
    vector_store = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory="./chroma_db",
        collection_name="qa_collection",
    )
    # chromadb >= 0.4 persists automatically when persist_directory is set, so
    # the explicit vector_store.persist() call has been dropped (it is also
    # absent from newer wrappers)
    print(f"Documents inserted: {len(documents)}")
except Exception as e:
    print(f"Error creating vector store: {e}")
    raise
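# A minimal sketch for reloading the persisted store on later runs instead of
# re-embedding everything (assumes ./chroma_db was written by a previous run;
# the same embedding model and collection name must be used):
#
#     vector_store = Chroma(
#         persist_directory="./chroma_db",
#         embedding_function=embeddings,
#         collection_name="qa_collection",
#     )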
@tool
def similar_question_search(query: str) -> str:
    """Search for similar questions that have been answered before. Always check here first before using other tools."""
    try:
        matched_docs = vector_store.similarity_search(query, k=3)
        formatted = "\n\n---\n\n".join(
            # Consistent <Document> tags (the original mixed <Question:> with </Document>)
            f'<Document question="{doc.metadata["question"]}">\n{doc.metadata["answer"]}\n</Document>'
            for doc in matched_docs
        )
        return formatted
    except Exception as e:
        return f"Similar question search failed: {e}"
# ---- System Prompt ----
system_prompt = """
You are an expert question-answering assistant. Follow these steps for each question:
1. FIRST check for similar questions using the similar_question_search tool
2. If a similar question exists with a clear answer, use that answer
3. If not, determine which tools might help answer the question
4. Use the tools systematically to gather information
5. Combine information from multiple sources if needed
6. Format your final answer precisely as:
FINAL ANSWER: [your answer here]

Rules for answers:
- Numbers: plain digits only (no commas, units, or symbols)
- Strings: minimal words, no articles, full names
- Lists: comma-separated with no extra formatting
- Be concise but accurate
"""
sys_msg = SystemMessage(content=system_prompt)
# ---- Tool List ----
tools = [
    similar_question_search,                   # Check this first
    multiply, add, subtract, divide, modulus,  # Math tools
    wiki_search, web_search, arxiv_search,     # Information tools
]
# ---- Graph Definition ----
def build_graph():
    try:
        # A hosted HuggingFace chat model
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                temperature=0,
                max_new_tokens=512,
                huggingfacehub_api_token=hf_token,
            )
        )
        llm_with_tools = llm.bind_tools(tools)

        def assistant(state: MessagesState):
            # Prepend the system message at invoke time so it always leads the
            # conversation; MessagesState's add_messages reducer would otherwise
            # append it after the user's question
            return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

        def retriever(state: MessagesState):
            try:
                # Surface similar solved questions as extra context for the model
                similar = vector_store.similarity_search(state["messages"][-1].content, k=2)
                if similar:
                    example_msg = HumanMessage(
                        content="Here are similar questions and their answers:\n\n"
                        + "\n\n".join(
                            f"Q: {doc.metadata['question']}\nA: {doc.metadata['answer']}"
                            for doc in similar
                        )
                    )
                    # Return only the new message: add_messages appends it to the
                    # existing state, so re-returning old messages is unnecessary
                    return {"messages": [example_msg]}
                return {"messages": []}
            except Exception as e:
                print(f"Retriever error: {e}")
                return {"messages": []}

        builder = StateGraph(MessagesState)
        builder.add_node("retriever", retriever)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()
    except Exception as e:
        print(f"Error building graph: {e}")
        raise
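# A minimal usage sketch (assumes metadata.jsonl and the API keys above are in
# place; the question is hypothetical):
if __name__ == "__main__":
    graph = build_graph()
    result = graph.invoke({"messages": [HumanMessage(content="What is 6 times 7?")]})
    print(result["messages"][-1].content)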