rwillats committed
Commit 7d63613 · verified · 1 Parent(s): 2a7b37a

Upload folder using huggingface_hub
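The commit title refers to huggingface_hub's folder-upload flow. A minimal sketch of what such an upload typically looks like is below; the folder path, repo id, and repo type are placeholders and assumptions, not values taken from this page.

from huggingface_hub import HfApi

# All identifiers below are placeholders for illustration; the actual target
# repo of this commit is not stated on this page.
api = HfApi()  # reads the token saved by `huggingface-cli login` or the HF_TOKEN env var
api.upload_folder(
    folder_path=".",                     # local folder to push
    repo_id="your-username/your-space",  # placeholder repo id
    repo_type="space",                   # assumption: the target is a Space
    commit_message="Upload folder using huggingface_hub",
)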

Files changed (1)
  1. ai_responses_demo.py +13 -37
ai_responses_demo.py CHANGED
@@ -440,8 +440,9 @@ def get_llama_guard_rating(together_client, user_input, user_response):
 # OpenAI moderation
 def get_openai_moderation(openai_client, user_input, user_response):
     try:
-        # Make the API call
+        start_time = time.time()
         response = openai_client.moderations.create(input=user_response, model="omni-moderation-latest")
+        end_time = time.time()

         moderation_result = response.results[0]
         flagged = moderation_result.flagged
@@ -449,63 +450,38 @@ def get_openai_moderation(openai_client, user_input, user_response):
         safety_level = "unsafe" if flagged else "safe"

         categories = moderation_result.categories
-        high_risk_categories = {}
-
-        # Check each category explicitly
-        for cat_name in ["hate", "hate_threatening", "harassment", "harassment_threatening",
-                         "self_harm", "self_harm_intent", "self_harm_instructions",
-                         "violence", "violence_graphic", "sexual", "sexual_minors",
-                         "illicit", "illicit_violent"]:
-            if hasattr(categories, cat_name):
-                score = getattr(categories, cat_name)
-                if score > 0.5:
-                    high_risk_categories[cat_name] = score
+        high_risk_categories = {
+            category: score
+            for category, score in vars(categories).items()
+            if isinstance(score, (int, float)) and score > 0.5
+        }

-        # Check for warning level
-        if not flagged:
-            for cat_name in ["hate", "hate_threatening", "harassment", "harassment_threatening",
-                             "self_harm", "self_harm_intent", "self_harm_instructions",
-                             "violence", "violence_graphic", "sexual", "sexual_minors",
-                             "illicit", "illicit_violent"]:
-                if hasattr(categories, cat_name):
-                    score = getattr(categories, cat_name)
-                    if score > 0.3:
-                        safety_level = "warning"
-                        break
+        if not flagged and any(score > 0.3 for score in vars(categories).values() if isinstance(score, (int, float))):
+            safety_level = "warning"

         formatted_result = f"Safety Status: {safety_status}\n"
         if high_risk_categories:
             formatted_result += "Flagged Categories (Confidence > 0.5):\n"
             for category, score in high_risk_categories.items():
-                formatted_result += f" - {category.replace('_', ' ').title()}: {score:.2f}\n"
+                formatted_result += f" - {category}: {score:.2f}\n"
         else:
             formatted_result += "Flagged Categories: None\n"

         return formatted_result, safety_level
-
     except Exception as e:
-        error_msg = str(e)
-
-        # Handle rate limit errors with a more user-friendly message
-        if "429" in error_msg or "Too Many Requests" in error_msg:
-            return (
-                "OpenAI Moderation temporarily unavailable due to rate limiting.\n"
-                "Please try again in a few minutes.",
-                "warning"
-            )
-        else:
-            return f"OpenAI Moderation unavailable.\nError: {error_msg[:100]}...", "warning"
+        return f"Safety Status: Error\nError: {str(e)}", "unsafe"

 # NEW APPROACH: Instead of modal, show/hide the knowledge directly in the page
 def rate_user_interaction(user_input, user_response):
     # Initialize APIs with hardcoded keys
     contextual_api = ContextualAPIUtils(api_key=ORACLE_API_KEY)
     together_client = Together(api_key=TOGETHER_API_KEY)
+    openai_client = openai.OpenAI(api_key=OPENAI_API_KEY)

     # Get ratings
     llama_rating, llama_safety = get_llama_guard_rating(together_client, user_input, user_response)
     contextual_rating, contextual_retrieval, contextual_safety = get_contextual_rating(contextual_api, user_input, user_response)
-    openai_rating, openai_safety = get_openai_moderation(OPENAI_CLIENT, user_input, user_response)
+    openai_rating, openai_safety = get_openai_moderation(openai_client, user_input, user_response)

     # Format responses carefully to avoid random line breaks
     llama_rating = re.sub(r'\.(?=\s+[A-Z])', '.\n', llama_rating)
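
The substance of the diff is that the two hard-coded category loops collapse into a single comprehension over vars(categories), the rate-limit-specific error handling becomes one generic error return, and the OpenAI client is now constructed per call instead of using a module-level OPENAI_CLIENT. As a minimal, hedged sketch of the resulting scoring logic, the snippet below reproduces the thresholds from the diff (> 0.5 flags a category, > 0.3 downgrades an otherwise safe result to "warning") over a plain dict of scores. It is an illustration, not the committed function; note that in the OpenAI SDK the numeric confidences are typically exposed on result.category_scores, while result.categories holds booleans.

# Illustrative sketch only: same thresholds as the committed
# get_openai_moderation(), applied to a plain dict of scores rather than
# the SDK's response object.
def summarize_scores(scores, flagged):
    # Categories with confidence above 0.5 are reported explicitly.
    high_risk = {cat: s for cat, s in scores.items()
                 if isinstance(s, (int, float)) and s > 0.5}
    level = "unsafe" if flagged else "safe"
    # Anything above 0.3 downgrades an otherwise "safe" result to "warning".
    if not flagged and any(s > 0.3 for s in scores.values()
                           if isinstance(s, (int, float))):
        level = "warning"
    lines = [f"Safety Status: {'Unsafe' if flagged else 'Safe'}"]
    if high_risk:
        lines.append("Flagged Categories (Confidence > 0.5):")
        lines.extend(f" - {cat}: {s:.2f}" for cat, s in high_risk.items())
    else:
        lines.append("Flagged Categories: None")
    return "\n".join(lines), level

# Example: one category over both thresholds, one under.
print(summarize_scores({"harassment": 0.62, "violence": 0.12}, flagged=True))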