""" | |
Enhanced Multi-LLM Agent System - CORRECTED VERSION | |
Fixes the issue where questions are returned as answers | |
""" | |
import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

load_dotenv()
# Enhanced system prompt for proper question-answering
ENHANCED_SYSTEM_PROMPT = (
    "You are a helpful assistant tasked with answering questions using available tools. "
    "Follow these guidelines:\n"
    "1. Read the question carefully and understand what is being asked\n"
    "2. Use available tools when you need external information\n"
    "3. Provide accurate, specific answers based on the information you find\n"
    "4. For numbers: don't use commas or units unless specified\n"
    "5. For strings: don't use articles or abbreviations; write digits in plain text\n"
    "6. Always end with 'FINAL ANSWER: [YOUR ANSWER]' where [YOUR ANSWER] is concise\n"
    "7. Never repeat the question as your answer\n"
    "8. If you cannot find the answer, state 'Information not available'\n"
)
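# Illustrative completion this prompt is meant to elicit (not real model output):
#   "George Washington took office in 1789. FINAL ANSWER: George Washington"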
# ---- Tool Definitions ----
# The @tool decorator is required: the graph later calls .invoke() on the
# search helpers, which only works on LangChain tool objects.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b

@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b

@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder when dividing the first integer by the second."""
    return a % b
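
# Once wrapped by @tool, each function is a LangChain runnable and is called
# with a dict of arguments, e.g.:
#   multiply.invoke({"a": 25, "b": 17})  # -> 425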
@tool
def optimized_web_search(query: str) -> str:
    """Perform a web search using TavilySearchResults."""
    try:
        # Jittered delay as a crude client-side rate limit
        time.sleep(random.uniform(0.7, 1.5))
        search_tool = TavilySearchResults(max_results=3)
        docs = search_tool.invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"
@tool
def optimized_wiki_search(query: str) -> str:
    """Perform a Wikipedia search and return page content."""
    try:
        time.sleep(random.uniform(0.3, 1))
        docs = WikipediaLoader(query=query, load_max_docs=2).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:1000]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"
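
# The fixed jitter above is a minimal rate-limiting measure. A sketch of an
# exponential-backoff retry wrapper that could replace it (illustrative only;
# _retry_with_backoff is not part of this module's public API):
def _retry_with_backoff(fn, retries: int = 3, base_delay: float = 0.5):
    """Call fn(), retrying with exponential backoff plus jitter on failure."""
    for attempt in range(retries):
        try:
            return fn()
        except Exception:
            if attempt == retries - 1:
                raise
            time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.25))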
# ---- Enhanced Agent State ----
class EnhancedAgentState(TypedDict):
    """State structure for the enhanced agent system."""
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
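
# Note on the reducer: Annotated[..., operator.add] tells LangGraph to merge
# updates to `messages` by list concatenation, e.g.
#   [HumanMessage("hi")] + [AIMessage("hello")]  -> both messages are retained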
# ---- Enhanced Multi-LLM System ----
class HybridLangGraphMultiLLMSystem:
    """Enhanced question-answering system with proper response handling."""

    def __init__(self):
        """Initialize the enhanced multi-LLM system."""
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()

    def _llm(self, model_name: str) -> ChatGroq:
        """Create a Groq LLM instance."""
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY")
        )
    def _build_graph(self) -> StateGraph:
        """Build the LangGraph state machine with proper response handling."""
        # Initialize LLMs. Note: "deepseek-chat" is not a Groq-hosted model;
        # use Groq's DeepSeek distillation instead.
        llama8_llm = self._llm("llama3-8b-8192")
        llama70_llm = self._llm("llama3-70b-8192")
        deepseek_llm = self._llm("deepseek-r1-distill-llama-70b")
        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Route queries to the appropriate LLM based on content analysis."""
            q = st["query"].lower()
            # Enhanced routing logic
            if any(keyword in q for keyword in ["calculate", "compute", "math", "multiply", "add", "subtract", "divide"]):
                t = "llama70"  # Use the more powerful model for calculations
            elif any(keyword in q for keyword in ["search", "find", "lookup", "wikipedia", "information about"]):
                t = "search_enhanced"  # Use search-enhanced processing
            elif "deepseek" in q or any(keyword in q for keyword in ["analyze", "reasoning", "complex"]):
                t = "deepseek"
            elif "llama-8" in q:
                t = "llama8"
            elif len(q.split()) > 20:  # Complex queries
                t = "llama70"
            else:
                t = "llama8"  # Default for simple queries
            return {**st, "agent_type": t}
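
        # Routing examples (illustrative):
        #   "Calculate 25 * 17"                   -> llama70 (math keyword)
        #   "Find information about black holes"  -> search_enhanced
        #   "Analyze the reasoning behind X"      -> deepseek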
        def llama8_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process the query with the Llama-3 8B model."""
            t0 = time.time()
            try:
                # Create an enhanced prompt with context
                enhanced_query = f"""
Question: {st["query"]}

Please provide a direct, accurate answer to this question. Do not repeat the question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama8_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                # Extract and clean the answer
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-8B"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}
        def llama70_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process the query with the Llama-3 70B model."""
            t0 = time.time()
            try:
                # Create an enhanced prompt with context
                enhanced_query = f"""
Question: {st["query"]}

Please provide a direct, accurate answer to this question. Do not repeat the question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                # Extract and clean the answer
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "perf": {"time": time.time() - t0, "prov": "Groq-Llama3-70B"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}
        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process the query with the DeepSeek model."""
            t0 = time.time()
            try:
                # Create an enhanced prompt with context
                enhanced_query = f"""
Question: {st["query"]}

Please provide a direct, accurate answer to this question. Do not repeat the question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = deepseek_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                # Extract and clean the answer
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "perf": {"time": time.time() - t0, "prov": "Groq-DeepSeek"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}
        def search_enhanced_node(st: EnhancedAgentState) -> EnhancedAgentState:
            """Process the query with search enhancement."""
            t0 = time.time()
            try:
                # Determine the search strategy
                query = st["query"]
                if any(keyword in query.lower() for keyword in ["wikipedia", "wiki"]):
                    search_results = optimized_wiki_search.invoke({"query": query})
                else:
                    search_results = optimized_web_search.invoke({"query": query})
                # Create a comprehensive prompt with the search results
                enhanced_query = f"""
Original Question: {query}

Search Results:
{search_results}

Based on the search results above, provide a direct answer to the original question.
Extract the specific information requested. Do not repeat the question.
"""
                sys = SystemMessage(content=ENHANCED_SYSTEM_PROMPT)
                res = llama70_llm.invoke([sys, HumanMessage(content=enhanced_query)])
                # Extract and clean the answer
                answer = res.content.strip()
                if "FINAL ANSWER:" in answer:
                    answer = answer.split("FINAL ANSWER:")[-1].strip()
                return {**st,
                        "final_answer": answer,
                        "perf": {"time": time.time() - t0, "prov": "Search-Enhanced-Llama70"}}
            except Exception as e:
                return {**st, "final_answer": f"Error: {e}", "perf": {"error": str(e)}}
        # Build graph
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("llama8", llama8_node)
        g.add_node("llama70", llama70_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("search_enhanced", search_enhanced_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", lambda s: s["agent_type"], {
            "llama8": "llama8",
            "llama70": "llama70",
            "deepseek": "deepseek",
            "search_enhanced": "search_enhanced"
        })
        for node in ["llama8", "llama70", "deepseek", "search_enhanced"]:
            g.add_edge(node, END)
        return g.compile(checkpointer=MemorySaver())
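
    # MemorySaver keeps checkpoints in process memory, keyed by the thread_id
    # that process_query passes in its config, so each question gets an
    # isolated checkpoint history.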
    def process_query(self, q: str) -> str:
        """Process a query and return the final answer."""
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        cfg = {"configurable": {"thread_id": f"qa_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
            answer = out.get("final_answer", "").strip()
            # Guard against the model echoing the question back as the answer
            if answer == q or answer.startswith(q):
                return "Information not available"
            return answer if answer else "No answer generated"
        except Exception as e:
            return f"Error processing query: {e}"
def build_graph(provider: str | None = None) -> StateGraph:
    """Build and return the graph for the enhanced agent system.

    The provider argument is accepted for interface compatibility but is
    currently ignored; Groq models are always used.
    """
    return HybridLangGraphMultiLLMSystem().graph
if __name__ == "__main__":
    # Test the system
    qa_system = HybridLangGraphMultiLLMSystem()
    test_questions = [
        "What is 25 multiplied by 17?",
        "Who was the first president of the United States?",
        "Find information about artificial intelligence on Wikipedia"
    ]
    for question in test_questions:
        print(f"Question: {question}")
        answer = qa_system.process_query(question)
        print(f"Answer: {answer}")
        print("-" * 50)