import os
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set protobuf implementation to avoid C++ extension issues
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# Load keys from environment (TavilySearchResults reads TAVILY_API_KEY directly)
hf_token = os.getenv("HUGGINGFACE_INFERENCE_TOKEN")
tavily_api_key = os.getenv("TAVILY_API_KEY")
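# Expected .env entries (an assumption, inferred from the lookups above):
#   HUGGINGFACE_INFERENCE_TOKEN=hf_...
#   TAVILY_API_KEY=tvly-...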

# ---- Imports ----
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
import json

# ---- Tools ----

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers together."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers together."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second number from the first."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first number by the second. Returns float or error if dividing by zero."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Returns the remainder after division of the first number by the second."""
    return a % b

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for information. Useful for factual questions about people, places, events, etc."""
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        formatted = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
                for doc in search_docs
            ]
        )
        return formatted
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"

@tool
def web_search(query: str) -> str:
    """Search the web for current information. Useful when you need recent or non-Wikipedia information."""
    try:
        search = TavilySearchResults(max_results=3)
        search_docs = search.invoke(query)
        formatted = "\n\n---\n\n".join(
            [
                f'<Document source="{doc["url"]}">\n{doc["content"]}\n</Document>'
                for doc in search_docs
            ]
        )
        return formatted
    except Exception as e:
        return f"Web search failed: {str(e)}"

@tool
def arxiv_search(query: str) -> str:
    """Search academic papers on ArXiv. Useful for technical or scientific questions."""
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=2).load()
        formatted = "\n\n---\n\n".join(
            [
                # ArxivLoader metadata may not carry a "source" key; fall back to the paper title
                f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}">\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ]
        )
        return formatted
    except Exception as e:
        return f"ArXiv search failed: {str(e)}"

# ---- Embedding & Vector Store Setup ----

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Load QA pairs
json_QA = []
try:
    with open('metadata.jsonl', 'r') as jsonl_file:
        for line in jsonl_file:
            json_QA.append(json.loads(line))
except Exception as e:
    print(f"Error loading metadata.jsonl: {e}")
    json_QA = []

documents = [
    Document(
        page_content=f"Question: {sample['Question']}\n\nAnswer: {sample['Final answer']}",
        metadata={"source": sample["task_id"], "question": sample["Question"], "answer": sample["Final answer"]}
    )
    for sample in json_QA
]

try:
    vector_store = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory="./chroma_db",
        collection_name="qa_collection"
    )
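    # Note: on chromadb >= 0.4 the directory is persisted automatically and
    # persist() may be a no-op; the call is kept for older versions.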
    vector_store.persist()
    print(f"Documents inserted: {len(documents)}")
except Exception as e:
    print(f"Error creating vector store: {e}")
    raise
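
# Optional smoke test (a sketch, not part of the pipeline): confirm the store
# returns a stored question before wiring it into the graph.
if documents:
    _probe = vector_store.similarity_search(documents[0].metadata["question"], k=1)
    print(f"Smoke test nearest question: {_probe[0].metadata['question'][:80]}")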

@tool
def similar_question_search(query: str) -> str:
    """Search for similar questions that have been answered before. Always check here first before using other tools."""
    try:
        matched_docs = vector_store.similarity_search(query, k=3)
        formatted = "\n\n---\n\n".join(
            [
                f'<Question: {doc.metadata["question"]}>\n<Answer: {doc.metadata["answer"]}>'
                for doc in matched_docs
            ]
        )
        return formatted
    except Exception as e:
        return f"Similar question search failed: {str(e)}"

# ---- System Prompt ----

system_prompt = """
You are an expert question-answering assistant. Follow these steps for each question:

1. FIRST check for similar questions using the similar_question_search tool
2. If a similar question exists with a clear answer, use that answer
3. If not, determine which tools might help answer the question
4. Use the tools systematically to gather information
5. Combine information from multiple sources if needed
6. Format your final answer precisely as:
FINAL ANSWER: [your answer here]

Rules for answers:
- Numbers: plain digits only (no commas, units, or symbols)
- Strings: minimal words, no articles, full names
- Lists: comma-separated with no extra formatting
- Be concise but accurate
"""

sys_msg = SystemMessage(content=system_prompt)

# ---- Tool List ----

tools = [
    similar_question_search,  # Check this first
    multiply, add, subtract, divide, modulus,  # Math tools
    wiki_search, web_search, arxiv_search  # Information tools
]

# ---- Graph Definition ----

def build_graph():
    try:
        # Mixtral-8x7B-Instruct served through the Hugging Face Inference API
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
                temperature=0.01,  # some HF/TGI backends reject temperature == 0
                max_new_tokens=512,
                huggingfacehub_api_token=hf_token
            )
        )
        
        llm_with_tools = llm.bind_tools(tools)

        def assistant(state: MessagesState):
            # Prepend the system prompt at call time: MessagesState's add_messages
            # reducer appends returned messages, so returning sys_msg from a node
            # would place it after the user's message instead of first.
            return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

        def retriever(state: MessagesState):
            try:
                # Surface previously answered similar questions as extra context
                similar = vector_store.similarity_search(state["messages"][-1].content, k=2)
                if similar:
                    example_msg = HumanMessage(
                        content="Here are similar questions and their answers:\n\n" +
                        "\n\n".join([f"Q: {doc.metadata['question']}\nA: {doc.metadata['answer']}"
                                     for doc in similar])
                    )
                    return {"messages": [example_msg]}
                return {"messages": []}
            except Exception as e:
                print(f"Retriever error: {e}")
                return {"messages": []}

        builder = StateGraph(MessagesState)
        builder.add_node("retriever", retriever)
        builder.add_node("assistant", assistant)
        builder.add_node("tools", ToolNode(tools))
        
        builder.add_edge(START, "retriever")
        builder.add_edge("retriever", "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        return builder.compile()
    
    except Exception as e:
        print(f"Error building graph: {e}")
        raise
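
# ---- Example Usage ----

# A minimal end-to-end sketch (assumes metadata.jsonl exists and the tokens in
# .env are valid): build the graph and run a single question through it.
if __name__ == "__main__":
    graph = build_graph()
    question = "What is 12 multiplied by 8?"
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    # The last message is the assistant's reply, ending in "FINAL ANSWER: ..."
    print(result["messages"][-1].content)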