Spaces:
Sleeping
Sleeping
from langchain_core.messages import HumanMessage, SystemMessage | |
from langchain_core.runnables.config import RunnableConfig | |
from langchain_gigachat.chat_models import GigaChat | |
import os | |
import re | |
from datetime import datetime | |
from pydantic import Field | |
from dotenv import load_dotenv | |
from typing import Any, Optional | |
import logging | |
from pathlib import Path | |
# Load variables from a local .env file into the process environment at
# import time, before any code below reads them via os.getenv().
load_dotenv()
class AnswerGigaChat(GigaChat):
    """GigaChat chat model that limits input size and tracks accumulated cost.

    Every successful ``invoke`` call is appended to a log file together with
    its token counts and computed cost. After each call the log is re-read
    to sum the cost of all recorded calls; once that sum reaches
    ``THRESHOLD_COST`` an alert banner is printed and prepended to the
    response content.
    """

    # Upper bound (exclusive) on len(system text) + len(user text) per call.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated cost at which the alert banner starts being emitted.
    THRESHOLD_COST: float = 5000.0
    # Alert template; ``{}`` is filled with the accumulated cost.
    ERROR_MESSAGE_EXCEEDED_COST: str = (
        "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= "
        + str(THRESHOLD_COST)
        + "\n\n\n\n"
    )
    # Per-token prices used by ``_calculate_response_cost``.
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    # File that stores one line per successful call; also the cost ledger.
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")
    logger: Any = Field(default=None)

    def __init__(self):
        """Create the model and attach the usage logger.

        Credentials are read from the ``GIGACHAT_CREDENTIALS`` environment
        variable instead of being embedded in the source.
        """
        # NOTE(review): a credential string used to be hard-coded here; if it
        # was ever committed or published it should be revoked and reissued.
        # NOTE(review): verify_ssl_certs=False disables TLS certificate
        # verification -- confirm this is really required in deployment.
        super().__init__(
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            verify_ssl_certs=False,
            model="GigaChat-Max",
            scope="GIGACHAT_API_PERS",
        )
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return the module logger, ensuring it writes to ``log_file``.

        The attached formatter expects every record to carry the extra
        fields ``input_tokens``, ``output_tokens``, ``cost``,
        ``execution_time`` and ``status``.

        Args:
            log_file: Path of the log file; parent directories are created.

        Returns:
            The logger named after this module, set to INFO level.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        log_path = Path(log_file)
        # FileHandler stores its target as an absolute path in baseFilename.
        target = os.path.abspath(log_path)
        has_file_handler = any(
            isinstance(handler, logging.FileHandler)
            and handler.baseFilename == target
            for handler in logger.handlers
        )
        # Look for *our* file handler rather than "any handler": otherwise an
        # unrelated handler on this logger would prevent the ledger file from
        # ever being written.
        if not has_file_handler:
            # Create the directory for the log file.
            log_path.parent.mkdir(parents=True, exist_ok=True)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )
            # Written as UTF-8 because _calculate_total_cost reads it as UTF-8.
            file_handler = logging.FileHandler(log_path, encoding="utf-8")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Run the model on ``input``, logging usage and checking the budget.

        Args:
            input: One of: a list of message objects; a dict with a
                ``"messages"`` list of ``{"role", "content"}`` dicts (only
                ``system`` and ``user`` roles are forwarded); an object with
                a ``messages`` attribute; or anything else, which is sent as
                a single human message via ``str(input)``.
            config: Optional runnable config passed to the parent ``invoke``.
            **kwargs: Extra keyword arguments passed to the parent ``invoke``.

        Returns:
            For dict input, a copy of ``input`` whose ``"messages"`` is a
            one-element list holding the assistant reply; otherwise the
            response object returned by the parent ``invoke``.

        Raises:
            ValueError: If the combined system + user text length is not
                below ``THRESHOLD_INPUT_SYMBOLS``.
        """
        # Extract the messages from the input.
        if isinstance(input, list):
            # A list of messages, e.g. [SystemMessage, HumanMessage].
            messages = input
        elif isinstance(input, dict) and "messages" in input:
            # A dict with a "messages" key (LangGraph format).
            messages = [
                SystemMessage(content=msg["content"]) if msg["role"] == "system"
                else HumanMessage(content=msg["content"])
                for msg in input["messages"]
                if msg["role"] in ["system", "user"]
            ]
        elif hasattr(input, "messages"):
            # An object exposing a ``messages`` attribute.
            messages = input.messages
        else:
            # Fall back to treating the input as a single message.
            messages = [HumanMessage(content=str(input))]

        # Check the input length.
        system_content = "".join(
            msg.content for msg in messages if isinstance(msg, SystemMessage)
        )
        user_content = "".join(
            msg.content for msg in messages if isinstance(msg, HumanMessage)
        )
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        # Delegate to the parent invoke.
        response = super().invoke(messages, config=config, **kwargs)

        # Log the request. usage_metadata may be absent on a response, in
        # which case the call is recorded with zero tokens instead of failing
        # after the model has already answered.
        usage = response.usage_metadata or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # NOTE(review): despite its name this stores the wall-clock
                # time at logging, not a duration.
                "execution_time": str(datetime.now()),
                "status": "success"
            }
        )

        # Check the accumulated cost (re-reads the whole log file each call).
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        # Return the answer in the format matching the input.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True when the combined text length is below the threshold."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one call from its input and output token counts."""
        return (
            num_input_tokens * self.COST_PER_INPUT_TOKEN
            + num_output_tokens * self.COST_PER_OUTPUT_TOKEN
        )

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the logged cost of all calls between two dates, inclusive.

        Args:
            start_date: First day to include, as ``YYYY-MM-DD``.
            end_date: Last day to include, as ``YYYY-MM-DD``. ``None`` means
                today, resolved on every call (a default computed in the
                signature would be frozen at import time).

        Returns:
            Total cost parsed from matching lines of ``LOG_FILE_PATH``;
            ``0.0`` when the log file does not exist.
        """
        first_day = datetime.strptime(start_date, '%Y-%m-%d').date()
        if end_date is None:
            last_day = datetime.now().date()
        else:
            last_day = datetime.strptime(end_date, '%Y-%m-%d').date()

        # Pattern extracting the date and cost from a log line. The cost is
        # written with ``%.5f``, so accept any number of fractional digits;
        # requiring exactly two never matched and the total stayed at zero.
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+\.\d+) - '
            r'execution_time=.*? - status=.*$'
        )

        total_cost = 0.0
        try:
            with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
                for line in file:
                    match = log_pattern.match(line)
                    if not match:
                        continue
                    log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                    if first_day <= log_date <= last_day:
                        total_cost += float(match.group('cost'))
        except FileNotFoundError:
            # No log yet means nothing has been spent.
            return 0.0
        return total_cost