Boris committed on
Commit
7879ba1
·
1 Parent(s): 149db8b

just test 2

Browse files
Files changed (1) hide show
  1. src/get_answer_gigachat.py +143 -139
src/get_answer_gigachat.py CHANGED
@@ -1,147 +1,151 @@
1
- from langchain_core.messages import HumanMessage, SystemMessage
2
- from langchain_core.runnables.config import RunnableConfig
3
- from langchain_gigachat.chat_models import GigaChat
4
- import os
5
- import re
6
- from datetime import datetime
7
- from pydantic import Field
8
- from dotenv import load_dotenv
9
- from typing import Any, Optional
10
- import logging
11
- from pathlib import Path
12
- load_dotenv()
13
-
14
- class AnswerGigaChat(GigaChat):
15
- THRESHOLD_INPUT_SYMBOLS: int = 40000
16
- THRESHOLD_COST: float = 5000.0
17
- ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
18
- COST_PER_INPUT_TOKEN: float = 0.0
19
- COST_PER_OUTPUT_TOKEN: float = 1.95e-3
20
- LOG_FILE_PATH: str = os.path.join(os.path.dirname(__file__), "gigachat.log")
21
 
22
- logger: Any = Field(default=None)
23
-
24
- def __init__(self):
25
- super().__init__(credentials=os.getenv("GIGACHAT_CREDENTIALS"),
26
- verify_ssl_certs=False,
27
- model="GigaChat-Max",
28
- scope="GIGACHAT_API_CORP")
29
- self.logger = self._setup_logger(self.LOG_FILE_PATH)
30
-
31
- def _setup_logger(self, log_file: str) -> logging.Logger:
32
- logger = logging.getLogger(__name__)
33
- logger.setLevel(logging.INFO)
34
-
35
- if not logger.handlers:
36
- log_file = Path(log_file)
37
- log_file.parent.mkdir(parents=True, exist_ok=True) # Создать папку для логов
38
-
39
- formatter = logging.Formatter(
40
- '%(asctime)s - %(name)s - %(levelname)s - '
41
- 'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
42
- 'cost=%(cost).5f - execution_time=%(execution_time)s - '
43
- 'status=%(status)s'
44
- )
45
-
46
- file_handler = logging.FileHandler(log_file)
47
- file_handler.setFormatter(formatter)
48
- logger.addHandler(file_handler)
49
-
50
- return logger
51
-
52
- def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
53
- # Извлекаем сообщения из входных данных
54
- if isinstance(input, list):
55
- # Если input - это список сообщений (например, [SystemMessage, HumanMessage])
56
- messages = input
57
- elif isinstance(input, dict) and "messages" in input:
58
- # Если input - это dict с ключом "messages" (формат LangGraph)
59
- messages = [
60
- SystemMessage(content=msg["content"]) if msg["role"] == "system"
61
- else HumanMessage(content=msg["content"])
62
- for msg in input["messages"]
63
- if msg["role"] in ["system", "user"]
64
- ]
65
- elif hasattr(input, "messages"):
66
- # Если input - это объект с атрибутом messages
67
- messages = input.messages
68
- else:
69
- # Попробуем интерпретировать input как одно сообщение
70
- messages = [HumanMessage(content=str(input))]
71
-
72
- # Проверяем длину ввода
73
- system_content = ""
74
- user_content = ""
75
- for msg in messages:
76
- if isinstance(msg, SystemMessage):
77
- system_content += msg.content
78
- elif isinstance(msg, HumanMessage):
79
- user_content += msg.content
80
-
81
- if not self._check_input_length(system_content, user_content):
82
- raise ValueError("Too long query")
83
-
84
- # Вызываем родительский метод invoke
85
- response = super().invoke(messages, config=config, **kwargs)
86
-
87
- # Логируем информацию о запросе
88
- num_input_tokens = response.usage_metadata["input_tokens"]
89
- num_output_tokens = response.usage_metadata["output_tokens"]
90
- cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
91
- self.logger.info(
92
- "got answer",
93
- extra={
94
- "input_tokens": num_input_tokens,
95
- "output_tokens": num_output_tokens,
96
- "cost": cost,
97
- "execution_time": str(datetime.now()),
98
- "status": "success"
99
- }
100
- )
101
-
102
- # Проверяем общую стоимость
103
- total_cost = self._calculate_total_cost()
104
- if total_cost >= self.THRESHOLD_COST:
105
- error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
106
- print(error_message)
107
- response.content = error_message + response.content
108
-
109
- # Возвращаем ответ в соответствующем формате
110
- if isinstance(input, dict):
111
- return {**input, "messages": [{"role": "assistant", "content": response.content}]}
112
- return response
113
 
