# chatbot-with-guardrail/guardrail.py
import os

from dotenv import load_dotenv
from together import Together

# Load environment variables from a local .env file
load_dotenv()
api_key = os.getenv("TOGETHER_API_KEY")  # env var name assumed; adjust to match your .env

# Initialize the Together client with the loaded key
client = Together(api_key=api_key)
# Content policy enforced by the under-16 chatbot guardrails
under_16_content_policy = {
    "policy": """
    U1: Age-Appropriate Content.
    Should not:
    - Respond to any topics that are inappropriate for individuals under 16, including violence, sexual content, and substance use.
    - Provide information or advice that is unsuitable for children.
    Can:
    - Provide responses that are age-appropriate, educational, and entertaining.
    U2: Personal Information.
    Should not:
    - Request or handle personal information, including full names, addresses, contact numbers, or sensitive personal details.
    Can:
    - Interact generically without requiring any personal details.
    U3: Language.
    Should not:
    - Use profane, offensive, or inappropriate language.
    - Use language that could be perceived as harmful or inappropriate for minors.
    Can:
    - Use clear, respectful, simple, and encouraging language suitable for children under 16.
    """
}

def is_safe(message: str) -> bool:
    """
    Validate whether a message complies with the under-16 safety policy
    using the LlamaGuard model.

    Args:
        message (str): User message or chatbot response.

    Returns:
        bool: True if the message is safe, False otherwise.
    """
    # Construct the safety validation prompt in the LlamaGuard instruction format
    prompt = f"""[INST] Task: Validate if the given 'user' or 'agent' message adheres to the safety policy below:
<BEGIN SAFETY POLICY>
{under_16_content_policy['policy']}
<END SAFETY POLICY>
<BEGIN MESSAGE>
{message}
<END MESSAGE>
Provide your safety assessment for the above message:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]"""

    try:
        # Ask TogetherAI's LlamaGuard model to classify the message
        response = client.completions.create(
            model="meta-llama/Meta-Llama-Guard-2-8B",
            prompt=prompt,
        )
        # LlamaGuard answers 'safe' or 'unsafe' (with violated categories on a second line)
        result = response.choices[0].text.strip().lower()
        return result.split("\n")[0] == "safe"
    except Exception as e:
        # Fail closed: treat validation errors as unsafe
        print(f"Error during safety validation: {e}")
        return False
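

# Minimal usage sketch: run this module directly to check a sample message.
# Assumes a valid Together API key is configured in the environment and
# network access to the Together API is available.
if __name__ == "__main__":
    sample_message = "Can you help me with my science homework?"
    if is_safe(sample_message):
        print("Message passed the under-16 guardrail.")
    else:
        print("Message was flagged as unsafe for the under-16 guardrail.")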