Update agent.py
agent.py
CHANGED
@@ -4,177 +4,103 @@ import os
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
-from langchain_community.utilities import WikipediaAPIWrapper
 from langchain_community.document_loaders import ArxivLoader
 from langchain_core.messages import SystemMessage, HumanMessage
 from langchain_core.tools import tool
-from sentence_transformers import SentenceTransformer
-from langchain.embeddings.base import Embeddings
-from typing import List
-import numpy as np
-import yaml
-
-import pandas as pd
-import uuid
-import requests
-import json
-from langchain_core.documents import Document
-from youtube_transcript_api import YouTubeTranscriptApi
-from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
-import re
-
-from langchain_community.document_loaders import TextLoader, PyMuPDFLoader
-from docx import Document as DocxDocument
-import openpyxl
-from io import StringIO
-
-from transformers import BertTokenizer, BertModel
-import torch
-import torch.nn.functional as F
-from langchain_community.chat_models import ChatOpenAI
-from langchain_community.tools import Tool
-import time
-from huggingface_hub import InferenceClient
-from langchain_community.llms import HuggingFaceHub
-from langchain.prompts import PromptTemplate
-from langchain.chains import LLMChain
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
-from langchain_huggingface import HuggingFaceEndpoint
-#from langchain.agents import initialize_agent
-#from langchain.agents import AgentType
-from typing import Union
-from functools import reduce
-import operator
-from typing import Union
-from functools import reduce
-from youtube_transcript_api import YouTubeTranscriptApi
-from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
-from langchain.schema import Document
-
-from langchain_community.vectorstores import FAISS
-from langchain_huggingface import HuggingFaceEmbeddings
 from langchain.tools.retriever import create_retriever_tool
-
-from typing import TypedDict, Annotated, List
-import gradio as gr
-from langchain.schema import Document
-
-load_dotenv()
-

 @tool
-def calculator(inputs):
-    """
-
     """

-
-    if "list" in inputs:
-        nums = inputs.get("list", [])
-        op = inputs.get("operation", "").lower()
-
-        if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
-            return "Invalid list input. Must be a list of numbers."
-
-        if op == "sum":
-            return sum(nums)
-        elif op == "multiply":
-            return reduce(operator.mul, nums, 1)
-        else:
-            return f"Unsupported list operation: {op}"
-
-    # Handle basic two-number operations
-    a = inputs.get("a")
-    b = inputs.get("b")
-    operation = inputs.get("operation", "").lower()

-    if operation == "add":
-        return a + b
-    elif operation == "subtract":
-        return a - b
-    elif operation == "multiply":
-        return a * b
-    elif operation == "divide":
-        return a / b
-    elif operation == "modulus":
-        return a % b
-    else:
-        return f"Unknown operation: {operation}"

 @tool
 def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-
     formatted_search_docs = "\n\n---\n\n".join(
         [
-            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
-        ]
-
-    return formatted_search_docs
-
-
-@tool
-def wikidata_query(query: str) -> str:
-    """
-    Run a SPARQL query on Wikidata and return results.
-    """
-    endpoint_url = "https://query.wikidata.org/sparql"
-    headers = {
-        "Accept": "application/sparql-results+json"
-    }
-    response = requests.get(endpoint_url, headers=headers, params={"query": query})
-    data = response.json()
-    return json.dumps(data, indent=2)
-

 @tool
 def web_search(query: str) -> str:
-    """Search Tavily for a query and return maximum 3 results."""
-    tavily_key = os.getenv("TAVILY_API_KEY")
-
-    if not tavily_key:
-        return "Error: Tavily API key not set."
-
-    search_tool = TavilySearchResults(tavily_api_key=tavily_key, max_results=3)
-    search_docs = search_tool.invoke(query=query)

     formatted_search_docs = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
         ])
-
-    return formatted_search_docs
-

 @tool
