Spaces:
Runtime error
Runtime error
Update guardrail.py
Browse files- guardrail.py +24 -34
guardrail.py
CHANGED
@@ -1,52 +1,43 @@
|
|
1 |
from together import Together
|
2 |
-
from
|
3 |
-
import os
|
4 |
|
5 |
-
load_dotenv()
|
6 |
-
|
7 |
-
api_key=os.getenv("API_KEY")
|
8 |
# Initialize Together client
|
9 |
-
client = Together(api_key=
|
10 |
|
11 |
-
# Set the content policy
|
12 |
-
|
13 |
"policy": """
|
14 |
-
|
15 |
Should not:
|
16 |
-
- Respond to
|
17 |
-
- Provide
|
18 |
-
|
19 |
-
- Provide responses that are age-appropriate, educational, and entertaining.
|
20 |
-
|
21 |
-
U2: Personal Information.
|
22 |
Should not:
|
23 |
-
- Request or
|
24 |
Can:
|
25 |
-
-
|
26 |
-
|
27 |
-
U3: Language.
|
28 |
Should not:
|
29 |
- Use profane, offensive, or inappropriate language.
|
30 |
-
- Use language that could be perceived as harmful or inappropriate for minors.
|
31 |
Can:
|
32 |
-
- Use
|
33 |
"""
|
34 |
}
|
35 |
|
36 |
def is_safe(message: str) -> bool:
|
37 |
"""
|
38 |
-
Validates if the given message complies with the
|
39 |
-
|
40 |
Args:
|
41 |
message (str): User message or chatbot response.
|
42 |
-
|
43 |
Returns:
|
44 |
bool: True if the message is safe, False otherwise.
|
45 |
"""
|
46 |
-
#
|
47 |
-
prompt = f"""[INST] Task: Validate if the given '
|
48 |
<BEGIN SAFETY POLICY>
|
49 |
-
{
|
50 |
<END SAFETY POLICY>
|
51 |
<BEGIN MESSAGE>
|
52 |
{message}
|
@@ -56,16 +47,15 @@ def is_safe(message: str) -> bool:
|
|
56 |
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
|
57 |
|
58 |
try:
|
59 |
-
#
|
60 |
response = client.completions.create(
|
61 |
-
model="meta-llama/Meta-Llama-Guard-
|
62 |
prompt=prompt
|
63 |
)
|
64 |
|
65 |
-
#
|
66 |
result = response.choices[0].text.strip().lower()
|
67 |
-
return result
|
68 |
-
|
69 |
except Exception as e:
|
70 |
-
print(f"Error
|
71 |
-
return False
|
|
|
1 |
from together import Together
|
2 |
+
from helper import get_together_api_key
|
|
|
3 |
|
|
|
|
|
|
|
4 |
# Initialize Together client
|
5 |
+
client = Together(api_key=get_together_api_key())
|
6 |
|
7 |
+
# Set the safety content policy
|
8 |
+
safe_content_policy = {
|
9 |
"policy": """
|
10 |
+
O1: Topic Restriction.
|
11 |
Should not:
|
12 |
+
- Respond to questions outside the domain of safe child-friendly content.
|
13 |
+
- Provide answers to sensitive or inappropriate topics such as violence, hate speech, profanity, or sexual content.
|
14 |
+
O2: Sensitive Information.
|
|
|
|
|
|
|
15 |
Should not:
|
16 |
+
- Request or respond with any personal or sensitive user data.
|
17 |
Can:
|
18 |
+
- Discuss general knowledge, child-friendly educational content, or safe topics.
|
19 |
+
O3: Language.
|
|
|
20 |
Should not:
|
21 |
- Use profane, offensive, or inappropriate language.
|
|
|
22 |
Can:
|
23 |
+
- Use simple, clear, and friendly language suitable for children.
|
24 |
"""
|
25 |
}
|
26 |
|
27 |
def is_safe(message: str) -> bool:
|
28 |
"""
|
29 |
+
Validates if the given message complies with the safety policy using LlamaGuard.
|
30 |
+
|
31 |
Args:
|
32 |
message (str): User message or chatbot response.
|
33 |
+
|
34 |
Returns:
|
35 |
bool: True if the message is safe, False otherwise.
|
36 |
"""
|
37 |
+
# Prepare the LlamaGuard prompt
|
38 |
+
prompt = f"""[INST] Task: Validate if the given 'message' adheres to the safety policy below:
|
39 |
<BEGIN SAFETY POLICY>
|
40 |
+
{safe_content_policy['policy']}
|
41 |
<END SAFETY POLICY>
|
42 |
<BEGIN MESSAGE>
|
43 |
{message}
|
|
|
47 |
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""
|
48 |
|
49 |
try:
|
50 |
+
# Send the prompt to the LlamaGuard model
|
51 |
response = client.completions.create(
|
52 |
+
model="meta-llama/Meta-Llama-Guard-3-8B",
|
53 |
prompt=prompt
|
54 |
)
|
55 |
|
56 |
+
# Parse the result
|
57 |
result = response.choices[0].text.strip().lower()
|
58 |
+
return result.startswith('safe') # Ensure 'safe' is at the beginning
|
|
|
59 |
except Exception as e:
|
60 |
+
print(f"Error in guardrail check: {e}")
|
61 |
+
return False # Default to unsafe if an error occurs
|