ginipick commited on
Commit
5c081fe
ยท
verified ยท
1 Parent(s): e73cb46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -27
app.py CHANGED
@@ -17,10 +17,6 @@ from huggingface_hub import hf_hub_download
17
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
18
 
19
  def do_web_search(query: str) -> str:
20
- """
21
- Brave Web Search API๋ฅผ ์ด์šฉํ•˜์—ฌ query ๊ฒ€์ƒ‰ ํ›„,
22
- ์ตœ๋Œ€ 10๊ฐœ ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ๋ฅผ title/url/description ํ˜•ํƒœ๋กœ ๋งˆํฌ๋‹ค์šด์œผ๋กœ ์š”์•ฝ ๋ฐ˜ํ™˜.
23
- """
24
  try:
25
  url = "https://api.search.brave.com/res/v1/web/search"
26
  params = {
@@ -31,7 +27,6 @@ def do_web_search(query: str) -> str:
31
  headers = {
32
  "Accept": "application/json",
33
  "Accept-Encoding": "gzip",
34
- # Brave API ํ‚ค ( SERPHOUSE_API_KEY ๋ฅผ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ )
35
  "X-Subscription-Token": SERPHOUSE_API_KEY,
36
  }
37
  response = requests.get(url, headers=headers, params=params, timeout=30)
@@ -125,14 +120,14 @@ css = """
125
 
126
  def get_messages_formatter_type(model_name):
127
  if "Mistral" in model_name or "BitSix" in model_name:
128
- return MessagesFormatterType.CHATML # Mistral ๊ณ„์—ด ๋ชจ๋ธ์€ ChatML ํ˜•์‹ ์‚ฌ์šฉ
129
  else:
130
  raise ValueError(f"Unsupported model: {model_name}")
131
 
132
  @spaces.GPU(duration=120)
133
  def respond(
134
  message,
135
- history: list[dict], # history ํ•ญ๋ชฉ์ด tuple์ด ์•„๋‹Œ dict ํ˜•์‹์œผ๋กœ ์ „๋‹ฌ๋จ
136
  system_message,
137
  max_tokens,
138
  temperature,
@@ -145,9 +140,7 @@ def respond(
145
 
146
  chat_template = get_messages_formatter_type(MISTRAL_MODEL_NAME)
147
 
148
- # ๋ชจ๋ธ ํŒŒ์ผ ๊ฒฝ๋กœ ํ™•์ธ
149
  model_path_local = os.path.join("./models", MISTRAL_MODEL_NAME)
150
-
151
  print(f"Model path: {model_path_local}")
152
 
153
  if not os.path.exists(model_path_local):
@@ -182,31 +175,47 @@ def respond(
182
  settings.stream = True
183
 
184
  # --------------------------------------------------------------------------------------
185
- # ์—ฌ๊ธฐ์„œ Brave Web Search๋ฅผ ์ˆ˜ํ–‰ํ•˜์—ฌ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ system_message์— ์ถ”๊ฐ€ (๊ธฐ๋ณธ ์ ์šฉ)
186
  # --------------------------------------------------------------------------------------
187
- # 1) ์œ ์ € ๋ฉ”์‹œ์ง€์— ๋Œ€ํ•ด Brave ๊ฒ€์ƒ‰
188
  search_results = do_web_search(message)
189
-
190
- # 2) system_message ๋์— ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์ถ”๊ฐ€
191
- # ํ•„์š”์— ๋”ฐ๋ผ ์•ˆ๋‚ด ๋ฌธ๊ตฌ๋„ ๋ง๋ถ™์ผ ์ˆ˜ ์žˆ์Œ
192
  agent.system_prompt += f"\n\n[Brave Search Results for '{message}']\n{search_results}\n"
193
  # --------------------------------------------------------------------------------------
194
 
195
  messages = BasicChatHistory()
196
 
197
- # history์˜ ๊ฐ ํ•ญ๋ชฉ์ด dict ํ˜•์‹์œผ๋กœ {'user': <user_message>, 'assistant': <assistant_message>} ํ˜•ํƒœ๋ผ๊ณ  ๊ฐ€์ •
198
- for msn in history:
199
- user_message = {
200
- 'role': Roles.user,
201
- 'content': msn.get('user', '')
202
- }
203
- assistant_message = {
204
- 'role': Roles.assistant,
205
- 'content': msn.get('assistant', '')
206
- }
207
- messages.add_message(user_message)
208
- messages.add_message(assistant_message)
209
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  stream = agent.get_chat_response(
211
  message,
212
  llm_sampling_settings=settings,
 
17
  SERPHOUSE_API_KEY = os.getenv("SERPHOUSE_API_KEY", "")
18
 
19
  def do_web_search(query: str) -> str:
 
 
 
 
20
  try:
21
  url = "https://api.search.brave.com/res/v1/web/search"
22
  params = {
 
27
  headers = {
28
  "Accept": "application/json",
29
  "Accept-Encoding": "gzip",
 
30
  "X-Subscription-Token": SERPHOUSE_API_KEY,
31
  }
32
  response = requests.get(url, headers=headers, params=params, timeout=30)
 
120
 
121
  def get_messages_formatter_type(model_name):
122
  if "Mistral" in model_name or "BitSix" in model_name:
123
+ return MessagesFormatterType.CHATML
124
  else:
125
  raise ValueError(f"Unsupported model: {model_name}")
126
 
127
  @spaces.GPU(duration=120)
128
  def respond(
129
  message,
130
+ history: list[dict],
131
  system_message,
132
  max_tokens,
133
  temperature,
 
140
 
141
  chat_template = get_messages_formatter_type(MISTRAL_MODEL_NAME)
142
 
 
143
  model_path_local = os.path.join("./models", MISTRAL_MODEL_NAME)
 
144
  print(f"Model path: {model_path_local}")
145
 
146
  if not os.path.exists(model_path_local):
 
175
  settings.stream = True
176
 
177
  # --------------------------------------------------------------------------------------
178
+ # Brave Web Search๋ฅผ ์ˆ˜ํ–‰ํ•˜์—ฌ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ system_message ๋์— ์ถ”๊ฐ€
179
  # --------------------------------------------------------------------------------------
 
180
  search_results = do_web_search(message)
 
 
 
181
  agent.system_prompt += f"\n\n[Brave Search Results for '{message}']\n{search_results}\n"
182
  # --------------------------------------------------------------------------------------
183
 
184
  messages = BasicChatHistory()
185
 
186
+ # ----------------------------------------------------------------------------
187
+ # 2๋ฒˆ ํ•ด๊ฒฐ์ฑ…: history ๋””๋ฒ„๊น… ๋ฐ ๋นˆ ๋ฉ”์‹œ์ง€ ๋ฐฉ์ง€
188
+ # ----------------------------------------------------------------------------
189
+ for i, msn in enumerate(history):
190
+ print(f"[DEBUG] History item #{i}: {msn}") # ์‹ค์ œ ๊ตฌ์กฐ๋ฅผ ํ™•์ธํ•˜๊ธฐ ์œ„ํ•œ ๋””๋ฒ„๊ทธ ๋กœ๊ทธ
191
+
192
+ user_text = msn.get("user", "")
193
+ assistant_text = msn.get("assistant", "")
194
+
195
+ # user (role=user)
196
+ if user_text.strip():
197
+ user_message = {
198
+ "role": Roles.user,
199
+ "content": user_text
200
+ }
201
+ messages.add_message(user_message)
202
+ else:
203
+ if "user" not in msn or not msn["user"]:
204
+ print(f"[WARN] History item #{i}: 'user'๊ฐ€ ์—†๊ฑฐ๋‚˜ ๋นˆ ๋ฌธ์ž์—ด์ž…๋‹ˆ๋‹ค.")
205
+
206
+ # assistant (role=assistant)
207
+ if assistant_text.strip():
208
+ assistant_message = {
209
+ "role": Roles.assistant,
210
+ "content": assistant_text
211
+ }
212
+ messages.add_message(assistant_message)
213
+ else:
214
+ if "assistant" not in msn or not msn["assistant"]:
215
+ print(f"[WARN] History item #{i}: 'assistant'๊ฐ€ ์—†๊ฑฐ๋‚˜ ๋นˆ ๋ฌธ์ž์—ด์ž…๋‹ˆ๋‹ค.")
216
+ # ----------------------------------------------------------------------------
217
+
218
+ # ๋ชจ๋ธ ์ƒ์„ฑ
219
  stream = agent.get_chat_response(
220
  message,
221
  llm_sampling_settings=settings,