Update agent.py

agent.py CHANGED

@@ -42,7 +42,7 @@ from io import StringIO
 from transformers import BertTokenizer, BertModel
 import torch
 import torch.nn.functional as F
-
+import time

 load_dotenv()

@@ -333,81 +333,8 @@ for task in tasks:


 # -------------------------------
-# Step 1: Planner
-# -------------------------------
-def planner(question: str):
-    """Break down the question into smaller tasks"""
-    if "how many" in question and "albums" in question:
-        return ["Retrieve album list", "Filter by date", "Count albums"]
-    elif "who" in question:
-        return ["Retrieve biography", "Find related works"]
-    return ["Default task"]
-
-# -------------------------------
-# Step 2: Task Classifier (decides the best tool to use)
-# -------------------------------
-def task_classifier(question: str):
-    """Classify the question to select the best tool"""
-    if "calculate" in question or any(op in question for op in ["+", "-", "*", "/"]):
-        return "math"
-    elif "album" in question or "music" in question:
-        return "wiki_search"
-    elif "file" in question or "attachment" in question:
-        return "file_analysis"
-    return "default_tool"
-
-# -------------------------------
-# Step 3: Decide Task Function
-# -------------------------------
-def decide_task(state):
-    """Logic to decide what to do based on prior actions and results"""
-    if "no relevant documents" in state.get("last_response", ""):
-        return "web_search"
-    if "not found" in state.get("last_response", "").lower():
-        return "wiki_search"
-    return "final_answer"
-
-# -------------------------------
-# Step 4: Node Skipper (Skip unnecessary nodes)
+# Step 4: Set up BERT Embeddings and FAISS VectorStore
 # -------------------------------
-def node_skipper(state):
-    """Skip unnecessary nodes based on context"""
-    if "just generate" in state.get("question", "").lower():
-        return "answer_generation"  # Skip all tools and just generate the answer.
-    return None  # Continue to the next tool or node
-
-
-
-# -------------------------------
-# Step 4: Set up HuggingFace Embeddings and FAISS VectorStore
-# -------------------------------
-# Initialize HuggingFace Embedding model
-#embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
-#embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
-
-
-
-class BERTEmbeddings(Embeddings):
-    def __init__(self, model_name='bert-base-uncased'):
-        self.tokenizer = BertTokenizer.from_pretrained(model_name)
-        self.model = BertModel.from_pretrained(model_name)
-        self.model.eval()  # Set to evaluation mode
-
-    def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return self._embed(texts)
-
-    def embed_query(self, text: str) -> List[float]:
-        return self._embed([text])[0]
-
-    def _embed(self, texts: List[str]) -> List[List[float]]:
-        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
-        with torch.no_grad():
-            outputs = self.model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
-        return embeddings.cpu().numpy().tolist()
-
-# Example usage of BERTEmbedding with LangChain
-

 # -----------------------------
 # 1. Define Custom BERT Embedding Model
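
A custom Embeddings subclass like the BERTEmbeddings removed above (and re-introduced under "1. Define Custom BERT Embedding Model" later in the file) drops straight into a FAISS vector store. A minimal wiring sketch, not part of this commit; the import paths assume a recent langchain_core/langchain_community split, and the document is invented for illustration:

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS

# Build a FAISS index over a toy corpus using the custom BERT embedder.
embedding_model = BERTEmbeddings()
docs = [Document(page_content="Mercedes Sosa released Cantora 1 in 2009.")]
vector_store = FAISS.from_documents(docs, embedding_model)

# Retrieve the top matches together with their distance scores.
hits = vector_store.similarity_search_with_score("Mercedes Sosa albums", k=4)
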
@@ -517,6 +444,7 @@ agent = initialize_agent(
 # -------------------------------
 # Step 8: Use the Planner, Classifier, and Decision Logic
 # -------------------------------
+
 def process_question(question):
     # Step 1: Planner generates the task sequence
     tasks = planner(question)
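
Note that process_question still calls planner here, and the execution block in the next hunk reads task_type, while this commit deletes the planner and task_classifier definitions. If they are not redefined elsewhere in the file, stubs along these lines, distilled from the removed code and purely illustrative, would be needed for the flow to run:

def planner(question: str):
    """Stub of the removed planner: split the question into tasks."""
    if "how many" in question and "albums" in question:
        return ["Retrieve album list", "Filter by date", "Count albums"]
    return ["Default task"]

def task_classifier(question: str):
    """Stub of the removed classifier: route to a tool by keyword."""
    if "album" in question or "music" in question:
        return "wiki_search"
    return "default_tool"
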
@@ -537,18 +465,31 @@ def process_question(question):
         print(f"Skipping to {skip}")
         return skip  # Or move directly to generating answer

-    # Execute
-
-
+    # Step 5: Execute task (with error handling)
+    try:
+        if task_type == "wiki_search":
+            response = wiki_search_tool(question)
+        elif task_type == "math":
+            response = calculator_tool(question)
+        else:
+            response = "Default answer logic"
+
+        # Step 6: Final response formatting
+        final_response = generate_final_answer(state, {'wiki_search': response})
+        return final_response
+
+    except Exception as e:
+        print(f"Error executing task: {e}")
+        return "Sorry, I encountered an error processing your request."

-
-#
-# -------------------------------
+
+# Run the process
 question = "How many albums did Mercedes Sosa release between 2000 and 2009?"
 response = process_question(question)
 print("Final Response:", response)


+
 question_retriever_tool = create_retriever_tool(
     retriever=retriever,
     name="Question_Search",
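
The new execute block leans on names that sit outside this hunk (task_type, state, wiki_search_tool, calculator_tool). A self-contained sketch of the same route-execute-format pattern, with the Space's tools replaced by stubs; every name below is a stand-in, not the commit's code:

def wiki_search_tool(q):
    return f"[wiki results for {q!r}]"  # stub tool

def calculator_tool(q):
    return "42"  # stub tool

def execute_task(question, task_type):
    try:
        if task_type == "wiki_search":
            response = wiki_search_tool(question)
        elif task_type == "math":
            response = calculator_tool(question)
        else:
            response = "Default answer logic"
        # Mirrors generate_final_answer's "<tool> result: <value>" formatting.
        return f"{task_type} result: {response}\n"
    except Exception as e:
        print(f"Error executing task: {e}")
        return "Sorry, I encountered an error processing your request."

print(execute_task("How many albums did Mercedes Sosa release?", "wiki_search"))
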
@@ -561,12 +502,12 @@ def retriever(state: MessagesState):
     """Retriever node using similarity scores for filtering"""
     query = state["messages"][0].content
     results = vector_store.similarity_search_with_score(query, k=4)  # top 4 matches
-
+
     # Dynamically adjust threshold based on query complexity
     threshold = 0.75 if "who" in query else 0.8
-
     filtered = [doc for doc, score in results if score < threshold]
-
+
+    # Provide a default message if no documents found
     if not filtered:
         example_msg = HumanMessage(content="No relevant documents found.")
     else:
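
One point worth keeping in mind when tuning the thresholds above: LangChain's FAISS wrapper returns a raw distance from similarity_search_with_score (L2 by default), so lower scores mean closer matches, which is why the filter keeps score < threshold. A small probe, reusing the vector_store from the earlier sketch:

for doc, score in vector_store.similarity_search_with_score("who is Mercedes Sosa?", k=4):
    # Smaller distance = better match; 0.75 matches the "who" threshold above.
    print(f"score={score:.3f} keep={score < 0.75} :: {doc.page_content[:60]}")
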
@@ -605,9 +546,12 @@ def get_llm(provider: str, config: dict):


 def generate_final_answer(state, tools_results):
-
-
-
+    final_answer = ""
+
+    # Concatenate results from each tool (wiki_search, calculator, etc.)
+    for tool_name, result in tools_results.items():
+        final_answer += f"{tool_name} result: {result}\n"
+
     return final_answer

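
A quick worked example of the filled-in generate_final_answer; note the state argument is accepted but unused in this version, so None is fine. The result values are invented for illustration:

tools_results = {"wiki_search": "found 2 matching albums", "math": "2"}
print(generate_final_answer(state=None, tools_results=tools_results))
# wiki_search result: found 2 matching albums
# math result: 2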