Spaces:
Build error
Build error
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -42,7 +42,7 @@ from io import StringIO
|
|
| 42 |
from transformers import BertTokenizer, BertModel
|
| 43 |
import torch
|
| 44 |
import torch.nn.functional as F
|
| 45 |
-
|
| 46 |
|
| 47 |
load_dotenv()
|
| 48 |
|
|
@@ -333,81 +333,8 @@ for task in tasks:
|
|
| 333 |
|
| 334 |
|
| 335 |
# -------------------------------
|
| 336 |
-
# Step 1: Planner (breaks the question into smaller tasks)
|
| 337 |
-
# -------------------------------
|
| 338 |
-
def planner(question: str):
    """Break the question down into an ordered list of smaller task names.

    Planning is a plain keyword heuristic over the question text. Matching
    is case-insensitive so that naturally capitalised questions such as
    "How many albums ..." or "Who ..." select the intended plan instead of
    silently falling through to the default.

    Args:
        question: The user's question.

    Returns:
        A list of task-name strings. ``["Default task"]`` when no keyword
        rule matches.
    """
    # Normalise once so every rule below compares against the same text.
    text = question.lower()

    if "how many" in text and "albums" in text:
        return ["Retrieve album list", "Filter by date", "Count albums"]
    # NOTE(review): substring match, so words like "whose"/"whole" also hit
    # this rule - kept as-is to preserve the existing heuristic.
    if "who" in text:
        return ["Retrieve biography", "Find related works"]
    return ["Default task"]
|
| 345 |
-
|
| 346 |
-
# -------------------------------
|
| 347 |
-
# Step 2: Task Classifier (decides the best tool to use)
|
| 348 |
-
# -------------------------------
|
| 349 |
-
def task_classifier(question: str):
    """Classify the question to select the best tool.

    Rules are checked in order and the first match wins:

    1. "calculate" or an arithmetic operator character -> ``"math"``
    2. "album" or "music"                              -> ``"wiki_search"``
    3. "file" or "attachment"                          -> ``"file_analysis"``
    4. otherwise                                       -> ``"default_tool"``

    Keyword matching is case-insensitive, so "Calculate", "Album" or
    "File" at the start of a sentence are recognised.

    Args:
        question: The user's question.

    Returns:
        The tool name as a string.
    """
    text = question.lower()

    # NOTE(review): a bare "-" or "/" also appears in hyphenated words and
    # dates, so such questions are classified as "math". Left unchanged to
    # preserve the existing heuristic.
    if "calculate" in text or any(op in text for op in ("+", "-", "*", "/")):
        return "math"
    if "album" in text or "music" in text:
        return "wiki_search"
    if "file" in text or "attachment" in text:
        return "file_analysis"
    return "default_tool"
|
| 358 |
-
|
| 359 |
-
# -------------------------------
|
| 360 |
-
# Step 3: Decide Task Function
|
| 361 |
-
# -------------------------------
|
| 362 |
-
def decide_task(state):
    """Decide the next step based on the previous response stored in state.

    Both checks are case-insensitive. Previously only the "not found"
    check lower-cased the response, so a capitalised message such as
    "No relevant documents found." never triggered the web-search
    fallback.

    Args:
        state: Mapping that may contain a ``"last_response"`` string. A
            missing key or a ``None`` value is treated as an empty response.

    Returns:
        ``"web_search"`` if the last response mentions "no relevant
        documents", ``"wiki_search"`` if it mentions "not found", and
        ``"final_answer"`` otherwise.
    """
    # `or ""` guards against an explicit None stored under the key.
    last_response = (state.get("last_response") or "").lower()

    if "no relevant documents" in last_response:
        return "web_search"
    if "not found" in last_response:
        return "wiki_search"
    return "final_answer"
