Boris committed on
Commit
1eb838c
·
1 Parent(s): 7879ba1

just test 3

Browse files
Files changed (2) hide show
  1. src/get_answer_gigachat.py +139 -143
  2. src/streamlit_app.py +97 -93
src/get_answer_gigachat.py CHANGED
@@ -1,151 +1,147 @@
1
- # from langchain_core.messages import HumanMessage, SystemMessage
2
- # from langchain_core.runnables.config import RunnableConfig
3
- # from langchain_gigachat.chat_models import GigaChat
4
- # import os
5
- # import re
6
- # from datetime import datetime
7
- # from pydantic import Field
8
- # from dotenv import load_dotenv
9
- # from typing import Any, Optional
10
- # import logging
11
- # from pathlib import Path
12
- # load_dotenv()
13
-
14
- # class AnswerGigaChat(GigaChat):
15
- # THRESHOLD_INPUT_SYMBOLS: int = 40000
16
- # THRESHOLD_COST: float = 5000.0
17
- # ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
18
- # COST_PER_INPUT_TOKEN: float = 0.0
19
- # COST_PER_OUTPUT_TOKEN: float = 1.95e-3
20
- # LOG_FILE_PATH: str = os.path.join(os.path.dirname(__file__), "gigachat.log")
21
 
22
- # logger: Any = Field(default=None)
23
-
24
- # def __init__(self):
25
- # super().__init__(credentials=os.getenv("GIGACHAT_CREDENTIALS"),
26
- # verify_ssl_certs=False,
27
- # model="GigaChat-Max",
28
- # scope="GIGACHAT_API_CORP")
29
- # self.logger = self._setup_logger(self.LOG_FILE_PATH)
30
-
31
- # def _setup_logger(self, log_file: str) -> logging.Logger:
32
- # logger = logging.getLogger(__name__)
33
- # logger.setLevel(logging.INFO)
34
-
35
- # if not logger.handlers:
36
- # log_file = Path(log_file)
37
- # log_file.parent.mkdir(parents=True, exist_ok=True) # Создать папку для логов
38
-
39
- # formatter = logging.Formatter(
40
- # '%(asctime)s - %(name)s - %(levelname)s - '
41
- # 'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
42
- # 'cost=%(cost).5f - execution_time=%(execution_time)s - '
43
- # 'status=%(status)s'
44
- # )
45
-
46
- # file_handler = logging.FileHandler(log_file)
47
- # file_handler.setFormatter(formatter)
48
- # logger.addHandler(file_handler)
49
-
50
- # return logger
51
-
52
- # def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
53
- # # Извлекаем сообщения из входных данных
54
- # if isinstance(input, list):
55
- # # Если input - это список сообщений (например, [SystemMessage, HumanMessage])
56
- # messages = input
57
- # elif isinstance(input, dict) and "messages" in input:
58
- # # Если input - это dict с ключом "messages" (формат LangGraph)
59
- # messages = [
60
- # SystemMessage(content=msg["content"]) if msg["role"] == "system"
61
- # else HumanMessage(content=msg["content"])
62
- # for msg in input["messages"]
63
- # if msg["role"] in ["system", "user"]
64
- # ]
65
- # elif hasattr(input, "messages"):
66
- # # Если input - это объект с атрибутом messages
67
- # messages = input.messages
68
- # else:
69
- # # Попробуем интерпретировать input как одно сообщение
70
- # messages = [HumanMessage(content=str(input))]
71
-
72
- # # Проверяем длину ввода
73
- # system_content = ""
74
- # user_content = ""
75
- # for msg in messages:
76
- # if isinstance(msg, SystemMessage):
77
- # system_content += msg.content
78
- # elif isinstance(msg, HumanMessage):
79
- # user_content += msg.content
80
-
81
- # if not self._check_input_length(system_content, user_content):
82
- # raise ValueError("Too long query")
83
-
84
- # # Вызываем родительский метод invoke
85
- # response = super().invoke(messages, config=config, **kwargs)
86
-
87
- # # Логируем информацию о запросе
88
- # num_input_tokens = response.usage_metadata["input_tokens"]
89
- # num_output_tokens = response.usage_metadata["output_tokens"]
90
- # cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
91
- # self.logger.info(
92
- # "got answer",
93
- # extra={
94
- # "input_tokens": num_input_tokens,
95
- # "output_tokens": num_output_tokens,
96
- # "cost": cost,
97
- # "execution_time": str(datetime.now()),
98
- # "status": "success"
99
- # }
100
- # )
101
-
102
- # # Проверяем общую стоимость
103
- # total_cost = self._calculate_total_cost()
104
- # if total_cost >= self.THRESHOLD_COST:
105
- # error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
106
- # print(error_message)
107
- # response.content = error_message + response.content
108
-
109
- # # Возвращаем ответ в соответствующем формате
110
- # if isinstance(input, dict):
111
- # return {**input, "messages": [{"role": "assistant", "content": response.content}]}
112
- # return response
113
 
