wt002 committed
Commit cdbcd7d · verified
1 parent: 9f6f7ea

Update agent.py

Files changed (1): agent.py (+124 -1140)

agent.py CHANGED
@@ -4,177 +4,103 @@ import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import tools_condition
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint#, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.utilities import WikipediaAPIWrapper
  from langchain_community.document_loaders import ArxivLoader
  from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
- from sentence_transformers import SentenceTransformer
- from langchain.embeddings.base import Embeddings
- from typing import List
- import numpy as np
- import yaml
-
- import pandas as pd
- import uuid
- import requests
- import json
- from langchain_core.documents import Document
- from youtube_transcript_api import YouTubeTranscriptApi
- from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
- import re
-
- from langchain_community.document_loaders import TextLoader, PyMuPDFLoader
- from docx import Document as DocxDocument
- import openpyxl
- from io import StringIO
-
- from transformers import BertTokenizer, BertModel
- import torch
- import torch.nn.functional as F
- from langchain_community.chat_models import ChatOpenAI
- from langchain_community.tools import Tool
- import time
- from huggingface_hub import InferenceClient
- from langchain_community.llms import HuggingFaceHub
- from langchain.prompts import PromptTemplate
- from langchain.chains import LLMChain
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
- from huggingface_hub import login
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
- from langchain_huggingface import HuggingFaceEndpoint
- #from langchain.agents import initialize_agent
- #from langchain.agents import AgentType
- from typing import Union
- from functools import reduce
- import operator
- from typing import Union
- from functools import reduce
- from youtube_transcript_api import YouTubeTranscriptApi
- from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
- from langchain.schema import Document
-
- from langchain_community.vectorstores import FAISS
- from langchain_huggingface import HuggingFaceEmbeddings
  from langchain.tools.retriever import create_retriever_tool
- #from langchain_community.tools import create_retriever_tool
- from typing import TypedDict, Annotated, List
- import gradio as gr
- from langchain.schema import Document
-
- load_dotenv()
-
  @tool
- def calculator(inputs: Union[str, dict]):
-     """
-     Perform mathematical operations based on the operation provided.
-     Supports both binary (a, b) operations and list operations.
      """
-     # If input is a JSON string, parse it
-     if isinstance(inputs, str):
-         try:
-             import json
-             inputs = json.loads(inputs)
-         except Exception as e:
-             return f"Invalid input format: {e}"
-
-     # Handle list-based operations like SUM
-     if "list" in inputs:
-         nums = inputs.get("list", [])
-         op = inputs.get("operation", "").lower()
-
-         if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
-             return "Invalid list input. Must be a list of numbers."
-
-         if op == "sum":
-             return sum(nums)
-         elif op == "multiply":
-             return reduce(operator.mul, nums, 1)
-         else:
-             return f"Unsupported list operation: {op}"
-
-     # Handle basic two-number operations
-     a = inputs.get("a")
-     b = inputs.get("b")
-     operation = inputs.get("operation", "").lower()

-     if a is None or b is None or not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
-         return "Both 'a' and 'b' must be numbers."
-
-     if operation == "add":
-         return a + b
-     elif operation == "subtract":
-         return a - b
-     elif operation == "multiply":
-         return a * b
-     elif operation == "divide":
-         if b == 0:
-             return "Error: Division by zero"
-         return a / b
-     elif operation == "modulus":
-         return a % b
-     else:
-         return f"Unknown operation: {operation}"

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return up to 2 results."""
      search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-
      formatted_search_docs = "\n\n---\n\n".join(
          [
-             f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
-         ]
-     )
-     return formatted_search_docs
-
-
- @tool
- def wikidata_query(query: str) -> str:
-     """
-     Run a SPARQL query on Wikidata and return results.
-     """
-     endpoint_url = "https://query.wikidata.org/sparql"
-     headers = {
-         "Accept": "application/sparql-results+json"
-     }
-     response = requests.get(endpoint_url, headers=headers, params={"query": query})
-     data = response.json()
-     return json.dumps(data, indent=2)
-
  @tool
  def web_search(query: str) -> str:
-     """Search Tavily for a query and return up to 3 results."""
-     tavily_key = os.getenv("TAVILY_API_KEY")
-
-     if not tavily_key:
-         return "Error: Tavily API key not set."
-
-     search_tool = TavilySearchResults(tavily_api_key=tavily_key, max_results=3)
-     search_docs = search_tool.invoke(query=query)
      formatted_search_docs = "\n\n---\n\n".join(
          [
              f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
          ])
-
-     return formatted_search_docs
-

  @tool
- def arxiv_search(query: str) -> str:
      """Search Arxiv for a query and return maximum 3 result.

      Args:
@@ -185,114 +111,10 @@ def arxiv_search(query: str) -> str:
              f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
              for doc in search_docs
          ])
-     return formatted_search_docs
-
-
- @tool
- def analyze_attachment(file_path: str) -> str:
-     """
-     Analyzes attachments including PY, PDF, TXT, DOCX, and XLSX files and returns text content.
-
-     Args:
-         file_path: Local path to the attachment.
-     """
-     if not os.path.exists(file_path):
-         return f"File not found: {file_path}"
-
-     try:
-         ext = file_path.lower()
-
-         if ext.endswith(".pdf"):
-             loader = PyMuPDFLoader(file_path)
-             documents = loader.load()
-             content = "\n\n".join([doc.page_content for doc in documents])
-
-         elif ext.endswith(".txt") or ext.endswith(".py"):
-             # Both .txt and .py are plain text files
-             with open(file_path, "r", encoding="utf-8") as file:
-                 content = file.read()
-
-         elif ext.endswith(".docx"):
-             doc = DocxDocument(file_path)
-             content = "\n".join([para.text for para in doc.paragraphs])
-
-         elif ext.endswith(".xlsx"):
-             wb = openpyxl.load_workbook(file_path, data_only=True)
-             content = ""
-             for sheet in wb:
-                 content += f"Sheet: {sheet.title}\n"
-                 for row in sheet.iter_rows(values_only=True):
-                     content += "\t".join([str(cell) if cell is not None else "" for cell in row]) + "\n"
-
-         else:
-             return "Unsupported file format. Please use PY, PDF, TXT, DOCX, or XLSX."
-
-         return content[:3000]  # Limit output size for readability
-
-     except Exception as e:
-         return f"An error occurred while processing the file: {str(e)}"
-
-
- @tool
- def get_youtube_transcript(url: str) -> str:
-     """
-     Fetch transcript text from a YouTube video.
-
-     Args:
-         url (str): Full YouTube video URL.
-
-     Returns:
-         str: Transcript text as a single string.
-
-     Raises:
-         ValueError: If no transcript is available or URL is invalid.
-     """
-     try:
-         # Extract video ID
-         video_id = extract_video_id(url)
-         transcript = YouTubeTranscriptApi.get_transcript(video_id)
-
-         # Combine all transcript text
-         full_text = " ".join([entry['text'] for entry in transcript])
-         return full_text
-
-     except (TranscriptsDisabled, VideoUnavailable) as e:
-         raise ValueError(f"Transcript not available: {e}")
-     except Exception as e:
-         raise ValueError(f"Failed to fetch transcript: {e}")
-
-
- @tool
- def extract_video_id(url: str) -> str:
-     """
-     Extract the video ID from a YouTube URL.
-     """
-     match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", url)
-     if not match:
-         raise ValueError("Invalid YouTube URL")
-     return match.group(1)
-

- # -----------------------------
- # Load configuration from YAML
- # -----------------------------
- with open("config.yaml", "r") as f:
-     config = yaml.safe_load(f)
-
- provider = config["provider"]
- model_config = config["models"][provider]
-
- #prompt_path = config["system_prompt_path"]
- enabled_tool_names = config["tools"]
-
-
- # -----------------------------
- # Load system prompt
- # -----------------------------
  # load the system prompt from the file
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
@@ -300,919 +122,81 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
  # System message
  sys_msg = SystemMessage(content=system_prompt)

-
- # -----------------------------
- # Map tool names to functions
- # -----------------------------
- tool_map = {
-     "math": calculator,
-     "wiki_search": wiki_search,
-     "web_search": web_search,
-     "arxiv_search": arxiv_search,
-     "get_youtube_transcript": get_youtube_transcript,
-     "extract_video_id": extract_video_id,
-     "analyze_attachment": analyze_attachment,
-     "wikidata_query": wikidata_query
- }
-
- # Then define which tools you want enabled
- enabled_tool_names = [
-     "math",
-     "wiki_search",
-     "web_search",
-     "arxiv_search",
-     "get_youtube_transcript",
-     "extract_video_id",
-     "analyze_attachment",
-     "wikidata_query"
- ]
-
-
- tools = [tool_map[name] for name in enabled_tool_names]
-
-
- # Safe version
- tools = []
- for name in enabled_tool_names:
-     if name not in tool_map:
-         print(f"❌ Tool not found: {name}")
-         continue
-     tools.append(tool_map[name])
-
-
-
- # -----------------------------
- # Prepare Documents
- # -----------------------------
- # Define the URL where the JSON file is hosted
-
- import faiss
-
- # 1. Type-Checked State for Gradio
- class ChatState(TypedDict):
-     messages: Annotated[
-         List[str],
-         gr.State(render=False),
-         "Stores chat history as list of strings"
-     ]
-
- # 2. Content Processing Utilities
- def process_content(raw_content) -> str:
-     """Convert any input to a clean string"""
-     if isinstance(raw_content, list):
-         return " ".join(str(item) for item in raw_content)
-     return str(raw_content)
-
- def reverse_text(text: str) -> str:
-     """Fix reversed text patterns"""
-     return text[::-1].replace("\\", "").strip() if text.startswith(('.', ',')) else text
-
-
- # 3. Unified Document Creation
-
- def create_documents(data_source: str, data: list) -> list:
-     """Handle both Gradio chat and JSON questions"""
-     docs = []
-
-     for item in data:
-         content = ""
-         # Process different data sources
-         if data_source == "json":
-             raw_question = item.get("question", "")
-             content = raw_question  # Adjust as per your content processing logic
-         else:
-             print(f"Skipping invalid data source: {data_source}")
-             continue
-
-         # Ensure metadata type safety
-         metadata = {
-             "task_id": str(item.get("task_id", "")),
-             "level": str(item.get("Level", "")),
-             "file_name": str(item.get("file_name", ""))
-         }
-
-         # Check if content is non-empty
-         if content.strip():  # Only append non-empty content
-             docs.append(Document(page_content=content, metadata=metadata))
-         else:
-             print(f"Skipping invalid entry with empty content: {item}")
-
-     return docs
-
- # Path to your data.json
- file_path = "/home/wendy/Downloads/data.json"
-
- def load_data(file_path: str) -> list[dict]:
-     """Safe JSON data loading with error handling"""
-     if not os.path.exists(file_path):
-         raise FileNotFoundError(f"Data file not found: {file_path}")
-
-     if not file_path.endswith('.json'):
-         raise ValueError("Invalid file format. Only JSON files supported")
-
-     try:
-         with open(file_path, "r", encoding="utf-8") as f:
-             return json.load(f)
-     except json.JSONDecodeError:
-         raise ValueError("Invalid JSON format in data file")
-     except Exception as e:
-         raise RuntimeError(f"Error loading data: {str(e)}")
-
-
-
- # 4. Vector Store Integration
-
- import faiss
-
- # Custom FAISS wrapper (optional, if you still want it)
- class MyVector_Store:
-     def __init__(self, index: faiss.Index):
-         self.index = index
-
-     def save_local(self, path: str):
-         faiss.write_index(self.index, path)
-
-     @classmethod
-     def load_local(cls, path: str):
-         index = faiss.read_index(path)
-         return cls(index)
-
- # -----------------------------
- # Process JSON data and create documents
- # -----------------------------
-
- file_path = "/home/wendy/Downloads/data.json"
-
- try:
-     with open(file_path, "r", encoding="utf-8") as f:
-         data = json.load(f)
-         print(data)
- except FileNotFoundError as e:
-     print(f"Error: {e}")
- except json.JSONDecodeError as e:
-     print(f"Error decoding JSON: {e}")
-
- docs = create_documents("json", data)
- texts = [doc.page_content for doc in docs]
-
-
- # -----------------------------
- # Initialize embedding model
- # -----------------------------
- embedding_model = HuggingFaceEmbeddings(
-     model_name="sentence-transformers/all-MiniLM-L6-v2"
- )
-
- # -----------------------------
- # Create FAISS index and save it
- # -----------------------------
- class ChatState(TypedDict):
-     messages: Annotated[
-         List[str],
-         gr.State(render=False),
-         "Stores chat history"
-     ]
-
- def initialize_vector_store():
-     embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-     index_path = "/home/wendy/my_hf_agent_course_projects/faiss_index"
-
-     if os.path.exists(os.path.join(index_path, "index.faiss")):
-         try:
-             return FAISS.load_local(
-                 index_path,
-                 embedding_model,
-                 allow_dangerous_deserialization=True
-             )
-         except Exception as e:
-             print(f"Error loading index: {e}")
-
-     # Fallback: Create new index
-     print("Building new vector store...")
-     docs = [...]  # Your document loading logic here
-     vector_store = FAISS.from_documents(docs, embedding_model)
-     vector_store.save_local(index_path)
-     return vector_store
-
- # Initialize at module level
- loaded_store = initialize_vector_store()
- retriever = loaded_store.as_retriever()
-
- # -----------------------------
- # Create LangChain Retriever Tool
- # -----------------------------
- #retriever = loaded_store.as_retriever()
-
- question_retriever_tool = create_retriever_tool(
-     retriever=retriever,
-     name="Question_Search",
-     description="A tool to retrieve documents related to a user's question."
  )
-
- # -----------------------------
- # Load HuggingFace LLM
- # -----------------------------
- llm = HuggingFaceEndpoint(
-     repo_id="HuggingFaceH4/zephyr-7b-beta",
-     task="text-generation",
-     huggingfacehub_api_token=os.getenv("HF_TOKEN"),
-     temperature=0.7,
-     max_new_tokens=512
  )

-
-
- # -------------------------------
- # Step 8: Use the Planner, Classifier, and Decision Logic
- # -------------------------------
-
- def process_question(question):
-     # Step 1: Planner generates the task sequence
-     tasks = planner(question)
-     print(f"Tasks to perform: {tasks}")
-
-     # Step 2: Classify the task (based on question)
-     task_type = task_classifier(question)
-     print(f"Task type: {task_type}")
-
-     # Step 3: Use the classifier and planner to decide on the next task or node
-     state = {"question": question, "last_response": ""}
-     next_task = decide_task(state)
-     print(f"Next task: {next_task}")
-
-     # Step 4: Use node skipper logic (skip if needed)
-     skip = node_skipper(state)
-     if skip:
-         print(f"Skipping to {skip}")
-         return skip  # Or move directly to generating answer
-
-     # Step 5: Execute task (with error handling)
-     try:
-         if task_type == "wiki_search":
-             response = wiki_search(question)
-         elif task_type == "math":
-             response = calculator(question)
-         else:
-             response = "Default answer logic"
-
-         # Step 6: Final response formatting
-         final_response = final_answer_tool(state, {'wiki_search': response})
-         return final_response
-
-     except Exception as e:
-         print(f"Error executing task: {e}")
-         return "Sorry, I encountered an error processing your request."
-
-
- # Run the process
- question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
- response = agent.invoke(question)
- print("Final Response:", response)
-
-
- from langchain.schema import HumanMessage
-
- def retriever(state: MessagesState, k: int = 4):
-     """
-     Retrieves documents from the vector store using similarity scores,
-     applies a dynamic threshold filter, and returns updated message state.
-
-     Args:
-         state (MessagesState): Current message state including the user's query.
-         k (int): Number of top results to retrieve from the vector store.
-
-     Returns:
-         dict: Updated messages state including relevant documents or fallback message.
-     """
-     query = state["messages"][0].content.strip()
-     results = vector_store.similarity_search_with_score(query, k=k)
-
-     # Determine dynamic similarity threshold
-     if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
-         threshold = 0.75
-     else:
-         threshold = 0.8
-
-     filtered = [doc for doc, score in results if score < threshold]
-
-     if not filtered:
-         response_msg = HumanMessage(content="No relevant documents found.")
-     else:
-         content = "\n\n".join(doc.page_content for doc in filtered)
-         response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")
-
-     return {"messages": [sys_msg] + state["messages"] + [response_msg]}
-

- # ----------------------------------------------------------------
- # LLM Loader
- # ----------------------------------------------------------------
- def get_llm(provider: str, config: dict):
      if provider == "google":
-         from langchain_google_genai import ChatGoogleGenerativeAI
-         return ChatGoogleGenerativeAI(
-             model=config.get("model"),
-             temperature=config.get("temperature", 0.7),
-             google_api_key=config.get("api_key")  # Optional: if needed
-         )
-
      elif provider == "groq":
-         from langchain_groq import ChatGroq
-         return ChatGroq(
-             model=config.get("model"),
-             temperature=config.get("temperature", 0.7),
-             groq_api_key=config.get("api_key")  # Optional: if needed
-         )
-
      elif provider == "huggingface":
-         from langchain_huggingface import ChatHuggingFace
-         from langchain_huggingface import HuggingFaceEndpoint
-         return ChatHuggingFace(
              llm=HuggingFaceEndpoint(
-                 endpoint_url=config.get("url"),
-                 temperature=config.get("temperature", 0.7),
-                 huggingfacehub_api_token=config.get("api_key")  # Optional
-             )
          )
-
      else:
-         raise ValueError(f"Invalid provider: {provider}")
-
-
-
- # ----------------------------------------------------------------
- # Planning & Execution Logic
- # ----------------------------------------------------------------
- def planner(question: str, tools: list) -> tuple:
-     """
-     Select the best-matching tool(s) for a question based on keyword-based intent detection and tool metadata.
-     Returns the detected intent and matched tools.
-     """
-     question = question.lower().strip()
-
-     # Define intent-based keywords
-     intent_keywords = {
-         "math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
-         "wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
-         "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
-         "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
-         "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
-         "extract_video_id": ["analyze video", "summarize video", "video content"],
-         "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
-         "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
-         "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
-     }
-
-     # Step 1: Identify intent
-     detected_intent = None
-     for intent, keywords in intent_keywords.items():
-         if any(keyword in question for keyword in keywords):
-             detected_intent = intent
-             break
-
-     # Step 2: Match tools by intent
-     matched_tools = []
-     if detected_intent:
-         for tool in tools:
-             name = getattr(tool, "name", "").lower()
-             description = getattr(tool, "description", "").lower()
-             if detected_intent in name or detected_intent in description:
-                 matched_tools.append(tool)
-
-     # Step 3: Fallback to general-purpose/default tools if no match found
-     if not matched_tools:
-         matched_tools = [
-             tool for tool in tools
-             if "default" in getattr(tool, "name", "").lower()
-             or "qa" in getattr(tool, "description", "").lower()
-         ]
-
-     return detected_intent, matched_tools if matched_tools else [tools[0]]
-
-
- def task_classifier(question: str) -> str:
-     """
-     Classifies the question into one of the predefined task categories.
-     """
-     question = question.lower().strip()
-
-     # Context-aware intent patterns
-     if any(phrase in question for phrase in [
-         "calculate", "how much is", "what is the result of", "evaluate", "solve"
-     ]) or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"]):
-         return "math"
-
-     elif any(phrase in question for phrase in [
-         "who is", "what is", "define", "explain", "tell me about", "give me an overview of"
-     ]):
-         return "wiki_search"
-
-     elif any(phrase in question for phrase in [
-         "search", "find", "look up", "google", "get the latest", "current news", "trending"
-     ]):
-         return "web_search"
-
-     elif any(phrase in question for phrase in [
-         "arxiv", "latest research", "scientific paper", "research paper", "preprint"
-     ]):
-         return "arxiv_search"
-
-     elif any(phrase in question for phrase in [
-         "youtube", "watch", "play the video", "show me a video"
-     ]):
-         return "get_youtube_transcript"
-
-     elif any(phrase in question for phrase in [
-         "analyze video", "summarize video", "what happens in the video", "video content"
-     ]):
-         return "video_analysis"
-
-     elif any(phrase in question for phrase in [
-         "analyze", "visualize", "plot", "graph", "inspect data", "explore dataset"
-     ]):
-         return "data_analysis"
-
-     elif any(phrase in question for phrase in [
-         "sparql", "wikidata", "query wikidata", "run sparql", "wikidata query"
-     ]):
-         return "wikidata_query"
-
-     return "default"
-
-
- def select_tool_and_run(question: str, tools: dict):
-     # Step 1: Classify intent
-     intent = task_classifier(question)  # assuming task_classifier maps the question to intent
-
-     # Map intent to tool names
-     intent_tool_map = {
-         "math": "calculator",  # maps to tools["math"] → calculator
-         "wiki_search": "wiki_search",  # → wiki_search
-         "web_search": "web_search",  # → web_search
-         "arxiv_search": "arxiv_search",  # → arxiv_search (spelling fixed)
-         "get_youtube_transcript": "get_youtube_transcript",  # → get_youtube_transcript
-         "extract_video_id": "extract_video_id",  # adjust based on your tools
-         "analyze_attachment": "analyze_attachment",  # assuming analyze_attachment handles this
-         "wikidata_query": "wikidata_query",  # → wikidata_query
-         "default": "default"  # → default_tool
-     }
-
-     # Get the corresponding tool name
-     tool_name = intent_tool_map.get(intent, "default")  # Default to "default" if no match
-
-     # Retrieve the tool from the tools dictionary
-     tool_func = tools.get(tool_name)
-
-     if not tool_func:
-         return f"Tool not found for intent '{intent}'"
-
-     # Step 2: Run the tool
-     try:
-         # If the tool needs JSON or structured data
-         try:
-             parsed_input = json.loads(question)
-         except json.JSONDecodeError:
-             parsed_input = question  # fallback to raw input if not JSON
-
-         # Run the selected tool
-         print(f"Running tool: {tool_name} with input: {parsed_input}")  # log the tool name and input
-         return tool_func(parsed_input)
-     except Exception as e:
-         return f"Error while running tool '{tool_name}': {str(e)}"
-
-
- # Function to extract math operation from the question
-
- def extract_math_from_question(question: str):
-     question = question.lower()
-
-     # Map natural language to symbols
-     ops = {
-         "add": "+", "plus": "+",
-         "subtract": "-", "minus": "-",
-         "multiply": "*", "times": "*",
-         "divide": "/", "divided by": "/",
-         "modulus": "%", "mod": "%"
-     }
-
-     for word, symbol in ops.items():
-         question = re.sub(rf"\b{word}\b", symbol, question)
-
-     # Extract math expression like "12 + 5"
-     match = re.search(r'(\d+)\s*([\+\-\*/%])\s*(\d+)', question)
-     if match:
-         num1 = int(match.group(1))
-         operator = match.group(2)
-         num2 = int(match.group(3))
-         return {
-             "a": num1,
-             "b": num2,
-             "operation": {
-                 "+": "add",
-                 "-": "subtract",
-                 "*": "multiply",
-                 "/": "divide",
-                 "%": "modulus"
-             }[operator]
-         }
-     return None
-
-
- # Example tool set (adjust these to match your actual tool names)
- intent_tool_map = {
-     "math": "math",  # maps to tools["math"] → calculator
-     "wiki_search": "wiki_search",  # → wiki_search
-     "web_search": "web_search",  # → web_search
-     "arxiv_search": "arxiv_search",  # → arxiv_search (spelling fixed)
-     "get_youtube_transcript": "get_youtube_transcript",  # → get_youtube_transcript
-     "extract_video_id": "extract_video_id",  # adjust based on your tools
-     "analyze_attachment": "analyze_attachment",  # assuming analyze_attachment handles this
-     "wikidata_query": "wikidata_query",  # → wikidata_query
-     "default": "default"  # → default_tool
- }
-
-
- # The task order can also include the tools for each task
- priority_order = [
-     {"task": "math", "tool": "math"},
-     {"task": "wiki_search", "tool": "wiki_search"},
-     {"task": "web_search", "tool": "web_search"},
-     {"task": "arxiv_search", "tool": "arxiv_search"},
-     {"task": "wikidata_query", "tool": "wikidata_query"},
-     {"task": "retriever", "tool": "retriever"},
-     {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
-     {"task": "extract_video_id", "tool": "extract_video_id"},
-     {"task": "analyze_attachment", "tool": "analyze_attachment"},
-     {"task": "default", "tool": "default"}  # Fallback
- ]
-
- def decide_task(state: dict) -> str:
-     """Decides which task to perform based on the current state."""
-
-     # Get the list of tasks from the planner
-     tasks = planner(state["question"])
-     print(f"Available tasks: {tasks}")  # Debugging: show all possible tasks
-
-     # Check if the tasks list is empty or invalid
-     if not tasks:
-         print("❌ No valid tasks were returned from the planner.")
-         return "default"  # Return a default task if no tasks were generated
-
-     # If there are multiple tasks, we can prioritize based on certain conditions
-     task = tasks[0]  # Default to the first task in the list
-     if len(tasks) > 1:
-         print(f"⚠️ Multiple tasks found. Deciding based on priority.")
-         # Example logic to prioritize tasks, adjust based on your use case
-         task = prioritize_tasks(tasks)
-
-     print(f"Decided on task: {task}")  # Debugging: show the final task
-     return task
-
-
- def prioritize_tasks(tasks: list) -> str:
-     """Prioritize tasks based on certain conditions or criteria, including tools."""
-     # Sort tasks based on priority_order mapping
-     for priority in priority_order:
-         # Check if any task matches the priority task type
-         for task in tasks:
-             if priority["task"] in task:
-                 print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}")  # Debugging: show the chosen task and tool
-                 # Assign the correct tool based on the task
-                 tool = tools.get(priority["tool"], tools["default"])  # Default to 'default_tool' if not found
-                 return task, tool
-
-     # If no priority task is found, return the first task with its default tool
-     return tasks[0], tools["default"]
-
-
- def process_question(question: str):
-     """Process the question and route it to the appropriate tool."""
-     # Get the tasks from the planner
-     tasks = planner(question)
-     print(f"Tasks to perform: {tasks}")
-
-     task_type, tool = decide_task({"question": question})
-     print(f"Next task: {task_type} with tool: {tool}")
-
-     if node_skipper({"question": question}):
-         print(f"Skipping task: {task_type}")
-         return "Task skipped."
-
-     try:
-         # Execute the corresponding tool for the task type
-         if task_type == "wiki_search":
-             response = tool.run(question)  # Assuming tool is wiki_tool
-         elif task_type == "math":
-             response = tool.run(question)  # Assuming tool is calc_tool
-         elif task_type == "retriever":
-             response = tool.run(question)  # Assuming tool is retriever_tool
-         else:
-             response = tool.run(question)  # Default tool
-
-         return generate_final_answer({"question": question}, {task_type: response})
-
-     except Exception as e:
-         print(f"❌ Error: {e}")
-         return f"Sorry, I encountered an error: {str(e)}"
-
-
- def call_llm(state):
-     messages = state["messages"]
-     response = llm.invoke(messages)
-     return {"messages": messages + [response]}
-
-
- from langchain.schema import AIMessage
- from typing import TypedDict, List, Optional
- from langchain_core.messages import BaseMessage
-
- class AgentState(TypedDict):
-     messages: List[BaseMessage]  # Chat history
-     input: str  # Original input
-     intent: str  # Derived or predicted intent
-     result: Optional[str]  # Optional result
-
-
- def tool_dispatcher(state: AgentState) -> AgentState:
-     last_msg = state["messages"][-1]
-
-     # Make sure it's an AI message with tool_calls
-     if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
-         tool_call = last_msg.tool_calls[0]
-         tool_name = tool_call["name"]
-         tool_input = tool_call["args"]  # Adjust based on your actual schema
-
-         tool_func = tool_map.get(tool_name, default_tool)
-
-         # If args is a dict and your tool expects unpacked values:
-         if isinstance(tool_input, dict):
-             result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
-         else:
-             result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)
-
-         # You can choose to append this to messages, or just save result
-         return {
-             **state,
-             "result": result,
-             # Optionally add: "messages": state["messages"] + [ToolMessage(...)]
-         }
-
-     # No tool call detected, return state unchanged
-     return state
-
-
- # Decide what to do next: if tool call → call_tool, else → end
- def should_call_tool(state):
-     last_msg = state["messages"][-1]
-     if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
-         return "call_tool"
-     return "end"
-
-
- from typing import TypedDict, List, Optional, Union
- from langchain.schema import BaseMessage
-
- class AgentState(TypedDict):
-     messages: List[BaseMessage]  # Chat history
-     input: str  # Original input
-     intent: str  # Derived or predicted intent
-     result: Optional[str]  # Final or intermediate result
-
-
- # To store previously asked questions and timestamps (simulating state persistence)
- recent_questions = {}
-
- def node_skipper(state: dict) -> bool:
-     """
-     Determines whether to skip the task based on the state.
-     This could include:
-     1. Repeated or similar questions
-     2. Irrelevant or empty questions
-     3. Tasks that have already been processed recently
-     """
-     question = state.get("question", "").strip()
-
-     if not question:
-         print("❌ Skipping: Empty or invalid question.")
-         return True  # Skip if no valid question
-
-     # 1. Skip if the question has already been asked recently (within a given time window)
-     # Here, we're using a simple example with a 5-minute window (300 seconds).
-     if question in recent_questions:
-         last_asked_time = recent_questions[question]
-         time_since_last_ask = time.time() - last_asked_time
-         if time_since_last_ask < 300:  # 5-minute threshold
-             print(f"❌ Skipping: The question has been asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
-             return True  # Skip if the question was asked within the last 5 minutes
-
-     # 2. Skip if the question is irrelevant or not meaningful enough
-     irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
-     if any(keyword in question.lower() for keyword in irrelevant_keywords):
-         print("❌ Skipping: Irrelevant or nonsense question.")
-         return True  # Skip if the question contains irrelevant keywords
-
-     # 3. Skip if the task has already been completed for this question (based on a unique task identifier)
-     if "last_response" in state and state["last_response"]:
-         print("❌ Skipping: Task has already been processed recently.")
-         return True  # Skip if a response has already been given
-
-     # 4. Skip based on a condition related to the task itself
-     # Example: Skip math-related tasks if the result is already known or trivial
-     if "math" in state.get("question", "").lower():
-         # If math is trivial (like "What is 2+2?")
-         trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
-         if any(trivial_question in question for trivial_question in trivial_math):
-             print(f"❌ Skipping trivial math question: {question}")
-             return True  # Skip if the math question is trivial
-
-     # 5. Skip based on external factors (e.g., current time, system load, etc.)
-     # Example: Avoid processing tasks at night if that's part of the business logic
-     current_hour = time.localtime().tm_hour
-     if current_hour >= 22 or current_hour < 6:
-         print("❌ Skipping: It's night time, not processing tasks.")
-         return True  # Skip tasks during night time (e.g., between 10 PM and 6 AM)
-
-     # If none of the conditions matched, don't skip the task
-     return False
-
- # Update recent questions (for simulating repeated question check)
- def update_recent_questions(question: str):
-     """Update the recent questions dictionary with the current timestamp."""
-     recent_questions[question] = time.time()
-
-
- def generate_final_answer(state: dict, task_results: dict) -> str:
-     """Generate a final answer based on the results of the task."""
-     if "wiki_search" in task_results:
-         return f"📚 Wiki Summary:\n{task_results['wiki_search']}"
-     elif "math" in task_results:
-         return f"🧮 Math Result: {task_results['math']}"
-     elif "retriever" in task_results:
-         return f"🔍 Retrieved Info: {task_results['retriever']}"
-     else:
-         return "🤖 Unable to generate a specific answer."
-
-
- def answer_question(question: str) -> str:
-     """Process a single question and return the answer."""
-     print(f"Processing question: {question[:50]}...")  # Debugging: show first 50 chars
-
-     # Wrap the question in a HumanMessage from langchain_core (assuming langchain is used)
-     messages = [HumanMessage(content=question)]
-     response = graph.invoke({"messages": messages})  # Assuming `graph` is defined elsewhere
-
-     # Extract the answer from the response
-     answer = response['messages'][-1].content
-     return answer[14:]  # Assuming 'answer[14:]' is correct based on your example
-
-
- def process_all_tasks(tasks: list):
-     """Process a list of tasks."""
-     results = {}
-
-     for task in tasks:
-         question = task.get("question", "").strip()
-         if not question:
-             print(f"Skipping task with missing or empty 'question': {task}")
-             continue
-
-         print(f"\n🟢 Processing Task: {task['task_id']} - Question: {question}")
-
-         # Call the existing process_question logic
-         response = process_question(question)
-
-         print(f"✅ Response: {response}")
-         results[task['task_id']] = response
-
-     return results
-
-
- ## Langgraph
-
- # Build graph function
- vector_store = vector_store.save_local("faiss_index")
-
- provider = "huggingface"
-
- model_config = {
-     "repo_id": "HuggingFaceH4/zephyr-7b-beta",
-     "task": "text-generation",
-     "temperature": 0.7,
-     "max_new_tokens": 512,
-     "huggingfacehub_api_token": os.getenv("HF_TOKEN")
- }
-
- # Get LLM
- def get_llm(provider: str, config: dict):
-     if provider == "huggingface":
-         from langchain_huggingface import HuggingFaceEndpoint
-         return HuggingFaceEndpoint(
-             repo_id=config["repo_id"],
-             task=config["task"],
-             huggingfacehub_api_token=config["huggingfacehub_api_token"],
-             temperature=config["temperature"],
-             max_new_tokens=config["max_new_tokens"]
          )
-     else:
-         raise ValueError(f"Unsupported provider: {provider}")

-
- def assistant(state: dict):
-     return {
-         "messages": [llm_with_tools.invoke(state["messages"])]
-     }
-
-
- def tools_condition(state: dict) -> str:
-     if "use tool" in state["messages"][-1].content.lower():
-         return "tools"
-     else:
-         return "END"
-
-
- from langgraph.graph import StateGraph
- from langchain_core.messages import SystemMessage
- from langchain_core.runnables import RunnableLambda
-
- def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
-     # Get LLM
-     llm = get_llm(provider, model_config)
-
-     # Define available tools
-     tools = [
-         wiki_search, calculator, web_search, arxiv_search,
-         get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query
-     ]
-
-     # Tool mapping (global if needed elsewhere)
-     global tool_map
-     tool_map = {t.name: t for t in tools}
-
-     # Bind tools only if LLM supports it
-     if hasattr(llm, "bind_tools"):
-         llm_with_tools = llm.bind_tools(tools)
-     else:
-         llm_with_tools = llm  # fallback for non-tool-aware models
-
-     sys_msg = SystemMessage(content="You are a helpful assistant.")
-
-     # Define nodes as runnables
-     retriever = RunnableLambda(lambda state: {
-         **state,
-         "retrieved_docs": vector_store.similarity_search(state["input"])
-     })
-
-     assistant = RunnableLambda(lambda state: {
-         **state,
-         "messages": [sys_msg] + state["messages"]
-     })
-
-     call_llm = llm_with_tools  # already configured
-
-     # Start building the graph
-     builder = StateGraph(AgentState)
      builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
-     builder.add_node("call_llm", call_llm)
-     builder.add_node("call_tool", tool_dispatcher)
-     builder.add_node("end", lambda state: state)  # Add explicit end node
-
-     # Define graph flow
-     builder.set_entry_point("retriever")
      builder.add_edge("retriever", "assistant")
-     builder.add_edge("assistant", "call_llm")
-
-     builder.add_conditional_edges("call_llm", should_call_tool, {
-         "call_tool": "call_tool",
-         "end": "end"  # ✅ fixed: must point to actual "end" node
-     })
-
-     builder.add_edge("call_tool", "call_llm")  # loop back after tool call

-     return builder.compile()
-
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
  from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
  from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client

+ load_dotenv()

  @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+
+     Args:
+         a: first int
+         b: second int
      """
+     return a * b

+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a + b

+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a - b

+ @tool
+ def divide(a: int, b: int) -> int:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b

+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a % b

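A minimal sketch of exercising the new arithmetic tools directly; @tool-decorated functions become structured tools that take a dict of arguments (note `divide` is annotated `-> int` but returns a float):

    # Sketch: direct invocation of the structured math tools.
    print(multiply.invoke({"a": 6, "b": 7}))  # 42
    print(add.invoke({"a": 40, "b": 2}))      # 42
    try:
        divide.invoke({"a": 1, "b": 0})
    except ValueError as err:
        print(err)  # Cannot divide by zero.
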
  @tool
  def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return maximum 2 results.
+
+     Args:
+         query: The search query."""
      search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
      formatted_search_docs = "\n\n---\n\n".join(
          [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
+         ])
+     return {"wiki_results": formatted_search_docs}

  @tool
  def web_search(query: str) -> str:
+     """Search Tavily for a query and return maximum 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
      formatted_search_docs = "\n\n---\n\n".join(
          [
              f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
              for doc in search_docs
          ])
+     return {"web_results": formatted_search_docs}

  @tool
+ def arvix_search(query: str) -> str:
      """Search Arxiv for a query and return maximum 3 result.

      Args:

              f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
              for doc in search_docs
          ])
+     return {"arvix_results": formatted_search_docs}

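The three search tools now return single-key dicts ({"wiki_results": ...}, {"web_results": ...}, {"arvix_results": ...}) despite the `-> str` annotations. A sketch of a direct call (assuming network access; note that TavilySearchResults yields plain dicts rather than Documents, so the `doc.metadata` access in `web_search` may need adjusting):

    # Sketch: calling the wiki tool directly (requires network access).
    out = wiki_search.invoke({"query": "Mercedes Sosa"})
    print(out["wiki_results"][:300])  # first chars of the formatted <Document .../> blocks
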
  # load the system prompt from the file
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()

  # System message
  sys_msg = SystemMessage(content=system_prompt)

+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
  )
+ create_retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
  )
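This setup assumes SUPABASE_URL and SUPABASE_SERVICE_KEY are set and that the Supabase project already contains a "documents" table plus a "match_documents_langchain" query function. Note the committed code also rebinds the imported name `create_retriever_tool` to the tool it returns; a sketch that keeps the tool under its own (hypothetical) name instead:

    # Sketch: same retriever tool, bound to a fresh, hypothetical name so the
    # imported create_retriever_tool helper is not shadowed.
    question_search_tool = create_retriever_tool(
        retriever=vector_store.as_retriever(),
        name="Question_Search",  # hypothetical rename; some providers reject spaces in tool names
        description="A tool to retrieve similar questions from a vector store.",
    )
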

+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arvix_search,
+ ]

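A quick sanity check, not part of the commit: each @tool-decorated entry exposes the function name and its docstring.

    # Sketch: print the registered tool names and short descriptions.
    for t in tools:
        print(t.name, "-", t.description.splitlines()[0])
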
+ # Build graph function
+ def build_graph(provider: str = "google"):
+     """Build the graph"""
+     # Load environment variables from .env file
      if provider == "google":
+         # Google Gemini
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
      elif provider == "groq":
+         # Groq https://console.groq.com/docs/models
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)  # optional: qwen-qwq-32b, gemma2-9b-it
      elif provider == "huggingface":
+         # TODO: Add huggingface endpoint
+         llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
+                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             ),
          )
      else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Node
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Retriever node"""
+         similar_question = vector_store.similarity_search(state["messages"][0].content)
+         example_msg = HumanMessage(
+             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
          )
+         return {"messages": [sys_msg] + state["messages"] + [example_msg]}

+     builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever)
      builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "retriever")
      builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
+     builder.add_edge("tools", "assistant")

+     # Compile graph
+     return builder.compile()
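
A minimal usage sketch, assuming the matching provider key (e.g. GOOGLE_API_KEY for the default provider) and the Supabase variables above are set:

    # Sketch: compile the graph and run a single question through it.
    if __name__ == "__main__":
        graph = build_graph(provider="google")
        result = graph.invoke({"messages": [HumanMessage(content="What is 6 times 7?")]})
        print(result["messages"][-1].content)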