114
- def _check_input_length(self, system_message: str, user_message: str) -> bool:
115
- return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS
116
 
117
- def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
118
- return num_input_tokens * self.COST_PER_INPUT_TOKEN + \
119
- num_output_tokens * self.COST_PER_OUTPUT_TOKEN
120
 
121
- def _calculate_total_cost(self,
122
- start_date: str = '2025-06-01',
123
- end_date: str = str(datetime.now().date())):
124
- total_cost = 0.0
125
- start_date = datetime.strptime(start_date, '%Y-%m-%d').date()
126
- end_date = datetime.strptime(end_date, '%Y-%m-%d').date()
127
 
128
- # Регулярное выражение для извлечения даты и cost из строки лога
129
- log_pattern = re.compile(
130
- r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
131
- r'input_tokens=\d+ - output_tokens=\d+ - '
132
- r'cost=(?P<cost>\d+\.\d{2}) - '
133
- r'execution_time=.*? - status=.*$'
134
- )
135
 
136
- with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
137
- for line in file:
138
- match = log_pattern.match(line)
139
- if match:
140
- log_date_str = match.group('date')
141
- log_date = datetime.strptime(log_date_str, '%Y-%m-%d').date()
142
- cost = float(match.group('cost'))
143
 
144
- if start_date <= log_date <= end_date:
145
- total_cost += cost
146
 
147
- return total_cost
 
 
 
 
 
1
+ # from langchain_core.messages import HumanMessage, SystemMessage
2
+ # from langchain_core.runnables.config import RunnableConfig
3
+ # from langchain_gigachat.chat_models import GigaChat
4
+ # import os
5
+ # import re
6
+ # from datetime import datetime
7
+ # from pydantic import Field
8
+ # from dotenv import load_dotenv
9
+ # from typing import Any, Optional
10
+ # import logging
11
+ # from pathlib import Path
12
+ # load_dotenv()
13
+
14
+ # class AnswerGigaChat(GigaChat):
15
+ # THRESHOLD_INPUT_SYMBOLS: int = 40000
16
+ # THRESHOLD_COST: float = 5000.0
17
+ # ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
18
+ # COST_PER_INPUT_TOKEN: float = 0.0
19
+ # COST_PER_OUTPUT_TOKEN: float = 1.95e-3
20
+ # LOG_FILE_PATH: str = os.path.join(os.path.dirname(__file__), "gigachat.log")
21
 