-def arxiv_search(query: str) -> str:
     """Search Arxiv for a query and return maximum 3 result.

     Args:
@@ -185,114 +111,10 @@ def arxiv_search(query: str) -> str:
         f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
         for doc in search_docs
     ])
-    return formatted_search_docs
-
-
-@tool
-def analyze_attachment(file_path: str) -> str:
-    """
-    Analyzes attachments including PY, PDF, TXT, DOCX, and XLSX files and returns text content.
-
-    Args:
-        file_path: Local path to the attachment.
-    """
-    if not os.path.exists(file_path):
-        return f"File not found: {file_path}"
-
-    try:
-        ext = file_path.lower()
-
-        if ext.endswith(".pdf"):
-            loader = PyMuPDFLoader(file_path)
-            documents = loader.load()
-            content = "\n\n".join([doc.page_content for doc in documents])
-
-        elif ext.endswith(".txt") or ext.endswith(".py"):
-            # Both .txt and .py are plain text files
-            with open(file_path, "r", encoding="utf-8") as file:
-                content = file.read()
-
-        elif ext.endswith(".docx"):
-            doc = DocxDocument(file_path)
-            content = "\n".join([para.text for para in doc.paragraphs])
-
-        elif ext.endswith(".xlsx"):
-            wb = openpyxl.load_workbook(file_path, data_only=True)
-            content = ""
-            for sheet in wb:
-                content += f"Sheet: {sheet.title}\n"
-                for row in sheet.iter_rows(values_only=True):
-                    content += "\t".join([str(cell) if cell is not None else "" for cell in row]) + "\n"
-
-        else:
-            return "Unsupported file format. Please use PY, PDF, TXT, DOCX, or XLSX."
-
-        return content[:3000]  # Limit output size for readability
-
-    except Exception as e:
-        return f"An error occurred while processing the file: {str(e)}"
-
-
-@tool
-def get_youtube_transcript(url: str) -> str:
-    """
-    Fetch transcript text from a YouTube video.
-
-    Args:
-        url (str): Full YouTube video URL.
-
-    Returns:
-        str: Transcript text as a single string.
-
-    Raises:
-        ValueError: If no transcript is available or URL is invalid.
-    """
-    try:
-        # Extract video ID
-        video_id = extract_video_id(url)
-        transcript = YouTubeTranscriptApi.get_transcript(video_id)
-
-        # Combine all transcript text
-        full_text = " ".join([entry['text'] for entry in transcript])
-        return full_text
-
-    except (TranscriptsDisabled, VideoUnavailable) as e:
-        raise ValueError(f"Transcript not available: {e}")
-    except Exception as e:
-        raise ValueError(f"Failed to fetch transcript: {e}")
-
-
-@tool
-def extract_video_id(url: str) -> str:
-    """
-    Extract the video ID from a YouTube URL.
-    """
-    match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", url)
-    if not match:
-        raise ValueError("Invalid YouTube URL")
-    return match.group(1)
-


-# -----------------------------
-# Load configuration from YAML
-# -----------------------------
-with open("config.yaml", "r") as f:
-    config = yaml.safe_load(f)
-
-provider = config["provider"]
-model_config = config["models"][provider]
-
-#prompt_path = config["system_prompt_path"]
-enabled_tool_names = config["tools"]
-
-
-# -----------------------------
-# Load system prompt
-# -----------------------------
 # load the system prompt from the file
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
@@ -300,919 +122,81 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
|
|
300 |
# System message
|
301 |
sys_msg = SystemMessage(content=system_prompt)
|
302 |
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
"
|
311 |
-
"
|
312 |
-
"get_youtube_transcript": get_youtube_transcript,
|
313 |
-
"extract_video_id": extract_video_id,
|
314 |
-
"analyze_attachment": analyze_attachment,
|
315 |
-
"wikidata_query": wikidata_query
|
316 |
-
}
|
317 |
-
|
318 |
-
# Then define which tools you want enabled
|
319 |
-
enabled_tool_names = [
|
320 |
-
"math",
|
321 |
-
"wiki_search",
|
322 |
-
"web_search",
|
323 |
-
"arxiv_search",
|
324 |
-
"get_youtube_transcript",
|
325 |
-
"extract_video_id",
|
326 |
-
"analyze_attachment",
|
327 |
-
"wikidata_query"
|
328 |
-
]
|
329 |
-
|
330 |
-
|
331 |
-
tools = [tool_map[name] for name in enabled_tool_names]
|
332 |
-
|
333 |
-
|
334 |
-
# Safe version
|
335 |
-
tools = []
|
336 |
-
for name in enabled_tool_names:
|
337 |
-
if name not in tool_map:
|
338 |
-
print(f"❌ Tool not found: {name}")
|
339 |
-
continue
|
340 |
-
tools.append(tool_map[name])
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
# -----------------------------
|
345 |
-
# Prepare Documents
|
346 |
-
# -----------------------------
|
347 |
-
# Define the URL where the JSON file is hosted
|
348 |
-
|
349 |
-
import faiss
|
350 |
-
|
351 |
-
# 1. Type-Checked State for Gradio
|
352 |
-
class ChatState(TypedDict):
|
353 |
-
messages: Annotated[
|
354 |
-
List[str],
|
355 |
-
gr.State(render=False),
|
356 |
-
"Stores chat history as list of strings"
|
357 |
-
]
|
358 |
-
|
359 |
-
# 2. Content Processing Utilities
|
360 |
-
def process_content(raw_content) -> str:
|
361 |
-
"""Convert any input to a clean string"""
|
362 |
-
if isinstance(raw_content, list):
|
363 |
-
return " ".join(str(item) for item in raw_content)
|
364 |
-
return str(raw_content)
|
365 |
-
|
366 |
-
def reverse_text(text: str) -> str:
|
367 |
-
"""Fix reversed text patterns"""
|
368 |
-
return text[::-1].replace("\\", "").strip() if text.startswith(('.', ',')) else text
|
369 |
-
|
370 |
-
|
371 |
-
# 3. Unified Document Creation
|
372 |
-
|
373 |
-
def create_documents(data_source: str, data: list) -> list:
|
374 |
-
"""Handle both Gradio chat and JSON questions"""
|
375 |
-
docs = []
|
376 |
-
|
377 |
-
for item in data:
|
378 |
-
content = ""
|
379 |
-
# Process different data sources
|
380 |
-
if data_source == "json":
|
381 |
-
raw_question = item.get("question", "")
|
382 |
-
content = raw_question # Adjust as per your content processing logic
|
383 |
-
else:
|
384 |
-
print(f"Skipping invalid data source: {data_source}")
|
385 |
-
continue
|
386 |
-
|
387 |
-
# Ensure metadata type safety
|
388 |
-
metadata = {
|
389 |
-
"task_id": str(item.get("task_id", "")),
|
390 |
-
"level": str(item.get("Level", "")),
|
391 |
-
"file_name": str(item.get("file_name", ""))
|
392 |
-
}
|
393 |
-
|
394 |
-
# Check if content is non-empty
|
395 |
-
if content.strip(): # Only append non-empty content
|
396 |
-
docs.append(Document(page_content=content, metadata=metadata))
|
397 |
-
else:
|
398 |
-
print(f"Skipping invalid entry with empty content: {item}")
|
399 |
-
|
400 |
-
return docs
|
401 |
-
|
402 |
-
# Path to your data.json
|
403 |
-
file_path = "/home/wendy/Downloads/data.json"
|
404 |
-
|
405 |
-
def load_data(file_path: str) -> list[dict]:
|
406 |
-
"""Safe JSON data loading with error handling"""
|
407 |
-
if not os.path.exists(file_path):
|
408 |
-
raise FileNotFoundError(f"Data file not found: {file_path}")
|
409 |
-
|
410 |
-
if not file_path.endswith('.json'):
|
411 |
-
raise ValueError("Invalid file format. Only JSON files supported")
|
412 |
-
|
413 |
-
try:
|
414 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
415 |
-
return json.load(f)
|
416 |
-
except json.JSONDecodeError:
|
417 |
-
raise ValueError("Invalid JSON format in data file")
|
418 |
-
except Exception as e:
|
419 |
-
raise RuntimeError(f"Error loading data: {str(e)}")
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
# 4. Vector Store Integration
|
424 |
-
|
425 |
-
import faiss
|
426 |
-
|
427 |
-
# Custom FAISS wrapper (optional, if you still want it)
|
428 |
-
class MyVector_Store:
|
429 |
-
def __init__(self, index: faiss.Index):
|
430 |
-
self.index = index
|
431 |
-
|
432 |
-
def save_local(self, path: str):
|
433 |
-
faiss.write_index(self.index, path)
|
434 |
-
|
435 |
-
@classmethod
|
436 |
-
def load_local(cls, path: str):
|
437 |
-
index = faiss.read_index(path)
|
438 |
-
return cls(index)
|
439 |
-
|
440 |
-
# -----------------------------
|
441 |
-
# Process JSON data and create documents
|
442 |
-
# -----------------------------
|
443 |
-
|
444 |
-
file_path = "/home/wendy/Downloads/data.json"
|
445 |
-
|
446 |
-
try:
|
447 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
448 |
-
data = json.load(f)
|
449 |
-
print(data)
|
450 |
-
except FileNotFoundError as e:
|
451 |
-
print(f"Error: {e}")
|
452 |
-
except json.JSONDecodeError as e:
|
453 |
-
print(f"Error decoding JSON: {e}")
|
454 |
-
|
455 |
-
docs = create_documents("json", data)
|
456 |
-
texts = [doc.page_content for doc in docs]
|
457 |
-
|
458 |
-
|
459 |
-
# -----------------------------
|
460 |
-
# Initialize embedding model
|
461 |
-
# -----------------------------
|
462 |
-
embedding_model = HuggingFaceEmbeddings(
|
463 |
-
model_name="sentence-transformers/all-MiniLM-L6-v2"
|
464 |
-
)
|
465 |
-
|
466 |
-
# -----------------------------
|
467 |
-
# Create FAISS index and save it
|
468 |
-
# -----------------------------
|
469 |
-
class ChatState(TypedDict):
|
470 |
-
messages: Annotated[
|
471 |
-
List[str],
|
472 |
-
gr.State(render=False),
|
473 |
-
"Stores chat history"
|
474 |
-
]
|
475 |
-
|
476 |
-
def initialize_vector_store():
|
477 |
-
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
478 |
-
index_path = "/home/wendy/my_hf_agent_course_projects/faiss_index"
|
479 |
-
|
480 |
-
if os.path.exists(os.path.join(index_path, "index.faiss")):
|
481 |
-
try:
|
482 |
-
return FAISS.load_local(
|
483 |
-
index_path,
|
484 |
-
embedding_model,
|
485 |
-
allow_dangerous_deserialization=True
|
486 |
-
)
|
487 |
-
except Exception as e:
|
488 |
-
print(f"Error loading index: {e}")
|
489 |
-
|
490 |
-
# Fallback: Create new index
|
491 |
-
print("Building new vector store...")
|
492 |
-
docs = [...] # Your document loading logic here
|
493 |
-
vector_store = FAISS.from_documents(docs, embedding_model)
|
494 |
-
vector_store.save_local(index_path)
|
495 |
-
return vector_store
|
496 |
-
|
497 |
-
# Initialize at module level
|
498 |
-
loaded_store = initialize_vector_store()
|
499 |
-
retriever = loaded_store.as_retriever()
|
500 |
-
|
501 |
-
# -----------------------------
|
502 |
-
# Create LangChain Retriever Tool
|
503 |
-
# -----------------------------
|
504 |
-
#retriever = loaded_store.as_retriever()
|
505 |
-
|
506 |
-
question_retriever_tool = create_retriever_tool(
|
507 |
-
retriever=retriever,
|
508 |
-
name="Question_Search",
|
509 |
-
description="A tool to retrieve documents related to a user's question."
|
510 |
)
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
llm = HuggingFaceEndpoint(
|
516 |
-
repo_id="HuggingFaceH4/zephyr-7b-beta",
|
517 |
-
task="text-generation",
|
518 |
-
huggingfacehub_api_token=os.getenv("HF_TOKEN"),
|
519 |
-
temperature=0.7,
|
520 |
-
max_new_tokens=512
|
521 |
)
|
522 |
|
523 |
-
|
524 |
-
|
525 |
-
# -------------------------------
|
526 |
-
# Step 8: Use the Planner, Classifier, and Decision Logic
|
527 |
-
# -------------------------------
|
528 |
-
|
529 |
-
def process_question(question):
|
530 |
-
# Step 1: Planner generates the task sequence
|
531 |
-
tasks = planner(question)
|
532 |
-
print(f"Tasks to perform: {tasks}")
|
533 |
-
|
534 |
-
# Step 2: Classify the task (based on question)
|
535 |
-
task_type = task_classifier(question)
|
536 |
-
print(f"Task type: {task_type}")
|
537 |
-
|
538 |
-
# Step 3: Use the classifier and planner to decide on the next task or node
|
539 |
-
state = {"question": question, "last_response": ""}
|
540 |
-
next_task = decide_task(state)
|
541 |
-
print(f"Next task: {next_task}")
|
542 |
-
|
543 |
-
# Step 4: Use node skipper logic (skip if needed)
|
544 |
-
skip = node_skipper(state)
|
545 |
-
if skip:
|
546 |
-
print(f"Skipping to {skip}")
|
547 |
-
return skip # Or move directly to generating answer
|
548 |
-
|
549 |
-
# Step 5: Execute task (with error handling)
|
550 |
-
try:
|
551 |
-
if task_type == "wiki_search":
|
552 |
-
response = wiki_search(question)
|
553 |
-
elif task_type == "math":
|
554 |
-
response = calculator(question)
|
555 |
-
else:
|
556 |
-
response = "Default answer logic"
|
557 |
-
|
558 |
-
# Step 6: Final response formatting
|
559 |
-
final_response = final_answer_tool(state, {'wiki_search': response})
|
560 |
-
return final_response
|
561 |
-
|
562 |
-
except Exception as e:
|
563 |
-
print(f"Error executing task: {e}")
|
564 |
-
return "Sorry, I encountered an error processing your request."
|
565 |
-
|
566 |
-
|
567 |
-
# Run the process
|
568 |
-
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
569 |
-
response = agent.invoke(question)
|
570 |
-
print("Final Response:", response)
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
from langchain.schema import HumanMessage
|
576 |
-
|
577 |
-
def retriever(state: MessagesState, k: int = 4):
|
578 |
-
"""
|
579 |
-
Retrieves documents from the vector store using similarity scores,
|
580 |
-
applies a dynamic threshold filter, and returns updated message state.
|
581 |
-
|
582 |
-
Args:
|
583 |
-
state (MessagesState): Current message state including the user's query.
|
584 |
-
k (int): Number of top results to retrieve from the vector store.
|
585 |
-
|
586 |
-
Returns:
|
587 |
-
dict: Updated messages state including relevant documents or fallback message.
|
588 |
-
"""
|
589 |
-
query = state["messages"][0].content.strip()
|
590 |
-
results = vector_store.similarity_search_with_score(query, k=k)
|
591 |
-
|
592 |
-
# Determine dynamic similarity threshold
|
593 |
-
if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
|
594 |
-
threshold = 0.75
|
595 |
-
else:
|
596 |
-
threshold = 0.8
|
597 |
-
|
598 |
-
filtered = [doc for doc, score in results if score < threshold]
|
599 |
-
|
600 |
-
if not filtered:
|
601 |
-
response_msg = HumanMessage(content="No relevant documents found.")
|
602 |
-
else:
|
603 |
-
content = "\n\n".join(doc.page_content for doc in filtered)
|
604 |
-
response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
|
605 |
-
|
606 |
-
return {"messages": [sys_msg] + state["messages"] + [response_msg]}
|
607 |
-
|
608 |
-
|
609 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
610 |
|
611 |
-
#
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
if provider == "google":
|
616 |
-
|
617 |
-
|
618 |
-
model=config.get("model"),
|
619 |
-
temperature=config.get("temperature", 0.7),
|
620 |
-
google_api_key=config.get("api_key") # Optional: if needed
|
621 |
-
)
|
622 |
-
|
623 |
elif provider == "groq":
|
624 |
-
|
625 |
-
|
626 |
-
model=config.get("model"),
|
627 |
-
temperature=config.get("temperature", 0.7),
|
628 |
-
groq_api_key=config.get("api_key") # Optional: if needed
|
629 |
-
)
|
630 |
-
|
631 |
elif provider == "huggingface":
|
632 |
-
|
633 |
-
|
634 |
-
return ChatHuggingFace(
|
635 |
llm=HuggingFaceEndpoint(
|
636 |
-
|
637 |
-
temperature=
|
638 |
-
|
639 |
-
)
|
640 |
)
|
641 |
-
|
642 |
else:
|
643 |
-
raise ValueError(
|
644 |
-
|
645 |
-
|
646 |
-
|
647 |
-
#
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
# Define intent-based keywords
|
658 |
-
intent_keywords = {
|
659 |
-
"math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
|
660 |
-
"wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
|
661 |
-
"web_search": ["search", "find", "look up", "google", "latest news", "current info"],
|
662 |
-
"arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
|
663 |
-
"get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
|
664 |
-
"extract_video_id": ["analyze video", "summarize video", "video content"],
|
665 |
-
"data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
|
666 |
-
"wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
|
667 |
-
"default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
|
668 |
-
}
|
669 |
-
|
670 |
-
|
671 |
-
|
672 |
-
# Step 1: Identify intent
|
673 |
-
detected_intent = None
|
674 |
-
for intent, keywords in intent_keywords.items():
|
675 |
-
if any(keyword in question for keyword in keywords):
|
676 |
-
detected_intent = intent
|
677 |
-
break
|
678 |
-
|
679 |
-
# Step 2: Match tools by intent
|
680 |
-
matched_tools = []
|
681 |
-
if detected_intent:
|
682 |
-
for tool in tools:
|
683 |
-
name = getattr(tool, "name", "").lower()
|
684 |
-
description = getattr(tool, "description", "").lower()
|
685 |
-
if detected_intent in name or detected_intent in description:
|
686 |
-
matched_tools.append(tool)
|
687 |
-
|
688 |
-
# Step 3: Fallback to general-purpose/default tools if no match found
|
689 |
-
if not matched_tools:
|
690 |
-
matched_tools = [
|
691 |
-
tool for tool in tools
|
692 |
-
if "default" in getattr(tool, "name", "").lower()
|
693 |
-
or "qa" in getattr(tool, "description", "").lower()
|
694 |
-
]
|
695 |
-
|
696 |
-
return detected_intent, matched_tools if matched_tools else [tools[0]]
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
def task_classifier(question: str) -> str:
|
702 |
-
"""
|
703 |
-
Classifies the question into one of the predefined task categories.
|
704 |
-
"""
|
705 |
-
question = question.lower().strip()
|
706 |
-
|
707 |
-
# Context-aware intent patterns
|
708 |
-
if any(phrase in question for phrase in [
|
709 |
-
"calculate", "how much is", "what is the result of", "evaluate", "solve"
|
710 |
-
]) or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"]):
|
711 |
-
return "math"
|
712 |
-
|
713 |
-
elif any(phrase in question for phrase in [
|
714 |
-
"who is", "what is", "define", "explain", "tell me about", "give me an overview of"
|
715 |
-
]):
|
716 |
-
return "wiki_search"
|
717 |
-
|
718 |
-
elif any(phrase in question for phrase in [
|
719 |
-
"search", "find", "look up", "google", "get the latest", "current news", "trending"
|
720 |
-
]):
|
721 |
-
return "web_search"
|
722 |
-
|
723 |
-
elif any(phrase in question for phrase in [
|
724 |
-
"arxiv", "latest research", "scientific paper", "research paper", "preprint"
|
725 |
-
]):
|
726 |
-
return "arxiv_search"
|
727 |
-
|
728 |
-
elif any(phrase in question for phrase in [
|
729 |
-
"youtube", "watch", "play the video", "show me a video"
|
730 |
-
]):
|
731 |
-
return "get_youtube_transcript"
|
732 |
-
|
733 |
-
elif any(phrase in question for phrase in [
|
734 |
-
"analyze video", "summarize video", "what happens in the video", "video content"
|
735 |
-
]):
|
736 |
-
return "video_analysis"
|
737 |
-
|
738 |
-
elif any(phrase in question for phrase in [
|
739 |
-
"analyze", "visualize", "plot", "graph", "inspect data", "explore dataset"
|
740 |
-
]):
|
741 |
-
return "data_analysis"
|
742 |
-
|
743 |
-
elif any(phrase in question for phrase in [
|
744 |
-
"sparql", "wikidata", "query wikidata", "run sparql", "wikidata query"
|
745 |
-
]):
|
746 |
-
return "wikidata_query"
|
747 |
-
|
748 |
-
return "default"
|
749 |
-
|
750 |
-
|
751 |
-
def select_tool_and_run(question: str, tools: dict):
|
752 |
-
# Step 1: Classify intent
|
753 |
-
intent = task_classifier(question) # assuming task_classifier maps the question to intent
|
754 |
-
|
755 |
-
# Map intent to tool names
|
756 |
-
intent_tool_map = {
|
757 |
-
"math": "calculator", # maps to tools["math"] → calculator
|
758 |
-
"wiki_search": "wiki_search", # → wiki_search
|
759 |
-
"web_search": "web_search", # → web_search
|
760 |
-
"arxiv_search": "arxiv_search", # → arxiv_search (spelling fixed)
|
761 |
-
"get_youtube_transcript": "get_youtube_transcript", # → get_youtube_transcript
|
762 |
-
"extract_video_id": "extract_video_id", # adjust based on your tools
|
763 |
-
"analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
|
764 |
-
"wikidata_query": "wikidata_query", # → wikidata_query
|
765 |
-
"default": "default" # → default_tool
|
766 |
-
}
|
767 |
-
|
768 |
-
# Get the corresponding tool name
|
769 |
-
tool_name = intent_tool_map.get(intent, "default") # Default to "default" if no match
|
770 |
-
|
771 |
-
# Retrieve the tool from the tools dictionary
|
772 |
-
tool_func = tools.get(tool_name)
|
773 |
-
|
774 |
-
if not tool_func:
|
775 |
-
return f"Tool not found for intent '{intent}'"
|
776 |
-
|
777 |
-
# Step 2: Run the tool
|
778 |
-
try:
|
779 |
-
# If the tool needs JSON or structured data
|
780 |
-
try:
|
781 |
-
parsed_input = json.loads(question)
|
782 |
-
except json.JSONDecodeError:
|
783 |
-
parsed_input = question # fallback to raw input if not JSON
|
784 |
-
|
785 |
-
# Run the selected tool
|
786 |
-
print(f"Running tool: {tool_name} with input: {parsed_input}") # log the tool name and input
|
787 |
-
return tool_func(parsed_input)
|
788 |
-
except Exception as e:
|
789 |
-
return f"Error while running tool '{tool_name}': {str(e)}"
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
# Function to extract math operation from the question
|
794 |
-
|
795 |
-
def extract_math_from_question(question: str):
|
796 |
-
question = question.lower()
|
797 |
-
|
798 |
-
# Map natural language to symbols
|
799 |
-
ops = {
|
800 |
-
"add": "+", "plus": "+",
|
801 |
-
"subtract": "-", "minus": "-",
|
802 |
-
"multiply": "*", "times": "*",
|
803 |
-
"divide": "/", "divided by": "/",
|
804 |
-
"modulus": "%", "mod": "%"
|
805 |
-
}
|
806 |
-
|
807 |
-
for word, symbol in ops.items():
|
808 |
-
question = re.sub(rf"\b{word}\b", symbol, question)
|
809 |
-
|
810 |
-
# Extract math expression like "12 + 5"
|
811 |
-
match = re.search(r'(\d+)\s*([\+\-\*/%])\s*(\d+)', question)
|
812 |
-
if match:
|
813 |
-
num1 = int(match.group(1))
|
814 |
-
operator = match.group(2)
|
815 |
-
num2 = int(match.group(3))
|
816 |
-
return {
|
817 |
-
"a": num1,
|
818 |
-
"b": num2,
|
819 |
-
"operation": {
|
820 |
-
"+": "add",
|
821 |
-
"-": "subtract",
|
822 |
-
"*": "multiply",
|
823 |
-
"/": "divide",
|
824 |
-
"%": "modulus"
|
825 |
-
}[operator]
|
826 |
-
}
|
827 |
-
return None
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
# Example tool set (adjust these to match your actual tool names)
|
832 |
-
intent_tool_map = {
|
833 |
-
"math": "math", # maps to tools["math"] → calculator
|
834 |
-
"wiki_search": "wiki_search", # → wiki_search
|
835 |
-
"web_search": "web_search", # → web_search
|
836 |
-
"arxiv_search": "arxiv_search", # → arxiv_search (spelling fixed)
|
837 |
-
"get_youtube_transcript": "get_youtube_transcript", # → get_youtube_transcript
|
838 |
-
"extract_video_id": "extract_video_id", # adjust based on your tools
|
839 |
-
"analyze_attachment": "analyze_attachment", # assuming analyze_attachment handles this
|
840 |
-
"wikidata_query": "wikidata_query", # → wikidata_query
|
841 |
-
"default": "default" # → default_tool
|
842 |
-
}
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
# The task order can also include the tools for each task
|
847 |
-
priority_order = [
|
848 |
-
{"task": "math", "tool": "math"},
|
849 |
-
{"task": "wiki_search", "tool": "wiki_search"},
|
850 |
-
{"task": "web_search", "tool": "web_search"},
|
851 |
-
{"task": "arxiv_search", "tool": "arxiv_search"},
|
852 |
-
{"task": "wikidata_query", "tool": "wikidata_query"},
|
853 |
-
{"task": "retriever", "tool": "retriever"},
|
854 |
-
{"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
|
855 |
-
{"task": "extract_video_id", "tool": "extract_video_id"},
|
856 |
-
{"task": "analyze_attachment", "tool": "analyze_attachment"},
|
857 |
-
{"task": "default", "tool": "default"} # Fallback
|
858 |
-
]
|
859 |
-
|
860 |
-
def decide_task(state: dict) -> str:
|
861 |
-
"""Decides which task to perform based on the current state."""
|
862 |
-
|
863 |
-
# Get the list of tasks from the planner
|
864 |
-
tasks = planner(state["question"])
|
865 |
-
print(f"Available tasks: {tasks}") # Debugging: show all possible tasks
|
866 |
-
|
867 |
-
# Check if the tasks list is empty or invalid
|
868 |
-
if not tasks:
|
869 |
-
print("❌ No valid tasks were returned from the planner.")
|
870 |
-
return "default" # Return a default task if no tasks were generated
|
871 |
-
|
872 |
-
# If there are multiple tasks, we can prioritize based on certain conditions
|
873 |
-
task = tasks[0] # Default to the first task in the list
|
874 |
-
if len(tasks) > 1:
|
875 |
-
print(f"⚠️ Multiple tasks found. Deciding based on priority.")
|
876 |
-
# Example logic to prioritize tasks, adjust based on your use case
|
877 |
-
task = prioritize_tasks(tasks)
|
878 |
-
|
879 |
-
print(f"Decided on task: {task}") # Debugging: show the final task
|
880 |
-
return task
|
881 |
-
|
882 |
-
|
883 |
-
def prioritize_tasks(tasks: list) -> str:
|
884 |
-
"""Prioritize tasks based on certain conditions or criteria, including tools."""
|
885 |
-
# Sort tasks based on priority_order mapping
|
886 |
-
for priority in priority_order:
|
887 |
-
# Check if any task matches the priority task type
|
888 |
-
for task in tasks:
|
889 |
-
if priority["task"] in task:
|
890 |
-
print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}") # Debugging: show the chosen task and tool
|
891 |
-
# Assign the correct tool based on the task
|
892 |
-
tool = tools.get(priority["tool"], tools["default"]) # Default to 'default_tool' if not found
|
893 |
-
return task, tool
|
894 |
-
|
895 |
-
# If no priority task is found, return the first task with its default tool
|
896 |
-
return tasks[0], tools["default"]
|
897 |
-
|
898 |
-
|
899 |
-
def process_question(question: str):
|
900 |
-
"""Process the question and route it to the appropriate tool."""
|
901 |
-
# Get the tasks from the planner
|
902 |
-
tasks = planner(question)
|
903 |
-
print(f"Tasks to perform: {tasks}")
|
904 |
-
|
905 |
-
task_type, tool = decide_task({"question": question})
|
906 |
-
print(f"Next task: {task_type} with tool: {tool}")
|
907 |
-
|
908 |
-
if node_skipper({"question": question}):
|
909 |
-
print(f"Skipping task: {task_type}")
|
910 |
-
return "Task skipped."
|
911 |
-
|
912 |
-
try:
|
913 |
-
# Execute the corresponding tool for the task type
|
914 |
-
if task_type == "wiki_search":
|
915 |
-
response = tool.run(question) # Assuming tool is wiki_tool
|
916 |
-
elif task_type == "math":
|
917 |
-
response = tool.run(question) # Assuming tool is calc_tool
|
918 |
-
elif task_type == "retriever":
|
919 |
-
response = tool.run(question) # Assuming tool is retriever_tool
|
920 |
-
else:
|
921 |
-
response = tool.run(question) # Default tool
|
922 |
-
|
923 |
-
return generate_final_answer({"question": question}, {task_type: response})
|
924 |
-
|
925 |
-
except Exception as e:
|
926 |
-
print(f"❌ Error: {e}")
|
927 |
-
return f"Sorry, I encountered an error: {str(e)}"
|
928 |
-
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
def call_llm(state):
|
933 |
-
messages = state["messages"]
|
934 |
-
response = llm.invoke(messages)
|
935 |
-
return {"messages": messages + [response]}
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
from langchain.schema import AIMessage
|
941 |
-
from typing import TypedDict, List, Optional
|
942 |
-
from langchain_core.messages import BaseMessage
|
943 |
-
|
944 |
-
class AgentState(TypedDict):
|
945 |
-
messages: List[BaseMessage] # Chat history
|
946 |
-
input: str # Original input
|
947 |
-
intent: str # Derived or predicted intent
|
948 |
-
result: Optional[str] # Optional result
|
949 |
-
|
950 |
-
|
951 |
-
def tool_dispatcher(state: AgentState) -> AgentState:
|
952 |
-
last_msg = state["messages"][-1]
|
953 |
-
|
954 |
-
# Make sure it's an AI message with tool_calls
|
955 |
-
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
956 |
-
tool_call = last_msg.tool_calls[0]
|
957 |
-
tool_name = tool_call["name"]
|
958 |
-
tool_input = tool_call["args"] # Adjust based on your actual schema
|
959 |
-
|
960 |
-
tool_func = tool_map.get(tool_name, default_tool)
|
961 |
-
|
962 |
-
# If args is a dict and your tool expects unpacked values:
|
963 |
-
if isinstance(tool_input, dict):
|
964 |
-
result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
|
965 |
-
else:
|
966 |
-
result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
|
967 |
-
|
968 |
-
# You can choose to append this to messages, or just save result
|
969 |
-
return {
|
970 |
-
**state,
|
971 |
-
"result": result,
|
972 |
-
# Optionally add: "messages": state["messages"] + [ToolMessage(...)]
|
973 |
-
}
|
974 |
-
|
975 |
-
# No tool call detected, return state unchanged
|
976 |
-
return state
|
977 |
-
|
978 |
-
|
979 |
-
|
980 |
-
|
981 |
-
# Decide what to do next: if tool call → call_tool, else → end
|
982 |
-
def should_call_tool(state):
|
983 |
-
last_msg = state["messages"][-1]
|
984 |
-
if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
|
985 |
-
return "call_tool"
|
986 |
-
return "end"
|
987 |
-
|
988 |
-
|
989 |
-
from typing import TypedDict, List, Optional, Union
|
990 |
-
from langchain.schema import BaseMessage
|
991 |
-
|
992 |
-
class AgentState(TypedDict):
|
993 |
-
messages: List[BaseMessage] # Chat history
|
994 |
-
input: str # Original input
|
995 |
-
intent: str # Derived or predicted intent
|
996 |
-
result: Optional[str] # Final or intermediate result
|
997 |
-
|
998 |
-
|
999 |
-
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
# To store previously asked questions and timestamps (simulating state persistence)
|
1004 |
-
recent_questions = {}
|
1005 |
-
|
1006 |
-
def node_skipper(state: dict) -> bool:
|
1007 |
-
"""
|
1008 |
-
Determines whether to skip the task based on the state.
|
1009 |
-
This could include:
|
1010 |
-
1. Repeated or similar questions
|
1011 |
-
2. Irrelevant or empty questions
|
1012 |
-
3. Tasks that have already been processed recently
|
1013 |
-
"""
|
1014 |
-
question = state.get("question", "").strip()
|
1015 |
-
|
1016 |
-
if not question:
|
1017 |
-
print("❌ Skipping: Empty or invalid question.")
|
1018 |
-
return True # Skip if no valid question
|
1019 |
-
|
1020 |
-
# 1. Skip if the question has already been asked recently (within a given time window)
|
1021 |
-
# Here, we're using a simple example with a 5-minute window (300 seconds).
|
1022 |
-
if question in recent_questions:
|
1023 |
-
last_asked_time = recent_questions[question]
|
1024 |
-
time_since_last_ask = time.time() - last_asked_time
|
1025 |
-
if time_since_last_ask < 300: # 5-minute threshold
|
1026 |
-
print(f"❌ Skipping: The question has been asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
|
1027 |
-
return True # Skip if the question was asked within the last 5 minutes
|
1028 |
-
|
1029 |
-
# 2. Skip if the question is irrelevant or not meaningful enough
|
1030 |
-
irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
|
1031 |
-
if any(keyword in question.lower() for keyword in irrelevant_keywords):
|
1032 |
-
print("❌ Skipping: Irrelevant or nonsense question.")
|
1033 |
-
return True # Skip if the question contains irrelevant keywords
|
1034 |
-
|
1035 |
-
# 3. Skip if the task has already been completed for this question (based on a unique task identifier)
|
1036 |
-
if "last_response" in state and state["last_response"]:
|
1037 |
-
print("❌ Skipping: Task has already been processed recently.")
|
1038 |
-
return True # Skip if a response has already been given
|
1039 |
-
|
1040 |
-
# 4. Skip based on a condition related to the task itself
|
1041 |
-
# Example: Skip math-related tasks if the result is already known or trivial
|
1042 |
-
if "math" in state.get("question", "").lower():
|
1043 |
-
# If math is trivial (like "What is 2+2?")
|
1044 |
-
trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
|
1045 |
-
if any(trivial_question in question for trivial_question in trivial_math):
|
1046 |
-
print(f"❌ Skipping trivial math question: {question}")
|
1047 |
-
return True # Skip if the math question is trivial
|
1048 |
-
|
1049 |
-
# 5. Skip based on external factors (e.g., current time, system load, etc.)
|
1050 |
-
# Example: Avoid processing tasks at night if that's part of the business logic
|
1051 |
-
current_hour = time.localtime().tm_hour
|
1052 |
-
if current_hour >= 22 or current_hour < 6:
|
1053 |
-
print("❌ Skipping: It's night time, not processing tasks.")
|
1054 |
-
return True # Skip tasks during night time (e.g., between 10 PM and 6 AM)
|
1055 |
-
|
1056 |
-
# If none of the conditions matched, don't skip the task
|
1057 |
-
return False
|
1058 |
-
|
1059 |
-
# Update recent questions (for simulating repeated question check)
|
1060 |
-
def update_recent_questions(question: str):
|
1061 |
-
"""Update the recent questions dictionary with the current timestamp."""
|
1062 |
-
recent_questions[question] = time.time()
|
1063 |
-
|
1064 |
-
|
1065 |
-
|
1066 |
-
def generate_final_answer(state: dict, task_results: dict) -> str:
|
1067 |
-
"""Generate a final answer based on the results of the task."""
|
1068 |
-
if "wiki_search" in task_results:
|
1069 |
-
return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
|
1070 |
-
elif "math" in task_results:
|
1071 |
-
return f"🧮 Math Result: {task_results['math']}"
|
1072 |
-
elif "retriever" in task_results:
|
1073 |
-
return f"🔍 Retrieved Info: {task_results['retriever']}"
|
1074 |
-
else:
|
1075 |
-
return "🤖 Unable to generate a specific answer."
|
1076 |
-
|
1077 |
-
|
1078 |
-
def answer_question(question: str) -> str:
|
1079 |
-
"""Process a single question and return the answer."""
|
1080 |
-
print(f"Processing question: {question[:50]}...") # Debugging: show first 50 chars
|
1081 |
-
|
1082 |
-
# Wrap the question in a HumanMessage from langchain_core (assuming langchain is used)
|
1083 |
-
messages = [HumanMessage(content=question)]
|
1084 |
-
response = graph.invoke({"messages": messages}) # Assuming `graph` is defined elsewhere
|
1085 |
-
|
1086 |
-
# Extract the answer from the response
|
1087 |
-
answer = response['messages'][-1].content
|
1088 |
-
return answer[14:] # Assuming 'answer[14:]' is correct based on your example
|
1089 |
-
|
1090 |
-
|
1091 |
-
def process_all_tasks(tasks: list):
|
1092 |
-
"""Process a list of tasks."""
|
1093 |
-
results = {}
|
1094 |
-
|
1095 |
-
for task in tasks:
|
1096 |
-
question = task.get("question", "").strip()
|
1097 |
-
if not question:
|
1098 |
-
print(f"Skipping task with missing or empty 'question': {task}")
|
1099 |
-
continue
|
1100 |
-
|
1101 |
-
print(f"\n🟢 Processing Task: {task['task_id']} - Question: {question}")
|
1102 |
-
|
1103 |
-
# Call the existing process_question logic
|
1104 |
-
response = process_question(question)
|
1105 |
-
|
1106 |
-
print(f"✅ Response: {response}")
|
1107 |
-
results[task['task_id']] = response
|
1108 |
-
|
1109 |
-
return results
|
1110 |
-
|
1111 |
-
|
1112 |
-
|
1113 |
-
|
1114 |
-
|
1115 |
-
## Langgraph
|
1116 |
-
|
1117 |
-
# Build graph function
|
1118 |
-
vector_store = vector_store.save_local("faiss_index")
|
1119 |
-
|
1120 |
-
provider = "huggingface"
|
1121 |
-
|
1122 |
-
model_config = {
|
1123 |
-
"repo_id": "HuggingFaceH4/zephyr-7b-beta",
|
1124 |
-
"task": "text-generation",
|
1125 |
-
"temperature": 0.7,
|
1126 |
-
"max_new_tokens": 512,
|
1127 |
-
"huggingfacehub_api_token": os.getenv("HF_TOKEN")
|
1128 |
-
}
|
1129 |
-
|
1130 |
-
# Get LLM
|
1131 |
-
def get_llm(provider: str, config: dict):
|
1132 |
-
if provider == "huggingface":
|
1133 |
-
from langchain_huggingface import HuggingFaceEndpoint
|
1134 |
-
return HuggingFaceEndpoint(
|
1135 |
-
repo_id=config["repo_id"],
|
1136 |
-
task=config["task"],
|
1137 |
-
huggingfacehub_api_token=config["huggingfacehub_api_token"],
|
1138 |
-
temperature=config["temperature"],
|
1139 |
-
max_new_tokens=config["max_new_tokens"]
|
1140 |
)
|
1141 |
-
|
1142 |
-
raise ValueError(f"Unsupported provider: {provider}")
|
1143 |
|
1144 |
-
|
1145 |
-
def assistant(state: dict):
|
1146 |
-
return {
|
1147 |
-
"messages": [llm_with_tools.invoke(state["messages"])]
|
1148 |
-
}
|
1149 |
-
|
1150 |
-
|
1151 |
-
def tools_condition(state: dict) -> str:
|
1152 |
-
if "use tool" in state["messages"][-1].content.lower():
|
1153 |
-
return "tools"
|
1154 |
-
else:
|
1155 |
-
return "END"
|
1156 |
-
|
1157 |
-
|
1158 |
-
|
1159 |
-
from langgraph.graph import StateGraph
|
1160 |
-
from langchain_core.messages import SystemMessage
|
1161 |
-
from langchain_core.runnables import RunnableLambda
|
1162 |
-
def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
|
1163 |
-
# Get LLM
|
1164 |
-
llm = get_llm(provider, model_config)
|
1165 |
-
|
1166 |
-
# Define available tools
|
1167 |
-
tools = [
|
1168 |
-
wiki_search, calculator, web_search, arxiv_search,
|
1169 |
-
get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query
|
1170 |
-
]
|
1171 |
-
|
1172 |
-
# Tool mapping (global if needed elsewhere)
|
1173 |
-
global tool_map
|
1174 |
-
tool_map = {t.name: t for t in tools}
|
1175 |
-
|
1176 |
-
# Bind tools only if LLM supports it
|
1177 |
-
if hasattr(llm, "bind_tools"):
|
1178 |
-
llm_with_tools = llm.bind_tools(tools)
|
1179 |
-
else:
|
1180 |
-
llm_with_tools = llm # fallback for non-tool-aware models
|
1181 |
-
|
1182 |
-
sys_msg = SystemMessage(content="You are a helpful assistant.")
|
1183 |
-
|
1184 |
-
# Define nodes as runnables
|
1185 |
-
retriever = RunnableLambda(lambda state: {
|
1186 |
-
**state,
|
1187 |
-
"retrieved_docs": vector_store.similarity_search(state["input"])
|
1188 |
-
})
|
1189 |
-
|
1190 |
-
assistant = RunnableLambda(lambda state: {
|
1191 |
-
**state,
|
1192 |
-
"messages": [sys_msg] + state["messages"]
|
1193 |
-
})
|
1194 |
-
|
1195 |
-
call_llm = llm_with_tools # already configured
|
1196 |
-
|
1197 |
-
# Start building the graph
|
1198 |
-
builder = StateGraph(AgentState)
|
1199 |
builder.add_node("retriever", retriever)
|
1200 |
builder.add_node("assistant", assistant)
|
1201 |
-
builder.add_node("
|
1202 |
-
builder.
|
1203 |
-
builder.add_node("end", lambda state: state) # Add explicit end node
|
1204 |
-
|
1205 |
-
# Define graph flow
|
1206 |
-
builder.set_entry_point("retriever")
|
1207 |
builder.add_edge("retriever", "assistant")
|
1208 |
-
builder.
|
1209 |
-
|
1210 |
-
|
1211 |
-
|
1212 |
-
|
1213 |
-
})
|
1214 |
-
|
1215 |
-
builder.add_edge("call_tool", "call_llm") # loop back after tool call
|
1216 |
|
1217 |
-
|
1218 |
-
|
|
|
 from dotenv import load_dotenv
 from langgraph.graph import START, StateGraph, MessagesState
 from langgraph.prebuilt import tools_condition
+from langgraph.prebuilt import ToolNode
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_groq import ChatGroq
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
+from langchain_community.vectorstores import SupabaseVectorStore
 from langchain_core.messages import SystemMessage, HumanMessage
 from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
+from supabase.client import Client, create_client

+load_dotenv()

 @tool
+def multiply(a: int, b: int) -> int:
+    """Multiply two numbers.
+    Args:
+        a: first int
+        b: second int
     """
+    return a * b

+@tool
+def add(a: int, b: int) -> int:
+    """Add two numbers.
+
+    Args:
+        a: first int
+        b: second int
+    """
+    return a + b

+@tool
+def subtract(a: int, b: int) -> int:
+    """Subtract two numbers.
+
+    Args:
+        a: first int
+        b: second int
+    """
+    return a - b

+@tool
+def divide(a: int, b: int) -> int:
+    """Divide two numbers.
+
+    Args:
+        a: first int
+        b: second int
+    """
+    if b == 0:
+        raise ValueError("Cannot divide by zero.")
+    return a / b

+@tool
+def modulus(a: int, b: int) -> int:
+    """Get the modulus of two numbers.
+
+    Args:
+        a: first int
+        b: second int
+    """
+    return a % b

 @tool
 def wiki_search(query: str) -> str:
+    """Search Wikipedia for a query and return maximum 2 results.
+
+    Args:
+        query: The search query."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
     formatted_search_docs = "\n\n---\n\n".join(
         [
+            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
+        ])
+    return {"wiki_results": formatted_search_docs}

 @tool
 def web_search(query: str) -> str:
+    """Search Tavily for a query and return maximum 3 results.

+    Args:
+        query: The search query."""
+    search_docs = TavilySearchResults(max_results=3).invoke(query=query)
     formatted_search_docs = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
         ])
+    return {"web_results": formatted_search_docs}

 @tool
+def arvix_search(query: str) -> str:
     """Search Arxiv for a query and return maximum 3 result.

     Args:
         f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
         for doc in search_docs
     ])
+    return {"arvix_results": formatted_search_docs}

 # load the system prompt from the file
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()

 # System message
 sys_msg = SystemMessage(content=system_prompt)

+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+supabase: Client = create_client(
+    os.environ.get("SUPABASE_URL"),
+    os.environ.get("SUPABASE_SERVICE_KEY"))
+vector_store = SupabaseVectorStore(
+    client=supabase,
+    embedding=embeddings,
+    table_name="documents",
+    query_name="match_documents_langchain",
 )
+create_retriever_tool = create_retriever_tool(
+    retriever=vector_store.as_retriever(),
+    name="Question Search",
+    description="A tool to retrieve similar questions from a vector store.",
 )

+tools = [
+    multiply,
+    add,
+    subtract,
+    divide,
+    modulus,
+    wiki_search,
+    web_search,
+    arvix_search,
+]

+# Build graph function
+def build_graph(provider: str = "google"):
+    """Build the graph"""
+    # Load environment variables from .env file
     if provider == "google":
+        # Google Gemini
+        llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
     elif provider == "groq":
+        # Groq https://console.groq.com/docs/models
+        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional : qwen-qwq-32b gemma2-9b-it
     elif provider == "huggingface":
+        # TODO: Add huggingface endpoint
+        llm = ChatHuggingFace(
             llm=HuggingFaceEndpoint(
+                url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                temperature=0,
+            ),
         )
     else:
+        raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+    # Bind tools to LLM
+    llm_with_tools = llm.bind_tools(tools)
+
+    # Node
+    def assistant(state: MessagesState):
+        """Assistant node"""
+        return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+    def retriever(state: MessagesState):
+        """Retriever node"""
+        similar_question = vector_store.similarity_search(state["messages"][0].content)
+        example_msg = HumanMessage(
+            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
         )
+        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

+    builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever)
     builder.add_node("assistant", assistant)
+    builder.add_node("tools", ToolNode(tools))
+    builder.add_edge(START, "retriever")
     builder.add_edge("retriever", "assistant")
+    builder.add_conditional_edges(
+        "assistant",
+        tools_condition,
+    )
+    builder.add_edge("tools", "assistant")

+    # Compile graph
+    return builder.compile()