from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_gigachat.chat_models import GigaChat
import os
import re
import time
from datetime import datetime
from pydantic import Field
from dotenv import load_dotenv
from typing import Any, Optional
import logging
from pathlib import Path

# Load variables from a local .env file (e.g. GIGACHAT_CREDENTIALS) into os.environ.
load_dotenv()

# Matches one line produced by the formatter in AnswerGigaChat._setup_logger and
# captures the record's date and cost. The formatter writes cost with "%.5f",
# so the fractional part is matched as "one or more digits", not exactly two.
_LOG_LINE_PATTERN = re.compile(
    r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
    r'input_tokens=\d+ - output_tokens=\d+ - '
    r'cost=(?P<cost>\d+\.\d+) - '
    r'execution_time=.*? - status=.*$'
)


class AnswerGigaChat(GigaChat):
    """GigaChat chat model that limits input size, logs usage and tracks spend.

    Every successful ``invoke`` appends one line with token counts and cost to
    ``LOG_FILE_PATH``. The accumulated cost is then re-read from that file; if
    it reaches ``THRESHOLD_COST``, an alert is printed and prepended to the
    response content.
    """

    # Maximum combined length, in characters, of all system + user text.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated cost (same units as COST_PER_*_TOKEN) that triggers the alert.
    THRESHOLD_COST: float = 5000.0
    ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")
    logger: Any = Field(default=None)

    def __init__(self):
        # Credentials come from the environment (populated from .env by
        # load_dotenv above); secrets must never be committed to source code.
        super().__init__(
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            verify_ssl_certs=False,
            model="GigaChat-Max",
            scope="GIGACHAT_API_PERS",
        )
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return the module logger, attaching a file handler on first use.

        The formatter expects every record to carry ``input_tokens``,
        ``output_tokens``, ``cost``, ``execution_time`` and ``status`` via
        ``extra=``; records emitted without them cannot be formatted.

        Args:
            log_file: Path of the log file; its parent directory is created.

        Returns:
            The configured ``logging.Logger``.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        # The logger is module-global, so only the first instance adds a handler.
        if not logger.handlers:
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)  # create the log directory
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )
            # UTF-8 to match the encoding _calculate_total_cost reads with.
            file_handler = logging.FileHandler(log_path, encoding="utf-8")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    @staticmethod
    def _extract_messages(input: Any) -> list:
        """Normalize the supported ``invoke`` input shapes into a message list.

        Accepted shapes:
          * a list of messages (e.g. ``[SystemMessage, HumanMessage]``);
          * a LangGraph-style dict with a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts. Only ``system`` and
            ``user`` entries are forwarded; other roles are ignored;
          * any object exposing a ``messages`` attribute;
          * anything else, which is stringified into a single HumanMessage.
        """
        if isinstance(input, list):
            return input
        if isinstance(input, dict) and "messages" in input:
            return [
                SystemMessage(content=msg["content"])
                if msg.get("role") == "system"
                else HumanMessage(content=msg["content"])
                for msg in input["messages"]
                if msg.get("role") in ("system", "user")
            ]
        if hasattr(input, "messages"):
            return input.messages
        return [HumanMessage(content=str(input))]

    @staticmethod
    def _collect_text(messages: list) -> tuple:
        """Return the concatenated (system_text, user_text) of ``messages``.

        Non-string content (e.g. multimodal content lists) is stringified so
        it still counts toward the length limit instead of raising TypeError.
        """
        system_parts = []
        user_parts = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                target = system_parts
            elif isinstance(msg, HumanMessage):
                target = user_parts
            else:
                continue
            content = msg.content
            target.append(content if isinstance(content, str) else str(content))
        return "".join(system_parts), "".join(user_parts)

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Validate input size, call GigaChat, log usage and check total spend.

        Args:
            input: Messages in any shape accepted by ``_extract_messages``.
            config: Optional LangChain runnable config, passed through.
            **kwargs: Passed through to ``GigaChat.invoke``.

        Returns:
            For dict input, a copy of ``input`` whose ``"messages"`` is replaced
            by a single assistant message. Otherwise, the model response object.

        Raises:
            ValueError: If system + user text reaches ``THRESHOLD_INPUT_SYMBOLS``.
        """
        messages = self._extract_messages(input)

        system_content, user_content = self._collect_text(messages)
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        started = time.perf_counter()
        response = super().invoke(messages, config=config, **kwargs)
        elapsed = time.perf_counter() - started

        # usage_metadata may be absent (None); treat missing counts as zero.
        usage = response.usage_metadata or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # Wall-clock duration of the model call, in seconds.
                "execution_time": f"{elapsed:.3f}s",
                "status": "success",
            },
        )

        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True if the combined text length is below THRESHOLD_INPUT_SYMBOLS."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one response from its input/output token counts."""
        return (num_input_tokens * self.COST_PER_INPUT_TOKEN
                + num_output_tokens * self.COST_PER_OUTPUT_TOKEN)

    def _calculate_total_cost(self, start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the cost of all logged responses whose date is in [start_date, end_date].

        Args:
            start_date: Inclusive lower bound, ``YYYY-MM-DD``.
            end_date: Inclusive upper bound, ``YYYY-MM-DD``. Defaults to today,
                computed at call time (not once at import).

        Returns:
            The accumulated cost; 0.0 if the log file does not exist yet.
        """
        start = datetime.strptime(start_date, '%Y-%m-%d').date()
        end = (datetime.strptime(end_date, '%Y-%m-%d').date()
               if end_date is not None else datetime.now().date())

        total_cost = 0.0
        # NOTE: the whole log is rescanned on every call, so cost grows with log size.
        try:
            with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
                for line in file:
                    match = _LOG_LINE_PATTERN.match(line)
                    if match is None:
                        continue
                    log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                    if start <= log_date <= end:
                        total_cost += float(match.group('cost'))
        except FileNotFoundError:
            return 0.0
        return total_cost