Spaces:
Sleeping
Sleeping
File size: 6,600 Bytes
1eb838c b719109 1094f14 1eb838c b688f8a 1eb838c 5c095fc 1eb838c 1094f14 1eb838c 1094f14 1eb838c 1094f14 1eb838c 1094f14 1eb838c 1094f14 1eb838c 1094f14 1eb838c 1094f14 1eb838c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_gigachat.chat_models import GigaChat
import os
import re
from datetime import datetime
from pydantic import Field
from dotenv import load_dotenv
from typing import Any, Optional
import logging
from pathlib import Path
# Load environment variables from a .env file (python-dotenv) at import time,
# before any client below is constructed.
load_dotenv()
class AnswerGigaChat(GigaChat):
    """GigaChat-Max client with an input-length guard and file-based cost tracking.

    Every successful ``invoke`` writes one record (token counts and cost) to
    ``LOG_FILE_PATH``. The costs logged in that file are then summed and, once
    the total reaches ``THRESHOLD_COST``, an alert is printed and prepended to
    the response text.
    """

    # Maximum combined length, in characters, of system + human text accepted by invoke().
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Total logged cost at which invoke() starts emitting the alert below.
    THRESHOLD_COST: float = 5000.0
    # Alert text; the "{}" placeholder receives the current total cost.
    ERROR_MESSAGE_EXCEEDED_COST: str = (
        "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= "
        + str(THRESHOLD_COST)
        + "\n\n\n\n"
    )
    # Per-token prices used by _calculate_response_cost (currency not stated here).
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    # File that receives the cost records and is re-read to compute the total.
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")
    # Assigned in __init__ from _setup_logger().
    logger: Any = Field(default=None)

    def __init__(self):
        """Create the client with fixed model settings and attach the cost logger."""
        super().__init__(
            # The secret comes from the environment (load_dotenv() runs at module
            # import) instead of being committed to source control.
            # NOTE(review): the key previously hard-coded here is exposed and
            # should be revoked/rotated.
            credentials=os.getenv("GIGACHAT_CREDENTIALS"),
            # NOTE(review): TLS certificate verification is disabled.
            verify_ssl_certs=False,
            model="GigaChat-Max",
            scope="GIGACHAT_API_PERS",
        )
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return this module's logger, attaching the cost file handler once.

        Args:
            log_file: path of the log file; its parent directory is created.

        Returns:
            The ``logging.getLogger(__name__)`` logger at INFO level. A handler
            is only added when the logger has none yet.

        Note:
            The formatter requires the ``input_tokens``, ``output_tokens``,
            ``cost``, ``execution_time`` and ``status`` extras on every record;
            ``_calculate_total_cost`` parses exactly this line layout.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)
        if not logger.handlers:
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)  # create the log directory
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )
            # Write UTF-8 explicitly so it matches how _calculate_total_cost reads the file.
            file_handler = logging.FileHandler(log_path, encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)
        return logger

    @staticmethod
    def _to_messages(payload: Any) -> list:
        """Normalize the input shapes accepted by ``invoke`` into a message list."""
        if isinstance(payload, list):
            # Already a list of messages, e.g. [SystemMessage, HumanMessage].
            return payload
        if isinstance(payload, dict) and "messages" in payload:
            # LangGraph-style dict: only system/user entries are forwarded.
            messages = []
            for msg in payload["messages"]:
                if isinstance(msg, (SystemMessage, HumanMessage)):
                    messages.append(msg)
                elif isinstance(msg, dict) and msg.get("role") == "system":
                    messages.append(SystemMessage(content=msg["content"]))
                elif isinstance(msg, dict) and msg.get("role") == "user":
                    messages.append(HumanMessage(content=msg["content"]))
            return messages
        if hasattr(payload, "messages"):
            # Any object exposing a `.messages` attribute.
            return payload.messages
        # Fall back to sending the input as a single human message.
        return [HumanMessage(content=str(payload))]

    @staticmethod
    def _collect_text(messages: Any) -> tuple[str, str]:
        """Return the concatenated (system text, human text) of ``messages``."""

        def as_text(content: Any) -> str:
            # Content may be non-string (e.g. a list of parts); measure its str() form.
            return content if isinstance(content, str) else str(content)

        system_parts = []
        user_parts = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                system_parts.append(as_text(msg.content))
            elif isinstance(msg, HumanMessage):
                user_parts.append(as_text(msg.content))
        return "".join(system_parts), "".join(user_parts)

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Call GigaChat with an input-length guard, cost logging and a cost alert.

        Args:
            input: a list of messages; a dict with a ``"messages"`` list
                (LangGraph style: only system/user entries are forwarded); an
                object exposing ``.messages``; or anything else, which is sent
                as one human message via ``str()``.
            config: optional runnable config forwarded to ``GigaChat.invoke``.
            **kwargs: forwarded to ``GigaChat.invoke``.

        Returns:
            For dict input, a shallow copy of ``input`` whose ``"messages"`` is a
            single assistant message with the response text; otherwise the
            response returned by ``GigaChat.invoke``.

        Raises:
            ValueError: if the system + human text is not shorter than
                ``THRESHOLD_INPUT_SYMBOLS`` characters.
        """
        messages = self._to_messages(input)

        system_content, user_content = self._collect_text(messages)
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        response = super().invoke(messages, config=config, **kwargs)

        # usage_metadata can be None; count missing usage as zero tokens.
        usage = getattr(response, "usage_metadata", None) or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # NOTE(review): this is the completion timestamp, not a duration.
                "execution_time": str(datetime.now()),
                "status": "success",
            },
        )

        # Check the accumulated cost (including the record just written).
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        # Mirror the input shape for dict (LangGraph-style) callers.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True if the combined length is below ``THRESHOLD_INPUT_SYMBOLS``."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one response from its input/output token counts."""
        return (num_input_tokens * self.COST_PER_INPUT_TOKEN
                + num_output_tokens * self.COST_PER_OUTPUT_TOKEN)

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the ``cost`` of log records dated within [start_date, end_date].

        Args:
            start_date: inclusive lower bound, ``YYYY-MM-DD``.
            end_date: inclusive upper bound, ``YYYY-MM-DD``. ``None`` means
                today, evaluated at call time (a ``datetime.now()`` default in
                the signature would be frozen at class-definition time).

        Returns:
            The total logged cost, or 0.0 if the log file does not exist.

        Raises:
            ValueError: if a date argument is not in ``YYYY-MM-DD`` format.
        """
        first_day = datetime.strptime(start_date, '%Y-%m-%d').date()
        last_day = (datetime.now().date() if end_date is None
                    else datetime.strptime(end_date, '%Y-%m-%d').date())

        # Matches the layout written by _setup_logger. The cost is written with
        # '%.5f', so any number of decimals must be accepted here.
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+(?:\.\d+)?) - '
            r'execution_time=.*? - status=.*$'
        )

        total_cost = 0.0
        try:
            with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
                for line in file:
                    match = log_pattern.match(line)
                    if not match:
                        continue
                    log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                    if first_day <= log_date <= last_day:
                        total_cost += float(match.group('cost'))
        except FileNotFoundError:
            # Nothing has been logged to this path yet.
            return 0.0
        return total_cost
|