rwillats committed on
Commit
22f3054
·
verified ·
1 Parent(s): 9c81739

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. ai_responses_demo.py +45 -32
ai_responses_demo.py CHANGED
@@ -67,7 +67,7 @@ def process_retrieval_text(retrieval_text, user_input):
67
# SECURITY: real API keys were hardcoded here and committed to source control.
# They are now public and must be revoked/rotated immediately. Defaults are
# empty strings so real values must come from environment variables.
ORACLE_API_KEY = ""
TOGETHER_API_KEY = ""
OPENAI_API_KEY = ""

# Try to load from environment variables if available
ORACLE_API_KEY = os.getenv("ORACLE_KEY", ORACLE_API_KEY)
@@ -439,37 +439,50 @@ def get_llama_guard_rating(together_client, user_input, user_response):
439
 
440
  # OpenAI moderation
441
def get_openai_moderation(openai_client, user_input, user_response):
    """Run OpenAI's moderation model over the assistant response.

    Args:
        openai_client: OpenAI client exposing ``moderations.create``.
        user_input: Original user prompt (currently unused by this check).
        user_response: Text actually submitted for moderation.

    Returns:
        Tuple of ``(summary, level)`` where ``summary`` is a human-readable
        report and ``level`` is "safe", "warning", or "unsafe". On any API
        error the function fails closed: an error summary with level "unsafe".
    """
    try:
        # FIX: removed start_time/end_time — they were computed but never used.
        response = openai_client.moderations.create(
            input=user_response, model="omni-moderation-latest"
        )

        moderation_result = response.results[0]
        flagged = moderation_result.flagged
        safety_status = "Unsafe" if flagged else "Safe"
        safety_level = "unsafe" if flagged else "safe"

        # Category scores are attributes on the categories object; keep only
        # numeric values so non-numeric metadata is ignored.
        categories = moderation_result.categories
        high_risk_categories = {
            category: score
            for category, score in vars(categories).items()
            if isinstance(score, (int, float)) and score > 0.5
        }

        # Borderline content: not flagged outright, but some score exceeds 0.3.
        if not flagged and any(
            score > 0.3
            for score in vars(categories).values()
            if isinstance(score, (int, float))
        ):
            safety_level = "warning"

        formatted_result = f"Safety Status: {safety_status}\n"
        if high_risk_categories:
            formatted_result += "Flagged Categories (Confidence > 0.5):\n"
            for category, score in high_risk_categories.items():
                formatted_result += f" - {category}: {score:.2f}\n"
        else:
            formatted_result += "Flagged Categories: None\n"

        return formatted_result, safety_level
    except Exception as e:
        # Fail closed: report the error and treat the content as unsafe.
        return f"Safety Status: Error\nError: {str(e)}", "unsafe"
 
 
 
 
 
 
 
 
 
 
 
 
 
473
 
474
  # NEW APPROACH: Instead of modal, show/hide the knowledge directly in the page
475
  def rate_user_interaction(user_input, user_response):
 
67
# SECURITY: real API keys were hardcoded here and committed to source control.
# They are now public and must be revoked/rotated immediately. Defaults are
# empty strings so real values must come from environment variables.
ORACLE_API_KEY = ""
TOGETHER_API_KEY = ""
OPENAI_API_KEY = ""

# Try to load from environment variables if available
ORACLE_API_KEY = os.getenv("ORACLE_KEY", ORACLE_API_KEY)
 
439
 
440
  # OpenAI moderation
441
def get_openai_moderation(openai_client, user_input, user_response):
    """Moderate *user_response* with OpenAI, retrying on rate limits.

    Up to three attempts are made; a 429 / "Too Many Requests" failure
    triggers exponential backoff with jitter before retrying. Returns a
    ``(summary, level)`` pair where ``level`` is "safe", "warning", or
    "unsafe". Any other failure — or exhausted retries — yields an error
    summary marked "unsafe" (fail closed).
    """
    max_retries = 3
    base_delay = 1  # seconds; doubled on every retry

    for attempt in range(max_retries):
        try:
            response = openai_client.moderations.create(
                input=user_response, model="omni-moderation-latest"
            )

            result = response.results[0]
            is_flagged = result.flagged
            safety_status = "Unsafe" if is_flagged else "Safe"
            safety_level = "unsafe" if is_flagged else "safe"

            # Collect the numeric category scores once and reuse them for
            # both the high-risk and the borderline checks.
            scores = {
                name: value
                for name, value in vars(result.categories).items()
                if isinstance(value, (int, float))
            }
            high_risk_categories = {
                name: value for name, value in scores.items() if value > 0.5
            }

            # Borderline content: not flagged, but some score exceeds 0.3.
            if not is_flagged and any(value > 0.3 for value in scores.values()):
                safety_level = "warning"

            report = [f"Safety Status: {safety_status}\n"]
            if high_risk_categories:
                report.append("Flagged Categories (Confidence > 0.5):\n")
                report.extend(
                    f" - {name}: {value:.2f}\n"
                    for name, value in high_risk_categories.items()
                )
            else:
                report.append("Flagged Categories: None\n")

            return "".join(report), safety_level

        except Exception as e:
            error_message = str(e)
            # Rate-limited? Back off and retry unless this was the last try.
            rate_limited = (
                "429" in error_message or "Too Many Requests" in error_message
            )
            if rate_limited and attempt < max_retries - 1:
                # Exponential backoff plus jitter to avoid synchronized retries.
                time.sleep(base_delay * (2 ** attempt) + random.uniform(0, 0.5))
                continue
            # Non-429 error, or retries exhausted.
            return f"Safety Status: Error\nError: {error_message}", "unsafe"
486
 
487
  # NEW APPROACH: Instead of modal, show/hide the knowledge directly in the page
488
  def rate_user_interaction(user_input, user_response):