File size: 5,260 Bytes
9e93462
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import gradio as gr
import time
from gradio import ChatMessage
from langchain_core.runnables import RunnableConfig
from langchain_teddynote.messages import random_uuid
from langchain_core.messages import BaseMessage, HumanMessage
from pprint import pprint

def format_namespace(namespace):
    """Return a short display label for a LangGraph stream namespace.

    ``namespace`` is the sequence of path segments yielded next to each chunk
    by ``app.stream(..., subgraphs=True)``. An empty sequence means the update
    came from the top-level graph and is labelled ``"root graph"``; otherwise
    the label is the innermost (last) segment cut at its first ``":"``.
    """
    if len(namespace) == 0:
        return "root graph"
    innermost = namespace[-1]
    label, _sep, _rest = innermost.partition(":")
    return label

from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph_supervisor import create_supervisor
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver, InMemorySaver
from langgraph.store.memory import InMemoryStore

# Chat model shared by both specialist agents and the supervisor.
model = ChatOpenAI(model="gpt-4o")

# In-memory persistence objects.
# NOTE(review): neither object is passed to `workflow.compile()`, so graph
# state is not persisted between requests — confirm whether that is intended.
checkpointer = InMemorySaver()
store = InMemoryStore()

# Create specialized agents

def add(a: float, b: float) -> float:
    """Add two numbers."""
    # The docstring above is left verbatim on purpose: the function is handed
    # to create_react_agent as a tool, and LangChain derives the tool
    # description the model sees from it.
    total = a + b
    return total

def multiply(a: float, b: float) -> float:
    """Multiply two numbers."""
    # Docstring kept verbatim: LangChain uses it as this tool's description.
    product = a * b
    return product

def web_search(query: str) -> str:
    """Search the web for information."""
    # Stub tool: `query` is ignored and the same canned FAANG headcount text
    # is returned every time. The docstring is kept verbatim because
    # LangChain uses it as the tool description.
    lines = [
        "Here are the headcounts for each of the FAANG companies in 2024:",
        "1. **Facebook (Meta)**: 67,317 employees.",
        "2. **Apple**: 164,000 employees.",
        "3. **Amazon**: 1,551,000 employees.",
        "4. **Netflix**: 14,000 employees.",
        "5. **Google (Alphabet)**: 181,269 employees.",
    ]
    return "\n".join(lines)

# Specialist ReAct agents. Each agent's `name` is what create_supervisor
# builds its handoff tool from, so the supervisor prompt below must refer
# to the agents by exactly these names.
math_agent = create_react_agent(
    model=model,
    tools=[add, multiply],
    name="math_expert",
    prompt="You are a math expert. Always use one tool at a time."
)

research_agent = create_react_agent(
    model=model,
    tools=[web_search],
    name="research_expert",
    prompt="You are a world class researcher with access to web search. Do not do any math."
)

# Supervisor workflow that routes each request to one of the agents above.
# The routing instructions name the agents by their `name` values
# ("research_expert", "math_expert"); names that match no agent leave the
# model to guess which handoff tool is meant.
workflow = create_supervisor(
    [research_agent, math_agent],
    model=model,
    prompt=(
        "You are a team supervisor managing a research expert and a math expert. "
        "For current events, use research_expert. "
        "For math problems, use math_expert."
    ),
)

# Compile once at import time; generate_response streams from `app` per request.
app = workflow.compile()

