File size: 5,377 Bytes
81d00fe
401799d
 
218633c
9bd791c
401799d
81d00fe
 
 
 
 
 
 
 
401799d
 
81d00fe
 
 
 
 
 
401799d
81d00fe
401799d
 
81d00fe
401799d
 
 
 
218633c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81d00fe
9bd791c
401799d
 
 
9bd791c
401799d
 
 
 
81d00fe
9bd791c
218633c
 
9bd791c
 
218633c
9bd791c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218633c
 
9bd791c
218633c
9bd791c
218633c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd791c
 
218633c
9bd791c
218633c
 
 
 
 
 
 
 
 
 
 
 
9bd791c
81d00fe
9bd791c
81d00fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import logging
import os
import uuid
from typing import Optional

from langgraph.types import Command

from graph import agent_graph

# Logging setup: the root configuration defaults to INFO
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LiteLLM verbose output is opt-in: it is enabled only when the
# LITELLM_DEBUG environment variable is set to "true" (case-insensitive)
import litellm

litellm.set_verbose = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
# This module's logger follows the same switch: DEBUG when enabled, else INFO
logger.setLevel(logging.DEBUG if litellm.set_verbose else logging.INFO)


class AgentRunner:
    """Runner class for the code agent.

    An instance is callable: it streams a question (or a resume command)
    through the agent graph and returns the first answer found in the
    streamed chunks. Every call passes the runner's ``thread_id`` to the
    graph in the ``configurable`` section of the run config.
    """

    def __init__(self):
        """Initialize the agent runner with the graph and a fresh thread id."""
        logger.info("Initializing AgentRunner")
        self.graph = agent_graph
        # Last streamed chunk that produced an answer (for testing/debugging)
        self.last_state = None
        # Unique thread_id generated once for this runner
        self.thread_id = str(uuid.uuid4())
        logger.info(f"Created AgentRunner with thread_id: {self.thread_id}")

    def _extract_answer(self, state: dict) -> Optional[str]:
        """Extract the answer from a streamed state.

        A non-empty ``answer`` field takes priority; otherwise the most
        recent message exposing a non-empty ``content`` attribute is used.

        Args:
            state: A dict emitted by the graph stream.

        Returns:
            The answer, or ``None`` when the state is empty or holds neither
            a non-empty ``answer`` nor a message with content.
        """
        if not state:
            return None

        # NOTE(review): only top-level keys are inspected. If the graph
        # streams per-node updates (chunks keyed by node name), the answer
        # would sit one level deeper and be missed here - confirm the
        # stream mode used by the graph.
        direct_answer = state.get("answer")
        if direct_answer:
            logger.info(f"Found answer in direct field: {direct_answer}")
            return direct_answer

        # Fall back to the messages, newest first
        for msg in reversed(state.get("messages") or []):
            content = getattr(msg, "content", None)
            if content:
                logger.info(f"Found answer in message: {content}")
                return content

        return None

    def _capture_answer(self, state: dict) -> Optional[str]:
        """Extract an answer and remember ``state`` when one is found."""
        answer = self._extract_answer(state)
        if answer:
            self.last_state = state
        return answer

    def _stream_resume(self, resume_input, config: dict) -> Optional[str]:
        """Stream the graph with a resume input and return the first answer."""
        for result in self.graph.stream(resume_input, config):
            logger.debug(f"Received resume result: {result}")
            if not isinstance(result, dict):
                logger.debug(f"Skipping result without answer: {result}")
                continue
            answer = self._capture_answer(result)
            if answer:
                return answer
        return None

    def _stream_question(self, question: str, config: dict) -> Optional[str]:
        """Stream a new question through the graph and return the first answer.

        When a chunk contains ``__interrupt__`` the graph is resumed
        automatically with ``Command(resume="c")``.
        """
        logger.info(f"Processing initial question: {question}")
        initial_state = {
            "question": question,
            "messages": [],
            "answer": None,
            "step_logs": [],
            "is_complete": False,
            "step_count": 0,
            # Initialize new memory fields
            "context": {},
            "memory_buffer": [],
            "last_action": None,
            "action_history": [],
            "error_count": 0,
            "success_count": 0,
        }
        logger.info(f"Initial state: {initial_state}")

        # Use stream to get interrupt information
        logger.info("Starting graph stream for initial question")
        for chunk in self.graph.stream(initial_state, config):
            logger.debug(f"Received chunk: {chunk}")

            if not isinstance(chunk, dict):
                logger.debug(f"Skipping chunk without answer: {chunk}")
                continue

            if "__interrupt__" in chunk:
                logger.info("Detected interrupt in stream")
                logger.info(f"Interrupt details: {chunk['__interrupt__']}")

                # Auto-resume with 'c' (presumably "continue" - confirm
                # against the graph's interrupt handling)
                logger.info("Resuming with 'c' command")
                answer = self._stream_resume(Command(resume="c"), config)
            else:
                answer = self._capture_answer(chunk)

            if answer:
                return answer

        return None

    def __call__(self, input_data) -> str:
        """Process a question through the agent graph and return the answer.

        Args:
            input_data: Either a question string or a Command object for resuming

        Returns:
            str: The agent's response, or "No answer generated" when the
            stream ends without producing an answer.

        Raises:
            Exception: Any error raised while processing is logged with its
                traceback and re-raised unchanged.
        """
        try:
            config = {"configurable": {"thread_id": self.thread_id}}
            logger.info(f"Using config: {config}")

            if isinstance(input_data, str):
                answer = self._stream_question(input_data, config)
            else:
                # Resuming from interrupt
                logger.info(f"Resuming from interrupt with input: {input_data}")
                answer = self._stream_resume(input_data, config)

            if answer:
                return answer

            # If we get here, we didn't find an answer
            logger.warning("No answer generated from stream")
            return "No answer generated"

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped
            logger.exception("Error processing input: %s", e)
            raise