File size: 6,600 Bytes
1eb838c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b719109
1094f14
1eb838c
 
 
b688f8a
 
1eb838c
5c095fc
1eb838c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1094f14
1eb838c
 
1094f14
1eb838c
 
 
1094f14
1eb838c
 
 
 
 
 
1094f14
1eb838c
 
 
 
 
 
 
1094f14
1eb838c
 
 
 
 
 
 
1094f14
1eb838c
 
1094f14
1eb838c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_gigachat.chat_models import GigaChat
import os
import re
from datetime import datetime
from pydantic import Field
from dotenv import load_dotenv
from typing import Any, Optional
import logging
from pathlib import Path
# Populate os.environ from a local .env file (if present) before the client
# below reads configuration such as GIGACHAT_CREDENTIALS.
load_dotenv()

class AnswerGigaChat(GigaChat):
    """GigaChat client with an input-size guard and file-based cost tracking.

    Every successful ``invoke`` appends a usage/cost record to
    ``LOG_FILE_PATH``. The accumulated cost parsed back from that log is
    compared with ``THRESHOLD_COST``; when it is reached, an alert is printed
    and prepended to the answer text.
    """

    # Maximum combined length (characters) of system + user message text.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated cost at which the alert below is emitted.
    THRESHOLD_COST: float = 5000.0
    ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    LOG_FILE_PATH: str = os.path.join("/tmp", "gigachat.log")

    logger: Any = Field(default=None)

    def __init__(self, credentials: Optional[str] = None, **kwargs: Any):
        """Create the client.

        Args:
            credentials: GigaChat authorization key. When omitted, the
                ``GIGACHAT_CREDENTIALS`` environment variable is used.
                Never hard-code this secret in source.
            **kwargs: Additional ``GigaChat`` options. They override the
                defaults applied here: ``verify_ssl_certs=False``,
                ``model="GigaChat-Max"``, ``scope="GIGACHAT_API_PERS"``.
        """
        if credentials is None:
            credentials = os.getenv("GIGACHAT_CREDENTIALS")
        # TLS verification stays off by default to preserve the original
        # behaviour; pass verify_ssl_certs=True (with the proper CA bundle)
        # to enable it.
        kwargs.setdefault("verify_ssl_certs", False)
        kwargs.setdefault("model", "GigaChat-Max")
        kwargs.setdefault("scope", "GIGACHAT_API_PERS")
        super().__init__(credentials=credentials, **kwargs)
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return this module's logger, attaching the cost-log handler once.

        The formatter requires every record to carry the extras
        ``input_tokens``, ``output_tokens``, ``cost``, ``execution_time`` and
        ``status``. ``_calculate_total_cost`` parses lines in exactly this
        format, so the two must be kept in sync.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        # Only the first instance attaches a handler; later instances reuse it.
        if not logger.handlers:
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)  # Create the log directory

            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )

            # Write UTF-8 explicitly so it matches how the log is read back.
            file_handler = logging.FileHandler(log_path, encoding="utf-8")
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        return logger

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Validate input size, call GigaChat, log usage/cost, flag overspending.

        Args:
            input: One of:
                - a list of messages;
                - a LangGraph-style dict with a ``"messages"`` list of
                  ``{"role": ..., "content": ...}`` dicts, where only
                  ``system`` and ``user`` entries are forwarded and other
                  roles are dropped;
                - an object exposing ``.messages``;
                - any other value, which is stringified into one user message.
            config: Optional runnable config passed to ``GigaChat.invoke``.
            **kwargs: Passed through to ``GigaChat.invoke``.

        Returns:
            For dict input, a copy of ``input`` whose ``"messages"`` is
            replaced by a single assistant message. Otherwise, the response
            object returned by ``GigaChat.invoke``.

        Raises:
            ValueError: If the combined system + user text length reaches
                ``THRESHOLD_INPUT_SYMBOLS``.
        """
        # Normalise the supported input shapes into a list of messages.
        if isinstance(input, list):
            messages = input
        elif isinstance(input, dict) and "messages" in input:
            messages = [
                SystemMessage(content=msg["content"]) if msg["role"] == "system"
                else HumanMessage(content=msg["content"])
                for msg in input["messages"]
                if msg["role"] in ["system", "user"]
            ]
        elif hasattr(input, "messages"):
            messages = input.messages
        else:
            messages = [HumanMessage(content=str(input))]

        # Enforce the input-size limit. str() tolerates non-string
        # (e.g. multimodal list) content instead of raising TypeError.
        system_content = "".join(
            str(msg.content) for msg in messages if isinstance(msg, SystemMessage)
        )
        user_content = "".join(
            str(msg.content) for msg in messages if isinstance(msg, HumanMessage)
        )
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        started_at = datetime.now()
        response = super().invoke(messages, config=config, **kwargs)
        elapsed = datetime.now() - started_at

        # usage_metadata may be None if the backend omits token counts.
        usage = response.usage_metadata or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # Duration of the underlying model call (timedelta as text).
                "execution_time": str(elapsed),
                "status": "success"
            }
        )

        # Compare the accumulated spend (parsed from the log) with the limit.
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        # Mirror the input format on the way out.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True if the combined text is below THRESHOLD_INPUT_SYMBOLS."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one response from its token counts."""
        return (num_input_tokens * self.COST_PER_INPUT_TOKEN
                + num_output_tokens * self.COST_PER_OUTPUT_TOKEN)

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the ``cost`` of log records dated within [start_date, end_date].

        Args:
            start_date: Inclusive lower bound, ``YYYY-MM-DD``.
            end_date: Inclusive upper bound, ``YYYY-MM-DD``. Defaults to
                today, evaluated on every call rather than once at import.

        Returns:
            The summed cost, or 0.0 if the log file does not exist yet.
        """
        start = datetime.strptime(start_date, '%Y-%m-%d').date()
        if end_date is None:
            end = datetime.now().date()
        else:
            end = datetime.strptime(end_date, '%Y-%m-%d').date()

        # Must match the formatter in _setup_logger. The cost is written with
        # "%.5f", so accept any number of fractional digits.
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+(?:\.\d+)?) - '
            r'execution_time=.*? - status=.*$'
        )

        log_path = Path(self.LOG_FILE_PATH)
        if not log_path.exists():
            return 0.0

        total_cost = 0.0
        with log_path.open('r', encoding='utf-8', errors='replace') as file:
            for line in file:
                match = log_pattern.match(line)
                if match is None:
                    continue
                log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                if start <= log_date <= end:
                    total_cost += float(match.group('cost'))

        return total_cost