Spaces:
Sleeping
Sleeping
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_core.runnables.config import RunnableConfig | |
| from langchain_gigachat.chat_models import GigaChat | |
| import os | |
| import re | |
| from datetime import datetime | |
| from pydantic import Field | |
| from dotenv import load_dotenv | |
| from typing import Any, Optional | |
| import logging | |
| from pathlib import Path | |
| load_dotenv() | |
class AnswerGigaChat(GigaChat):
    """GigaChat chat model wrapper that limits input size and tracks spend.

    Each successful ``invoke`` call is written to ``LOG_FILE_PATH`` together
    with its token counts and computed cost. The log file is then re-read to
    obtain the accumulated cost; once it reaches ``THRESHOLD_COST`` an alert
    is printed and prepended to the model answer.
    """

    # Upper bound (exclusive) on the combined character length of the system
    # and user message texts accepted by ``invoke``.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated cost at (or above) which the alert message is emitted.
    THRESHOLD_COST: float = 5000.0
    # Template for the alert; ``{}`` receives the accumulated cost.
    ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
    # Per-token prices used by ``_calculate_response_cost``.
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    # File that receives one line per successful request.
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")
    # Usage logger; populated in ``__init__``.
    logger: Any = Field(default=None)

    def __init__(self):
        """Build the model from ``GIGACHAT_CREDENTIALS`` and attach the usage logger."""
        super().__init__(
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            # NOTE(review): TLS certificate verification is disabled here.
            # Confirm this is intentional; it exposes the connection to
            # man-in-the-middle attacks.
            verify_ssl_certs=False,
            model="GigaChat-Max",
            scope="GIGACHAT_API_CORP",
        )
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return the module logger, attaching a file handler on first use.

        Args:
            log_file: Path of the log file; its parent directory is created
                if it does not exist.

        Returns:
            The ``logging.Logger`` named after this module. Records sent to
            it must carry the ``input_tokens``, ``output_tokens``, ``cost``,
            ``execution_time`` and ``status`` extras that the formatter uses.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        # The logger is shared per module name, so only configure it once;
        # otherwise every new instance would duplicate each log line.
        if not logger.handlers:
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)  # create the log directory
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )
            # Written as UTF-8 because _calculate_total_cost reads it back as UTF-8.
            file_handler = logging.FileHandler(log_path, encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    @staticmethod
    def _as_text(content: Any) -> str:
        """Return message content as a string (non-string content is stringified)."""
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _extract_messages(input: Any) -> Any:
        """Normalize the supported ``invoke`` input shapes into a message list."""
        if isinstance(input, list):
            # Already a list of messages, e.g. [SystemMessage, HumanMessage].
            return input
        if isinstance(input, dict) and "messages" in input:
            # Dict with a "messages" key (LangGraph-style state).
            messages = []
            for msg in input["messages"]:
                if not isinstance(msg, dict):
                    # Not a role/content dict (e.g. already a message
                    # object): pass it through unchanged.
                    messages.append(msg)
                elif msg["role"] == "system":
                    messages.append(SystemMessage(content=msg["content"]))
                elif msg["role"] == "user":
                    messages.append(HumanMessage(content=msg["content"]))
                # Dict entries with any other role are dropped.
            return messages
        if hasattr(input, "messages"):
            # Object exposing a ``messages`` attribute.
            return input.messages
        # Fallback: treat the input as a single user message.
        return [HumanMessage(content=str(input))]

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Call the model, log token usage and cost, and enforce the limits.

        Args:
            input: A list of messages, a dict with a ``"messages"`` key, an
                object with a ``messages`` attribute, or any other value
                (converted with ``str`` into one user message).
            config: Optional runnable config forwarded to the parent ``invoke``.
            **kwargs: Extra keyword arguments forwarded to the parent ``invoke``.

        Returns:
            For dict input, a copy of the dict whose ``"messages"`` holds a
            single assistant entry with the answer; otherwise the response
            object returned by the parent ``invoke``.

        Raises:
            ValueError: If the combined system and user text reaches
                ``THRESHOLD_INPUT_SYMBOLS`` characters.
        """
        messages = self._extract_messages(input)

        # Check the input length (only system and user messages are counted).
        system_content = "".join(
            self._as_text(msg.content) for msg in messages if isinstance(msg, SystemMessage)
        )
        user_content = "".join(
            self._as_text(msg.content) for msg in messages if isinstance(msg, HumanMessage)
        )
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        # Delegate the actual request to the parent implementation.
        response = super().invoke(messages, config=config, **kwargs)

        # Log the request; missing usage metadata is treated as zero tokens
        # instead of raising.
        usage = getattr(response, "usage_metadata", None) or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # NOTE: this is the completion timestamp, not a duration.
                "execution_time": str(datetime.now()),
                "status": "success"
            }
        )

        # Check the accumulated cost (includes the request just logged).
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + self._as_text(response.content)

        # Return the answer in a shape matching the input.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True when the combined length is below ``THRESHOLD_INPUT_SYMBOLS``."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one response from its token counts and per-token prices."""
        input_cost = num_input_tokens * self.COST_PER_INPUT_TOKEN
        output_cost = num_output_tokens * self.COST_PER_OUTPUT_TOKEN
        return input_cost + output_cost

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the logged costs whose date lies in ``[start_date, end_date]``.

        Args:
            start_date: First day to include, formatted ``YYYY-MM-DD``.
            end_date: Last day to include, formatted ``YYYY-MM-DD``. ``None``
                (the default) means today, resolved at call time so that a
                long-running process does not keep a stale end date.

        Returns:
            The accumulated cost, or ``0.0`` when the log file does not exist.
        """
        first_day = datetime.strptime(start_date, '%Y-%m-%d').date()
        if end_date is None:
            last_day = datetime.now().date()
        else:
            last_day = datetime.strptime(end_date, '%Y-%m-%d').date()

        # Extracts the date and the cost from one log line. The cost is
        # written with ``%.5f``, so any number of decimals must be accepted
        # (a fixed two-digit pattern never matches the lines actually written).
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+\.\d+) - '
            r'execution_time=.*? - status=.*$'
        )

        total_cost = 0.0
        try:
            with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
                for line in file:
                    match = log_pattern.match(line)
                    if not match:
                        continue
                    log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                    if first_day <= log_date <= last_day:
                        total_cost += float(match.group('cost'))
        except FileNotFoundError:
            # Nothing has been logged yet, so nothing has been spent.
            return 0.0
        return total_cost