22
+ # logger: Any = Field(default=None)
23
+
24
+ # def __init__(self):
25
+ # super().__init__(credentials=os.getenv("GIGACHAT_CREDENTIALS"),
26
+ # verify_ssl_certs=False,
27
+ # model="GigaChat-Max",
28
+ # scope="GIGACHAT_API_CORP")
29
+ # self.logger = self._setup_logger(self.LOG_FILE_PATH)
30
+
31
+ # def _setup_logger(self, log_file: str) -> logging.Logger:
32
+ # logger = logging.getLogger(__name__)
33
+ # logger.setLevel(logging.INFO)
34
+
35
+ # if not logger.handlers:
36
+ # log_file = Path(log_file)
37
+ # log_file.parent.mkdir(parents=True, exist_ok=True) # Создать папку для логов
38
+
39
+ # formatter = logging.Formatter(
40
+ # '%(asctime)s - %(name)s - %(levelname)s - '
41
+ # 'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
42
+ # 'cost=%(cost).5f - execution_time=%(execution_time)s - '
43
+ # 'status=%(status)s'
44
+ # )
45
+
46
+ # file_handler = logging.FileHandler(log_file)
47
+ # file_handler.setFormatter(formatter)
48
+ # logger.addHandler(file_handler)
49
+
50
+ # return logger
51
+
52
+ # def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
53
+ # # Извлекаем сообщения из входных данных
54
+ # if isinstance(input, list):
55
+ # # Если input - это список сообщений (например, [SystemMessage, HumanMessage])
56
+ # messages = input
57
+ # elif isinstance(input, dict) and "messages" in input:
58
+ # # Если input - это dict с ключом "messages" (формат LangGraph)
59
+ # messages = [
60
+ # SystemMessage(content=msg["content"]) if msg["role"] == "system"
61
+ # else HumanMessage(content=msg["content"])
62
+ # for msg in input["messages"]
63
+ # if msg["role"] in ["system", "user"]
64
+ # ]
65
+ # elif hasattr(input, "messages"):
66
+ # # Если input - это объект с атрибутом messages
67
+ # messages = input.messages
68
+ # else:
69
+ # # Попробуем интерпретировать input как одно сообщение
70
+ # messages = [HumanMessage(content=str(input))]
71
+
72
+ # # Проверяем длину ввода
73
+ # system_content = ""
74
+ # user_content = ""
75
+ # for msg in messages:
76
+ # if isinstance(msg, SystemMessage):
77
+ # system_content += msg.content
78
+ # elif isinstance(msg, HumanMessage):
79
+ # user_content += msg.content
80
+
81
+ # if not self._check_input_length(system_content, user_content):
82
+ # raise ValueError("Too long query")
83
+
84
+ # # Вызываем родительский метод invoke
85
+ # response = super().invoke(messages, config=config, **kwargs)
86
+
87
+ # # Логируем информацию о запросе
88
+ # num_input_tokens = response.usage_metadata["input_tokens"]
89
+ # num_output_tokens = response.usage_metadata["output_tokens"]
90
+ # cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
91
+ # self.logger.info(
92
+ # "got answer",
93
+ # extra={
94
+ # "input_tokens": num_input_tokens,
95
+ # "output_tokens": num_output_tokens,
96
+ # "cost": cost,
97
+ # "execution_time": str(datetime.now()),
98
+ # "status": "success"
99
+ # }
100
+ # )
101
+
102
+ # # Проверяем общую стоимость
103
+ # total_cost = self._calculate_total_cost()
104
+ # if total_cost >= self.THRESHOLD_COST:
105
+ # error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
106
+ # print(error_message)
107
+ # response.content = error_message + response.content
108
+
109
+ # # Возвращаем ответ в соответствующем формате
110
+ # if isinstance(input, dict):
111
+ # return {**input, "messages": [{"role": "assistant", "content": response.content}]}
112
+ # return response
113
 
114
+ # def _check_input_length(self, system_message: str, user_message: str) -> bool:
115
+ # return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS
116
 
117
+ # def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
118
+ # return num_input_tokens * self.COST_PER_INPUT_TOKEN + \
119
+ # num_output_tokens * self.COST_PER_OUTPUT_TOKEN
120
 
121
+ # def _calculate_total_cost(self,
122
+ # start_date: str = '2025-06-01',
123
+ # end_date: str = str(datetime.now().date())):
124
+ # total_cost = 0.0
125
+ # start_date = datetime.strptime(start_date, '%Y-%m-%d').date()
126
+ # end_date = datetime.strptime(end_date, '%Y-%m-%d').date()
127
 
128
+ # # Регулярное выражение для извлечения даты и cost из строки лога
129
+ # log_pattern = re.compile(
130
+ # r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
131
+ # r'input_tokens=\d+ - output_tokens=\d+ - '
132
+ # r'cost=(?P<cost>\d+\.\d{2}) - '
133
+ # r'execution_time=.*? - status=.*$'
134
+ # )
135
 
136
+ # with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
137
+ # for line in file:
138
+ # match = log_pattern.match(line)
139
+ # if match:
140
+ # log_date_str = match.group('date')
141
+ # log_date = datetime.strptime(log_date_str, '%Y-%m-%d').date()
142
+ # cost = float(match.group('cost'))
143
 
144
+ # if start_date <= log_date <= end_date:
145
+ # total_cost += cost
146
 
147
+ # return total_cost
148
+
149
+
150
+ import streamlit as st
151
+ st.write("Hello, World!")