Update agent.py
agent.py
CHANGED
@@ -1,147 +1,1218 @@
[Previous 147-line version removed. The recoverable fragments show it imported os and requests and defined @tool helpers: a file-download tool that saved to tempfile.gettempdir() and returned the file path or an error message, a Gemini-based image analysis tool built on client.files.upload and client.models.generate_content with the GOOGLE_MODEL_ID model, and an Excel analysis tool based on pd.read_excel that reported row and column counts, each with try/except error handling.]
# agent.py

import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, TextLoader, PyMuPDFLoader
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer
from langchain.embeddings.base import Embeddings
from typing import TypedDict, Annotated, List, Optional, Union
import numpy as np
import yaml

import pandas as pd
import uuid
import requests
import json
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import TranscriptsDisabled, VideoUnavailable
import re

from docx import Document as DocxDocument
import openpyxl
from io import StringIO

from transformers import BertTokenizer, BertModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig
import torch
import torch.nn.functional as F
from langchain_community.chat_models import ChatOpenAI
from langchain_community.tools import Tool
import time
from huggingface_hub import InferenceClient, login
from langchain_community.llms import HuggingFaceHub
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
# from langchain.agents import initialize_agent
# from langchain.agents import AgentType
from functools import reduce
import operator

from langchain_community.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
# from langchain_community.tools import create_retriever_tool
import faiss
import gradio as gr

load_dotenv()

@tool
def calculator(inputs: Union[str, dict]):
    """
    Perform mathematical operations based on the operation provided.
    Supports both binary (a, b) operations and list operations.
    """
    # If input is a JSON string, parse it
    if isinstance(inputs, str):
        try:
            inputs = json.loads(inputs)
        except Exception as e:
            return f"Invalid input format: {e}"

    # Handle list-based operations like SUM
    if "list" in inputs:
        nums = inputs.get("list", [])
        op = inputs.get("operation", "").lower()

        if not isinstance(nums, list) or not all(isinstance(n, (int, float)) for n in nums):
            return "Invalid list input. Must be a list of numbers."

        if op == "sum":
            return sum(nums)
        elif op == "multiply":
            return reduce(operator.mul, nums, 1)
        else:
            return f"Unsupported list operation: {op}"

    # Handle basic two-number operations
    a = inputs.get("a")
    b = inputs.get("b")
    operation = inputs.get("operation", "").lower()

    if a is None or b is None or not isinstance(a, (int, float)) or not isinstance(b, (int, float)):
        return "Both 'a' and 'b' must be numbers."

    if operation == "add":
        return a + b
    elif operation == "subtract":
        return a - b
    elif operation == "multiply":
        return a * b
    elif operation == "divide":
        if b == 0:
            return "Error: Division by zero"
        return a / b
    elif operation == "modulus":
        return a % b
    else:
        return f"Unknown operation: {operation}"

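# Minimal usage sketch for the calculator tool (hypothetical inputs, not part
# of the original file). @tool-wrapped functions are invoked via .invoke():
#
#   calculator.invoke({"inputs": {"a": 12, "b": 5, "operation": "add"}})        # -> 17
#   calculator.invoke({"inputs": {"list": [2, 3, 4], "operation": "multiply"}}) # -> 24
#   calculator.invoke({"inputs": '{"a": 9, "b": 0, "operation": "divide"}'})    # -> "Error: Division by zero"
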

@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return up to 2 results."""
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()

    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata.get("source", "Wikipedia")}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    return formatted_search_docs

@tool
def wikidata_query(query: str) -> str:
    """
    Run a SPARQL query on Wikidata and return results.
    """
    endpoint_url = "https://query.wikidata.org/sparql"
    headers = {
        "Accept": "application/sparql-results+json"
    }
    response = requests.get(endpoint_url, headers=headers, params={"query": query})
    data = response.json()
    return json.dumps(data, indent=2)

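# Illustrative call (hypothetical query, not from the original file): fetch a
# few items that are instances of "cat" (wd:Q146) from Wikidata.
#
#   wikidata_query.invoke({"query": """
#       SELECT ?item ?itemLabel WHERE {
#         ?item wdt:P31 wd:Q146.
#         SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
#       } LIMIT 3
#   """})
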
@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return up to 3 results."""
    tavily_key = os.getenv("TAVILY_API_KEY")

    if not tavily_key:
        return "Error: Tavily API key not set."

    search_tool = TavilySearchResults(tavily_api_key=tavily_key, max_results=3)
    # BaseTool.invoke takes the input positionally, not as a keyword
    search_docs = search_tool.invoke(query)

    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ])

    return formatted_search_docs

@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ])
    return formatted_search_docs

@tool
def analyze_attachment(file_path: str) -> str:
    """
    Analyzes attachments including PY, PDF, TXT, DOCX, and XLSX files and returns text content.

    Args:
        file_path: Local path to the attachment.
    """
    if not os.path.exists(file_path):
        return f"File not found: {file_path}"

    try:
        ext = file_path.lower()

        if ext.endswith(".pdf"):
            loader = PyMuPDFLoader(file_path)
            documents = loader.load()
            content = "\n\n".join([doc.page_content for doc in documents])

        elif ext.endswith(".txt") or ext.endswith(".py"):
            # Both .txt and .py are plain text files
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()

        elif ext.endswith(".docx"):
            doc = DocxDocument(file_path)
            content = "\n".join([para.text for para in doc.paragraphs])

        elif ext.endswith(".xlsx"):
            wb = openpyxl.load_workbook(file_path, data_only=True)
            content = ""
            for sheet in wb:
                content += f"Sheet: {sheet.title}\n"
                for row in sheet.iter_rows(values_only=True):
                    content += "\t".join([str(cell) if cell is not None else "" for cell in row]) + "\n"

        else:
            return "Unsupported file format. Please use PY, PDF, TXT, DOCX, or XLSX."

        return content[:3000]  # Limit output size for readability

    except Exception as e:
        return f"An error occurred while processing the file: {str(e)}"

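# Example (hypothetical path, not from the original file):
#
#   analyze_attachment.invoke({"file_path": "/tmp/report.xlsx"})
#
# returns the first 3000 characters of the sheet-by-sheet, tab-separated dump
# produced above, or an error string if the file is missing or unsupported.
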
@tool
def get_youtube_transcript(url: str) -> str:
    """
    Fetch transcript text from a YouTube video.

    Args:
        url (str): Full YouTube video URL.

    Returns:
        str: Transcript text as a single string.

    Raises:
        ValueError: If no transcript is available or URL is invalid.
    """
    try:
        # Extract video ID (extract_video_id is itself a @tool, so call via .invoke)
        video_id = extract_video_id.invoke(url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)

        # Combine all transcript text
        full_text = " ".join([entry['text'] for entry in transcript])
        return full_text

    except (TranscriptsDisabled, VideoUnavailable) as e:
        raise ValueError(f"Transcript not available: {e}")
    except Exception as e:
        raise ValueError(f"Failed to fetch transcript: {e}")

@tool
def extract_video_id(url: str) -> str:
    """
    Extract the video ID from a YouTube URL.
    """
    match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", url)
    if not match:
        raise ValueError("Invalid YouTube URL")
    return match.group(1)

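# Example (hypothetical URL): both forms resolve to the same 11-character ID.
#
#   extract_video_id.invoke("https://www.youtube.com/watch?v=dQw4w9WgXcQ")  # -> "dQw4w9WgXcQ"
#   extract_video_id.invoke("https://youtu.be/dQw4w9WgXcQ")                 # -> "dQw4w9WgXcQ"
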

# -----------------------------
# Load configuration from YAML
# -----------------------------
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

provider = config["provider"]
model_config = config["models"][provider]

#prompt_path = config["system_prompt_path"]
enabled_tool_names = config["tools"]

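# The config.yaml shape assumed by the lookups above (a sketch; the real file
# is not part of this diff, so these keys and values are illustrative):
#
#   provider: huggingface
#   models:
#     huggingface:
#       repo_id: HuggingFaceH4/zephyr-7b-beta
#       temperature: 0.7
#   tools:
#     - math
#     - wiki_search
#     - web_search
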

# -----------------------------
# Load system prompt
# -----------------------------
# load the system prompt from the file
with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

# System message
sys_msg = SystemMessage(content=system_prompt)


# -----------------------------
# Map tool names to functions
# -----------------------------
tool_map = {
    "math": calculator,
    "wiki_search": wiki_search,
    "web_search": web_search,
    "arxiv_search": arxiv_search,
    "get_youtube_transcript": get_youtube_transcript,
    "extract_video_id": extract_video_id,
    "analyze_attachment": analyze_attachment,
    "wikidata_query": wikidata_query
}

# Then define which tools you want enabled
# (NOTE: this overrides the list read from config.yaml above)
enabled_tool_names = [
    "math",
    "wiki_search",
    "web_search",
    "arxiv_search",
    "get_youtube_transcript",
    "extract_video_id",
    "analyze_attachment",
    "wikidata_query"
]

# Safe version: skip names that are missing from tool_map instead of raising
tools = []
for name in enabled_tool_names:
    if name not in tool_map:
        print(f"❌ Tool not found: {name}")
        continue
    tools.append(tool_map[name])


# -----------------------------
# Prepare Documents
# -----------------------------
# Define the URL where the JSON file is hosted

# 1. Type-Checked State for Gradio
class ChatState(TypedDict):
    messages: Annotated[
        List[str],
        gr.State(render=False),
        "Stores chat history as list of strings"
    ]

# 2. Content Processing Utilities
def process_content(raw_content) -> str:
    """Convert any input to a clean string"""
    if isinstance(raw_content, list):
        return " ".join(str(item) for item in raw_content)
    return str(raw_content)

def reverse_text(text: str) -> str:
    """Fix reversed text patterns"""
    return text[::-1].replace("\\", "").strip() if text.startswith(('.', ',')) else text


# 3. Unified Document Creation

def create_documents(data_source: str, data: list) -> list:
    """Handle both Gradio chat and JSON questions"""
    docs = []

    for item in data:
        content = ""
        # Process different data sources
        if data_source == "json":
            raw_question = item.get("question", "")
            content = raw_question  # Adjust as per your content processing logic
        else:
            print(f"Skipping invalid data source: {data_source}")
            continue

        # Ensure metadata type safety
        metadata = {
            "task_id": str(item.get("task_id", "")),
            "level": str(item.get("Level", "")),
            "file_name": str(item.get("file_name", ""))
        }

        # Check if content is non-empty
        if content.strip():  # Only append non-empty content
            docs.append(Document(page_content=content, metadata=metadata))
        else:
            print(f"Skipping invalid entry with empty content: {item}")

    return docs

# Path to your data.json
file_path = "/home/wendy/Downloads/data.json"

def load_data(file_path: str) -> list[dict]:
    """Safe JSON data loading with error handling"""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Data file not found: {file_path}")

    if not file_path.endswith('.json'):
        raise ValueError("Invalid file format. Only JSON files supported")

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON format in data file")
    except Exception as e:
        raise RuntimeError(f"Error loading data: {str(e)}")


# 4. Vector Store Integration

# Custom FAISS wrapper (optional, if you still want it)
class MyVector_Store:
    def __init__(self, index: faiss.Index):
        self.index = index

    def save_local(self, path: str):
        faiss.write_index(self.index, path)

    @classmethod
    def load_local(cls, path: str):
        index = faiss.read_index(path)
        return cls(index)

# -----------------------------
# Process JSON data and create documents
# -----------------------------

file_path = "/home/wendy/Downloads/data.json"

try:
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    print(data)
except FileNotFoundError as e:
    print(f"Error: {e}")
except json.JSONDecodeError as e:
    print(f"Error decoding JSON: {e}")

docs = create_documents("json", data)
texts = [doc.page_content for doc in docs]


# -----------------------------
# Initialize embedding model
# -----------------------------
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# -----------------------------
# Create FAISS index and save it
# -----------------------------
def initialize_vector_store():
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    index_path = "/home/wendy/my_hf_agent_course_projects/faiss_index"

    if os.path.exists(os.path.join(index_path, "index.faiss")):
        try:
            return FAISS.load_local(
                index_path,
                embedding_model,
                allow_dangerous_deserialization=True
            )
        except Exception as e:
            print(f"Error loading index: {e}")

    # Fallback: Create new index from the documents prepared above
    print("Building new vector store...")
    vector_store = FAISS.from_documents(docs, embedding_model)
    vector_store.save_local(index_path)
    return vector_store

# Initialize at module level
loaded_store = initialize_vector_store()
retriever = loaded_store.as_retriever()

# -----------------------------
# Create LangChain Retriever Tool
# -----------------------------
question_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Search",
    description="A tool to retrieve documents related to a user's question."
)

# -----------------------------
# Load HuggingFace LLM
# -----------------------------
llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",
    task="text-generation",
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.7,
    max_new_tokens=512
)


# -------------------------------
# Step 8: Use the Planner, Classifier, and Decision Logic
# -------------------------------
# (superseded by the second process_question defined further below)

def process_question(question):
    # Step 1: Planner generates the task sequence (pass the enabled tools in)
    tasks = planner(question, tools)
    print(f"Tasks to perform: {tasks}")

    # Step 2: Classify the task (based on question)
    task_type = task_classifier(question)
    print(f"Task type: {task_type}")

    # Step 3: Use the classifier and planner to decide on the next task or node
    state = {"question": question, "last_response": ""}
    next_task = decide_task(state)
    print(f"Next task: {next_task}")

    # Step 4: Use node skipper logic (skip if needed)
    skip = node_skipper(state)
    if skip:
        print(f"Skipping to {skip}")
        return skip  # Or move directly to generating answer

    # Step 5: Execute task (with error handling)
    try:
        if task_type == "wiki_search":
            response = wiki_search.invoke(question)
        elif task_type == "math":
            response = calculator.invoke(question)
        else:
            response = "Default answer logic"

        # Step 6: Final response formatting
        # NOTE: the original called an undefined final_answer_tool; generate_final_answer
        # (defined below) is the closest match and is used here instead.
        final_response = generate_final_answer(state, {'wiki_search': response})
        return final_response

    except Exception as e:
        print(f"Error executing task: {e}")
        return "Sorry, I encountered an error processing your request."


# Run the process
#question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
#response = agent.invoke(question)
#print("Final Response:", response)

def retriever_node(state: MessagesState, k: int = 4):
    """
    Retrieves documents from the vector store using similarity scores,
    applies a dynamic threshold filter, and returns updated message state.

    Args:
        state (MessagesState): Current message state including the user's query.
        k (int): Number of top results to retrieve from the vector store.

    Returns:
        dict: Updated messages state including relevant documents or fallback message.
    """
    # NOTE: renamed from `retriever` so it does not shadow the retriever
    # created from the vector store above.
    query = state["messages"][0].content.strip()
    results = vector_store.similarity_search_with_score(query, k=k)

    # Determine dynamic similarity threshold
    if any(keyword in query.lower() for keyword in ["who", "what", "where", "when", "why", "how"]):
        threshold = 0.75
    else:
        threshold = 0.8

    filtered = [doc for doc, score in results if score < threshold]

    if not filtered:
        response_msg = HumanMessage(content="No relevant documents found.")
    else:
        content = "\n\n".join(doc.page_content for doc in filtered)
        response_msg = HumanMessage(content=f"Here are relevant reference documents:\n\n{content}")

    return {"messages": [sys_msg] + state["messages"] + [response_msg]}


# ----------------------------------------------------------------
# LLM Loader
# ----------------------------------------------------------------
def get_llm(provider: str, config: dict):
    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=config.get("model"),
            temperature=config.get("temperature", 0.7),
            google_api_key=config.get("api_key")  # Optional: if needed
        )

    elif provider == "groq":
        return ChatGroq(
            model=config.get("model"),
            temperature=config.get("temperature", 0.7),
            groq_api_key=config.get("api_key")  # Optional: if needed
        )

    elif provider == "huggingface":
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                endpoint_url=config.get("url"),
                temperature=config.get("temperature", 0.7),
                huggingfacehub_api_token=config.get("api_key")  # Optional
            )
        )

    else:
        raise ValueError(f"Invalid provider: {provider}")

# ----------------------------------------------------------------
# Planning & Execution Logic
# ----------------------------------------------------------------
def planner(question: str, tools: list) -> tuple:
    """
    Select the best-matching tool(s) for a question based on keyword-based intent detection and tool metadata.
    Returns the detected intent and matched tools.
    """
    question = question.lower().strip()

    # Define intent-based keywords
    intent_keywords = {
        "math": ["calculate", "evaluate", "add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"],
        "wiki_search": ["who is", "what is", "define", "explain", "tell me about", "overview of"],
        "web_search": ["search", "find", "look up", "google", "latest news", "current info"],
        "arxiv_search": ["arxiv", "research paper", "scientific paper", "preprint"],
        "get_youtube_transcript": ["youtube", "watch", "play video", "show me a video"],
        "extract_video_id": ["analyze video", "summarize video", "video content"],
        "data_analysis": ["analyze", "plot", "graph", "data", "visualize"],
        "wikidata_query": ["wikidata", "sparql", "run sparql", "query wikidata"],
        "default": ["why", "how", "difference between", "compare", "what happens", "reason for", "cause of", "effect of"]
    }

    # Step 1: Identify intent
    detected_intent = None
    for intent, keywords in intent_keywords.items():
        if any(keyword in question for keyword in keywords):
            detected_intent = intent
            break

    # Step 2: Match tools by intent
    matched_tools = []
    if detected_intent:
        for tool in tools:
            name = getattr(tool, "name", "").lower()
            description = getattr(tool, "description", "").lower()
            if detected_intent in name or detected_intent in description:
                matched_tools.append(tool)

    # Step 3: Fallback to general-purpose/default tools if no match found
    if not matched_tools:
        matched_tools = [
            tool for tool in tools
            if "default" in getattr(tool, "name", "").lower()
            or "qa" in getattr(tool, "description", "").lower()
        ]

    return detected_intent, matched_tools if matched_tools else [tools[0]]

def task_classifier(question: str) -> str:
    """
    Classifies the question into one of the predefined task categories.
    """
    question = question.lower().strip()

    # Context-aware intent patterns
    if any(phrase in question for phrase in [
        "calculate", "how much is", "what is the result of", "evaluate", "solve"
    ]) or any(op in question for op in ["add", "subtract", "multiply", "divide", "modulus", "plus", "minus", "times"]):
        return "math"

    elif any(phrase in question for phrase in [
        "who is", "what is", "define", "explain", "tell me about", "give me an overview of"
    ]):
        return "wiki_search"

    elif any(phrase in question for phrase in [
        "search", "find", "look up", "google", "get the latest", "current news", "trending"
    ]):
        return "web_search"

    elif any(phrase in question for phrase in [
        "arxiv", "latest research", "scientific paper", "research paper", "preprint"
    ]):
        return "arxiv_search"

    elif any(phrase in question for phrase in [
        "youtube", "watch", "play the video", "show me a video"
    ]):
        return "get_youtube_transcript"

    elif any(phrase in question for phrase in [
        "analyze video", "summarize video", "what happens in the video", "video content"
    ]):
        return "video_analysis"

    elif any(phrase in question for phrase in [
        "analyze", "visualize", "plot", "graph", "inspect data", "explore dataset"
    ]):
        return "data_analysis"

    elif any(phrase in question for phrase in [
        "sparql", "wikidata", "query wikidata", "run sparql", "wikidata query"
    ]):
        return "wikidata_query"

    return "default"

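# Classification examples (illustrative, not from the original file):
#
#   task_classifier("What is 12 times 5?")              # -> "math"
#   task_classifier("Who is Ada Lovelace?")             # -> "wiki_search"
#   task_classifier("Run sparql for cats on wikidata")  # -> "wikidata_query"
#   task_classifier("Compare TCP and UDP")              # -> "default"
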
def select_tool_and_run(question: str, tools: dict):
    # Step 1: Classify intent
    intent = task_classifier(question)  # task_classifier maps the question to an intent

    # Map intent to tool names (keys of the `tools` dict passed in)
    # NOTE: the original mapped "math" to "calculator", but the tool dict uses
    # the key "math"; fixed so the lookup below actually resolves.
    intent_tool_map = {
        "math": "math",
        "wiki_search": "wiki_search",
        "web_search": "web_search",
        "arxiv_search": "arxiv_search",
        "get_youtube_transcript": "get_youtube_transcript",
        "extract_video_id": "extract_video_id",
        "analyze_attachment": "analyze_attachment",
        "wikidata_query": "wikidata_query",
        "default": "default"
    }

    # Get the corresponding tool name
    tool_name = intent_tool_map.get(intent, "default")  # Default to "default" if no match

    # Retrieve the tool from the tools dictionary
    tool_func = tools.get(tool_name)

    if not tool_func:
        return f"Tool not found for intent '{intent}'"

    # Step 2: Run the tool
    try:
        # If the tool needs JSON or structured data
        try:
            parsed_input = json.loads(question)
        except json.JSONDecodeError:
            parsed_input = question  # fallback to raw input if not JSON

        # Run the selected tool
        print(f"Running tool: {tool_name} with input: {parsed_input}")  # log the tool name and input
        return tool_func(parsed_input)
    except Exception as e:
        return f"Error while running tool '{tool_name}': {str(e)}"


# Function to extract math operation from the question
def extract_math_from_question(question: str):
    question = question.lower()

    # Map natural language to symbols
    ops = {
        "add": "+", "plus": "+",
        "subtract": "-", "minus": "-",
        "multiply": "*", "times": "*",
        "divide": "/", "divided by": "/",
        "modulus": "%", "mod": "%"
    }

    for word, symbol in ops.items():
        question = re.sub(rf"\b{word}\b", symbol, question)

    # Extract math expression like "12 + 5"
    match = re.search(r'(\d+)\s*([\+\-\*/%])\s*(\d+)', question)
    if match:
        num1 = int(match.group(1))
        operator_symbol = match.group(2)  # renamed so it doesn't shadow the operator module
        num2 = int(match.group(3))
        return {
            "a": num1,
            "b": num2,
            "operation": {
                "+": "add",
                "-": "subtract",
                "*": "multiply",
                "/": "divide",
                "%": "modulus"
            }[operator_symbol]
        }
    return None

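# Example (illustrative): "what is 12 times 5" is rewritten to "what is 12 * 5"
# and parsed into the calculator's input schema:
#
#   extract_math_from_question("what is 12 times 5")
#   # -> {"a": 12, "b": 5, "operation": "multiply"}
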

# Example tool set (adjust these to match your actual tool names)
intent_tool_map = {
    "math": "math",  # maps to tools["math"] -> calculator
    "wiki_search": "wiki_search",
    "web_search": "web_search",
    "arxiv_search": "arxiv_search",
    "get_youtube_transcript": "get_youtube_transcript",
    "extract_video_id": "extract_video_id",
    "analyze_attachment": "analyze_attachment",
    "wikidata_query": "wikidata_query",
    "default": "default"
}


# The task order can also include the tools for each task
priority_order = [
    {"task": "math", "tool": "math"},
    {"task": "wiki_search", "tool": "wiki_search"},
    {"task": "web_search", "tool": "web_search"},
    {"task": "arxiv_search", "tool": "arxiv_search"},
    {"task": "wikidata_query", "tool": "wikidata_query"},
    {"task": "retriever", "tool": "retriever"},
    {"task": "get_youtube_transcript", "tool": "get_youtube_transcript"},
    {"task": "extract_video_id", "tool": "extract_video_id"},
    {"task": "analyze_attachment", "tool": "analyze_attachment"},
    {"task": "default", "tool": "default"}  # Fallback
]

def decide_task(state: dict) -> str:
    """Decides which task to perform based on the current state."""

    # Get the list of tasks from the planner (pass the enabled tools in)
    # NOTE: planner returns (intent, matched_tools); this draft treats the
    # result loosely as a task list.
    tasks = planner(state["question"], tools)
    print(f"Available tasks: {tasks}")  # Debugging: show all possible tasks

    # Check if the tasks list is empty or invalid
    if not tasks:
        print("❌ No valid tasks were returned from the planner.")
        return "default"  # Return a default task if no tasks were generated

    # If there are multiple tasks, we can prioritize based on certain conditions
    task = tasks[0]  # Default to the first task in the list
    if len(tasks) > 1:
        print("⚠️ Multiple tasks found. Deciding based on priority.")
        # Example logic to prioritize tasks, adjust based on your use case
        task = prioritize_tasks(tasks)

    print(f"Decided on task: {task}")  # Debugging: show the final task
    return task


def prioritize_tasks(tasks: list) -> str:
    """Prioritize tasks based on certain conditions or criteria, including tools."""
    # Sort tasks based on priority_order mapping
    for priority in priority_order:
        # Check if any task matches the priority task type
        for task in tasks:
            if priority["task"] in task:
                print(f"✅ Prioritizing task: {task} with tool: {priority['tool']}")  # Debugging: show the chosen task and tool
                # Assign the correct tool based on the task
                # NOTE: use tool_map (name -> tool); the module-level `tools` is a list
                tool = tool_map.get(priority["tool"])
                return task, tool

    # If no priority task is found, return the first task with no specific tool
    return tasks[0], None

def process_question(question: str):
    """Process the question and route it to the appropriate tool."""
    # Get the tasks from the planner (pass the enabled tools in)
    tasks = planner(question, tools)
    print(f"Tasks to perform: {tasks}")

    # NOTE: decide_task returns a bare task name unless prioritize_tasks fires,
    # in which case it returns a (task, tool) pair; this unpack assumes the latter.
    task_type, tool = decide_task({"question": question})
    print(f"Next task: {task_type} with tool: {tool}")

    if node_skipper({"question": question}):
        print(f"Skipping task: {task_type}")
        return "Task skipped."

    try:
        # Execute the corresponding tool for the task type
        if task_type == "wiki_search":
            response = tool.run(question)  # Assuming tool is wiki_tool
        elif task_type == "math":
            response = tool.run(question)  # Assuming tool is calc_tool
        elif task_type == "retriever":
            response = tool.run(question)  # Assuming tool is retriever_tool
        else:
            response = tool.run(question)  # Default tool

        return generate_final_answer({"question": question}, {task_type: response})

    except Exception as e:
        print(f"❌ Error: {e}")
        return f"Sorry, I encountered an error: {str(e)}"

def call_llm(state):
    messages = state["messages"]
    response = llm.invoke(messages)
    return {"messages": messages + [response]}


from langchain.schema import AIMessage
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    messages: List[BaseMessage]   # Chat history
    input: str                    # Original input
    intent: str                   # Derived or predicted intent
    result: Optional[str]         # Final or intermediate result


def tool_dispatcher(state: AgentState) -> AgentState:
    last_msg = state["messages"][-1]

    # Make sure it's an AI message with tool_calls
    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
        tool_call = last_msg.tool_calls[0]
        tool_name = tool_call["name"]
        tool_input = tool_call["args"]  # Adjust based on your actual schema

        # NOTE: the original fell back to an undefined `default_tool`; guard instead
        tool_func = tool_map.get(tool_name)
        if tool_func is None:
            return {**state, "result": f"Unknown tool: {tool_name}"}

        # If args is a dict and your tool expects unpacked values:
        if isinstance(tool_input, dict):
            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(**tool_input)
        else:
            result = tool_func.invoke(tool_input) if hasattr(tool_func, "invoke") else tool_func(tool_input)

        # You can choose to append this to messages, or just save result
        return {
            **state,
            "result": result,
            # Optionally add: "messages": state["messages"] + [ToolMessage(...)]
        }

    # No tool call detected, return state unchanged
    return state


# Decide what to do next: if tool call -> call_tool, else -> end
def should_call_tool(state):
    last_msg = state["messages"][-1]
    if isinstance(last_msg, AIMessage) and last_msg.tool_calls:
        return "call_tool"
    return "end"


# To store previously asked questions and timestamps (simulating state persistence)
recent_questions = {}

def node_skipper(state: dict) -> bool:
    """
    Determines whether to skip the task based on the state.
    This could include:
    1. Repeated or similar questions
    2. Irrelevant or empty questions
    3. Tasks that have already been processed recently
    """
    question = state.get("question", "").strip()

    if not question:
        print("❌ Skipping: Empty or invalid question.")
        return True  # Skip if no valid question

    # 1. Skip if the question has already been asked recently (within a given time window)
    # Here, we're using a simple example with a 5-minute window (300 seconds).
    if question in recent_questions:
        last_asked_time = recent_questions[question]
        time_since_last_ask = time.time() - last_asked_time
        if time_since_last_ask < 300:  # 5-minute threshold
            print(f"❌ Skipping: The question has been asked recently. Time since last ask: {time_since_last_ask:.2f} seconds.")
            return True  # Skip if the question was asked within the last 5 minutes

    # 2. Skip if the question is irrelevant or not meaningful enough
    irrelevant_keywords = ["blah", "nothing", "invalid", "nonsense"]
    if any(keyword in question.lower() for keyword in irrelevant_keywords):
        print("❌ Skipping: Irrelevant or nonsense question.")
        return True  # Skip if the question contains irrelevant keywords

    # 3. Skip if the task has already been completed for this question (based on a unique task identifier)
    if "last_response" in state and state["last_response"]:
        print("❌ Skipping: Task has already been processed recently.")
        return True  # Skip if a response has already been given

    # 4. Skip based on a condition related to the task itself
    # Example: Skip math-related tasks if the result is already known or trivial
    if "math" in state.get("question", "").lower():
        # If math is trivial (like "What is 2+2?")
        trivial_math = ["2 + 2", "1 + 1", "3 + 3"]
        if any(trivial_question in question for trivial_question in trivial_math):
            print(f"❌ Skipping trivial math question: {question}")
            return True  # Skip if the math question is trivial

    # 5. Skip based on external factors (e.g., current time, system load, etc.)
    # Example: Avoid processing tasks at night if that's part of the business logic
    current_hour = time.localtime().tm_hour
    if current_hour >= 22 or current_hour < 6:
        print("❌ Skipping: It's night time, not processing tasks.")
        return True  # Skip tasks during night time (e.g., between 10 PM and 6 AM)

    # If none of the conditions matched, don't skip the task
    return False

# Update recent questions (for simulating repeated question check)
def update_recent_questions(question: str):
    """Update the recent questions dictionary with the current timestamp."""
    recent_questions[question] = time.time()

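# Behaviour sketch (illustrative): the same question within the 5-minute
# window is skipped on the second call.
#
#   update_recent_questions("what is 2 + 2")
#   node_skipper({"question": "what is 2 + 2"})  # -> True (asked recently)
#   node_skipper({"question": ""})               # -> True (empty question)
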

def generate_final_answer(state: dict, task_results: dict) -> str:
    """Generate a final answer based on the results of the task."""
    if "wiki_search" in task_results:
        return f"📘 Wiki Summary:\n{task_results['wiki_search']}"
    elif "math" in task_results:
        return f"🧮 Math Result: {task_results['math']}"
    elif "retriever" in task_results:
        return f"📄 Retrieved Info: {task_results['retriever']}"
    else:
        return "🤔 Unable to generate a specific answer."


def answer_question(question: str) -> str:
    """Process a single question and return the answer."""
    print(f"Processing question: {question[:50]}...")  # Debugging: show first 50 chars

    # Wrap the question in a HumanMessage from langchain_core
    messages = [HumanMessage(content=question)]
    response = graph.invoke({"messages": messages})  # Assuming `graph` is defined elsewhere

    # Extract the answer from the response
    answer = response['messages'][-1].content
    return answer[14:]  # Assuming 'answer[14:]' strips a fixed-length prefix, per the original


def process_all_tasks(tasks: list):
    """Process a list of tasks."""
    results = {}

    for task in tasks:
        question = task.get("question", "").strip()
        if not question:
            print(f"Skipping task with missing or empty 'question': {task}")
            continue

        print(f"\n🔢 Processing Task: {task['task_id']} - Question: {question}")

        # Call the existing process_question logic
        response = process_question(question)

        print(f"✅ Response: {response}")
        results[task['task_id']] = response

    return results

## Langgraph

# Build graph function
# NOTE: save_local() returns None, so don't assign its result; keep the store
# initialized earlier and persist it separately.
vector_store = loaded_store
vector_store.save_local("faiss_index")

provider = "huggingface"

model_config = {
    "repo_id": "HuggingFaceH4/zephyr-7b-beta",
    "task": "text-generation",
    "temperature": 0.7,
    "max_new_tokens": 512,
    "huggingfacehub_api_token": os.getenv("HF_TOKEN")
}

# Get LLM
# (NOTE: this redefines the get_llm from the "LLM Loader" section above)
def get_llm(provider: str, config: dict):
    if provider == "huggingface":
        return HuggingFaceEndpoint(
            repo_id=config["repo_id"],
            task=config["task"],
            huggingfacehub_api_token=config["huggingfacehub_api_token"],
            temperature=config["temperature"],
            max_new_tokens=config["max_new_tokens"]
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")


def assistant(state: dict):
    # NOTE: llm_with_tools is bound inside build_graph below; this helper is
    # only usable once tools have been bound to an LLM in scope.
    return {
        "messages": [llm_with_tools.invoke(state["messages"])]
    }


# (NOTE: this overrides the tools_condition imported from langgraph.prebuilt)
def tools_condition(state: dict) -> str:
    if "use tool" in state["messages"][-1].content.lower():
        return "tools"
    else:
        return "END"


from langchain_core.runnables import RunnableLambda

def build_graph(vector_store, provider: str, model_config: dict) -> StateGraph:
    # Get LLM
    llm = get_llm(provider, model_config)

    # Define available tools
    tools = [
        wiki_search, calculator, web_search, arxiv_search,
        get_youtube_transcript, extract_video_id, analyze_attachment, wikidata_query
    ]

    # Tool mapping (global if needed elsewhere)
    global tool_map
    tool_map = {t.name: t for t in tools}

    # Bind tools only if LLM supports it
    if hasattr(llm, "bind_tools"):
        llm_with_tools = llm.bind_tools(tools)
    else:
        llm_with_tools = llm  # fallback for non-tool-aware models

    sys_msg = SystemMessage(content="You are a helpful assistant.")

    # Define nodes as runnables
    retriever = RunnableLambda(lambda state: {
        **state,
        "retrieved_docs": vector_store.similarity_search(state["input"])
    })

    assistant = RunnableLambda(lambda state: {
        **state,
        "messages": [sys_msg] + state["messages"]
    })

    call_llm = llm_with_tools  # already configured

    # Start building the graph
    builder = StateGraph(AgentState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("call_llm", call_llm)
    builder.add_node("call_tool", tool_dispatcher)
    builder.add_node("end", lambda state: state)  # Add explicit end node

    # Define graph flow
    builder.set_entry_point("retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_edge("assistant", "call_llm")

    builder.add_conditional_edges("call_llm", should_call_tool, {
        "call_tool": "call_tool",
        "end": "end"  # ✅ fixed: must point to actual "end" node
    })

    builder.add_edge("call_tool", "call_llm")  # loop back after tool call

    return builder.compile()

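# Wiring it together (a sketch, assuming the module-level names above; the
# original file does not show this invocation):
#
#   graph = build_graph(vector_store, provider, model_config)
#   print(answer_question("Who is Ada Lovelace?"))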