File size: 7,354 Bytes
401799d
 
81d00fe
bc3bc22
9bd791c
f622879
401799d
 
 
 
 
 
f622879
401799d
 
 
81d00fe
 
 
 
 
401799d
 
 
 
 
 
 
 
 
 
bc3bc22
 
 
401799d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d00fe
401799d
 
 
81d00fe
401799d
 
 
 
9bd791c
 
 
 
 
 
 
401799d
81d00fe
 
401799d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd791c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401799d
9bd791c
 
 
 
 
 
 
 
401799d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d00fe
9bd791c
 
401799d
 
9bd791c
401799d
 
 
 
 
 
 
 
9bd791c
401799d
 
 
 
9bd791c
401799d
81d00fe
401799d
81d00fe
 
401799d
 
 
 
 
 
 
 
 
 
81d00fe
401799d
9bd791c
 
 
 
 
 
 
 
 
 
 
 
401799d
9bd791c
401799d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
"""Define the agent graph and its components."""

import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional, TypedDict, Union

import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.errors import GraphInterrupt
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel

from configuration import Configuration
from tools import tools

# Configure logging for the whole module; level may be tightened below
# depending on the LITELLM_DEBUG environment variable.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set.
# NOTE(review): litellm is deliberately imported here, after logging is
# configured, rather than at the top of the file.
import litellm

if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)

# Configure LiteLLM to drop unsupported parameters instead of raising
# when a model backend does not accept them.
litellm.drop_params = True

# Load default prompt templates from prompts/code_agent.yaml, resolved
# relative to this file so the module works regardless of the CWD.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

with open(yaml_path, "r") as f:
    prompt_templates = yaml.safe_load(f)

# Initialize the model and agent using configuration read at import time.
# These module-level globals (config, model, agent) are used by the graph
# construction at the bottom of the file.
config = Configuration()
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)

agent = CodeAgent(
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    verbosity_level=logging.DEBUG,
)


class AgentState(TypedDict):
    """State for the agent graph.

    Carries the conversation, the current question/answer pair, and the
    bookkeeping fields the nodes update on each step.
    """

    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]  # conversation so far
    question: str  # the user question being answered
    answer: Optional[str]  # latest agent result, None until the agent runs
    step_logs: List[Dict]  # per-step snapshots recorded by the callback node
    is_complete: bool  # set True when the user quits
    step_count: int  # number of agent steps executed
    # Memory-related fields
    context: Dict[str, Any]  # contextual information (fixed: was builtin `any`, not typing.Any)
    memory_buffer: List[Dict]  # important information carried across steps
    last_action: Optional[str]  # the last action taken
    action_history: List[Dict]  # history of actions taken
    error_count: int  # number of errors observed
    success_count: int  # number of successful operations


class AgentNode:
    """Node that runs the agent."""

    def __init__(self, agent: CodeAgent):
        """Initialize the agent node.

        Args:
            agent: The CodeAgent used to answer the state's question.
        """
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on the current state.

        Runs the agent on ``state["question"]``, appends the result to the
        message list, and updates the bookkeeping fields (step count, last
        action, action history, memory buffer, success count).

        Args:
            state: The current graph state.
            config: Optional runnable config used to derive a Configuration.

        Returns:
            The updated state (a shallow copy of ``state``).

        Raises:
            Exception: Whatever the underlying agent raises is propagated
                after being logged and recorded in the action history.
        """
        # Lazy %-style args: messages are only formatted when INFO is enabled.
        logger.info("Current state before processing:")
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        logger.info("Starting agent execution")

        # Keep the try body minimal: only the agent call can meaningfully fail.
        try:
            result = self.agent.run(state["question"])
        except Exception as e:
            logger.exception("Error during agent execution: %s", e)
            # Record the failure in the action history before propagating.
            # NOTE(review): the previous version also incremented error_count
            # on a copied state dict that the `raise` immediately discarded —
            # that dead bookkeeping is not reproduced here. The append below
            # is kept because the history list is shared with the caller's
            # state, so the record survives the raise, as before.
            state["action_history"].append(
                {"step": state["step_count"], "action": "error", "error": str(e)}
            )
            raise

        # Shallow copy: the list fields are shared with the input state, so
        # the appends below also mutate the caller's lists. Preserved as-is
        # to avoid changing downstream behavior.
        new_state = state.copy()
        new_state["messages"].append(AIMessage(content=result))
        new_state["answer"] = result
        new_state["step_count"] += 1
        new_state["last_action"] = "agent_response"
        new_state["action_history"].append(
            {
                "step": state["step_count"],
                "action": "agent_response",
                "result": result,
            }
        )
        new_state["success_count"] += 1

        # Store non-empty results in the memory buffer with a timestamp.
        if result:
            new_state["memory_buffer"].append(
                {
                    "step": state["step_count"],
                    "content": result,
                    "timestamp": datetime.now().isoformat(),
                }
            )

        logger.info("Updated state after processing:")
        logger.info("Messages: %s", new_state["messages"])
        logger.info("Question: %s", new_state["question"])
        logger.info("Answer: %s", new_state["answer"])

        return new_state


class StepCallbackNode:
    """Node that handles step callbacks and user interaction."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Handle step callback and user interaction.

        Records a snapshot of the current step, then pauses via
        ``interrupt`` for a user decision: 'c' continues, 'q' marks the
        run complete, 'i' logs extra info.

        Args:
            state: The current graph state (mutated in place).
            config: Optional runnable config used to derive a Configuration.

        Returns:
            The (possibly updated) state.
        """
        cfg = Configuration.from_runnable_config(config)

        # Log the step before asking the user anything.
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        state["step_logs"].append(step_log)

        try:
            interrupt_result = interrupt(
                "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
            )
            # NOTE(review): assumes the resume value is an indexable sequence
            # whose first element is the user's input — confirm against the
            # code that resumes the graph.
            user_input = interrupt_result[0]

            if user_input.lower() == "q":
                state["is_complete"] = True
                return state
            elif user_input.lower() == "i":
                logger.info("Current step: %s", state["step_count"])
                logger.info("Question: %s", state["question"])
                logger.info("Current answer: %s", state["answer"])
                return state
            elif user_input.lower() == "c":
                return state
            else:
                logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
                return state

        except GraphInterrupt:
            # interrupt() pauses the graph by raising GraphInterrupt, which
            # subclasses Exception; it must propagate to the langgraph
            # runtime. The previous blanket `except Exception` swallowed it
            # and silently disabled the human-in-the-loop pause.
            raise
        except Exception as e:
            logger.warning("Error during interrupt: %s", e)
            return state


def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent workflow graph.

    The graph alternates between the agent node and the user-interaction
    callback node until the state is marked complete, or an answer exists
    with no recorded errors.

    Args:
        agent: The AgentNode to install as the graph's entry point.

    Returns:
        The compiled workflow.
    """
    graph = StateGraph(AgentState)

    # Wire the two nodes: every agent step flows into the callback.
    graph.add_node("agent", agent)
    graph.add_node("callback", StepCallbackNode())
    graph.add_edge("agent", "callback")

    def should_continue(state: AgentState) -> str:
        """Route to END when finished, otherwise loop back to the agent."""
        finished = state["is_complete"] or (
            state["answer"] and state["error_count"] == 0
        )
        return END if finished else "agent"

    graph.add_conditional_edges(
        "callback", should_continue, {END: END, "agent": "agent"}
    )
    graph.set_entry_point("agent")

    return graph.compile()


# Initialize the module-level compiled graph, wiring the globally
# configured CodeAgent (built at import time above) into the workflow.
agent_graph = build_agent_graph(AgentNode(agent))