|
| 369 |
-
|
| 370 |
-
# -------------------------------
|
| 371 |
-
# Step 4: Node Skipper (Skip unnecessary nodes)
|
| 372 |
# -------------------------------
|
| 373 |
-
def node_skipper(state):
    """Decide whether the tool nodes can be skipped for this question.

    Args:
        state: Mapping that may contain a ``"question"`` string. A missing
            key or a ``None`` value is treated as an empty question.

    Returns:
        ``"answer_generation"`` when the question contains the phrase
        "just generate" (case-insensitive), meaning all tools are skipped
        and the answer is generated directly; ``None`` to continue with the
        next tool or node.
    """
    # `or ""` guards against an explicit None stored under the key, which
    # would otherwise raise AttributeError on .lower().
    question = state.get("question") or ""
    if "just generate" in question.lower():
        return "answer_generation"  # Skip all tools and just generate the answer.
    return None  # Continue to the next tool or node
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
# -------------------------------
|
| 382 |
-
# Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
|
| 383 |
-
# -------------------------------
|
| 384 |
-
# Initialize HuggingFace Embedding model
|
| 385 |
-
#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
| 386 |
-
#embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
class BERTEmbeddings(Embeddings):
    """Embedding backend that mean-pools BERT's last hidden state.

    The tokenizer and model are loaded once in the constructor and the
    model is put in evaluation mode; inference runs under
    ``torch.no_grad()``.

    NOTE(review): pooling now ignores padding positions, so vectors for
    documents embedded in a padded batch differ from those produced by the
    previous plain mean. Any vector index persisted with the old pooling
    should be rebuilt.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and model identified by ``model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Set to evaluation mode

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text."""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        return self._embed([text])[0]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Tokenize ``texts`` as one padded batch and mean-pool the outputs.

        Returns an empty list for empty input instead of calling the
        tokenizer with nothing to encode.
        """
        if not texts:
            return []

        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden = outputs.last_hidden_state
        # Mask-weighted mean: with padding=True shorter texts are padded to
        # the longest one in the batch, and a plain mean over dim=1 would
        # average those padding positions in, making a text's embedding
        # depend on what else happens to be in the batch.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)  # avoid division by zero
        embeddings = (hidden * mask).sum(dim=1) / token_counts
        return embeddings.cpu().numpy().tolist()
|
| 408 |
-
|
| 409 |
-
# Example usage of BERTEmbedding with LangChain
|
| 410 |
-
|
| 411 |
|
| 412 |
# -----------------------------
|
| 413 |
# 1. Define Custom BERT Embedding Model
|
|
@@ -517,6 +444,7 @@ agent = initialize_agent(
|
|
| 517 |
# -------------------------------
|
| 518 |
# Step 8: Use the Planner, Classifier, and Decision Logic
|
| 519 |
# -------------------------------
|
|
|
|
| 520 |
def process_question(question):
|
| 521 |
# Step 1: Planner generates the task sequence
|
| 522 |
tasks = planner(question)
|
|
@@ -537,18 +465,31 @@ def process_question(question):
|
|
| 537 |
print(f"Skipping to {skip}")
|
| 538 |
return skip # Or move directly to generating answer
|
| 539 |
|
| 540 |
-
# Execute
|
| 541 |
-
|
| 542 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
|
| 544 |
-
|
| 545 |
-
#
|
| 546 |
-
# -------------------------------
|
| 547 |
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
| 548 |
response = process_question(question)
|
| 549 |
print("Final Response:", response)
|
| 550 |
|
| 551 |
|
|
|
|
| 552 |
question_retriever_tool = create_retriever_tool(
|
| 553 |
retriever=retriever,
|
| 554 |
name="Question_Search",
|
|
@@ -561,12 +502,12 @@ def retriever(state: MessagesState):
|
|
| 561 |
"""Retriever node using similarity scores for filtering"""
|
| 562 |
query = state["messages"][0].content
|
| 563 |
results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
|
| 564 |
-
|
| 565 |
# Dynamically adjust threshold based on query complexity
|
| 566 |
threshold = 0.75 if "who" in query else 0.8
|
| 567 |
-
|
| 568 |
filtered = [doc for doc, score in results if score < threshold]
|
| 569 |
-
|
|
|
|
| 570 |
if not filtered:
|
| 571 |
example_msg = HumanMessage(content="No relevant documents found.")
|
| 572 |
else:
|
|
@@ -605,9 +546,12 @@ def get_llm(provider: str, config: dict):
|
|
| 605 |
|
| 606 |
|
| 607 |
def generate_final_answer(state, tools_results):
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
|
|
|
|
|
|
|
|
|
| 611 |
return final_answer
|
| 612 |
|
| 613 |
|
|
|
|
| 42 |
from transformers import BertTokenizer, BertModel
|
| 43 |
import torch
|
| 44 |
import torch.nn.functional as F
|
| 45 |
+
import time
|
| 46 |
|
| 47 |
load_dotenv()
|
| 48 |
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
# -------------------------------
|
| 336 |
+
# Step 4: Set up BERT Embeddings and FAISS VectorStore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
# -------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
# -----------------------------
|
| 340 |
# 1. Define Custom BERT Embedding Model
|
|
|
|
| 444 |
# -------------------------------
|
| 445 |
# Step 8: Use the Planner, Classifier, and Decision Logic
|
| 446 |
# -------------------------------
|
| 447 |
+
|
| 448 |
def process_question(question):
|
| 449 |
# Step 1: Planner generates the task sequence
|
| 450 |
tasks = planner(question)
|
|
|
|
| 465 |
print(f"Skipping to {skip}")
|
| 466 |
return skip # Or move directly to generating answer
|
| 467 |
|
| 468 |
+
# Step 5: Execute task (with error handling)
|
| 469 |
+
try:
|
| 470 |
+
if task_type == "wiki_search":
|
| 471 |
+
response = wiki_search_tool(question)
|
| 472 |
+
elif task_type == "math":
|
| 473 |
+
response = calculator_tool(question)
|
| 474 |
+
else:
|
| 475 |
+
response = "Default answer logic"
|
| 476 |
+
|
| 477 |
+
# Step 6: Final response formatting
|
| 478 |
+
final_response = generate_final_answer(state, {'wiki_search': response})
|
| 479 |
+
return final_response
|
| 480 |
+
|
| 481 |
+
except Exception as e:
|
| 482 |
+
print(f"Error executing task: {e}")
|
| 483 |
+
return "Sorry, I encountered an error processing your request."
|
| 484 |
|
| 485 |
+
|
| 486 |
+
# Run the process
|
|
|
|
| 487 |
question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
| 488 |
response = process_question(question)
|
| 489 |
print("Final Response:", response)
|
| 490 |
|
| 491 |
|
| 492 |
+
|
| 493 |
question_retriever_tool = create_retriever_tool(
|
| 494 |
retriever=retriever,
|
| 495 |
name="Question_Search",
|
|
|
|
| 502 |
"""Retriever node using similarity scores for filtering"""
|
| 503 |
query = state["messages"][0].content
|
| 504 |
results = vector_store.similarity_search_with_score(query, k=4) # top 4 matches
|
| 505 |
+
|
| 506 |
# Dynamically adjust threshold based on query complexity
|
| 507 |
threshold = 0.75 if "who" in query else 0.8
|
|
|
|
| 508 |
filtered = [doc for doc, score in results if score < threshold]
|
| 509 |
+
|
| 510 |
+
# Provide a default message if no documents found
|
| 511 |
if not filtered:
|
| 512 |
example_msg = HumanMessage(content="No relevant documents found.")
|
| 513 |
else:
|
|
|
|
| 546 |
|
| 547 |
|
| 548 |
def generate_final_answer(state, tools_results):
    """Format the collected tool results into the final answer text.

    Args:
        state: Accepted for interface compatibility; not read here.
        tools_results: Mapping of tool name to that tool's result. ``None``
            or an empty mapping yields an empty string.

    Returns:
        One ``"<tool_name> result: <result>\\n"`` line per entry, in the
        mapping's iteration order, concatenated into a single string.
    """
    if not tools_results:
        return ""

    # Build the text in one pass with join rather than repeated `+=`.
    return "".join(
        f"{tool_name} result: {result}\n"
        for tool_name, result in tools_results.items()
    )
|
| 556 |
|
| 557 |
|