114
- # def _check_input_length(self, system_message: str, user_message: str) -> bool:
115
- # return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS
116
 
117
- # def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
118
- # return num_input_tokens * self.COST_PER_INPUT_TOKEN + \
119
- # num_output_tokens * self.COST_PER_OUTPUT_TOKEN
120
 
121
- # def _calculate_total_cost(self,
122
- # start_date: str = '2025-06-01',
123
- # end_date: str = str(datetime.now().date())):
124
- # total_cost = 0.0
125
- # start_date = datetime.strptime(start_date, '%Y-%m-%d').date()
126
- # end_date = datetime.strptime(end_date, '%Y-%m-%d').date()
127
 
128
- # # Регулярное выражение для извлечения даты и cost из строки лога
129
- # log_pattern = re.compile(
130
- # r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
131
- # r'input_tokens=\d+ - output_tokens=\d+ - '
132
- # r'cost=(?P<cost>\d+\.\d{2}) - '
133
- # r'execution_time=.*? - status=.*$'
134
- # )
135
 
136
- # with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
137
- # for line in file:
138
- # match = log_pattern.match(line)
139
- # if match:
140
- # log_date_str = match.group('date')
141
- # log_date = datetime.strptime(log_date_str, '%Y-%m-%d').date()
142
- # cost = float(match.group('cost'))
143
 
144
- # if start_date <= log_date <= end_date:
145
- # total_cost += cost
146
 
147
- # return total_cost
148
-
149
-
150
- import streamlit as st
151
- st.write("Hello, World!")
 
