File size: 4,223 Bytes
81d00fe
73b637f
81d00fe
73b637f
81d00fe
 
bc3bc22
 
81d00fe
 
 
 
 
bc3bc22
 
 
81d00fe
 
 
 
 
 
 
 
bc3bc22
 
 
73b637f
 
bc3bc22
 
 
81d00fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc3bc22
81d00fe
bc3bc22
 
 
81d00fe
 
73b637f
 
 
 
 
 
 
 
a6295a0
73b637f
 
 
 
dfc37bb
81d00fe
 
73b637f
 
 
81d00fe
 
 
 
41df8d0
 
 
 
 
 
81d00fe
41df8d0
81d00fe
 
41df8d0
 
 
 
 
81d00fe
 
41df8d0
 
 
 
 
 
 
81d00fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
from typing import Callable, List, Optional, TypedDict
from langgraph.graph import StateGraph, END
from smolagents import CodeAgent, ToolCallingAgent, LiteLLMModel
from tools import tools
import yaml
import os
import litellm

# Make LiteLLM silently discard request parameters that the selected backend
# does not support, instead of raising an error for them.
litellm.drop_params = True

# Configure the root logger at INFO; ``logger`` is this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class AgentState(TypedDict):
    """State dictionary carried through the agent graph.

    All three keys are declared as required by this TypedDict.
    """

    # Conversation history; the agent node in this module only logs it.
    messages: list
    # The question passed to the agent's ``run()``.
    question: str
    # The agent's output, or an "Error: ..." string on failure; may be None.
    # Typed as str, although the agent node stores whatever agent.run() returns.
    answer: Optional[str]

class AgentNode:
    """LangGraph node that answers ``state["question"]`` with a smolagents CodeAgent.

    The agent is configured with:

    * the prompt templates in ``prompts/code_agent.yaml`` next to this file;
    * an ``ollama/codellama`` model reached through LiteLLM at
      ``http://localhost:11434``.
    """

    # smolagents uses its own verbosity scale (LogLevel: OFF=-1, ERROR=0,
    # INFO=1, DEBUG=2). That scale is unrelated to stdlib ``logging`` levels,
    # so ``logging.DEBUG`` (10) is out of range; 2 requests the most detailed
    # agent output.
    _AGENT_VERBOSITY_DEBUG = 2

    def __init__(self):
        prompt_templates = self._load_prompt_templates()

        # NOTE: ToolCallingAgent with prompts/toolcalling_agent.yaml is an
        # alternative setup that was previously sketched here.
        self.model = LiteLLMModel(
            api_base="http://localhost:11434",
            api_key=None,
            model_id="ollama/codellama",
        )

        step_callbacks: List[Callable] = [self._log_step]

        # max_steps=1 caps the agent at a single step per run.
        self.agent = CodeAgent(
            add_base_tools=True,
            max_steps=1,
            model=self.model,
            prompt_templates=prompt_templates,
            step_callbacks=step_callbacks,
            tools=tools,
            verbosity_level=self._AGENT_VERBOSITY_DEBUG,
        )

    @staticmethod
    def _load_prompt_templates() -> dict:
        """Load the CodeAgent prompt templates and log their system prompt.

        Returns:
            The mapping parsed from ``prompts/code_agent.yaml`` located next
            to this file.

        Raises:
            OSError: If the YAML file cannot be opened.
            KeyError: If the templates have no ``"system_prompt"`` entry.
        """
        prompts_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompts")
        yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

        # Prompt files are text; don't depend on the platform default encoding.
        with open(yaml_path, "r", encoding="utf-8") as f:
            prompt_templates = yaml.safe_load(f)

        logger.info("Default system prompt:")
        logger.info("-" * 80)
        logger.info(prompt_templates["system_prompt"])
        logger.info("-" * 80)
        return prompt_templates

    @staticmethod
    def _log_step(step) -> None:
        """Step callback: log which agent step finished and what it produced."""
        # The original callback read ``step.action``, which smolagents step
        # objects are not known to define. An AttributeError here would abort
        # agent.run(). Read the attributes defensively so other step types or
        # library versions still log cleanly.
        logger.info(
            "Step %s completed: %s",
            getattr(step, "step_number", "?"),
            getattr(step, "action_output", None),
        )

    @staticmethod
    def _log_state(header: str, state: AgentState) -> None:
        """Log ``header``, then the state's messages, question and answer."""
        # .get(): the incoming state may lack keys (e.g. when the graph is
        # invoked with only a question). Logging must not raise KeyError.
        logger.info(header)
        logger.info("Messages: %s", state.get("messages"))
        logger.info("Question: %s", state.get("question"))
        logger.info("Answer: %s", state.get("answer"))

    def __call__(self, state: AgentState) -> AgentState:
        """Run the agent on ``state["question"]``.

        Args:
            state: Current graph state; only ``"question"`` is required here.

        Returns:
            A new dict equal to ``state`` but with ``"answer"`` replaced.

            * On success, ``"answer"`` is the value returned by ``agent.run()``.
            * If the run (or the question lookup) raised, it is
              ``"Error: <message>"``.

            The input dict is not modified.
        """
        self._log_state("Current state before processing:", state)

        try:
            logger.info("Calling agent.run()...")
            result = self.agent.run(state["question"])
        except Exception as e:
            # Best-effort node boundary: report the failure in the answer
            # instead of aborting the whole graph run.
            logger.exception("Error in agent node: %s", e)
            return {**state, "answer": f"Error: {e}"}

        logger.info("Agent run completed:")
        logger.info("Result type: %s", type(result))
        logger.info("Result value: %s", result)

        updated: AgentState = {**state, "answer": result}
        self._log_state("Updated state after processing:", updated)
        return updated

def build_agent_graph():
    """Assemble and compile the single-node agent workflow.

    The graph has one node, ``"agent"`` (an ``AgentNode`` instance). That node
    is both the entry point and the only step before ``END``.

    Returns:
        The object produced by ``StateGraph.compile()``.
    """
    node_name = "agent"

    workflow = StateGraph(AgentState)
    workflow.add_node(node_name, AgentNode())
    workflow.set_entry_point(node_name)
    workflow.add_edge(node_name, END)

    compiled = workflow.compile()
    return compiled

# Module-level compiled graph. It is built at import time, so importing this
# module also:
#   * reads prompts/code_agent.yaml;
#   * constructs the LiteLLM model and CodeAgent (via AgentNode()).
agent_graph = build_agent_graph()