|
""" |
|
Финальный агент для Agent Challenge (LangGraph) |
|
""" |
|
|
|
import os |
|
import json |
|
import re |
|
import math |
|
import requests |
|
from typing import List, Dict, Any, Optional, TypedDict, Annotated, Literal, Union |
|
from datetime import datetime |
|
|
|
|
|
from langgraph.graph import StateGraph, END |
|
from langgraph.prebuilt import ToolNode, tools_condition |
|
|
|
|
|
from langchain_core.tools import tool |
|
|
|
|
|
|
|
|
|
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") |
|
|
|
|
|
client = None |
|
try: |
|
from huggingface_hub import InferenceClient |
|
client = InferenceClient( |
|
model="mistralai/Mixtral-8x7B-Instruct-v0.1", |
|
token=HUGGINGFACE_TOKEN, |
|
timeout=120 |
|
) |
|
except ImportError: |
|
print("Ошибка: библиотека huggingface_hub не установлена. Установите: pip install huggingface_hub") |
|
except Exception as e: |
|
print(f"Ошибка инициализации InferenceClient: {e}. Проверьте токен и доступность модели.") |
|
|
|
|
|
|
|
@tool |
|
def calculator(expression: str) -> str: |
|
"""Выполняет математические вычисления. |
|
Пример входа: "(2 + 3) * 4 / 2" |
|
Возвращает результат вычисления или сообщение об ошибке. |
|
""" |
|
try: |
|
|
|
allowed_names = {k: v for k, v in math.__dict__.items() if not k.startswith("__")} |
|
allowed_names["abs"] = abs |
|
allowed_names["round"] = round |
|
allowed_names["max"] = max |
|
allowed_names["min"] = min |
|
|
|
|
|
safe_expression = re.sub(r"[^0-9\.\+\-\*\/\(\)\s]|\b(import|exec|eval|open|lambda|\_\_)\b", "", expression) |
|
|
|
if safe_expression != expression: |
|
return "Ошибка: Обнаружены недопустимые символы в выражении." |
|
|
|
result = eval(safe_expression, {"__builtins__": {}}, allowed_names) |
|
return f"Результат: {result}" |
|
except Exception as e: |
|
return f"Ошибка в вычислении: {str(e)}" |
|
|
|
@tool |
|
def web_search(query: str) -> str: |
|
"""Выполняет поиск в интернете по заданному запросу. |
|
Пример входа: "прогноз погоды в Париже" |
|
Возвращает результаты поиска (симуляция). |
|
Для реального использования замените на API поисковой системы (например, Tavily, Serper). |
|
""" |
|
print(f"--- Выполняется поиск: {query} ---") |
|
|
|
|
|
try: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "погода" in query.lower(): |
|
return json.dumps([{"title": "Прогноз погоды", "content": "В городе, который вы ищете, сегодня солнечно, +25C."}]) |
|
elif "hugging face" in query.lower(): |
|
return json.dumps([{"title": "Hugging Face", "content": "Hugging Face - это платформа и сообщество для работы с моделями машинного обучения."}]) |
|
elif "langgraph" in query.lower(): |
|
return json.dumps([{"title": "LangGraph", "content": "LangGraph - это библиотека для создания агентов с состоянием на основе LangChain."}]) |
|
else: |
|
return json.dumps([{"title": "Результат поиска", "content": f"По вашему запросу '{query}' найдена общая информация."}]) |
|
|
|
except requests.exceptions.RequestException as e: |
|
return f"Ошибка сети при поиске: {e}" |
|
except Exception as e: |
|
return f"Ошибка при выполнении поиска: {str(e)}" |
|
|
|
@tool |
|
def get_current_datetime() -> str: |
|
"""Возвращает текущую дату и время. |
|
Не требует входных данных. |
|
""" |
|
now = datetime.now() |
|
return f"Текущая дата и время: {now.strftime('%Y-%m-%d %H:%M:%S')}" |
|
|
|
|
|
tools_list = [calculator, web_search, get_current_datetime] |
|
|
|
|
|
|
|
class AgentState(TypedDict): |
|
"""Состояние агента LangGraph.""" |
|
messages: List[Union[Dict[str, str], Any]] |
|
|
|
|
|
def agent_node(state: AgentState) -> Dict[str, Any]: |
|
"""Вызывает LLM для определения следующего шага (вызов инструмента или финальный ответ).""" |
|
if client is None: |
|
raise ValueError("Клиент Hugging Face не инициализирован.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
prompt_messages = [] |
|
for msg in state["messages"]: |
|
if isinstance(msg, dict) and "role" in msg and "content" in msg: |
|
prompt_messages.append(msg) |
|
elif hasattr(msg, "type") and msg.type == "human": |
|
prompt_messages.append({"role": "user", "content": msg.content}) |
|
elif hasattr(msg, "type") and msg.type == "ai": |
|
|
|
content = msg.content |
|
if hasattr(msg, "tool_calls") and msg.tool_calls: |
|
tool_calls_str = json.dumps([tc["name"] for tc in msg.tool_calls]) |
|
content += f"\n(Вызов инструментов: {tool_calls_str})" |
|
prompt_messages.append({"role": "assistant", "content": content}) |
|
elif hasattr(msg, "type") and msg.type == "tool": |
|
prompt_messages.append({ |
|
"role": "tool", |
|
"content": f"Результат инструмента {msg.name}: {msg.content}", |
|
"name": msg.name |
|
}) |
|
else: |
|
|
|
print(f"Пропущено сообщение неизвестного типа: {type(msg)}") |
|
|
|
print("--- Промпт для LLM ---") |
|
|
|
print("...") |
|
|
|
|
|
response = client.chat_completion( |
|
messages=prompt_messages, |
|
tool_choice="auto", |
|
tools=[tool.get_input_schema().schema() for tool in tools_list], |
|
temperature=0.1, |
|
max_tokens=1500 |
|
) |
|
|
|
ai_message = response["choices"][0]["message"] |
|
|
|
print("--- Ответ LLM ---") |
|
print(ai_message) |
|
|
|
|
|
return {"messages": [ai_message]} |
|
|
|
|
|
workflow = StateGraph(AgentState) |
|
|
|
|
|
workflow.add_node("agent", agent_node) |
|
workflow.add_node("tools", ToolNode(tools_list)) |
|
|
|
|
|
workflow.set_entry_point("agent") |
|
|
|
|
|
workflow.add_conditional_edges( |
|
"agent", |
|
|
|
tools_condition, |
|
|
|
{ |
|
"tools": "tools", |
|
END: END, |
|
}, |
|
) |
|
|
|
|
|
workflow.add_edge("tools", "agent") |
|
|
|
|
|
agent_graph = workflow.compile() |
|
|
|
|
|
def run_final_agent(query: str) -> str: |
|
"""Запускает финального агента LangGraph для ответа на вопрос.""" |
|
if agent_graph is None: |
|
return "Ошибка: Граф агента не скомпилирован." |
|
|
|
|
|
initial_state = {"messages": [{"role": "user", "content": query}]} |
|
|
|
final_state = None |
|
try: |
|
|
|
final_state = agent_graph.invoke(initial_state, {"recursion_limit": 10}) |
|
|
|
except Exception as e: |
|
print(f"Ошибка выполнения графа: {e}") |
|
return f"Произошла ошибка во время обработки запроса: {e}" |
|
|
|
|
|
if final_state and "messages" in final_state and final_state["messages"]: |
|
|
|
for msg in reversed(final_state["messages"]): |
|
|
|
is_ai = (isinstance(msg, dict) and msg.get("role") == "assistant") or (hasattr(msg, "type") and msg.type == "ai") |
|
has_tool_calls = (isinstance(msg, dict) and msg.get("tool_calls")) or (hasattr(msg, "tool_calls") and msg.tool_calls) |
|
|
|
if is_ai and not has_tool_calls: |
|
return msg.get("content") if isinstance(msg, dict) else msg.content |
|
|
|
|
|
last_ai_msg = next((m for m in reversed(final_state["messages"]) if (isinstance(m, dict) and m.get("role") == "assistant") or (hasattr(m, "type") and m.type == "ai")), None) |
|
if last_ai_msg: |
|
return last_ai_msg.get("content") if isinstance(last_ai_msg, dict) else last_ai_msg.content |
|
|
|
return "Не удалось получить финальный ответ от агента." |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
from fastapi import FastAPI, Request |
|
from fastapi.responses import JSONResponse |
|
import uvicorn |
|
|
|
app = FastAPI(title="Agent Challenge - Финальный агент") |
|
|
|
@app.get("/") |
|
async def root(): |
|
return {"message": "Агент готов к работе! Отправьте POST запрос на /agent с JSON {'query': 'ваш вопрос'}"} |
|
|
|
@app.post("/agent") |
|
async def agent_endpoint(request: Request): |
|
try: |
|
data = await request.json() |
|
query = data.get("query", "") |
|
|
|
if not query: |
|
return JSONResponse( |
|
status_code=400, |
|
content={"error": "Запрос должен содержать поле 'query'"} |
|
) |
|
|
|
response = run_final_agent(query) |
|
return {"answer": response} |
|
|
|
except Exception as e: |
|
return JSONResponse( |
|
status_code=500, |
|
content={"error": f"Ошибка обработки запроса: {str(e)}"} |
|
) |
|
|
|
|
|
if os.environ.get("RUN_LOCAL") == "true": |
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
|
|