1
+ from langchain_core.messages import HumanMessage, SystemMessage
2
+ from langchain_core.runnables.config import RunnableConfig
3
+ from langchain_gigachat.chat_models import GigaChat
4
+ import os
5
+ import re
6
+ from datetime import datetime
7
+ from pydantic import Field
8
+ from dotenv import load_dotenv
9
+ from typing import Any, Optional
10
+ import logging
11
+ from pathlib import Path
12
+ load_dotenv()
13
+
14
class AnswerGigaChat(GigaChat):
    """GigaChat chat model with an input-size guard and cost accounting.

    Every ``invoke`` call rejects over-long input, forwards the messages to
    the parent ``GigaChat.invoke``, appends one line with token usage and
    cost to ``LOG_FILE_PATH``, and prepends an alert to the answer once the
    cost accumulated in that log reaches ``THRESHOLD_COST``.
    """

    # Upper bound (exclusive) on the combined character length of all
    # system and human message contents accepted by ``invoke``.
    THRESHOLD_INPUT_SYMBOLS: int = 40000
    # Accumulated logged cost at or above which the alert is prepended.
    THRESHOLD_COST: float = 5000.0
    # Alert template; ``{}`` is filled with the accumulated cost.
    ERROR_MESSAGE_EXCEEDED_COST: str = "\n\n\n\n\nALERT!!!!!!\nCOST NORM WAS EXCEEDED!!!!!!!!\n{} >= " + str(THRESHOLD_COST) + "\n\n\n\n"
    # Per-token price multipliers (the currency/unit is not stated here).
    COST_PER_INPUT_TOKEN: float = 0.0
    COST_PER_OUTPUT_TOKEN: float = 1.95e-3
    # The usage log lives next to this module.
    LOG_FILE_PATH: str = os.path.join(os.path.dirname(__file__), "gigachat.log")

    # Declared as a model field so that it can be assigned in ``__init__``.
    logger: Any = Field(default=None)

    def __init__(self):
        """Create the client from ``GIGACHAT_CREDENTIALS`` and attach the usage logger."""
        # NOTE(review): TLS certificate verification is disabled here;
        # confirm that this is intended for the target environment.
        super().__init__(credentials=os.getenv("GIGACHAT_CREDENTIALS"),
                         verify_ssl_certs=False,
                         model="GigaChat-Max",
                         scope="GIGACHAT_API_CORP")
        self.logger = self._setup_logger(self.LOG_FILE_PATH)

    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Return the module logger, attaching the usage file handler once.

        Args:
            log_file: Path of the file that receives the usage records.

        Returns:
            The logger named after this module, set to ``INFO`` level.
        """
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.INFO)

        # The logger is shared by name, so only the first caller adds a handler.
        if not logger.handlers:
            log_path = Path(log_file)
            # Create the log directory if it does not exist yet.
            log_path.parent.mkdir(parents=True, exist_ok=True)

            # Every record must carry these fields through ``extra=``;
            # ``_calculate_total_cost`` parses this exact layout back.
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - '
                'input_tokens=%(input_tokens)d - output_tokens=%(output_tokens)d - '
                'cost=%(cost).5f - execution_time=%(execution_time)s - '
                'status=%(status)s'
            )

            # Write as UTF-8 because the file is read back as UTF-8.
            file_handler = logging.FileHandler(log_path, encoding='utf-8')
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

        return logger

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any:
        """Send ``input`` to the model, log usage and apply the cost alert.

        Args:
            input: A list of messages, a dict with a ``"messages"`` list of
                ``{"role", "content"}`` items, an object with a ``messages``
                attribute, or any other value (sent as one human message
                built from ``str(input)``).
            config: Passed through to the parent ``invoke``.
            **kwargs: Passed through to the parent ``invoke``.

        Returns:
            For dict input, a copy of the dict whose ``"messages"`` holds the
            single assistant reply; otherwise the parent's response object.

        Raises:
            ValueError: If the combined system and human content is too long.
        """
        # Normalise the supported input shapes into a list of messages.
        if isinstance(input, list):
            # Already a list of messages, e.g. [SystemMessage, HumanMessage].
            messages = input
        elif isinstance(input, dict) and "messages" in input:
            # Dict with a "messages" key; only system and user roles are kept.
            messages = [
                SystemMessage(content=msg["content"]) if msg["role"] == "system"
                else HumanMessage(content=msg["content"])
                for msg in input["messages"]
                if msg["role"] in ["system", "user"]
            ]
        elif hasattr(input, "messages"):
            # Object exposing a ``messages`` attribute.
            messages = input.messages
        else:
            # Fall back to treating the input as a single human message.
            messages = [HumanMessage(content=str(input))]

        # Enforce the input-length limit before the paid call.
        system_content = "".join(
            msg.content for msg in messages if isinstance(msg, SystemMessage)
        )
        user_content = "".join(
            msg.content for msg in messages if isinstance(msg, HumanMessage)
        )
        if not self._check_input_length(system_content, user_content):
            raise ValueError("Too long query")

        response = super().invoke(messages, config=config, **kwargs)

        # Record usage. Missing usage metadata is logged as zero tokens so
        # that an answer already obtained is not lost to a TypeError.
        usage = getattr(response, "usage_metadata", None) or {}
        num_input_tokens = usage.get("input_tokens", 0)
        num_output_tokens = usage.get("output_tokens", 0)
        cost = self._calculate_response_cost(num_input_tokens, num_output_tokens)
        self.logger.info(
            "got answer",
            extra={
                "input_tokens": num_input_tokens,
                "output_tokens": num_output_tokens,
                "cost": cost,
                # Despite the key name this is the timestamp of the call.
                "execution_time": str(datetime.now()),
                "status": "success"
            }
        )

        # Prepend the alert once the accumulated cost reaches the threshold.
        total_cost = self._calculate_total_cost()
        if total_cost >= self.THRESHOLD_COST:
            error_message = self.ERROR_MESSAGE_EXCEEDED_COST.format(total_cost)
            print(error_message)
            response.content = error_message + response.content

        # Mirror the input shape in the return value.
        if isinstance(input, dict):
            return {**input, "messages": [{"role": "assistant", "content": response.content}]}
        return response

    def _check_input_length(self, system_message: str, user_message: str) -> bool:
        """Return True when the combined length is below ``THRESHOLD_INPUT_SYMBOLS``."""
        return len(system_message) + len(user_message) < self.THRESHOLD_INPUT_SYMBOLS

    def _calculate_response_cost(self, num_input_tokens: int, num_output_tokens: int) -> float:
        """Return the cost of one response from its input and output token counts."""
        return (num_input_tokens * self.COST_PER_INPUT_TOKEN
                + num_output_tokens * self.COST_PER_OUTPUT_TOKEN)

    def _calculate_total_cost(self,
                              start_date: str = '2025-06-01',
                              end_date: Optional[str] = None) -> float:
        """Sum the logged costs whose date lies in ``[start_date, end_date]``.

        Args:
            start_date: First day to include, formatted ``YYYY-MM-DD``.
            end_date: Last day to include, formatted ``YYYY-MM-DD``. ``None``
                means today, resolved on every call rather than once at
                import time.

        Returns:
            The summed cost; ``0.0`` when the log file does not exist.
        """
        first_day = datetime.strptime(start_date, '%Y-%m-%d').date()
        if end_date is None:
            last_day = datetime.now().date()
        else:
            last_day = datetime.strptime(end_date, '%Y-%m-%d').date()

        # Mirrors the formatter in ``_setup_logger``. The cost is written
        # with ``%.5f``, so any number of decimals must be accepted.
        log_pattern = re.compile(
            r'^(?P<date>\d{4}-\d{2}-\d{2}) \d{2}:\d{2}:\d{2},\d{3} - .*? - .*? - '
            r'input_tokens=\d+ - output_tokens=\d+ - '
            r'cost=(?P<cost>\d+\.\d+) - '
            r'execution_time=.*? - status=.*$'
        )

        total_cost = 0.0
        try:
            with open(self.LOG_FILE_PATH, 'r', encoding='utf-8') as file:
                for line in file:
                    match = log_pattern.match(line)
                    if match is None:
                        continue
                    log_date = datetime.strptime(match.group('date'), '%Y-%m-%d').date()
                    if first_day <= log_date <= last_day:
                        total_cost += float(match.group('cost'))
        except FileNotFoundError:
            # Nothing has been logged yet, so nothing has been spent.
            return 0.0

        return total_cost
 
 
 
 
src/streamlit_app.py CHANGED
@@ -1,105 +1,109 @@
1
- import streamlit as st
2
- from langchain_core.messages import HumanMessage, AIMessage
3
- from get_classification import get_graph_class
4
- from datetime import datetime
5
-
6
-
7
def message_to_dict(messages):
    """Convert human and AI messages into ``{"role", "content"}`` dicts.

    Messages of any other type and messages with empty or ``None`` content
    are left out. Human messages get the role ``"user"``, AI messages the
    role ``"assistant"``. Each inspected message is echoed to stdout.
    """
    converted = []
    for item in messages:
        if not isinstance(item, (HumanMessage, AIMessage)):
            continue
        print("message:", item.content)
        if item.content is None or item.content == "":
            continue
        role = "user" if isinstance(item, HumanMessage) else "assistant"
        converted.append({"role": role, "content": item.content})
        print("-" * 100)
    return converted
20
-
21
-
22
def find_last_bot_message(messages):
    """Return the content of the latest non-empty AI message, or ``None``."""
    newest_first = messages[::-1]
    return next(
        (
            candidate.content
            for candidate in newest_first
            if isinstance(candidate, AIMessage) and len(candidate.content) > 0
        ),
        None,
    )
28
-
29
-
30
def display_chat_messages():
    """Render the stored chat history, one chat bubble per message."""
    history = st.session_state.messages
    for entry in history:
        bubble = st.chat_message(entry["role"])
        with bubble:
            st.markdown(entry["content"])
35
-
36
-
37
def save_broken_case():
    """Build a timestamped transcript of the current chat.

    The transcript has one ``role: content`` line per stored message.
    Writing it to disk is currently disabled, so the function has no
    visible effect beyond reading the session state.
    """
    transcript = "".join(
        f"{entry['role']}: {entry['content']}\n"
        for entry in st.session_state.messages
    )
    formatted_datetime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # Persisting the transcript is switched off; the previous implementation
    # is kept below for reference.
    # with open("/Users/admin/my_documents/retrieval_part/services/broken_cases.txt", "a") as file:
    #     file.write(f"{formatted_datetime}\n" + result_str + "\n" + "-" * 50 + "\n\n")
50
-
51
-
52
def handle_user_input():
    """Read one chat prompt and render the bot's reply.

    A prompt starting with ``log`` (case-insensitive) stores the current
    conversation through ``save_broken_case`` and resets the history. Any
    other prompt is echoed, passed to the bot (created on first use), and
    the bot's latest non-empty answer is appended to the history and shown.
    If the bot cannot be created or invoked, the error is displayed and no
    answer is recorded.
    """
    prompt = st.chat_input("Введите ваш вопрос")
    if not prompt:
        return

    st.session_state.messages.append({"role": "user", "content": prompt})

    if prompt.lower().startswith("log"):
        save_broken_case()
        st.session_state.messages = []
        display_chat_messages()
        return

    with st.chat_message("user"):
        st.markdown(prompt)

    try:
        if "bot" not in st.session_state:
            st.session_state.bot = get_graph_class(prompt)
        st.session_state.bot.invoke(prompt)
    except Exception as e:
        # UI boundary: report the failure and stop. Continuing would either
        # dereference a bot that was never created or show a stale answer.
        st.error(f"Ошибка: {str(e)}")
        return

    # Pick up the bot's latest answer.
    last_bot_message = find_last_bot_message(st.session_state.bot.messages)
    if last_bot_message is None:
        # The bot produced no non-empty answer; keep None out of the history.
        return

    st.session_state.messages.append(
        {"role": "assistant", "content": last_bot_message}
    )
    with st.chat_message("assistant"):
        st.markdown(last_bot_message)
 
79
 
 
 
 
 
80
 
81
def clear_chat():
    """Empty the chat history and drop the bot so that it is rebuilt on next use."""
    st.session_state.messages = []
    # The bot only exists after the first prompt; deleting a missing key
    # raises, so "Clear" on a fresh session must skip the deletion.
    if "bot" in st.session_state:
        del st.session_state.bot
85
 
 
 
 
 
86
 
87
def main():
    """Draw the page: title, "Clear" button, chat history and the input box."""
    st.title("Чат-бот технической поддержки OpenVPN")

    # The "Clear" button resets the conversation.
    clear_requested = st.button("Clear")
    if clear_requested:
        clear_chat()

    # Make sure that the history exists before it is rendered.
    history_exists = "messages" in st.session_state
    if not history_exists:
        st.session_state.messages = []

    # Show the conversation so far, then process new input.
    display_chat_messages()
    handle_user_input()
102
 
 
 
103
 
104
- if __name__ == "__main__":
105
- main()
 
 
1
+ # import streamlit as st
2
+ # from langchain_core.messages import HumanMessage, AIMessage
3
+ # from get_classification import get_graph_class
4
+ # from datetime import datetime
5
+
6
+
7
+ # def message_to_dict(messages):
8
+ # result = []
9
+ # for message in messages:
10
+ # if isinstance(message, HumanMessage) or isinstance(message, AIMessage):
11
+ # print("message:", message.content)
12
+ # if message.content == "" or message.content is None:
13
+ # continue
14
+ # if isinstance(message, HumanMessage):
15
+ # result.append({"role": "user", "content": message.content})
16
+ # elif isinstance(message, AIMessage):
17
+ # result.append({"role": "assistant", "content": message.content})
18
+ # print("-" * 100)
19
+ # return result
20
+
21
+
22
+ # def find_last_bot_message(messages):
23
+ # """Находит последнее сообщение бота"""
24
+ # for message in messages[::-1]:
25
+ # if isinstance(message, AIMessage) and len(message.content) > 0:
26
+ # return message.content
27
+ # return None
28
+
29
+
30
+ # def display_chat_messages():
31
+ # """Отображает историю сообщений в чате"""
32
+ # for message in st.session_state.messages:
33
+ # with st.chat_message(message["role"]):
34
+ # st.markdown(message["content"])
35
+
36
+
37
+ # def save_broken_case():
38
+ # messages_dict = st.session_state.messages
39
+ # result_str = ""
40
+ # for elem in messages_dict:
41
+ # role = elem["role"]
42
+ # content = elem["content"]
43
+ # result_str += f"{role}: {content}\n"
44
 
45
+ # current_datetime = datetime.now()
46
+ # formatted_datetime = current_datetime.strftime("%Y-%m-%d %H:%M:%S")
47
+
48
+ # # with open("/Users/admin/my_documents/retrieval_part/services/broken_cases.txt", "a") as file:
49
+ # # file.write(f"{formatted_datetime}\n" + result_str + "\n" + "-" * 50 + "\n\n")
50
+
51
+
52
+ # def handle_user_input():
53
+ # """Обрабатывает ввод пользователя и генерирует ответ бота"""
54
+ # if prompt := st.chat_input("Введите ваш вопрос"):
55
+ # st.session_state.messages.append({"role": "user", "content": prompt})
56
+ # if prompt.lower().startswith("log"):
57
+ # save_broken_case()
58
+ # st.session_state.messages = []
59
+ # display_chat_messages()
60
+ # else:
61
+ # with st.chat_message("user"):
62
+ # st.markdown(prompt)
63
 
64
+ # try:
65
+ # if "bot" not in st.session_state:
66
+ # st.session_state.bot = get_graph_class(prompt)
67
+ # st.session_state.bot.invoke(prompt)
68
+ # except Exception as e:
69
+ # st.error(f"Ошибка: {str(e)}")
70
 
71
+ # # Извлекаем последнее сообщение бота
72
+ # last_bot_message = find_last_bot_message(st.session_state.bot.messages)
73
+ # st.session_state.messages.append(
74
+ # {"role": "assistant", "content": last_bot_message}
75
+ # )
76
+ # with st.chat_message("assistant"):
77
+ # st.markdown(last_bot_message)
78
+ # raise ValueError
79
+
80
 
81
+ # def clear_chat():
82
+ # """Очищает чат и пересоздает бота"""
83
+ # st.session_state.messages = []
84
+ # del st.session_state.bot
85
 
 
 
 
 
86
 
87
+ # def main():
88
+ # """Основная функция приложения"""
89
+ # # Заголовок приложения
90
+ # st.title("Чат-бот технической поддержки OpenVPN")
91
 
92
+ # # Кнопка очистки чата
93
+ # if st.button("Clear"):
94
+ # clear_chat()
 
95
 
96
+ # if "messages" not in st.session_state:
97
+ # st.session_state.messages = []
 
98
 
99
+ # # Отображение чата и обработка ввода
100
+ # display_chat_messages()
101
+ # handle_user_input()
102
 
 
 
 
103
 
104
+ # if __name__ == "__main__":
105
+ # main()
106
 
107
+
108
+ import streamlit as st
109
+ st.write("Hello, World!")