import logging
import os
from typing import Annotated, Type, Union

from langchain_core.messages import HumanMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

from base_service import BaseService
from get_answer_gigachat import AnswerGigaChat

logger = logging.getLogger(__name__)


class State(TypedDict):
    """Graph state: the running conversation.

    ``add_messages`` is the reducer, so messages returned by a node are merged
    into the existing list rather than replacing it.
    """

    messages: Annotated[list, add_messages]


class BaseGraph:
    """Agent/tool loop around a GigaChat model bound to a service's tools.

    The current service supplies the tools (``service.tools``) and the initial
    conversation (``service.get_initial_messages()``). A tool named
    ``make_redirect`` switches the graph to the service class it returns.
    """

    # Markers stripped from / used to trim the model's reply in _clean_response.
    # NOTE(review): in the original code all of these literals were empty
    # strings. They look like tag text (e.g. special tokens) that was lost.
    # With empty markers the original raised ValueError("empty separator") on
    # every reply; empty markers are now skipped. Fill in the real markers.
    _STRIP_TOKENS = ("", "")  # removed everywhere in the reply
    _AFTER_FIRST_MARKER = ""  # keep text between 1st and 2nd occurrence
    _AFTER_LAST_MARKER = ""  # keep text after the last occurrence

    def __init__(self, service: Type[BaseService]):
        """Bind to ``service``, load its initial messages and compile the graph."""
        self._bind_service(service)
        self.messages = service.get_initial_messages()
        self.graph = self._build_graph()
        logger.info("BaseGraph with service %s was built", service)

    def _bind_service(self, service: Type[BaseService]) -> None:
        """Point tool lookup and the tool-bound LLM at ``service``'s tools."""
        self.service = service
        self.tools_dict = {tool.name: tool for tool in service.tools}
        self.llm_with_tools = AnswerGigaChat().bind_tools(service.tools)

    def rebuild_with_new_service(self, service: Type[BaseService]):
        """Switch to ``service`` and carry the stored history over.

        The history is transformed via ``service.get_messages_from_redirect``.
        """
        self._bind_service(service)
        # The graph's nodes read self.* at call time, so recompiling yields an
        # equivalent graph; kept so self.graph is always freshly built.
        self.graph = self._build_graph()
        self.messages = service.get_messages_from_redirect(self.messages)
        logger.info("BaseGraph was rebuilt with service %s", service)

    def _agent_node(self, state: State):
        """Call the tool-bound LLM on the conversation and return its reply.

        The reply's content is passed through ``_clean_response`` first.
        """
        try:
            logger.info("Starting agent_node")
            response = self.llm_with_tools.invoke(state["messages"])
            response.content = self._clean_response(response.content)
            return {"messages": [response]}
        except Exception as e:
            logger.exception("Error in agent_node: %s", e)
            raise

    def _tool_node(self, state: State):
        """Run every tool call requested by the last message.

        Returns one ``ToolMessage`` per call, each tagged with that call's id,
        so the model can match every result to its request. When the
        ``make_redirect`` tool runs, its return value is treated as the new
        service and the graph is rebound via ``rebuild_with_new_service``.
        Later calls in the same batch then resolve against the new tools.

        Raises:
            ValueError: if there are no tool calls, or a call names a tool that
                the current service does not provide.
        """
        try:
            logger.info("Starting tool_node")
            last_message = state["messages"][-1]
            tool_calls = getattr(last_message, "tool_calls", None) or []
            if not tool_calls:
                # Previously this fell through to an UnboundLocalError on `call`.
                raise ValueError("tool_node reached without any tool calls")
            tool_messages = []
            for call in tool_calls:
                tool_name = call["name"]
                logger.info("Running tool %s", tool_name)
                tool = self.tools_dict.get(tool_name)
                if tool is None:
                    raise ValueError(f"Tool {tool_name} not found")
                tool_response = tool.invoke(call["args"])
                if tool_name == "make_redirect":
                    self.rebuild_with_new_service(tool_response)
                tool_messages.append(
                    ToolMessage(content=str(tool_response), tool_call_id=call["id"])
                )
            return {"messages": tool_messages}
        except Exception as e:
            logger.exception("Error in tool_node: %s", e)
            raise

    def _should_continue(self, state: State):
        """Route to ``"tool"`` if the last message requests a tool, else ``"end"``.

        Checks both the standard ``tool_calls`` field, which is what
        ``_tool_node`` executes, and a GigaChat-style ``function_call`` entry
        in ``additional_kwargs``.
        """
        try:
            logger.info("Checking should continue")
            last_message = state["messages"][-1]
            wants_tool = bool(getattr(last_message, "tool_calls", None)) or (
                "function_call" in getattr(last_message, "additional_kwargs", {})
            )
            return "tool" if wants_tool else "end"
        except Exception as e:
            logger.exception("Error in should_continue: %s", e)
            raise

    def _build_graph(self):
        """Compile the agent <-> tool loop.

        ``agent`` is the entry point. After each agent step, ``_should_continue``
        routes either to ``tool`` (which always returns to ``agent``) or to the
        end of the run.
        """
        try:
            logger.info("Building graph")
            graph_builder = StateGraph(State)
            graph_builder.add_node("agent", self._agent_node)
            graph_builder.add_node("tool", self._tool_node)
            graph_builder.add_conditional_edges(
                "agent", self._should_continue, {"tool": "tool", "end": END}
            )
            graph_builder.add_edge("tool", "agent")
            graph_builder.set_entry_point("agent")
            return graph_builder.compile()
        except Exception as e:
            logger.exception("Error building graph: %s", e)
            raise

    def _clean_response(self, content: Union[str, list]) -> Union[str, list]:
        """Strip marker tokens from a model reply and trim it around markers.

        Steps (empty markers are skipped, since ``str.split("")`` raises):
          1. remove every ``_STRIP_TOKENS`` entry;
          2. if ``_AFTER_FIRST_MARKER`` occurs, keep the text between its
             first and second occurrence (or after the first, if only one);
          3. if ``_AFTER_LAST_MARKER`` occurs, keep the text after its last
             occurrence.
        Non-string content (e.g. a list of message parts) is returned as is.
        """
        if not isinstance(content, str):
            return content
        for token in self._STRIP_TOKENS:
            if token:
                content = content.replace(token, "")
        marker = self._AFTER_FIRST_MARKER
        if marker and marker in content:
            content = content.split(marker)[1]
        marker = self._AFTER_LAST_MARKER
        if marker and marker in content:
            content = content.rpartition(marker)[2]
        return content

    def invoke(self, user_input):
        """Add ``user_input`` as a human message and run the graph to completion.

        On success, ``self.messages`` is replaced by the final state's messages
        and the graph's result dict is returned. If the run raises, the human
        message stays appended to ``self.messages``.
        """
        try:
            # Appended to the stored history before the run. rebuild_with_new_service,
            # which a redirect tool may trigger mid-run, reads self.messages
            # and therefore sees this turn.
            self.messages.append(HumanMessage(content=user_input))
            result = self.graph.invoke({"messages": self.messages})
            # NOTE(review): this overwrites anything rebuild_with_new_service
            # stored in self.messages during the run. The history produced by
            # get_messages_from_redirect is therefore discarded. Confirm intended.
            self.messages = result["messages"]
            return result
        except Exception as e:
            logger.exception("Error invoking graph: %s", e)
            raise