import os
import time
import random
import operator
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_groq import ChatGroq

load_dotenv()  # expects GROQ_API_KEY (and TAVILY_API_KEY for web search) in your .env
# --- basic math helpers ---------------------------------------------------
def multiply(a: int, b: int) -> int: return a * b
def add(a: int, b: int) -> int: return a + b
def subtract(a: int, b: int) -> int: return a - b

def divide(a: int, b: int) -> float:
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

def modulus(a: int, b: int) -> int: return a % b
def optimized_web_search(query: str) -> str:
    """Search the web via Tavily and return up to two short, tagged snippets."""
    try:
        time.sleep(random.uniform(0.7, 1.5))  # jitter to stay under rate limits
        docs = TavilySearchResults(max_results=2).invoke(query)
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"
def optimized_wiki_search(query: str) -> str:
    """Fetch the top Wikipedia article for the query as a tagged snippet."""
    try:
        time.sleep(random.uniform(0.3, 1))  # jitter to stay under rate limits
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"
class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
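
# Note on the state schema: `operator.add` is a LangGraph reducer, so a node
# returning {"messages": [AIMessage(content="hi")]} appends to the existing
# message list rather than overwriting it; un-annotated keys such as
# `final_answer` are simply replaced by whatever the node returns.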
class HybridLangGraphMultiLLMSystem:
    """
    Router that picks between Groq-hosted Llama-3 8B, Llama-3 70B (the
    default), and Groq-hosted DeepSeek-R1-Distill-Llama-70B based on
    keywords in the query.
    """
    def __init__(self):
        # Registered for future tool binding; the current graph nodes call
        # the LLMs directly without tools.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()
    def _llm(self, model_name: str):
        return ChatGroq(
            model=model_name,
            temperature=0,
            api_key=os.getenv("GROQ_API_KEY")
        )
    def _build_graph(self):
        llama8_llm = self._llm("llama3-8b-8192")
        llama70_llm = self._llm("llama3-70b-8192")
        # DeepSeek model actually served by Groq; "deepseek-chat" is the name
        # of DeepSeek's own API model and is not available on Groq.
        deepseek_llm = self._llm("deepseek-r1-distill-llama-70b")

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            """Pick a model from simple keyword matches in the query."""
            q = st["query"].lower()
            if "llama-8" in q:
                t = "llama8"
            elif "deepseek" in q:
                t = "deepseek"
            else:
                t = "llama70"
            return {**st, "agent_type": t}
        def make_node(llm, provider: str):
            """Create a node that answers the query with `llm` and records timing."""
            def node(st: EnhancedAgentState) -> EnhancedAgentState:
                t0 = time.time()
                sys = SystemMessage(content="You are a helpful AI assistant.")
                res = llm.invoke([sys, HumanMessage(content=st["query"])])
                return {**st,
                        "final_answer": res.content,
                        "perf": {"time": time.time() - t0, "prov": provider}}
            return node

        llama8_node = make_node(llama8_llm, "Groq-Llama3-8B")
        llama70_node = make_node(llama70_llm, "Groq-Llama3-70B")
        deepseek_node = make_node(deepseek_llm, "Groq-DeepSeek")
        # Topology: router fans out to exactly one model node, then END.
        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("llama8", llama8_node)
        g.add_node("llama70", llama70_node)
        g.add_node("deepseek", deepseek_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", lambda s: s["agent_type"],
                                {"llama8": "llama8", "llama70": "llama70", "deepseek": "deepseek"})
        g.add_edge("llama8", END)
        g.add_edge("llama70", END)
        g.add_edge("deepseek", END)
        # MemorySaver checkpoints graph state in-process, keyed by thread_id.
        return g.compile(checkpointer=MemorySaver())
    def process_query(self, q: str) -> str:
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # hash() is salted per interpreter run, so these thread ids are only
        # stable within a single process, which is fine for MemorySaver.
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        out = self.graph.invoke(state, cfg)
        return out.get("final_answer", "").strip()
def build_graph(provider: str | None = None):
    # `provider` is accepted for interface compatibility but ignored:
    # model choice happens per-query inside the router node.
    return HybridLangGraphMultiLLMSystem().graph
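
# Hypothetical standalone use of build_graph (the input dict must carry every
# key of EnhancedAgentState, and the checkpointer requires a thread_id):
#
#   graph = build_graph()
#   out = graph.invoke(
#       {"messages": [], "query": "deepseek: state the Riemann Hypothesis",
#        "agent_type": "", "final_answer": "", "perf": {}, "agno_resp": ""},
#       {"configurable": {"thread_id": "demo"}},
#   )
#   print(out["final_answer"])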
if __name__ == "__main__":
    qa_system = HybridLangGraphMultiLLMSystem()
    # Test each routing branch
    print(qa_system.process_query("llama-8: What is the capital of France?"))
    print(qa_system.process_query("llama-70: Tell me about quantum mechanics."))
    print(qa_system.process_query("deepseek: What is the Riemann Hypothesis?"))