import json

from rich.panel import Panel
from rich.console import Console
from rich.markdown import Markdown
from rich.json import JSON

from typing import TypedDict, Sequence, Annotated
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode, tools_condition
from tqdm import tqdm


def print_conversation(messages):
    console = Console(width=200, soft_wrap=True)
    
    for msg in messages:
        role = msg.get("role", "unknown").capitalize()
        content = msg.get("content") or ""  # guard against None content

        try:
            parsed_json = json.loads(content)
            rendered_content = JSON.from_data(parsed_json)
        except (json.JSONDecodeError, TypeError):
            rendered_content = Markdown(content.strip())

        panel = Panel(
            rendered_content,
            title=f"[bold blue]{role}[/]",
            border_style="green" if role == "User" else "magenta",
            expand=True
        )

        console.print(panel)
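
# Example call (illustrative): each message is a dict with 'role' and 'content' keys.
#   print_conversation([{"role": "user", "content": "Hi"},
#                       {"role": "assistant", "content": "Hello!"}])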


def generate_final_answer(qa: dict) -> str:
    """Invokes gpt-4o to generate a final answer from the query, response, and metadata"""

    final_answer_llm = ChatOpenAI(model="gpt-4o", temperature=0)

    system_prompt = (
        "You will receive a JSON string containing a user's query, a response, and metadata. "
        "Extract and return only the final answer to the query as a plain string. "
        "Do not return anything else. "
        "Avoid any labels, prefixes, or explanation. "
        "Return only the exact value that satisfies the query, suitable for string comparison."
        "If the query is not answerable due to a missing file in the input and is reflected in the response, answer with 'File not found'. "
    )

    system_message = SystemMessage(content=system_prompt)
    messages = [
        system_message,
        HumanMessage(content=f'Generate the final answer for the following query:\n\n{json.dumps(qa)}')
    ]

    response = final_answer_llm.invoke(messages)

    return response.content
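
# Example input shape (illustrative sketch; the values are stand-ins):
#   generate_final_answer({"query": "2 + 2?", "response": "The answer is 4.", "metadata": {}})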
    

class AgentState(TypedDict):
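    # `add_messages` is a reducer: updates from graph nodes are appended to the
    # running message list instead of overwriting it.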
    messages: Annotated[Sequence[BaseMessage], add_messages]


class BasicOpenAIAgentWorkflow:
    """Basic custom class from an agent prompted for tool-use pattern"""
    
    def __init__(self, tools: list, model='gpt-4o', backstory: str = "", streaming=False):
        self.name = "Basic OpenAI Agent Workflow"
        self.tools = tools
        self.llm = ChatOpenAI(model=model, temperature=0, streaming=streaming)
        self.graph = None
        self.history = []
        self.history_messages = []  # Store messages in LangChain format
        self.backstory = backstory if backstory else "You are a helpful assistant that can use tools to answer questions. Your name is Gaia."

        role_message = {'role': 'system', 'content': self.backstory}
        self.history.append(role_message)

    
    def _call_llm(self, state: AgentState):
        """invokes the assigned llm"""
        return {'messages': [self.llm.invoke(state['messages'])]}


    def _convert_history_to_messages(self):
        """Convert self.history to LangChain-compatible messages"""
        converted = []
        for msg in self.history:
            content = msg['content']

            if not isinstance(content, str):
                raise ValueError(f"Expected string content, got: {type(content)}{content}")

            if msg['role'] == 'user':
                converted.append(HumanMessage(content=content))
            elif msg['role'] == 'assistant':
                converted.append(AIMessage(content=content))
            elif msg['role'] == 'system':
                converted.append(SystemMessage(content=content))
            else:
                raise ValueError(f"Unknown role in message: {msg}")
        self.history_messages = converted

    
    def create_basic_tool_use_agent_state_graph(self, custom_tools_nm="tools"):
        """Binds tools, creates and compiles graph"""
        self.llm = self.llm.bind_tools(self.tools)

        # Graph Init
        graph = StateGraph(AgentState)
        
        # Nodes
        graph.add_node('agent', self._call_llm)
        tools_node = ToolNode(self.tools)
        graph.add_node(custom_tools_nm, tools_node)
        
        # Edges
        graph.add_edge(START, "agent")
        graph.add_conditional_edges('agent', tools_condition, {'tools': custom_tools_nm, END: END})
        graph.add_edge(custom_tools_nm, 'agent')  # route tool output back to the agent so it can use the results
        
        self.graph = graph.compile()
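        # Resulting loop (sketch): START -> agent -> (tools -> agent)* -> END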


    def chat(self, query, verbose=2, only_final_answer=False):
        """Simple agent call"""
        if isinstance(query, dict):
            query = query["messages"]

        user_message = {'role': 'user', 'content': query}
        self.history.append(user_message)

        # Ensure history has at least 1 message
        if not self.history:
            raise ValueError("History is empty. Cannot proceed.")

        self._convert_history_to_messages()

        if not self.history_messages:
            raise ValueError("Converted message history is empty. Something went wrong.")

        response = self.graph.invoke({'messages': self.history_messages})  # invoke with the full history to preserve context (naive memory)
        response = response['messages'][-1].content
        
        if only_final_answer:
            final_answer_content = {
                'query': query,
                'response': response,
                'metadata': {}
            }
            response = generate_final_answer(final_answer_content)

        assistant_message = {'role': 'assistant', 'content': response}
        self.history.append(assistant_message)

        if verbose == 2:
            print_conversation(self.history)
        elif verbose == 1:
            print_conversation([assistant_message])

        return response


    def invoke(self, input_str: str):
        """Invoke the compiled graph with the input data"""
        _ = self.chat(input_str)  # prints response in terminal
        self._convert_history_to_messages()
        return {'messages': self.history_messages}


    def chat_batch(self, queries=None, only_final_answer=False):
        """Send several simple agent calls to the LLM using the compiled graph"""
        if queries is None:
            queries = []
        for i, query in enumerate(tqdm(queries), start=1):
            # Only print the full conversation after the final query
            verbose = 2 if i == len(queries) else 0
            self.chat(query, verbose=verbose, only_final_answer=only_final_answer)
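

# Minimal usage sketch (illustrative, not part of the original module): the
# `multiply` tool and the query below are stand-ins for real tools and inputs,
# and OPENAI_API_KEY must be set in the environment for ChatOpenAI to work.
if __name__ == "__main__":
    from langchain_core.tools import tool

    @tool
    def multiply(a: int, b: int) -> int:
        """Multiply two integers."""
        return a * b

    agent = BasicOpenAIAgentWorkflow(tools=[multiply])
    agent.create_basic_tool_use_agent_state_graph()
    agent.chat("What is 6 times 7?", verbose=1)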