pratikshahp committed (verified)
Commit 00f554d · Parent(s): 2a9a68c

Update guardrail.py

Files changed (1): guardrail.py (+24, -34)
guardrail.py CHANGED
@@ -1,52 +1,43 @@
 from together import Together
-from dotenv import load_dotenv
-import os
+from helper import get_together_api_key
 
-load_dotenv()
-
-api_key=os.getenv("API_KEY")
 # Initialize Together client
-client = Together(api_key=api_key)
+client = Together(api_key=get_together_api_key())
 
-# Set the content policy for the under-16 chatbot guardrails
-under_16_content_policy = {
+# Set the safety content policy
+safe_content_policy = {
     "policy": """
-    U1: Age-Appropriate Content.
+    O1: Topic Restriction.
     Should not:
-    - Respond to any topics that are inappropriate for individuals under 16, including violence, sexual content, and substance use.
-    - Provide information or advice that is unsuitable for children.
-    Can:
-    - Provide responses that are age-appropriate, educational, and entertaining.
-
-    U2: Personal Information.
+    - Respond to questions outside the domain of safe child-friendly content.
+    - Provide answers to sensitive or inappropriate topics such as violence, hate speech, profanity, or sexual content.
+    O2: Sensitive Information.
     Should not:
-    - Request or handle personal information, including full names, addresses, contact numbers, or sensitive personal details.
+    - Request or respond with any personal or sensitive user data.
     Can:
-    - Interact generically without requiring any personal details.
-
-    U3: Language.
+    - Discuss general knowledge, child-friendly educational content, or safe topics.
+    O3: Language.
     Should not:
     - Use profane, offensive, or inappropriate language.
-    - Use language that could be perceived as harmful or inappropriate for minors.
     Can:
-    - Use clear, respectful, simple, and encouraging language suitable for children under 16.
+    - Use simple, clear, and friendly language suitable for children.
     """
 }
 
 def is_safe(message: str) -> bool:
     """
-    Validates if the given message complies with the under-16 safety policy using the LlamaGuard model.
-
+    Validates if the given message complies with the safety policy using LlamaGuard.
+
     Args:
         message (str): User message or chatbot response.
-
+
     Returns:
         bool: True if the message is safe, False otherwise.
     """
-    # Construct the safety validation prompt
-    prompt = f"""[INST] Task: Validate if the given 'user' or 'agent' message adheres to the safety policy below:
+    # Prepare the LlamaGuard prompt
+    prompt = f"""[INST] Task: Validate if the given 'message' adheres to the safety policy below:
     <BEGIN SAFETY POLICY>
-    {under_16_content_policy['policy']}
+    {safe_content_policy['policy']}
     <END SAFETY POLICY>
     <BEGIN MESSAGE>
     {message}
@@ -56,16 +47,15 @@ def is_safe(message: str) -> bool:
     - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
 
     try:
-        # Use TogetherAI's LlamaGuard model to validate the content
+        # Send the prompt to the LlamaGuard model
        response = client.completions.create(
-            model="meta-llama/Meta-Llama-Guard-2-8B",
+            model="meta-llama/Meta-Llama-Guard-3-8B",
            prompt=prompt
        )
 
-        # Extract the response result
+        # Parse the result
        result = response.choices[0].text.strip().lower()
-        return result == 'safe'
-
+        return result.startswith('safe')  # Ensure 'safe' is at the beginning
    except Exception as e:
-        print(f"Error during safety validation: {e}")
-        return False
+        print(f"Error in guardrail check: {e}")
+        return False  # Default to unsafe if an error occurs
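For reviewers, a minimal usage sketch of how the updated guardrail might be wired into a chat flow. Only is_safe comes from guardrail.py; the helper import mirrors the one added in this commit, while the guarded_reply wrapper and the chat model name are illustrative assumptions, not part of the repository.

# Hypothetical usage sketch (not part of this commit), assuming guardrail.py
# and helper.py sit alongside this file: screen both the user message and the
# model reply with is_safe() before anything is shown to the user.
from together import Together

from guardrail import is_safe
from helper import get_together_api_key  # assumed to simply return the Together API key

client = Together(api_key=get_together_api_key())

def guarded_reply(user_message: str) -> str:
    # Block disallowed input before it ever reaches the chat model.
    if not is_safe(user_message):
        return "Sorry, I can't help with that topic."

    # Chat model choice below is an illustrative assumption, not from this repo.
    response = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
        messages=[{"role": "user", "content": user_message}],
    )
    reply = response.choices[0].message.content

    # Screen the model's output as well before returning it.
    return reply if is_safe(reply) else "Sorry, I can't share that."

if __name__ == "__main__":
    print(guarded_reply("Can you explain how rainbows form?"))

Because the updated except block returns False on any error, the guardrail stays fail-closed, which is why the sketch treats anything other than a 'safe' verdict as a refusal.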