Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ class ChutesClient:
|
|
16 |
"""Client for interacting with Chutes API"""
|
17 |
|
18 |
def __init__(self, api_key: str):
|
19 |
-
self.api_key = api_key
|
20 |
self.base_url = "https://llm.chutes.ai/v1"
|
21 |
|
22 |
async def chat_completions_create(self, **kwargs) -> Dict:
|
@@ -58,8 +58,8 @@ class ChutesClient:
|
|
58 |
chunk_json = json.loads(data)
|
59 |
if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
|
60 |
delta = chunk_json["choices"][0].get("delta", {})
|
61 |
-
if "content" in delta:
|
62 |
-
content += delta["content"]
|
63 |
except json.JSONDecodeError:
|
64 |
continue
|
65 |
|
@@ -97,14 +97,14 @@ class CreativeAgenticAI:
|
|
97 |
chutes_api_key: Chutes API key
|
98 |
model: Which model to use
|
99 |
"""
|
100 |
-
self.groq_api_key = groq_api_key
|
101 |
-
self.chutes_api_key = chutes_api_key
|
102 |
if not self.groq_api_key and model != "openai/gpt-oss-20b":
|
103 |
raise ValueError("No Groq API key provided")
|
104 |
if not self.chutes_api_key and model == "openai/gpt-oss-20b":
|
105 |
raise ValueError("No Chutes API key provided")
|
106 |
|
107 |
-
self.model = model
|
108 |
self.groq_client = Groq(api_key=self.groq_api_key) if self.groq_api_key else None
|
109 |
self.chutes_client = ChutesClient(api_key=self.chutes_api_key) if self.chutes_api_key else None
|
110 |
self.conversation_history = []
|
@@ -140,6 +140,11 @@ class CreativeAgenticAI:
|
|
140 |
Returns:
|
141 |
AI response with metadata
|
142 |
"""
|
|
|
|
|
|
|
|
|
|
|
143 |
# Enhanced system prompt for better behavior
|
144 |
if not system_prompt:
|
145 |
if self.model == "openai/gpt-oss-20b":
|
@@ -158,9 +163,11 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
158 |
|
159 |
domain_context = ""
|
160 |
if include_domains and self._supports_web_search():
|
161 |
-
|
|
|
162 |
elif exclude_domains and self._supports_web_search():
|
163 |
-
|
|
|
164 |
|
165 |
search_instruction = ""
|
166 |
if search_type == "browser_search" and self._supports_browser_search():
|
@@ -192,10 +199,15 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
192 |
if (include_domains or exclude_domains) and self._supports_web_search():
|
193 |
filter_context = []
|
194 |
if include_domains:
|
195 |
-
|
|
|
|
|
196 |
if exclude_domains:
|
197 |
-
|
198 |
-
|
|
|
|
|
|
|
199 |
|
200 |
messages.append({"role": "user", "content": enhanced_message})
|
201 |
|
@@ -209,10 +221,14 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
209 |
|
210 |
# Add domain filtering for compound models (Groq only)
|
211 |
if self._supports_web_search():
|
212 |
-
if include_domains
|
213 |
-
|
214 |
-
|
215 |
-
|
|
|
|
|
|
|
|
|
216 |
|
217 |
# Add tools only for Groq models that support browser search
|
218 |
tools = []
|
@@ -233,13 +249,23 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
233 |
params["stream"] = True
|
234 |
response = await self.chutes_client.chat_completions_create(**params)
|
235 |
# Handle Chutes response
|
236 |
-
content =
|
|
|
|
|
|
|
|
|
|
|
237 |
tool_calls = None
|
238 |
else:
|
239 |
# Groq API call
|
240 |
params["max_completion_tokens"] = params.pop("max_tokens", None)
|
241 |
response = self.groq_client.chat.completions.create(**params)
|
242 |
-
content =
|
|
|
|
|
|
|
|
|
|
|
243 |
tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") else None
|
244 |
|
245 |
# Extract tool usage information
|
@@ -302,18 +328,20 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
302 |
if tools:
|
303 |
for tool in tools:
|
304 |
tool_dict = {
|
305 |
-
"tool_type": getattr(tool, "type", "unknown"),
|
306 |
-
"tool_name": getattr(tool, "name", "unknown"),
|
307 |
}
|
308 |
if hasattr(tool, "input"):
|
309 |
-
tool_input =
|
310 |
-
|
|
|
311 |
if "search" in tool_dict["tool_name"].lower():
|
312 |
-
tool_info["search_queries"].append(
|
313 |
if hasattr(tool, "output"):
|
314 |
-
tool_output =
|
315 |
-
|
316 |
-
|
|
|
317 |
tool_info["sources_found"].extend(urls)
|
318 |
tool_info["tools_used"].append(tool_dict)
|
319 |
|
@@ -321,18 +349,25 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
321 |
if tool_calls:
|
322 |
for tool_call in tool_calls:
|
323 |
tool_dict = {
|
324 |
-
"tool_type": getattr(tool_call, "type", "browser_search"),
|
325 |
-
"tool_name":
|
326 |
-
"tool_id": getattr(tool_call, "id", None)
|
327 |
}
|
328 |
-
if hasattr(tool_call, "function") and
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
tool_info["tools_used"].append(tool_dict)
|
337 |
|
338 |
return tool_info
|
@@ -341,39 +376,42 @@ IMPORTANT: When you search the web and find information, you MUST:
|
|
341 |
"""Extract URLs from text"""
|
342 |
if not text:
|
343 |
return []
|
|
|
344 |
url_pattern = r'https?://[^\s<>"]{2,}'
|
345 |
-
urls = re.findall(url_pattern,
|
346 |
return list(set(urls))
|
347 |
|
348 |
def _enhance_citations(self, content: str, tool_info: Dict) -> str:
|
349 |
"""Enhance content with better citation formatting"""
|
350 |
if not content:
|
351 |
return ""
|
|
|
352 |
if not tool_info or not tool_info.get("sources_found"):
|
353 |
-
return
|
354 |
|
355 |
-
if "Sources Used:" not in
|
356 |
sources_section = "\n\n---\n\n### Sources Used:\n"
|
357 |
for i, url in enumerate(tool_info["sources_found"][:10], 1):
|
358 |
-
domain = self._extract_domain(url)
|
359 |
sources_section += f"{i}. [{domain}]({url})\n"
|
360 |
-
|
361 |
|
362 |
-
return
|
363 |
|
364 |
def _extract_domain(self, url: str) -> str:
|
365 |
"""Extract domain name from URL for display"""
|
366 |
if not url:
|
367 |
return ""
|
|
|
368 |
try:
|
369 |
-
if
|
370 |
-
domain =
|
371 |
if domain.startswith('www.'):
|
372 |
domain = domain[4:]
|
373 |
return domain
|
374 |
-
return
|
375 |
except:
|
376 |
-
return
|
377 |
|
378 |
def get_model_info(self) -> Dict:
|
379 |
"""Get information about current model capabilities"""
|
@@ -401,15 +439,16 @@ async def validate_api_keys(groq_api_key: str, chutes_api_key: str, model: str)
|
|
401 |
"""Validate both Groq and Chutes API keys and initialize AI instance"""
|
402 |
global ai_instance, api_key_status
|
403 |
|
404 |
-
# Handle None values
|
405 |
-
groq_api_key = groq_api_key
|
406 |
-
chutes_api_key = chutes_api_key
|
|
|
407 |
|
408 |
-
if model == "openai/gpt-oss-20b" and not chutes_api_key:
|
409 |
api_key_status = "Invalid β"
|
410 |
return "β Please enter a valid Chutes API key for the selected model"
|
411 |
|
412 |
-
if model in ["compound-beta", "compound-beta-mini"] and not groq_api_key:
|
413 |
api_key_status = "Invalid β"
|
414 |
return "β Please enter a valid Groq API key for the selected model"
|
415 |
|
@@ -452,6 +491,8 @@ def update_model(model: str) -> str:
|
|
452 |
"""Update the model selection"""
|
453 |
global ai_instance
|
454 |
|
|
|
|
|
455 |
if ai_instance:
|
456 |
ai_instance.model = model
|
457 |
model_info = ai_instance.get_model_info()
|
@@ -471,6 +512,7 @@ def get_search_options(model: str) -> gr.update:
|
|
471 |
if not ai_instance:
|
472 |
return gr.update(choices=["none"], value="none")
|
473 |
|
|
|
474 |
model_info = ai_instance.available_models.get(model, {})
|
475 |
options = ["none"]
|
476 |
|
@@ -497,16 +539,18 @@ async def chat_with_ai(message: str,
|
|
497 |
|
498 |
if not ai_instance:
|
499 |
error_msg = "β οΈ Please set your API keys first!"
|
500 |
-
history.append([message, error_msg])
|
501 |
return history, ""
|
502 |
|
503 |
-
|
504 |
-
|
|
|
|
|
|
|
|
|
505 |
|
506 |
-
|
507 |
-
|
508 |
-
exclude_domains = exclude_domains or ""
|
509 |
-
system_prompt = system_prompt or ""
|
510 |
|
511 |
include_list = [d.strip() for d in include_domains.split(",") if d.strip()] if include_domains.strip() else []
|
512 |
exclude_list = [d.strip() for d in exclude_domains.split(",") if d.strip()] if exclude_domains.strip() else []
|
@@ -523,7 +567,7 @@ async def chat_with_ai(message: str,
|
|
523 |
force_search=force_search
|
524 |
)
|
525 |
|
526 |
-
ai_response = response
|
527 |
|
528 |
# Add tool usage info for Groq models
|
529 |
if response.get("tool_usage") and ai_instance.model != "openai/gpt-oss-20b":
|
@@ -537,7 +581,7 @@ async def chat_with_ai(message: str,
|
|
537 |
tool_summary.append(f"π Sources found: {len(tool_info['sources_found'])}")
|
538 |
|
539 |
if tool_info.get("tools_used"):
|
540 |
-
tool_types = [tool.get("tool_type", "unknown") for tool in tool_info["tools_used"]]
|
541 |
unique_types = list(set(tool_types))
|
542 |
tool_summary.append(f"π§ Tools used: {', '.join(unique_types)}")
|
543 |
|
@@ -546,7 +590,7 @@ async def chat_with_ai(message: str,
|
|
546 |
|
547 |
# Add search settings info
|
548 |
search_info = []
|
549 |
-
if response.get("search_type_used") and response["search_type_used"] != "none":
|
550 |
search_info.append(f"π Search type: {response['search_type_used']}")
|
551 |
|
552 |
if force_search:
|
|
|
16 |
"""Client for interacting with Chutes API"""
|
17 |
|
18 |
def __init__(self, api_key: str):
|
19 |
+
self.api_key = api_key or ""
|
20 |
self.base_url = "https://llm.chutes.ai/v1"
|
21 |
|
22 |
async def chat_completions_create(self, **kwargs) -> Dict:
|
|
|
58 |
chunk_json = json.loads(data)
|
59 |
if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
|
60 |
delta = chunk_json["choices"][0].get("delta", {})
|
61 |
+
if "content" in delta and delta["content"]:
|
62 |
+
content += str(delta["content"])
|
63 |
except json.JSONDecodeError:
|
64 |
continue
|
65 |
|
|
|
97 |
chutes_api_key: Chutes API key
|
98 |
model: Which model to use
|
99 |
"""
|
100 |
+
self.groq_api_key = str(groq_api_key) if groq_api_key else ""
|
101 |
+
self.chutes_api_key = str(chutes_api_key) if chutes_api_key else ""
|
102 |
if not self.groq_api_key and model != "openai/gpt-oss-20b":
|
103 |
raise ValueError("No Groq API key provided")
|
104 |
if not self.chutes_api_key and model == "openai/gpt-oss-20b":
|
105 |
raise ValueError("No Chutes API key provided")
|
106 |
|
107 |
+
self.model = str(model) if model else "compound-beta"
|
108 |
self.groq_client = Groq(api_key=self.groq_api_key) if self.groq_api_key else None
|
109 |
self.chutes_client = ChutesClient(api_key=self.chutes_api_key) if self.chutes_api_key else None
|
110 |
self.conversation_history = []
|
|
|
140 |
Returns:
|
141 |
AI response with metadata
|
142 |
"""
|
143 |
+
# Safe string conversion
|
144 |
+
message = str(message) if message else ""
|
145 |
+
system_prompt = str(system_prompt) if system_prompt else ""
|
146 |
+
search_type = str(search_type) if search_type else "auto"
|
147 |
+
|
148 |
# Enhanced system prompt for better behavior
|
149 |
if not system_prompt:
|
150 |
if self.model == "openai/gpt-oss-20b":
|
|
|
163 |
|
164 |
domain_context = ""
|
165 |
if include_domains and self._supports_web_search():
|
166 |
+
safe_domains = [str(d) for d in include_domains if d]
|
167 |
+
domain_context = f"\nYou are restricted to searching ONLY these domains: {', '.join(safe_domains)}. Make sure to find and cite sources specifically from these domains."
|
168 |
elif exclude_domains and self._supports_web_search():
|
169 |
+
safe_domains = [str(d) for d in exclude_domains if d]
|
170 |
+
domain_context = f"\nAvoid searching these domains: {', '.join(safe_domains)}. Search everywhere else on the web."
|
171 |
|
172 |
search_instruction = ""
|
173 |
if search_type == "browser_search" and self._supports_browser_search():
|
|
|
199 |
if (include_domains or exclude_domains) and self._supports_web_search():
|
200 |
filter_context = []
|
201 |
if include_domains:
|
202 |
+
safe_domains = [str(d) for d in include_domains if d]
|
203 |
+
if safe_domains:
|
204 |
+
filter_context.append(f"ONLY search these domains: {', '.join(safe_domains)}")
|
205 |
if exclude_domains:
|
206 |
+
safe_domains = [str(d) for d in exclude_domains if d]
|
207 |
+
if safe_domains:
|
208 |
+
filter_context.append(f"EXCLUDE these domains: {', '.join(safe_domains)}")
|
209 |
+
if filter_context:
|
210 |
+
enhanced_message += f"\n\n[Domain Filtering: {' | '.join(filter_context)}]"
|
211 |
|
212 |
messages.append({"role": "user", "content": enhanced_message})
|
213 |
|
|
|
221 |
|
222 |
# Add domain filtering for compound models (Groq only)
|
223 |
if self._supports_web_search():
|
224 |
+
if include_domains:
|
225 |
+
safe_domains = [str(d).strip() for d in include_domains if d and str(d).strip()]
|
226 |
+
if safe_domains:
|
227 |
+
params["include_domains"] = safe_domains
|
228 |
+
if exclude_domains:
|
229 |
+
safe_domains = [str(d).strip() for d in exclude_domains if d and str(d).strip()]
|
230 |
+
if safe_domains:
|
231 |
+
params["exclude_domains"] = safe_domains
|
232 |
|
233 |
# Add tools only for Groq models that support browser search
|
234 |
tools = []
|
|
|
249 |
params["stream"] = True
|
250 |
response = await self.chutes_client.chat_completions_create(**params)
|
251 |
# Handle Chutes response
|
252 |
+
content = ""
|
253 |
+
if response and "choices" in response and response["choices"]:
|
254 |
+
message_content = response["choices"][0].get("message", {}).get("content")
|
255 |
+
content = str(message_content) if message_content else "No response content"
|
256 |
+
else:
|
257 |
+
content = "No response received"
|
258 |
tool_calls = None
|
259 |
else:
|
260 |
# Groq API call
|
261 |
params["max_completion_tokens"] = params.pop("max_tokens", None)
|
262 |
response = self.groq_client.chat.completions.create(**params)
|
263 |
+
content = ""
|
264 |
+
if response and response.choices and response.choices[0].message:
|
265 |
+
message_content = response.choices[0].message.content
|
266 |
+
content = str(message_content) if message_content else "No response content"
|
267 |
+
else:
|
268 |
+
content = "No response received"
|
269 |
tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") else None
|
270 |
|
271 |
# Extract tool usage information
|
|
|
328 |
if tools:
|
329 |
for tool in tools:
|
330 |
tool_dict = {
|
331 |
+
"tool_type": str(getattr(tool, "type", "unknown")),
|
332 |
+
"tool_name": str(getattr(tool, "name", "unknown")),
|
333 |
}
|
334 |
if hasattr(tool, "input"):
|
335 |
+
tool_input = getattr(tool, "input")
|
336 |
+
tool_input_str = str(tool_input) if tool_input is not None else ""
|
337 |
+
tool_dict["input"] = tool_input_str
|
338 |
if "search" in tool_dict["tool_name"].lower():
|
339 |
+
tool_info["search_queries"].append(tool_input_str)
|
340 |
if hasattr(tool, "output"):
|
341 |
+
tool_output = getattr(tool, "output")
|
342 |
+
tool_output_str = str(tool_output) if tool_output is not None else ""
|
343 |
+
tool_dict["output"] = tool_output_str
|
344 |
+
urls = self._extract_urls(tool_output_str)
|
345 |
tool_info["sources_found"].extend(urls)
|
346 |
tool_info["tools_used"].append(tool_dict)
|
347 |
|
|
|
349 |
if tool_calls:
|
350 |
for tool_call in tool_calls:
|
351 |
tool_dict = {
|
352 |
+
"tool_type": str(getattr(tool_call, "type", "browser_search")),
|
353 |
+
"tool_name": "browser_search",
|
354 |
+
"tool_id": str(getattr(tool_call, "id", "")) if getattr(tool_call, "id", None) else ""
|
355 |
}
|
356 |
+
if hasattr(tool_call, "function") and tool_call.function:
|
357 |
+
tool_dict["tool_name"] = str(getattr(tool_call.function, "name", "browser_search"))
|
358 |
+
if hasattr(tool_call.function, "arguments"):
|
359 |
+
try:
|
360 |
+
args_raw = tool_call.function.arguments
|
361 |
+
if isinstance(args_raw, str):
|
362 |
+
args = json.loads(args_raw)
|
363 |
+
else:
|
364 |
+
args = args_raw or {}
|
365 |
+
tool_dict["arguments"] = args
|
366 |
+
if "query" in args:
|
367 |
+
tool_info["search_queries"].append(str(args["query"]))
|
368 |
+
except:
|
369 |
+
args_str = str(args_raw) if args_raw is not None else ""
|
370 |
+
tool_dict["arguments"] = args_str
|
371 |
tool_info["tools_used"].append(tool_dict)
|
372 |
|
373 |
return tool_info
|
|
|
376 |
"""Extract URLs from text"""
|
377 |
if not text:
|
378 |
return []
|
379 |
+
text_str = str(text)
|
380 |
url_pattern = r'https?://[^\s<>"]{2,}'
|
381 |
+
urls = re.findall(url_pattern, text_str)
|
382 |
return list(set(urls))
|
383 |
|
384 |
def _enhance_citations(self, content: str, tool_info: Dict) -> str:
|
385 |
"""Enhance content with better citation formatting"""
|
386 |
if not content:
|
387 |
return ""
|
388 |
+
content_str = str(content)
|
389 |
if not tool_info or not tool_info.get("sources_found"):
|
390 |
+
return content_str
|
391 |
|
392 |
+
if "Sources Used:" not in content_str and "sources:" not in content_str.lower():
|
393 |
sources_section = "\n\n---\n\n### Sources Used:\n"
|
394 |
for i, url in enumerate(tool_info["sources_found"][:10], 1):
|
395 |
+
domain = self._extract_domain(str(url))
|
396 |
sources_section += f"{i}. [{domain}]({url})\n"
|
397 |
+
content_str += sources_section
|
398 |
|
399 |
+
return content_str
|
400 |
|
401 |
def _extract_domain(self, url: str) -> str:
|
402 |
"""Extract domain name from URL for display"""
|
403 |
if not url:
|
404 |
return ""
|
405 |
+
url_str = str(url)
|
406 |
try:
|
407 |
+
if url_str.startswith(('http://', 'https://')):
|
408 |
+
domain = url_str.split('/')[2]
|
409 |
if domain.startswith('www.'):
|
410 |
domain = domain[4:]
|
411 |
return domain
|
412 |
+
return url_str
|
413 |
except:
|
414 |
+
return url_str
|
415 |
|
416 |
def get_model_info(self) -> Dict:
|
417 |
"""Get information about current model capabilities"""
|
|
|
439 |
"""Validate both Groq and Chutes API keys and initialize AI instance"""
|
440 |
global ai_instance, api_key_status
|
441 |
|
442 |
+
# Handle None values and convert to strings
|
443 |
+
groq_api_key = str(groq_api_key) if groq_api_key else ""
|
444 |
+
chutes_api_key = str(chutes_api_key) if chutes_api_key else ""
|
445 |
+
model = str(model) if model else "compound-beta"
|
446 |
|
447 |
+
if model == "openai/gpt-oss-20b" and not chutes_api_key.strip():
|
448 |
api_key_status = "Invalid β"
|
449 |
return "β Please enter a valid Chutes API key for the selected model"
|
450 |
|
451 |
+
if model in ["compound-beta", "compound-beta-mini"] and not groq_api_key.strip():
|
452 |
api_key_status = "Invalid β"
|
453 |
return "β Please enter a valid Groq API key for the selected model"
|
454 |
|
|
|
491 |
"""Update the model selection"""
|
492 |
global ai_instance
|
493 |
|
494 |
+
model = str(model) if model else "compound-beta"
|
495 |
+
|
496 |
if ai_instance:
|
497 |
ai_instance.model = model
|
498 |
model_info = ai_instance.get_model_info()
|
|
|
512 |
if not ai_instance:
|
513 |
return gr.update(choices=["none"], value="none")
|
514 |
|
515 |
+
model = str(model) if model else "compound-beta"
|
516 |
model_info = ai_instance.available_models.get(model, {})
|
517 |
options = ["none"]
|
518 |
|
|
|
539 |
|
540 |
if not ai_instance:
|
541 |
error_msg = "β οΈ Please set your API keys first!"
|
542 |
+
history.append([str(message) if message else "", error_msg])
|
543 |
return history, ""
|
544 |
|
545 |
+
# Convert all inputs to strings and handle None values
|
546 |
+
message = str(message) if message else ""
|
547 |
+
include_domains = str(include_domains) if include_domains else ""
|
548 |
+
exclude_domains = str(exclude_domains) if exclude_domains else ""
|
549 |
+
system_prompt = str(system_prompt) if system_prompt else ""
|
550 |
+
search_type = str(search_type) if search_type else "auto"
|
551 |
|
552 |
+
if not message.strip():
|
553 |
+
return history, ""
|
|
|
|
|
554 |
|
555 |
include_list = [d.strip() for d in include_domains.split(",") if d.strip()] if include_domains.strip() else []
|
556 |
exclude_list = [d.strip() for d in exclude_domains.split(",") if d.strip()] if exclude_domains.strip() else []
|
|
|
567 |
force_search=force_search
|
568 |
)
|
569 |
|
570 |
+
ai_response = str(response.get("content", "No response received"))
|
571 |
|
572 |
# Add tool usage info for Groq models
|
573 |
if response.get("tool_usage") and ai_instance.model != "openai/gpt-oss-20b":
|
|
|
581 |
tool_summary.append(f"π Sources found: {len(tool_info['sources_found'])}")
|
582 |
|
583 |
if tool_info.get("tools_used"):
|
584 |
+
tool_types = [str(tool.get("tool_type", "unknown")) for tool in tool_info["tools_used"]]
|
585 |
unique_types = list(set(tool_types))
|
586 |
tool_summary.append(f"π§ Tools used: {', '.join(unique_types)}")
|
587 |
|
|
|
590 |
|
591 |
# Add search settings info
|
592 |
search_info = []
|
593 |
+
if response.get("search_type_used") and str(response["search_type_used"]) != "none":
|
594 |
search_info.append(f"π Search type: {response['search_type_used']}")
|
595 |
|
596 |
if force_search:
|