def _render_update_value(key, value):
    """Echo one field of a node update to stdout and return its display lines.

    Messages (``BaseMessage``) are rendered with ``pretty_print``/``pretty_repr``;
    lists are rendered item by item; nested dicts are rendered as
    ``"<key>:\\n<value>"`` pairs; anything else as ``"<key>:\\n<value>"``.
    Every returned element is a ``str`` so callers can safely ``"\\n".join`` them.
    """
    lines = []
    if isinstance(value, BaseMessage):
        value.pretty_print()
        lines.append(value.pretty_repr())
    elif isinstance(value, list):
        for item in value:
            if isinstance(item, BaseMessage):
                item.pretty_print()
                lines.append(item.pretty_repr())
            else:
                print(item)
                lines.append(str(item))
    elif isinstance(value, dict):
        for sub_key, sub_value in value.items():
            print(f"{sub_key}:\n{sub_value}")
            lines.append(f"{sub_key}:\n{sub_value}")
    else:
        print(f"\033[1;32m{key}\033[0m:\n{value}")
        lines.append(f"{key}:\n{value}")
    return lines


def generate_response(message, history):
    """Stream the supervisor graph's progress for one user message as chat messages.

    Every node update emitted by ``app.stream(..., stream_mode="updates",
    subgraphs=True)`` becomes a ``ChatMessage`` whose metadata title names
    the node (and, for subgraph updates, the namespace it ran in), with status
    ``"pending"`` while it is being filled and ``"done"`` once the next step
    starts. The same information is echoed to stdout. After the stream ends,
    the content of the last message seen in an update's ``"messages"`` list
    is appended as the final answer.

    Args:
        message: The user's text from the chat box.
        history: Prior chat turns supplied by Gradio. Not used: every request
            starts a fresh run of the graph.

    Yields:
        The growing list of ``ChatMessage`` objects rendered so far.
    """
    inputs = {"messages": [HumanMessage(content=message)]}
    # Optional allow-list of node names to display; empty means "show every node".
    node_names = []
    response = []
    final_answer = None  # content of the latest message seen in any update

    for namespace, chunk in app.stream(inputs, stream_mode="updates", subgraphs=True):
        # A single update chunk can hold several nodes; render each of them.
        for node_name, node_chunk in chunk.items():
            if node_names and node_name not in node_names:
                continue

            # Starting a new step: mark the previous one as finished.
            if response:
                response[-1].metadata["status"] = "done"

            formatted_namespace = format_namespace(namespace)
            if formatted_namespace == "root graph":
                print(f"🔄 Node: \033[1;36m{node_name}\033[0m 🔄")
                meta_title = f"🤔 `{node_name}`"
            else:
                print(
                    f"🔄 Node: \033[1;36m{node_name}\033[0m in [\033[1;33m{formatted_namespace}\033[0m] 🔄"
                )
                meta_title = f"🤔 `{node_name}` in `{formatted_namespace}`"

            response.append(
                ChatMessage(content="", metadata={"title": meta_title, "status": "pending"})
            )
            yield response
            print("- " * 25)

            # Render the node's update data, refreshing the UI after each field.
            out_str = []
            if isinstance(node_chunk, dict):
                for key, value in node_chunk.items():
                    out_str.extend(_render_update_value(key, value))
                    response[-1].content = "\n".join(out_str)
                    yield response
                messages = node_chunk.get("messages")
                if isinstance(messages, list) and messages and isinstance(messages[-1], BaseMessage):
                    final_answer = messages[-1].content
            elif node_chunk is not None:
                for item in node_chunk:
                    print(item)
                    out_str.append(str(item))
                response[-1].content = "\n".join(out_str)
                yield response
            print("=" * 50)

    # Guarded: the stream may produce no updates, or none carrying messages.
    if response:
        response[-1].metadata["status"] = "done"
    if final_answer is not None:
        response.append(ChatMessage(content=final_answer))
    yield response


# Chat UI that renders the message lists yielded by generate_response.
demo = gr.ChatInterface(
    generate_response,
    type="messages",
    title="Nested Thoughts Chat Interface",
    # Example prompt (mixed Korean/English): "Please give an analysis, in
    # Korean, of the combined headcount figures of the FAANG companies in 2024!"
    examples=["2024년의 the combined headcount of the FAANG companies수치에 대한 분석을 한국어로 부탁해!"]
)

if __name__ == "__main__":
    demo.launch()