shukdevdattaEX commited on
Commit
4364505
Β·
verified Β·
1 Parent(s): fa91193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -61
app.py CHANGED
@@ -16,7 +16,7 @@ class ChutesClient:
16
  """Client for interacting with Chutes API"""
17
 
18
  def __init__(self, api_key: str):
19
- self.api_key = api_key
20
  self.base_url = "https://llm.chutes.ai/v1"
21
 
22
  async def chat_completions_create(self, **kwargs) -> Dict:
@@ -58,8 +58,8 @@ class ChutesClient:
58
  chunk_json = json.loads(data)
59
  if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
60
  delta = chunk_json["choices"][0].get("delta", {})
61
- if "content" in delta:
62
- content += delta["content"]
63
  except json.JSONDecodeError:
64
  continue
65
 
@@ -97,14 +97,14 @@ class CreativeAgenticAI:
97
  chutes_api_key: Chutes API key
98
  model: Which model to use
99
  """
100
- self.groq_api_key = groq_api_key or ""
101
- self.chutes_api_key = chutes_api_key or ""
102
  if not self.groq_api_key and model != "openai/gpt-oss-20b":
103
  raise ValueError("No Groq API key provided")
104
  if not self.chutes_api_key and model == "openai/gpt-oss-20b":
105
  raise ValueError("No Chutes API key provided")
106
 
107
- self.model = model
108
  self.groq_client = Groq(api_key=self.groq_api_key) if self.groq_api_key else None
109
  self.chutes_client = ChutesClient(api_key=self.chutes_api_key) if self.chutes_api_key else None
110
  self.conversation_history = []
@@ -140,6 +140,11 @@ class CreativeAgenticAI:
140
  Returns:
141
  AI response with metadata
142
  """
 
 
 
 
 
143
  # Enhanced system prompt for better behavior
144
  if not system_prompt:
145
  if self.model == "openai/gpt-oss-20b":
@@ -158,9 +163,11 @@ IMPORTANT: When you search the web and find information, you MUST:
158
 
159
  domain_context = ""
160
  if include_domains and self._supports_web_search():
161
- domain_context = f"\nYou are restricted to searching ONLY these domains: {', '.join(include_domains)}. Make sure to find and cite sources specifically from these domains."
 
162
  elif exclude_domains and self._supports_web_search():
163
- domain_context = f"\nAvoid searching these domains: {', '.join(exclude_domains)}. Search everywhere else on the web."
 
164
 
165
  search_instruction = ""
166
  if search_type == "browser_search" and self._supports_browser_search():
@@ -192,10 +199,15 @@ IMPORTANT: When you search the web and find information, you MUST:
192
  if (include_domains or exclude_domains) and self._supports_web_search():
193
  filter_context = []
194
  if include_domains:
195
- filter_context.append(f"ONLY search these domains: {', '.join(include_domains)}")
 
 
196
  if exclude_domains:
197
- filter_context.append(f"EXCLUDE these domains: {', '.join(exclude_domains)}")
198
- enhanced_message += f"\n\n[Domain Filtering: {' | '.join(filter_context)}]"
 
 
 
199
 
200
  messages.append({"role": "user", "content": enhanced_message})
201
 
@@ -209,10 +221,14 @@ IMPORTANT: When you search the web and find information, you MUST:
209
 
210
  # Add domain filtering for compound models (Groq only)
211
  if self._supports_web_search():
212
- if include_domains and include_domains[0] and include_domains[0].strip():
213
- params["include_domains"] = [domain.strip() for domain in include_domains if domain and domain.strip()]
214
- if exclude_domains and exclude_domains[0] and exclude_domains[0].strip():
215
- params["exclude_domains"] = [domain.strip() for domain in exclude_domains if domain and domain.strip()]
 
 
 
 
216
 
217
  # Add tools only for Groq models that support browser search
218
  tools = []
@@ -233,13 +249,23 @@ IMPORTANT: When you search the web and find information, you MUST:
233
  params["stream"] = True
234
  response = await self.chutes_client.chat_completions_create(**params)
235
  # Handle Chutes response
236
- content = response.get("choices", [{}])[0].get("message", {}).get("content", "No response content")
 
 
 
 
 
237
  tool_calls = None
238
  else:
239
  # Groq API call
240
  params["max_completion_tokens"] = params.pop("max_tokens", None)
241
  response = self.groq_client.chat.completions.create(**params)
242
- content = response.choices[0].message.content or ""
 
 
 
 
 
243
  tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") else None
244
 
245
  # Extract tool usage information
@@ -302,18 +328,20 @@ IMPORTANT: When you search the web and find information, you MUST:
302
  if tools:
303
  for tool in tools:
304
  tool_dict = {
305
- "tool_type": getattr(tool, "type", "unknown"),
306
- "tool_name": getattr(tool, "name", "unknown"),
307
  }
308
  if hasattr(tool, "input"):
309
- tool_input = str(tool.input) if tool.input is not None else ""
310
- tool_dict["input"] = tool_input
 
311
  if "search" in tool_dict["tool_name"].lower():
312
- tool_info["search_queries"].append(tool_input)
313
  if hasattr(tool, "output"):
314
- tool_output = str(tool.output) if tool.output is not None else ""
315
- tool_dict["output"] = tool_output
316
- urls = self._extract_urls(tool_output)
 
317
  tool_info["sources_found"].extend(urls)
318
  tool_info["tools_used"].append(tool_dict)
319
 
@@ -321,18 +349,25 @@ IMPORTANT: When you search the web and find information, you MUST:
321
  if tool_calls:
322
  for tool_call in tool_calls:
323
  tool_dict = {
324
- "tool_type": getattr(tool_call, "type", "browser_search"),
325
- "tool_name": getattr(tool_call, "function", {}).get("name", "browser_search") if hasattr(tool_call, "function") else "browser_search",
326
- "tool_id": getattr(tool_call, "id", None)
327
  }
328
- if hasattr(tool_call, "function") and hasattr(tool_call.function, "arguments"):
329
- try:
330
- args = json.loads(tool_call.function.arguments) if isinstance(tool_call.function.arguments, str) else tool_call.function.arguments
331
- tool_dict["arguments"] = args
332
- if "query" in args:
333
- tool_info["search_queries"].append(str(args["query"]))
334
- except:
335
- tool_dict["arguments"] = str(tool_call.function.arguments) if tool_call.function.arguments is not None else ""
 
 
 
 
 
 
 
336
  tool_info["tools_used"].append(tool_dict)
337
 
338
  return tool_info
@@ -341,39 +376,42 @@ IMPORTANT: When you search the web and find information, you MUST:
341
  """Extract URLs from text"""
342
  if not text:
343
  return []
 
344
  url_pattern = r'https?://[^\s<>"]{2,}'
345
- urls = re.findall(url_pattern, text)
346
  return list(set(urls))
347
 
348
  def _enhance_citations(self, content: str, tool_info: Dict) -> str:
349
  """Enhance content with better citation formatting"""
350
  if not content:
351
  return ""
 
352
  if not tool_info or not tool_info.get("sources_found"):
353
- return content
354
 
355
- if "Sources Used:" not in content and "sources:" not in content.lower():
356
  sources_section = "\n\n---\n\n### Sources Used:\n"
357
  for i, url in enumerate(tool_info["sources_found"][:10], 1):
358
- domain = self._extract_domain(url)
359
  sources_section += f"{i}. [{domain}]({url})\n"
360
- content += sources_section
361
 
362
- return content
363
 
364
  def _extract_domain(self, url: str) -> str:
365
  """Extract domain name from URL for display"""
366
  if not url:
367
  return ""
 
368
  try:
369
- if url.startswith(('http://', 'https://')):
370
- domain = url.split('/')[2]
371
  if domain.startswith('www.'):
372
  domain = domain[4:]
373
  return domain
374
- return url
375
  except:
376
- return url
377
 
378
  def get_model_info(self) -> Dict:
379
  """Get information about current model capabilities"""
@@ -401,15 +439,16 @@ async def validate_api_keys(groq_api_key: str, chutes_api_key: str, model: str)
401
  """Validate both Groq and Chutes API keys and initialize AI instance"""
402
  global ai_instance, api_key_status
403
 
404
- # Handle None values
405
- groq_api_key = groq_api_key or ""
406
- chutes_api_key = chutes_api_key or ""
 
407
 
408
- if model == "openai/gpt-oss-20b" and not chutes_api_key:
409
  api_key_status = "Invalid ❌"
410
  return "❌ Please enter a valid Chutes API key for the selected model"
411
 
412
- if model in ["compound-beta", "compound-beta-mini"] and not groq_api_key:
413
  api_key_status = "Invalid ❌"
414
  return "❌ Please enter a valid Groq API key for the selected model"
415
 
@@ -452,6 +491,8 @@ def update_model(model: str) -> str:
452
  """Update the model selection"""
453
  global ai_instance
454
 
 
 
455
  if ai_instance:
456
  ai_instance.model = model
457
  model_info = ai_instance.get_model_info()
@@ -471,6 +512,7 @@ def get_search_options(model: str) -> gr.update:
471
  if not ai_instance:
