import os
import time
import random
import operator
import requests
from typing import List, Dict, Any, TypedDict, Annotated

from dotenv import load_dotenv
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver

# Load environment variables
load_dotenv()
# ---- Tool Definitions ----
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the product."""
    return a * b

def add(a: int, b: int) -> int:
    """Add two integers and return the sum."""
    return a + b

def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference."""
    return a - b

def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient."""
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

def modulus(a: int, b: int) -> int:
    """Return the remainder of the division of the first integer by the second."""
    return a % b
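
# Illustrative sketch only: these math helpers are plain functions, but a
# LangChain chat model can accept them directly via `bind_tools`, which turns
# callables with type hints and docstrings into tool schemas. The `llm`
# parameter is a hypothetical chat model instance (e.g. the ChatGroq object
# built below); this helper is not called anywhere in the graph.
def _example_bind_math_tools(llm):
    """Sketch: expose the arithmetic helpers as tools on a chat model."""
    return llm.bind_tools([multiply, add, subtract, divide, modulus])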
def optimized_web_search(query: str) -> str:
    """Perform an optimized web search using TavilySearchResults and return concatenated document snippets."""
    try:
        # Small random delay to avoid hitting the API in rapid bursts
        time.sleep(random.uniform(1, 2))
        search_tool = TavilySearchResults(max_results=2)
        docs = search_tool.invoke({"query": query})
        return "\n\n---\n\n".join(
            f"<Doc url='{d.get('url', '')}'>{d.get('content', '')[:500]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Web search failed: {e}"

def optimized_wiki_search(query: str) -> str:
    """Perform an optimized Wikipedia search and return concatenated document snippets."""
    try:
        time.sleep(random.uniform(0.5, 1))
        docs = WikipediaLoader(query=query, load_max_docs=1).load()
        return "\n\n---\n\n".join(
            f"<Doc src='{d.metadata.get('source', 'Wikipedia')}'>{d.page_content[:800]}</Doc>"
            for d in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {e}"
# ---- LLM Integrations with Error Handling ----
try:
    from langchain_groq import ChatGroq
    GROQ_AVAILABLE = True
except ImportError:
    GROQ_AVAILABLE = False

def deepseek_generate(prompt, api_key=None):
    """Call the DeepSeek chat-completions API directly."""
    if not api_key:
        return "DeepSeek API key not provided"
    url = "https://api.deepseek.com/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": "deepseek-chat",
        "messages": [{"role": "user", "content": prompt}],
        "stream": False
    }
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        resp.raise_for_status()
        choices = resp.json().get("choices", [])
        if choices and "message" in choices[0]:
            return choices[0]["message"].get("content", "")
        return "No response from DeepSeek"
    except Exception as e:
        return f"DeepSeek API error: {e}"

def baidu_ernie_generate(prompt, api_key=None):
    """Call the Baidu ERNIE API."""
    if not api_key:
        return "Baidu ERNIE API key not provided"
    # Placeholder endpoint: the production ERNIE (wenxinworkshop) API typically
    # expects an OAuth access_token query parameter rather than a Bearer header,
    # and the path ends in a specific model name. Replace both before relying
    # on this function.
    url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    data = {
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.1,
        "top_p": 0.8
    }
    try:
        resp = requests.post(url, headers=headers, json=data, timeout=30)
        resp.raise_for_status()
        result = resp.json().get("result", "")
        return result if result else "No response from Baidu ERNIE"
    except Exception as e:
        return f"Baidu ERNIE API error: {e}"
# ---- Graph State ----
class EnhancedAgentState(TypedDict):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    query: str
    agent_type: str
    final_answer: str
    perf: Dict[str, Any]
    agno_resp: str
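
# Sketch of what the Annotated reducer above means: LangGraph merges successive
# node updates to `messages` with operator.add (list concatenation) instead of
# overwriting the key. This helper is illustrative only and never called.
def _example_message_reducer() -> list:
    """Sketch: two state updates to `messages` concatenate like plain lists."""
    first = [HumanMessage(content="hi")]
    second = [AIMessage(content="hello")]
    return operator.add(first, second)  # -> [HumanMessage(...), AIMessage(...)]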
class HybridLangGraphMultiLLMSystem:
    def __init__(self, provider="groq"):
        # `provider` is kept for API compatibility; actual routing is
        # keyword-based (see the router node in _build_graph).
        self.provider = provider
        # Tools are collected here for future binding; the current graph nodes
        # call the LLMs directly and do not invoke these tools.
        self.tools = [
            multiply, add, subtract, divide, modulus,
            optimized_web_search, optimized_wiki_search
        ]
        self.graph = self._build_graph()
    def _build_graph(self):
        # Initialize Groq LLM with error handling
        groq_llm = None
        if GROQ_AVAILABLE and os.getenv("GROQ_API_KEY"):
            try:
                groq_llm = ChatGroq(
                    # llama-3.1-70b-versatile has been decommissioned by Groq;
                    # llama-3.3-70b-versatile is its documented replacement
                    # (verify against Groq's current model list).
                    model="llama-3.3-70b-versatile",
                    temperature=0,
                    api_key=os.getenv("GROQ_API_KEY")
                )
            except Exception as e:
                print(f"Failed to initialize Groq: {e}")

        def router(st: EnhancedAgentState) -> EnhancedAgentState:
            # Route on provider keywords in the query; otherwise fall back to
            # the first provider that looks available.
            q = st["query"].lower()
            if "groq" in q and groq_llm:
                t = "groq"
            elif "deepseek" in q:
                t = "deepseek"
            elif "ernie" in q or "baidu" in q:
                t = "baidu"
            else:
                if groq_llm:
                    t = "groq"
                elif os.getenv("DEEPSEEK_API_KEY"):
                    t = "deepseek"
                else:
                    t = "baidu"
            return {**st, "agent_type": t}

        def groq_node(st: EnhancedAgentState) -> EnhancedAgentState:
            if not groq_llm:
                return {**st, "final_answer": "Groq not available", "perf": {"error": "No Groq LLM"}}
            t0 = time.time()
            try:
                sys = SystemMessage(content="You are a helpful AI assistant. Provide accurate and detailed answers. Be concise but thorough.")
                res = groq_llm.invoke([sys, HumanMessage(content=st["query"])])
                return {**st, "final_answer": res.content, "perf": {"time": time.time() - t0, "prov": "Groq"}}
            except Exception as e:
                return {**st, "final_answer": f"Groq error: {e}", "perf": {"error": str(e)}}

        def deepseek_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            try:
                prompt = f"You are a helpful AI assistant. Provide accurate and detailed answers. Be concise but thorough.\n\nUser question: {st['query']}"
                resp = deepseek_generate(prompt, api_key=os.getenv("DEEPSEEK_API_KEY"))
                return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "DeepSeek"}}
            except Exception as e:
                return {**st, "final_answer": f"DeepSeek error: {e}", "perf": {"error": str(e)}}

        def baidu_node(st: EnhancedAgentState) -> EnhancedAgentState:
            t0 = time.time()
            try:
                prompt = f"You are a helpful AI assistant. Provide accurate and detailed answers. Be concise but thorough.\n\nUser question: {st['query']}"
                resp = baidu_ernie_generate(prompt, api_key=os.getenv("BAIDU_API_KEY"))
                return {**st, "final_answer": resp, "perf": {"time": time.time() - t0, "prov": "Baidu ERNIE"}}
            except Exception as e:
                return {**st, "final_answer": f"Baidu ERNIE error: {e}", "perf": {"error": str(e)}}

        def pick(st: EnhancedAgentState) -> str:
            return st["agent_type"]

        g = StateGraph(EnhancedAgentState)
        g.add_node("router", router)
        g.add_node("groq", groq_node)
        g.add_node("deepseek", deepseek_node)
        g.add_node("baidu", baidu_node)
        g.set_entry_point("router")
        g.add_conditional_edges("router", pick, {
            "groq": "groq",
            "deepseek": "deepseek",
            "baidu": "baidu"
        })
        for n in ["groq", "deepseek", "baidu"]:
            g.add_edge(n, END)
        return g.compile(checkpointer=MemorySaver())
    def process_query(self, q: str) -> str:
        state = {
            "messages": [HumanMessage(content=q)],
            "query": q,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": ""
        }
        # Each query gets its own checkpointer thread, keyed on the query hash
        cfg = {"configurable": {"thread_id": f"hyb_{hash(q)}"}}
        try:
            out = self.graph.invoke(state, cfg)
            raw_answer = out.get("final_answer", "No answer generated")
            # Clean up the answer
            if isinstance(raw_answer, str):
                return raw_answer.strip()
            return str(raw_answer)
        except Exception as e:
            return f"Error processing query: {e}"
# Function expected by app.py
def build_graph(provider="groq"):
    """Build and return the compiled graph for the agent system."""
    system = HybridLangGraphMultiLLMSystem(provider=provider)
    return system.graph
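
# Sketch of how a caller such as app.py might drive the compiled graph
# directly. The input dict mirrors EnhancedAgentState, and MemorySaver requires
# a thread_id in the config; "app-demo" here is an arbitrary example value.
def _example_build_graph_usage() -> None:
    """Sketch: invoke the compiled graph returned by build_graph."""
    graph = build_graph()
    question = "What is 6 * 7?"
    out = graph.invoke(
        {
            "messages": [HumanMessage(content=question)],
            "query": question,
            "agent_type": "",
            "final_answer": "",
            "perf": {},
            "agno_resp": "",
        },
        {"configurable": {"thread_id": "app-demo"}},
    )
    print(out.get("final_answer", ""))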
if __name__ == "__main__":
    query = "What are the main benefits of using multiple LLM providers?"
    system = HybridLangGraphMultiLLMSystem()
    result = system.process_query(query)
    print("LangGraph Multi-LLM Result:", result)