Spaces:
Sleeping
Sleeping
File size: 4,437 Bytes
016be3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
from typing import Dict, TypedDict, List, Annotated, Literal, Union, Any
from .tools import get_tools
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langgraph.graph import StateGraph, START, END
import operator
from langchain_core.messages import (
AIMessage,
HumanMessage,
SystemMessage,
ToolMessage,
FunctionMessage,
)
from langchain_core.tools import BaseTool, tool
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
import json
# State definition
class AgentState(TypedDict):
    """Graph state shared by every node of the research agent."""

    # Conversation history. The ``add_messages`` reducer makes LangGraph merge
    # each node's returned messages into this list rather than replacing it,
    # which is why agent_node returns only the newly generated message.
    messages: Annotated[list, add_messages]
# Initialize tools (runs at import time). The same tool list is bound to the
# chat model and wrapped in the ToolNode by create_agent_graph.
tools = get_tools()
# System prompt template. Every ``{domain}`` placeholder is filled by
# create_system_message via str.format, so any literal braces added here
# must be doubled ("{{" / "}}").
# NOTE(review): the tool names listed below are prose for the model only;
# keep them in sync with what get_tools() actually returns (not visible here).
system_prompt = """You are an AI research assistant specialized in {domain}.
Your goal is to help users find accurate information about {domain} topics.
You have access to the following tools:
1. Web Search - For general queries and recent information
2. Research Paper Search - For academic and scientific information
3. Wikipedia Search - For comprehensive background information and factual summaries
4. Data Analysis - For analyzing data provided by the user
Choose the most appropriate tool(s) based on the user's question:
- Use Web Search for current events, recent developments, or general information
- Use Research Paper Search for academic knowledge, scientific findings, or technical details
- Use Wikipedia Search for conceptual explanations, definitions, historical context, or general facts
- Use Data Analysis when the user provides data to be analyzed
Always try to provide the most accurate and helpful information.
When responding, cite your sources appropriately."""
def create_system_message(domain):
    """
    Build the system message for the research agent.

    Args:
        domain: Subject area substituted for every ``{domain}`` placeholder
            in the module-level ``system_prompt`` template.

    Returns:
        A ``SystemMessage`` whose content is the filled-in prompt.
    """
    filled_prompt = system_prompt.format(domain=domain)
    return SystemMessage(content=filled_prompt)
def create_agent_graph(domain="general research", model_name="gpt-4o", temperature=0):
    """
    Create a LangGraph for the research agent using prebuilt components.

    The graph alternates between an LLM "agent" node and a prebuilt
    ``ToolNode``. After each agent turn, ``tools_condition`` routes to the
    "tools" node or to END, and tool results always flow back to "agent".

    Args:
        domain: Subject area interpolated into the system prompt.
        model_name: OpenAI chat model used by the agent node.
        temperature: Sampling temperature passed to the chat model.

    Returns:
        The compiled graph, invoked as ``graph.invoke({"messages": [...]})``.
    """
    workflow = StateGraph(AgentState)

    # Built once per graph; prepended to the conversation on each agent turn.
    system_prompt_message = create_system_message(domain)

    # Construct the chat model and bind the tools once, at graph-build time.
    # Neither changes between turns, so rebuilding them inside agent_node
    # (as before) recreated the client on every single LLM call.
    model_with_tools = ChatOpenAI(
        model=model_name, temperature=temperature
    ).bind_tools(tools)

    def agent_node(state: AgentState):
        """Call the tool-bound model on the conversation and return its reply."""
        messages = state["messages"]
        # The system prompt is not written back into graph state; it is
        # prepended to a local list for this call unless one is already first.
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [system_prompt_message] + list(messages)
        return {"messages": [model_with_tools.invoke(messages)]}

    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", ToolNode(tools=tools))

    # tools_condition returns "tools" or END; map each to its destination.
    workflow.add_conditional_edges(
        "agent",
        tools_condition,
        {
            "tools": "tools",
            END: END,
        },
    )
    # After tools run, hand their results back to the agent.
    workflow.add_edge("tools", "agent")
    workflow.add_edge(START, "agent")

    return workflow.compile()
def run_agent(user_input, domain="general research", messages=None):
    """
    Run the agent on one user turn.

    Args:
        user_input: Text of the new user message.
        domain: Subject area passed to ``create_agent_graph``.
        messages: Optional prior conversation (any iterable of messages).
            It is NOT modified; a new list is built for this run.

    Returns:
        The ``"messages"`` list from the graph's final state.
    """
    graph = create_agent_graph(domain)
    # Copy instead of appending in place so the caller's history is not
    # silently mutated as a side effect of running the agent.
    history = [] if messages is None else list(messages)
    history.append(HumanMessage(content=user_input))
    result = graph.invoke({"messages": history})
    return result["messages"]
def _preview_text(content, limit=100):
    """
    Return ``content`` as text, cut to ``limit`` characters plus "..." if longer.

    Message content may be a plain string or a list of content blocks, so
    non-string content is converted with ``str()`` before measuring/slicing.
    """
    text = content if isinstance(content, str) else str(content)
    return text[:limit] + "..." if len(text) > limit else text


if __name__ == "__main__":
    # Smoke test: run one query and print the resulting conversation.
    domain = "artificial intelligence"
    query = "What are the latest developments in natural language processing?"
    messages = run_agent(query, domain)
    for message in messages:
        if isinstance(message, AIMessage):
            print("AI:", message.content)
        elif isinstance(message, HumanMessage):
            print("Human:", message.content)
        elif isinstance(message, ToolMessage):
            print(f"Tool ({message.name}):", _preview_text(message.content))