472
  return gr.update(choices=["none"], value="none")
473
 
 
474
  model_info = ai_instance.available_models.get(model, {})
475
  options = ["none"]
476
 
@@ -497,16 +539,18 @@ async def chat_with_ai(message: str,
497
 
498
  if not ai_instance:
499
  error_msg = "⚠️ Please set your API keys first!"
500
- history.append([message, error_msg])
501
  return history, ""
502
 
503
- if not message or not message.strip():
504
- return history, ""
 
 
 
 
505
 
506
- # Handle None values and empty strings
507
- include_domains = include_domains or ""
508
- exclude_domains = exclude_domains or ""
509
- system_prompt = system_prompt or ""
510
 
511
  include_list = [d.strip() for d in include_domains.split(",") if d.strip()] if include_domains.strip() else []
512
  exclude_list = [d.strip() for d in exclude_domains.split(",") if d.strip()] if exclude_domains.strip() else []
@@ -523,7 +567,7 @@ async def chat_with_ai(message: str,
523
  force_search=force_search
524
  )
525
 
526
- ai_response = response["content"] or "No response received"
527
 
528
  # Add tool usage info for Groq models
529
  if response.get("tool_usage") and ai_instance.model != "openai/gpt-oss-20b":
@@ -537,7 +581,7 @@ async def chat_with_ai(message: str,
537
  tool_summary.append(f"πŸ“„ Sources found: {len(tool_info['sources_found'])}")
538
 
539
  if tool_info.get("tools_used"):
540
- tool_types = [tool.get("tool_type", "unknown") for tool in tool_info["tools_used"]]
541
  unique_types = list(set(tool_types))
542
  tool_summary.append(f"πŸ”§ Tools used: {', '.join(unique_types)}")
543
 
@@ -546,7 +590,7 @@ async def chat_with_ai(message: str,
546
 
547
  # Add search settings info
548
  search_info = []
549
- if response.get("search_type_used") and response["search_type_used"] != "none":
550
  search_info.append(f"πŸ” Search type: {response['search_type_used']}")
551
 
552
  if force_search:
 
16
  """Client for interacting with Chutes API"""
17
 
18
  def __init__(self, api_key: str):
19
+ self.api_key = api_key or ""
20
  self.base_url = "https://llm.chutes.ai/v1"
21
 
22
  async def chat_completions_create(self, **kwargs) -> Dict:
 
58
  chunk_json = json.loads(data)
59
  if "choices" in chunk_json and len(chunk_json["choices"]) > 0:
60
  delta = chunk_json["choices"][0].get("delta", {})
61
+ if "content" in delta and delta["content"]:
62
+ content += str(delta["content"])
63
  except json.JSONDecodeError:
64
  continue
65
 
 
97
  chutes_api_key: Chutes API key
98
  model: Which model to use
99
  """
100
+ self.groq_api_key = str(groq_api_key) if groq_api_key else ""
101
+ self.chutes_api_key = str(chutes_api_key) if chutes_api_key else ""
102
  if not self.groq_api_key and model != "openai/gpt-oss-20b":
103
  raise ValueError("No Groq API key provided")
104
  if not self.chutes_api_key and model == "openai/gpt-oss-20b":
105
  raise ValueError("No Chutes API key provided")
106
 
107
+ self.model = str(model) if model else "compound-beta"
108
  self.groq_client = Groq(api_key=self.groq_api_key) if self.groq_api_key else None
109
  self.chutes_client = ChutesClient(api_key=self.chutes_api_key) if self.chutes_api_key else None
110
  self.conversation_history = []
 
140
  Returns:
141
  AI response with metadata
142
  """
143
+ # Safe string conversion
144
+ message = str(message) if message else ""
145
+ system_prompt = str(system_prompt) if system_prompt else ""
146
+ search_type = str(search_type) if search_type else "auto"
147
+
148
  # Enhanced system prompt for better behavior
149
  if not system_prompt:
150
  if self.model == "openai/gpt-oss-20b":
 
163
 
164
  domain_context = ""
165
  if include_domains and self._supports_web_search():
166
+ safe_domains = [str(d) for d in include_domains if d]
167
+ domain_context = f"\nYou are restricted to searching ONLY these domains: {', '.join(safe_domains)}. Make sure to find and cite sources specifically from these domains."
168
  elif exclude_domains and self._supports_web_search():
169
+ safe_domains = [str(d) for d in exclude_domains if d]
170
+ domain_context = f"\nAvoid searching these domains: {', '.join(safe_domains)}. Search everywhere else on the web."
171
 
172
  search_instruction = ""
173
  if search_type == "browser_search" and self._supports_browser_search():
 
199
  if (include_domains or exclude_domains) and self._supports_web_search():
200
  filter_context = []
201
  if include_domains:
202
+ safe_domains = [str(d) for d in include_domains if d]
203
+ if safe_domains:
204
+ filter_context.append(f"ONLY search these domains: {', '.join(safe_domains)}")
205
  if exclude_domains:
206
+ safe_domains = [str(d) for d in exclude_domains if d]
207
+ if safe_domains:
208
+ filter_context.append(f"EXCLUDE these domains: {', '.join(safe_domains)}")
209
+ if filter_context:
210
+ enhanced_message += f"\n\n[Domain Filtering: {' | '.join(filter_context)}]"
211
 
212
  messages.append({"role": "user", "content": enhanced_message})
213
 
 
221
 
222
  # Add domain filtering for compound models (Groq only)
223
  if self._supports_web_search():
224
+ if include_domains:
225
+ safe_domains = [str(d).strip() for d in include_domains if d and str(d).strip()]
226
+ if safe_domains:
227
+ params["include_domains"] = safe_domains
228
+ if exclude_domains:
229
+ safe_domains = [str(d).strip() for d in exclude_domains if d and str(d).strip()]
230
+ if safe_domains:
231
+ params["exclude_domains"] = safe_domains
232
 
233
  # Add tools only for Groq models that support browser search
234
  tools = []
 
249
  params["stream"] = True
250
  response = await self.chutes_client.chat_completions_create(**params)
251
  # Handle Chutes response
252
+ content = ""
253
+ if response and "choices" in response and response["choices"]:
254
+ message_content = response["choices"][0].get("message", {}).get("content")
255
+ content = str(message_content) if message_content else "No response content"
256
+ else:
257
+ content = "No response received"
258
  tool_calls = None
259
  else:
260
  # Groq API call
261
  params["max_completion_tokens"] = params.pop("max_tokens", None)
262
  response = self.groq_client.chat.completions.create(**params)
263
+ content = ""
264
+ if response and response.choices and response.choices[0].message:
265
+ message_content = response.choices[0].message.content
266
+ content = str(message_content) if message_content else "No response content"
267
+ else:
268
+ content = "No response received"
269
  tool_calls = response.choices[0].message.tool_calls if hasattr(response.choices[0].message, "tool_calls") else None
270
 
271
  # Extract tool usage information
 
328
  if tools:
329
  for tool in tools:
330
  tool_dict = {
331
+ "tool_type": str(getattr(tool, "type", "unknown")),
332
+ "tool_name": str(getattr(tool, "name", "unknown")),
333
  }
334
  if hasattr(tool, "input"):
335
+ tool_input = getattr(tool, "input")
336
+ tool_input_str = str(tool_input) if tool_input is not None else ""
337
+ tool_dict["input"] = tool_input_str
338
  if "search" in tool_dict["tool_name"].lower():
339
+ tool_info["search_queries"].append(tool_input_str)
340
  if hasattr(tool, "output"):
341
+ tool_output = getattr(tool, "output")
342
+ tool_output_str = str(tool_output) if tool_output is not None else ""
343
+ tool_dict["output"] = tool_output_str
344
+ urls = self._extract_urls(tool_output_str)
345
  tool_info["sources_found"].extend(urls)
346
  tool_info["tools_used"].append(tool_dict)
347
 
 
349
  if tool_calls:
350
  for tool_call in tool_calls:
351
  tool_dict = {
352
+ "tool_type": str(getattr(tool_call, "type", "browser_search")),
353
+ "tool_name": "browser_search",
354
+ "tool_id": str(getattr(tool_call, "id", "")) if getattr(tool_call, "id", None) else ""
355
  }
356
+ if hasattr(tool_call, "function") and tool_call.function:
357
+ tool_dict["tool_name"] = str(getattr(tool_call.function, "name", "browser_search"))
358
+ if hasattr(tool_call.function, "arguments"):
359
+ try:
360
+ args_raw = tool_call.function.arguments
361
+ if isinstance(args_raw, str):
362
+ args = json.loads(args_raw)
363
+ else:
364
+ args = args_raw or {}
365
+ tool_dict["arguments"] = args
366
+ if "query" in args:
367
+ tool_info["search_queries"].append(str(args["query"]))
368
+ except:
369
+ args_str = str(args_raw) if args_raw is not None else ""
370
+ tool_dict["arguments"] = args_str
371
  tool_info["tools_used"].append(tool_dict)
372
 
373
  return tool_info
 
376
  """Extract URLs from text"""
377
  if not text:
378
  return []
379
+ text_str = str(text)
380
  url_pattern = r'https?://[^\s<>"]{2,}'
381
+ urls = re.findall(url_pattern, text_str)
382
  return list(set(urls))
383
 
384
  def _enhance_citations(self, content: str, tool_info: Dict) -> str:
385
  """Enhance content with better citation formatting"""
386
  if not content:
387
  return ""
388
+ content_str = str(content)
389
  if not tool_info or not tool_info.get("sources_found"):
390
+ return content_str
391
 
392
+ if "Sources Used:" not in content_str and "sources:" not in content_str.lower():
393
  sources_section = "\n\n---\n\n### Sources Used:\n"
394
  for i, url in enumerate(tool_info["sources_found"][:10], 1):
395
+ domain = self._extract_domain(str(url))
396
  sources_section += f"{i}. [{domain}]({url})\n"
397
+ content_str += sources_section
398
 
399
+ return content_str
400
 
401
  def _extract_domain(self, url: str) -> str:
402
  """Extract domain name from URL for display"""
403
  if not url:
404
  return ""
405
+ url_str = str(url)
406
  try:
407
+ if url_str.startswith(('http://', 'https://')):
408
+ domain = url_str.split('/')[2]
409
  if domain.startswith('www.'):
410
  domain = domain[4:]
411
  return domain
412
+ return url_str
413
  except:
414
+ return url_str
415
 
416
  def get_model_info(self) -> Dict:
417
  """Get information about current model capabilities"""
 
439
  """Validate both Groq and Chutes API keys and initialize AI instance"""
440
  global ai_instance, api_key_status
441
 
442
+ # Handle None values and convert to strings
443
+ groq_api_key = str(groq_api_key) if groq_api_key else ""
444
+ chutes_api_key = str(chutes_api_key) if chutes_api_key else ""
445
+ model = str(model) if model else "compound-beta"
446
 
447
+ if model == "openai/gpt-oss-20b" and not chutes_api_key.strip():
448
  api_key_status = "Invalid ❌"
449
  return "❌ Please enter a valid Chutes API key for the selected model"
450
 
451
+ if model in ["compound-beta", "compound-beta-mini"] and not groq_api_key.strip():
452
  api_key_status = "Invalid ❌"
453
  return "❌ Please enter a valid Groq API key for the selected model"
454
 
 
491
  """Update the model selection"""
492
  global ai_instance
493
 
494
+ model = str(model) if model else "compound-beta"
495
+
496
  if ai_instance:
497
  ai_instance.model = model
498
  model_info = ai_instance.get_model_info()
 
512
  if not ai_instance:
513
  return gr.update(choices=["none"], value="none")
514
 
515
+ model = str(model) if model else "compound-beta"
516
  model_info = ai_instance.available_models.get(model, {})
517
  options = ["none"]
518
 
 
539
 
540
  if not ai_instance:
541
  error_msg = "⚠️ Please set your API keys first!"
542
+ history.append([str(message) if message else "", error_msg])
543
  return history, ""
544
 
545
+ # Convert all inputs to strings and handle None values
546
+ message = str(message) if message else ""
547
+ include_domains = str(include_domains) if include_domains else ""
548
+ exclude_domains = str(exclude_domains) if exclude_domains else ""
549
+ system_prompt = str(system_prompt) if system_prompt else ""
550
+ search_type = str(search_type) if search_type else "auto"
551
 
552
+ if not message.strip():
553
+ return history, ""
 
 
554
 
555
  include_list = [d.strip() for d in include_domains.split(",") if d.strip()] if include_domains.strip() else []
556
  exclude_list = [d.strip() for d in exclude_domains.split(",") if d.strip()] if exclude_domains.strip() else []
 
567
  force_search=force_search
568
  )
569
 
570
+ ai_response = str(response.get("content", "No response received"))
571
 
572
  # Add tool usage info for Groq models
573
  if response.get("tool_usage") and ai_instance.model != "openai/gpt-oss-20b":
 
581
  tool_summary.append(f"πŸ“„ Sources found: {len(tool_info['sources_found'])}")
582
 
583
  if tool_info.get("tools_used"):
584
+ tool_types = [str(tool.get("tool_type", "unknown")) for tool in tool_info["tools_used"]]
585
  unique_types = list(set(tool_types))
586
  tool_summary.append(f"πŸ”§ Tools used: {', '.join(unique_types)}")
587
 
 
590
 
591
  # Add search settings info
592
  search_info = []
593
+ if response.get("search_type_used") and str(response["search_type_used"]) != "none":
594
  search_info.append(f"πŸ” Search type: {response['search_type_used']}")
595
 
596
  if force_search: