"""Enhanced LangGraph + Agno Hybrid Agent System with TavilyTools""" | |
import os | |
import time | |
import random | |
from dotenv import load_dotenv | |
from typing import List, Dict, Any, TypedDict, Annotated | |
import operator | |
# LangGraph imports | |
from langgraph.graph import START, StateGraph, MessagesState | |
from langgraph.prebuilt import tools_condition, ToolNode | |
from langgraph.checkpoint.memory import MemorySaver | |
# LangChain imports | |
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage | |
from langchain_core.tools import tool | |
from langchain_groq import ChatGroq | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings | |
from langchain_community.tools.tavily_search import TavilySearchResults | |
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader, JSONLoader | |
from langchain_community.vectorstores import FAISS | |
from langchain.tools.retriever import create_retriever_tool | |
from langchain_text_splitters import RecursiveCharacterTextSplitter | |
# Agno imports (Agno's model wrappers are `Groq` / `Gemini` and take an `id=` kwarg)
from agno.agent import Agent
from agno.models.groq import Groq
from agno.models.google import Gemini
from agno.tools.tavily import TavilyTools
from agno.memory.agent import AgentMemory
from agno.memory.db.sqlite import SqliteMemoryDb
from agno.storage.sqlite import SqliteStorage

load_dotenv()
# Rate limiter with a sliding 60-second window plus exponential backoff on failures
class PerformanceRateLimiter:
    def __init__(self, rpm: int, name: str):
        self.rpm = rpm
        self.name = name
        self.times: List[float] = []
        self.failures = 0

    def wait_if_needed(self):
        now = time.time()
        # keep only the calls made within the last 60 seconds
        self.times = [t for t in self.times if now - t < 60]
        if len(self.times) >= self.rpm:
            wait = 60 - (now - self.times[0]) + random.uniform(1, 3)
            time.sleep(wait)
        if self.failures:
            backoff = min(2 ** self.failures, 30) + random.uniform(0.5, 1.5)
            time.sleep(backoff)
        # record the actual call time, not the stale `now` captured before sleeping
        self.times.append(time.time())

    def record_success(self):
        self.failures = 0

    def record_failure(self):
        self.failures += 1
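
# A minimal usage sketch (names below are illustrative, not part of the module):
# each provider gets its own limiter instance, so throttling one API never delays
# calls to another.
#
#   limiter = PerformanceRateLimiter(rpm=30, name="Demo")
#   limiter.wait_if_needed()      # sleeps only once 30 calls land in a 60s window
#   try:
#       ...                       # the provider call goes here
#       limiter.record_success()  # resets the exponential backoff counter
#   except Exception:
#       limiter.record_failure()  # next wait adds up to 2**failures extra seconds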
# Initialize rate limiters (requests-per-minute budget per provider)
gemini_limiter = PerformanceRateLimiter(28, "Gemini")
groq_limiter = PerformanceRateLimiter(28, "Groq")
nvidia_limiter = PerformanceRateLimiter(4, "NVIDIA")
# Create Agno agents with SQLite-backed session storage and user memories
def create_agno_agents():
    os.makedirs("tmp", exist_ok=True)  # the SQLite files below live under tmp/
    storage = SqliteStorage(
        table_name="agent_sessions",
        db_file="tmp/agent_sessions.db",
        auto_upgrade_schema=True
    )
    # AgentMemory expects a MemoryDb backend rather than session storage, so user
    # memories get their own SQLite table
    memory_db = SqliteMemoryDb(table_name="agent_memories", db_file="tmp/agent_memories.db")
    math_agent = Agent(
        name="MathSpecialist",
        model=Groq(
            id="llama-3.3-70b-versatile",
            api_key=os.getenv("GROQ_API_KEY"),
            temperature=0
        ),
        description="Expert mathematical problem solver",
        instructions=[
            "Solve math problems with precision",
            "Show step-by-step calculations",
            "Use calculation tools as needed",
            "Finish with: FINAL ANSWER: [result]"
        ],
        storage=storage,
        memory=AgentMemory(
            db=memory_db,
            create_user_memories=True,
            create_session_summary=True
        ),
        show_tool_calls=False,
        markdown=False
    )
    research_agent = Agent(
        name="ResearchSpecialist",
        model=Gemini(
            id="gemini-2.0-flash-lite",
            api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=0
        ),
        description="Expert research and information-gathering specialist",
        instructions=[
            "Conduct thorough research using available tools",
            "Synthesize information from multiple sources",
            "Provide comprehensive, well-cited answers",
            "Finish with: FINAL ANSWER: [answer]"
        ],
        tools=[
            TavilyTools(
                api_key=os.getenv("TAVILY_API_KEY"),
                search=True,
                max_tokens=6000,
                search_depth="advanced",
                format="markdown"
            )
        ],
        storage=storage,
        memory=AgentMemory(
            db=memory_db,
            create_user_memories=True,
            create_session_summary=True
        ),
        show_tool_calls=False,
        markdown=False
    )
    return {"math": math_agent, "research": research_agent}
# LangGraph tools
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    return a * b

def add(a: int, b: int) -> int:
    """Add two numbers."""
    return a + b

def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    return a - b

def divide(a: int, b: int) -> float:
    """Divide two numbers."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

def modulus(a: int, b: int) -> int:
    """Get the remainder of division."""
    return a % b
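
# These plain typed functions become callable tools when passed to `bind_tools`,
# which derives each tool's JSON schema from the type hints and docstring.
# For example (illustrative):
#
#   llm = ChatGroq(model="llama-3.3-70b-versatile").bind_tools([multiply, divide])
#   llm.invoke("What is 6 * 7?")   # the response carries a `multiply` tool call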
def optimized_web_search(query: str) -> str:
    """Optimized Tavily web search."""
    try:
        time.sleep(random.uniform(1, 2))  # jitter to avoid hammering the API
        # BaseTool.invoke takes the input as its first argument, not a `query=` kwarg
        docs = TavilySearchResults(max_results=2).invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"

def optimized_wiki_search(query: str) -> str:
    """Optimized Wikipedia search."""
    try:
        time.sleep(random.uniform(0.5, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata['source']}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"
# FAISS setup: index prior questions from metadata.jsonl for similarity retrieval
def setup_faiss():
    try:
        schema = """
        {
            page_content: .Question,
            metadata: { task_id: .task_id, Final_answer: ."Final answer" }
        }
        """
        loader = JSONLoader(file_path="metadata.jsonl", jq_schema=schema, json_lines=True, text_content=False)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=256, chunk_overlap=50)
        chunks = splitter.split_documents(docs)
        embeds = NVIDIAEmbeddings(
            model="nvidia/nv-embedqa-e5-v5",
            api_key=os.getenv("NVIDIA_API_KEY")
        )
        return FAISS.from_documents(chunks, embeds)
    except Exception as e:
        print(f"FAISS setup failed: {e}")
        return None
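
# The jq schema above expects each metadata.jsonl line to be shaped like
# (values are illustrative):
#
#   {"task_id": "abc-123", "Question": "What is ...?", "Final answer": "42"}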
class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
class HybridLangGraphAgnoSystem:
    def __init__(self):
        self.agno = create_agno_agents()
        self.store = setup_faiss()
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        if self.store:
            retr = self.store.as_retriever(search_type="similarity", search_kwargs={"k": 2})
            self.tools.append(create_retriever_tool(
                retriever=retr,
                name="Question_Search",
                description="Retrieve similar questions"
            ))
        self.graph = self._build_graph()
    def _build_graph(self):
        # only the Groq LLM is invoked directly; the Gemini and NVIDIA routes
        # delegate to the Agno agents below
        groq_llm = ChatGroq(model="llama-3.3-70b-versatile", temperature=0)

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Route the query to a node via simple keyword matching."""
            q = st["query"].lower()
            if any(k in q for k in ["calculate", "math"]):
                t = "lg_math"
            elif any(k in q for k in ["research", "analyze"]):
                t = "agno_research"
            elif any(k in q for k in ["what is", "who is"]):
                t = "lg_retrieval"
            else:
                t = "agno_general"
            return {**st, "agent_type": t}
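
        # e.g. "calculate 15 % 4" -> lg_math, "research quantum computing" ->
        # agno_research, "what is a transformer" -> lg_retrieval; anything else
        # falls back to agno_general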
        def lg_math(st: EnhancedAgentState) -> EnhancedAgentState:
            groq_limiter.wait_if_needed()
            t0 = time.time()
            # single-shot call: tool calls are exposed to the model but not executed here
            llm = groq_llm.bind_tools([multiply, add, subtract, divide, modulus])
            sys = SystemMessage(content="Fast calculator. FINAL ANSWER: [result]")
            res = llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "LG-Groq"}}

        def agno_research(st: EnhancedAgentState) -> EnhancedAgentState:
            gemini_limiter.wait_if_needed()
            t0 = time.time()
            # Agno's run() returns a RunResponse; the answer text lives on .content
            resp = self.agno["research"].run(st["query"], stream=False)
            return {**st, "final_answer": resp.content, "perf": {"time": time.time() - t0, "prov": "Agno-Gemini"}}

        def lg_retrieval(st: EnhancedAgentState) -> EnhancedAgentState:
            groq_limiter.wait_if_needed()
            t0 = time.time()
            llm = groq_llm.bind_tools(self.tools)
            sys = SystemMessage(content="Answer using the retrieval tools. FINAL ANSWER: [answer]")
            res = llm.invoke([sys, HumanMessage(content=st["query"])])
            return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "LG-Retrieval"}}

        def agno_general(st: EnhancedAgentState) -> EnhancedAgentState:
            nvidia_limiter.wait_if_needed()
            t0 = time.time()
            if any(k in st["query"].lower() for k in ["calculate", "compute"]):
                resp = self.agno["math"].run(st["query"], stream=False)
            else:
                resp = self.agno["research"].run(st["query"], stream=False)
            return {**st, "final_answer": resp.content, "perf": {"time": time.time() - t0, "prov": "Agno-General"}}

        def pick(st: EnhancedAgentState) -> str:
            return st["agent_type"]
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("lg_math", lg_math)
        g.add_node("agno_research", agno_research)
        g.add_node("lg_retrieval", lg_retrieval)
        g.add_node("agno_general", agno_general)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "lg_math": "lg_math",
            "agno_research": "agno_research",
            "lg_retrieval": "lg_retrieval",
            "agno_general": "agno_general"
        })
        for n in ["lg_math", "agno_research", "lg_retrieval", "agno_general"]:
            g.add_edge(n, END)  # END is LangGraph's terminal node, not a string label
        return g.compile(checkpointer=MemorySaver())
    def process_query(self, q: str) -> Dict[str, Any]:
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q, "agent_type": "", "final_answer": "", "perf": {}, "agno_resp": ""
        }
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
            return {
                "answer": out["final_answer"],
                "performance_metrics": out["perf"],
                "provider_used": out["perf"].get("prov")
            }
        except Exception as e:
            return {"answer": f"Error: {e}", "performance_metrics": {}, "provider_used": "Error"}
def build_graph(provider: str = "hybrid"):
    if provider == "hybrid":
        return HybridLangGraphAgnoSystem().graph
    raise ValueError("Only 'hybrid' supported")
if __name__ == "__main__":
    graph = build_graph()
    query = "What are the names of the US presidents who were assassinated?"
    # the router reads st["query"], so the initial state must seed every field
    state = {
        "messages": [HumanMessage(content=query)],
        "query": query, "agent_type": "", "final_answer": "", "perf": {}, "agno_resp": ""
    }
    res = graph.invoke(state, {"configurable": {"thread_id": "test"}})
    print(res["final_answer"])