# testing-mvp / src/get_answer_gigachat.py
# Boris
# dump test
# b688f8a
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_gigachat.chat_models import GigaChat
import os
import re
from datetime import datetime
from pydantic import Field
from dotenv import load_dotenv
from typing import Any, Optional
import logging
from pathlib import Path
# Populate os.environ from a .env file (if one is found) at import time.
load_dotenv()
class AnswerGigaChat(GigaChat):
    """GigaChat chat model with an input-size guard, per-call cost logging and a spend alert.

    Each successful ``invoke`` appends one line to ``LOG_FILE_PATH`` recording the
    token usage and computed cost of that call. Afterwards the costs recorded in
    that log are summed; once the total reaches ``THRESHOLD_COST`` an alert is
    printed and prepended to the response content.
    """

    # Max combined length (in characters) of system + user message text per call.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated logged cost at which the alert is raised.
    THRESHOLD_COST: float = 5000.0
    # Alert text; "{}" is filled with the current total cost.
    ERROR_MESSAGE_EXCEEDED_COST: str = (
        "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= "
        + str(THRESHOLD_COST)
        + "\n\n\n\n"
    )
    # Price per token (units/currency are not stated in this file).
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    # File that receives usage records and is re-read to compute the total spend.
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")
    logger: Any = Field(default=None)

    def __init__(self, **kwargs: Any):
        """Create the client and attach the usage logger.

        Credentials are read from the ``GIGACHAT_CREDENTIALS`` environment
        variable (the module calls ``load_dotenv()``, so a ``.env`` file works
        too). Any keyword argument overrides the defaults below and is
        forwarded to ``GigaChat``.
        """
        params: dict = {
            # Never hardcode API secrets in source; take them from the environment.
            "credentials": os.getenv("GIGACHAT_CREDENTIALS"),
            # NOTE(review): TLS certificate verification is disabled; consider
            # enabling it with the appropriate CA bundle.
            "verify_ssl_certs": False,
            "model": "GigaChat-Max",
            "scope": "GIGACHAT_API_PERS",
        }
        params.update(kwargs)
        super().__init__(**params)
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return the module logger with a file handler that writes usage records to *log_file*.

        The handler's format requires the ``input_tokens``, ``output_tokens``,
        ``cost``, ``execution_time`` and ``status`` extras, so every record sent
        to this logger must supply them; ``_calculate_total_cost`` parses the
        lines this format produces. A handler is added only if none for the
        same file is attached yet, so several instances do not duplicate lines.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        # FileHandler stores os.path.abspath(filename) as baseFilename.
        target = os.path.abspath(log_file)
        has_handler = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == target
            for handler in logger.handlers
        )
        if not has_handler:
            Path(target).parent.mkdir(parents=True, exist_ok=True)  # create the log directory
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )
            # UTF-8 to match the encoding _calculate_total_cost reads with.
            file_handler = logging.FileHandler(target, encoding="utf-8")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    @staticmethod
    def _extract_messages(raw: Any) -> list:
        """Normalize the input shapes accepted by ``invoke`` into a list of messages."""
        if isinstance(raw, list):
            # Already a list of messages, e.g. [SystemMessage, HumanMessage].
            return raw
        if isinstance(raw, dict) and "messages" in raw:
            # LangGraph-style state: only system and user messages are forwarded.
            # Entries may be role/content dicts or message objects.
            messages = []
            for msg in raw["messages"]:
                if isinstance(msg, (SystemMessage, HumanMessage)):
                    messages.append(msg)
                elif isinstance(msg, dict) and msg.get("role") == "system":
                    messages.append(SystemMessage(content=msg["content"]))
                elif isinstance(msg, dict) and msg.get("role") == "user":
                    messages.append(HumanMessage(content=msg["content"]))
            return messages
        if hasattr(raw, "messages"):
            # An object exposing a .messages attribute.
            return raw.messages
        # Fallback: treat the input as a single user message.
        return [HumanMessage(content=str(raw))]

    @staticmethod
    def _content_text(content: Any) -> str:
        """Return the text of a message ``content`` (a string or a list of parts)."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Multi-part content: keep plain-string parts and the "text" of dict parts.
            return "".join(
                part if isinstance(part, str) else str(part.get("text", ""))
                for part in content
                if isinstance(part, (str, dict))
            )
        return str(content)

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Call GigaChat with size and cost guards and log the usage of the call.

        Args:
            input: A list of messages; a LangGraph-style dict with a ``"messages"``
                list (role/content dicts or message objects; only system and user
                entries are sent); an object exposing ``.messages``; or anything
                else, which is sent as one ``HumanMessage`` of its ``str()``.
            config: Passed through to ``GigaChat.invoke``.
            **kwargs: Passed through to ``GigaChat.invoke``.

        Returns:
            For dict input, a copy of the dict whose ``"messages"`` is replaced by
            a single assistant message dict; otherwise the model's response.

        Raises:
            ValueError: If the combined system + user text is not shorter than
                ``THRESHOLD_INPUT_SYMBOLS`` characters.
        """
        messages = self._extract_messages(input)

        # Check the input length before spending tokens.
        system_content = "".join(
            self._content_text(msg.content) for msg in messages if isinstance(msg, SystemMessage)
        )
        user_content = "".join(
            self._content_text(msg.content) for msg in messages if isinstance(msg, HumanMessage)
        )
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        started_at = datetime.now()
        response = super().invoke(messages, config=config, **kwargs)
        elapsed = datetime.now() - started_at

        # usage_metadata is Optional on AI messages; treat missing usage as 0 tokens.
        usage = response.usage_metadata or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                "execution_time": str(elapsed),  # wall-clock duration of the API call
                "status": "success",
            },
        )

        # Alert once the accumulated logged spend reaches the threshold.
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        # Reply in the same shape the caller used.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True if the combined length is below ``THRESHOLD_INPUT_SYMBOLS``."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one call from its token counts and the per-token prices."""
        return (num_input_tokens * self.COST_PER_INPUT_TOKEN
                + num_output_tokens * self.COST_PER_OUTPUT_TOKEN)

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the ``cost`` of usage records in ``LOG_FILE_PATH`` within a date range.

        Args:
            start_date: First day included, formatted ``YYYY-MM-DD``.
            end_date: Last day included, formatted ``YYYY-MM-DD``; defaults to
                today, evaluated on every call (not once at import time).

        Returns:
            Total cost of matching records; 0.0 if the log file does not exist.
        """
        start = datetime.strptime(start_date, '%Y-%m-%d').date()
        end = (datetime.strptime(end_date, '%Y-%m-%d').date()
               if end_date is not None else datetime.now().date())
        # Matches lines produced by the formatter in _setup_logger. The cost is
        # written as %(cost).5f, so any number of decimals must be accepted.
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+(?:\.\d+)?) - '
            r'execution_time=.*? - status=.*$'
        )
        try:
            log_file = open(self.LOG_FILE_PATH, 'r', encoding='utf-8')
        except FileNotFoundError:
            return 0.0
        total_cost = 0.0
        with log_file:
            for line in log_file:
                match = log_pattern.match(line)
                if not match:
                    continue
                log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                if start <= log_date <= end:
                    total_cost += float(match.group('cost'))
        return total_cost