File size: 3,347 Bytes
2a28594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b3b414
 
 
2a28594
 
 
 
 
 
 
 
 
 
 
9b3b414
 
 
2a28594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b3b414
 
 
2a28594
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph
from langgraph.graph.message import MessagesState
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from chattr import (
    ASSETS_DIR,
    MODEL_API_KEY,
    MODEL_NAME,
    MODEL_TEMPERATURE,
    MODEL_URL,
)

# System prompt for the assistant. It is not stored in the graph state;
# the agent node in create_graph() prepends it to the message list on every
# model invocation.
SYSTEM_MESSAGE: SystemMessage = SystemMessage(
    content="You are a helpful assistant that can answer questions about the time."
)


async def create_graph() -> CompiledStateGraph:
    """
    Build and compile the agent/tool loop used by the time assistant.

    The tools are loaded from an MCP server registered under the name "time",
    which is started with ``docker run -i --rm mcp/time`` and spoken to over
    stdio. They are bound to a ChatOpenAI model configured from MODEL_URL,
    MODEL_NAME, MODEL_API_KEY and MODEL_TEMPERATURE, with parallel tool calls
    disabled.

    Graph layout: START -> "agent"; after "agent", ``tools_condition`` picks
    the next step; "tools" always hands control back to "agent".

    Returns:
        CompiledStateGraph: The compiled graph, ready to be invoked.

    Raises:
        RuntimeError: If constructing the chat model or binding the tools
            to it fails; the original exception is chained as the cause.
    """
    server_config = {
        "time": {
            "command": "docker",
            "args": ["run", "-i", "--rm", "mcp/time"],
            "transport": "stdio",
        }
    }
    mcp_client = MultiServerMCPClient(server_config)
    available_tools: list[BaseTool] = await mcp_client.get_tools()

    try:
        chat_model = ChatOpenAI(
            model=MODEL_NAME,
            base_url=MODEL_URL,
            api_key=MODEL_API_KEY,
            temperature=MODEL_TEMPERATURE,
        )
        # Kept under a separate name: bind_tools returns a wrapped runnable,
        # not a ChatOpenAI instance.
        tool_aware_model = chat_model.bind_tools(
            available_tools, parallel_tool_calls=False
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to initialize ChatOpenAI model: {e}"
        ) from e

    def call_model(state: MessagesState) -> MessagesState:
        """
        Run the tool-bound model on the conversation with the system prompt in front.

        Parameters:
            state (MessagesState): Current graph state; its "messages" entry is read.

        Returns:
            MessagesState: A state update whose "messages" list holds only the
            model's reply.
        """
        prompt = [SYSTEM_MESSAGE] + state["messages"]
        reply = tool_aware_model.invoke(prompt)
        return {"messages": [reply]}

    workflow: StateGraph = StateGraph(MessagesState)

    # Nodes first, then the edges that connect them.
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", ToolNode(available_tools))

    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", tools_condition)
    workflow.add_edge("tools", "agent")

    return workflow.compile()


def draw_graph(graph: CompiledStateGraph) -> None:
    """
    Write a Mermaid PNG rendering of the compiled graph to ASSETS_DIR / "graph.png".

    Parameters:
        graph (CompiledStateGraph): The compiled graph to visualise.
    """
    drawable = graph.get_graph()
    png_path = ASSETS_DIR / "graph.png"
    # The rendered bytes returned by the call are not needed here; only the
    # file written to png_path matters.
    drawable.draw_mermaid_png(output_file_path=png_path)


if __name__ == "__main__":
    import asyncio

    async def test_graph():
        """
        Smoke-run the assistant: build the graph, ask it for the time, and pretty-print every message in the final state.
        """
        compiled: CompiledStateGraph = await create_graph()

        question = HumanMessage(content="What is the time?")
        final_state = await compiled.ainvoke({"messages": [question]})

        for message in final_state["messages"]:
            message.pretty_print()

    asyncio.run(test_graph())