# gabrielchua's picture
# update UI
# 8cad311 unverified
# raw
# history blame
# 25.9 kB
"""
simple_demo.py
"""
import os
import gradio as gr
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import gspread
from google.oauth2 import service_account
import json
from datetime import datetime
import uuid
# Local imports
from lionguard2 import LionGuard2, CATEGORIES
from utils import get_embeddings
# Google Sheets configuration
# Both settings are read from the environment and are None when unset; the
# analysis and voting handlers in this file only write to Sheets when both
# are truthy.
GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
# Service-account credentials as a JSON string (parsed with json.loads before use).
GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
# Names of the worksheets that receive analysis rows and vote rows.
RESULTS_SHEET_NAME = "results"
VOTES_SHEET_NAME = "votes"
# Helper to save results data
def save_results_data(row):
    """Append one analysis result to the results worksheet.

    This is best-effort logging: every failure (bad credentials, network,
    missing worksheet, missing ``text_id`` key, ...) is printed and swallowed
    so the caller is never interrupted.

    Args:
        row: Mapping of column name to value. The values are written in the
            mapping's iteration order; ``row['text_id']`` is used for the
            confirmation message.
    """
    try:
        # Build service-account credentials from the JSON held in the env var
        creds = service_account.Credentials.from_service_account_info(
            json.loads(GOOGLE_CREDENTIALS),
            scopes=[
                "https://www.googleapis.com/auth/spreadsheets",
                "https://www.googleapis.com/auth/drive",
            ],
        )
        # Authorise a client and locate the target worksheet
        client = gspread.authorize(creds)
        spreadsheet = client.open_by_url(GOOGLE_SHEET_URL)
        results_ws = spreadsheet.worksheet(RESULTS_SHEET_NAME)
        results_ws.append_row([*row.values()])
        print(f"Saved results data for text_id: {row['text_id']}")
    except Exception as err:
        print(f"Error saving results data: {err}")
# Helper to save vote data
def save_vote_data(text_id, agree):
    """Append one user vote to the votes worksheet.

    Best-effort: any failure is printed and swallowed, so the caller cannot
    tell from the return value whether the vote was stored.

    Args:
        text_id: Identifier of the analysis the vote refers to.
        agree: Whether the user agreed with the analysis result.
    """
    try:
        # Build service-account credentials from the JSON held in the env var
        creds = service_account.Credentials.from_service_account_info(
            json.loads(GOOGLE_CREDENTIALS),
            scopes=[
                "https://www.googleapis.com/auth/spreadsheets",
                "https://www.googleapis.com/auth/drive",
            ],
        )
        # Authorise a client and locate the target worksheet
        client = gspread.authorize(creds)
        spreadsheet = client.open_by_url(GOOGLE_SHEET_URL)
        votes_ws = spreadsheet.worksheet(VOTES_SHEET_NAME)
        # Column order written to the sheet: datetime, text_id, agree
        timestamp = datetime.now().isoformat()
        votes_ws.append_row([timestamp, text_id, agree])
        print(f"Saved vote data for text_id: {text_id}, agree: {agree}")
    except Exception as err:
        print(f"Error saving vote data: {err}")
def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
    """
    Fetch a model file from a Hugging Face repository into ``./cache``.

    Args:
        repo_id: The Hugging Face repository ID (e.g., "username/repo-name")
        filename: Name of the file to fetch (default: "LionGuard2.safetensors")
        token: Hugging Face access token; when None, the ``HF_API_KEY``
            environment variable is used instead (which may also be unset)

    Returns:
        Path to the downloaded file, as returned by ``hf_hub_download``
    """
    # Fall back to the environment only when no token was passed explicitly
    access_token = os.environ.get("HF_API_KEY") if token is None else token
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        token=access_token,
        cache_dir="./cache",
    )
def load_model(repo_id=None, use_local=True):
    """
    Build a LionGuard2 model and load its weights from disk or Hugging Face.

    The weights are downloaded from ``repo_id`` when a repo is given and either
    ``use_local`` is False or no local ``LionGuard2.safetensors`` exists. If the
    download fails, the local file is used as a fallback when present.

    Args:
        repo_id: The Hugging Face repository ID (optional)
        use_local: Whether to prefer the local file when it exists (default: True)

    Returns:
        The LionGuard2 instance, switched to eval mode, with weights loaded.

    Raises:
        FileNotFoundError: The download failed and no local weights file exists.
            The original download error is attached as ``__cause__``.
    """
    local_path = "LionGuard2.safetensors"

    model = LionGuard2()
    model.eval()

    model_path = local_path
    # Try to download from HF repo if specified and local file doesn't exist or use_local is False
    if repo_id and (not use_local or not os.path.exists(local_path)):
        try:
            print(f"Downloading LionGuard2.safetensors from {repo_id}...")
            model_path = download_model(repo_id)
            print(f"Model downloaded to: {model_path}")
        except Exception as e:
            print(f"Failed to download from HF repo: {e}")
            if not os.path.exists(local_path):
                # Chain the cause so the real download failure stays visible
                raise FileNotFoundError(
                    "No local model file found and failed to download from HF repo"
                ) from e
            # model_path still points at local_path because the assignment above never ran
            print("Falling back to local file...")

    state_dict = load_file(model_path)
    model.load_state_dict(state_dict)
    return model
# Load model once at startup
# NOTE: this runs at import time, so importing this module may trigger a
# download from the Hugging Face Hub (see load_model).
HF_REPO_ID = "gabrielchua/refactored-guacamole" # Update this with the actual repo ID
# Module-level model instance used by analyze_text for inference.
model = load_model(repo_id=HF_REPO_ID)
def format_score_with_style(score_str, compact=False):
    """
    Format a score as a coloured HTML badge with an emoji and a percentage.

    Scores below 0.4 are rendered green, scores in [0.4, 0.7) amber, and
    everything else red.

    Args:
        score_str: Score as string (could be "-" for missing values)
        compact: Whether to use compact styling (smaller padding)

    Returns:
        HTML formatted string with styling. "-" yields a grey dash; a value
        that cannot be turned into a percentage (non-numeric, NaN, infinity,
        None) is returned unchanged.
    """
    if score_str == "-":
        return '<span style="color: #9ca3af;">-</span>'

    try:
        score = float(score_str)
        # Round away binary float noise before truncating: 0.29 * 100 is
        # 28.999999999999996 and would otherwise be displayed as 28%.
        percentage = int(round(score * 100, 6))
    except (TypeError, ValueError, OverflowError):
        return score_str

    padding = "6px 12px" if compact else "8px 16px"

    if score < 0.4:
        # Safe - refined green
        start, end, color, border, emoji = "#065f46", "#047857", "#34d399", "#10b981", "👌"
    elif score < 0.7:
        # Warning - refined amber
        start, end, color, border, emoji = "#92400e", "#b45309", "#fbbf24", "#f59e0b", "⚠️"
    else:
        # High risk - refined red
        start, end, color, border, emoji = "#991b1b", "#b91c1c", "#fca5a5", "#ef4444", "🚨"

    return (
        f'<span style="background: linear-gradient(135deg, {start} 0%, {end} 100%); '
        f'color: {color}; padding: {padding}; border-radius: 8px; font-size: 13px; '
        f'font-weight: 600; border: 1px solid {border}; '
        f'box-shadow: 0 2px 4px rgba(0,0,0,0.2);">{emoji} {percentage}%</span>'
    )
def format_binary_score(score):
    """Format the binary score as a large Pass/Warning/Fail banner for dark mode.

    Args:
        score: Numeric score; below 0.4 renders "Pass", [0.4, 0.7) renders
            "Warning", anything else renders "Fail".

    Returns:
        HTML ``<div>`` string showing the verdict and the score out of 100.
    """
    # Round away binary float noise before truncating: 0.29 * 100 is
    # 28.999999999999996 and would otherwise be displayed as 28/100.
    percentage = int(round(float(score) * 100, 6))

    if score < 0.4:
        start, end, color, border, label = "#065f46", "#047857", "#34d399", "#10b981", "✅ Pass"
    elif score < 0.7:
        start, end, color, border, label = "#92400e", "#b45309", "#fbbf24", "#f59e0b", "⚠️ Warning"
    else:
        start, end, color, border, label = "#991b1b", "#b91c1c", "#fca5a5", "#ef4444", "🚨 Fail"

    return (
        f'<div style="background: linear-gradient(135deg, {start} 0%, {end} 100%); '
        f'color: {color}; padding: 24px; border-radius: 12px; text-align: center; '
        f'font-weight: 700; border: 2px solid {border}; font-size: 20px; '
        f'box-shadow: 0 4px 12px rgba(0,0,0,0.3); margin: 16px 0;">'
        f'{label} ({percentage}/100)</div>'
    )
# Emoji shown next to each category title in the results table.
_CATEGORY_EMOJIS = {
    'Hateful': '🤬',
    'Insults': '💢',
    'Sexual': '🔞',
    'Physical Violence': '⚔️',
    'Self Harm': '☹️',
    'All Other Misconduct': '🙅‍♀️',
}

# Categories rendered in the table, in display order (excluding binary).
_MAIN_CATEGORIES = (
    'hateful',
    'insults',
    'sexual',
    'physical_violence',
    'self_harm',
    'all_other_misconduct',
)

# (sheet column, CATEGORIES key, index of the level inside that category),
# in the column order written to the results worksheet.
_SHEET_SCORE_COLUMNS = (
    ("hateful_l1_score", "hateful", 0),
    ("hateful_l2_score", "hateful", 1),
    ("insults_score", "insults", 0),
    ("sexual_l1_score", "sexual", 0),
    ("sexual_l2_score", "sexual", 1),
    ("physical_violence_score", "physical_violence", 0),
    ("self_harm_l1_score", "self_harm", 0),
    ("self_harm_l2_score", "self_harm", 1),
    ("aom_l1_score", "all_other_misconduct", 0),
    ("aom_l2_score", "all_other_misconduct", 1),
)


def _level_score(results, category, level_index):
    """Return the first score of one category level, or 0.0 if that level is absent."""
    level_keys = CATEGORIES[category]
    if len(level_keys) > level_index:
        return results.get(level_keys[level_index], [0.0])[0]
    return 0.0


def analyze_text(text):
    """
    Analyze text for content moderation violations.

    Blank input (empty, whitespace-only or None) short-circuits with a
    placeholder. Otherwise the text is embedded, scored by the module-level
    model, optionally logged to Google Sheets (only when both the sheet URL
    and the credentials are configured), and rendered as HTML.

    Args:
        text: Input text to analyze

    Returns:
        A 4-tuple ``(binary_score, category_table, text_id, voting_section)``:
        binary_score: Overall safety score with styling (HTML)
        category_table: HTML table with the highest score per category
        text_id: Unique identifier for this analysis; "" when nothing was
            analyzed or an error occurred
        voting_section: HTML header for the voting area; "" when nothing was
            analyzed or an error occurred
    """
    # Guard against None as well as blank strings; .strip() on None would raise.
    if not text or not text.strip():
        empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
        return empty_html, empty_html, "", ""

    try:
        # Generate unique text ID
        text_id = str(uuid.uuid4())

        # Get embeddings for the text and run inference.
        # `results` is used as a mapping from score key to an indexable of scores.
        embeddings = get_embeddings([text])
        results = model.predict(embeddings)

        # Extract binary score (overall safety)
        binary_score = results.get('binary', [0.0])[0]

        # Save results to Google Sheets
        if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
            results_row = {
                "datetime": datetime.now().isoformat(),
                "text_id": text_id,
                "text": text,
                "binary_score": binary_score,
            }
            for column, category, level_index in _SHEET_SCORE_COLUMNS:
                results_row[column] = _level_score(results, category, level_index)
            save_results_data(results_row)

        # One table row per category, showing only the highest level score
        categories_html = []
        for category in _MAIN_CATEGORIES:
            category_name = category.replace('_', ' ').title()
            category_display = f"{_CATEGORY_EMOJIS.get(category_name, '📝')} {category_name}"

            level_scores = [
                results.get(subcategory_key, [0.0])[0]
                for subcategory_key in CATEGORIES[category]
            ]
            max_score = max(level_scores) if level_scores else 0.0

            categories_html.append(f'''
            <tr style="border-bottom: 1px solid #374151; transition: background-color 0.2s ease;">
                <td style="padding: 16px; font-weight: 500; color: #f9fafb; font-size: 15px;">{category_display}</td>
                <td style="padding: 16px; text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
            </tr>
            ''')

        # Create refined HTML table for dark mode
        html_table = f'''
        <div style="margin: 24px 0;">
            <div style="margin-bottom: 20px; text-align: center;">
                <h2 style="color: #f9fafb; font-size: 20px; font-weight: 600; margin-bottom: 6px;">📊 Category-Specific Scores</h2>
            </div>
            <div style="background: #1f2937; border-radius: 12px; overflow: hidden; box-shadow: 0 4px 12px rgba(0,0,0,0.3); border: 1px solid #374151;">
                <table style="width: 100%; border-collapse: collapse;">
                    <thead>
                        <tr style="background: linear-gradient(135deg, #374151 0%, #4b5563 100%);">
                            <th style="padding: 16px; text-align: left; font-weight: 600; font-size: 15px; color: #f9fafb;">Category</th>
                            <th style="padding: 16px; text-align: center; font-weight: 600; font-size: 15px; color: #f9fafb;">Score</th>
                        </tr>
                    </thead>
                    <tbody>
                        {"".join(categories_html)}
                    </tbody>
                </table>
            </div>
        </div>
        '''

        # Create voting section
        voting_html = '''
        <div style="margin: 24px 0; text-align: center;">
            <div style="background: #1f2937; border-radius: 12px; padding: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); border: 1px solid #374151;">
                <h3 style="color: #f9fafb; font-size: 18px; font-weight: 600; margin-bottom: 12px;">📊 How accurate are these results?</h3>
                <p style="color: #d1d5db; font-size: 14px; margin-bottom: 16px;">Your feedback helps improve the model</p>
            </div>
        </div>
        '''

        return format_binary_score(binary_score), html_table, text_id, voting_html

    except Exception as e:
        import html

        # The exception text is not under our control, so escape it before
        # embedding it in markup.
        error_msg = html.escape(f"Error analyzing text: {str(e)}")
        binary_error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">❌ {error_msg}</div>'
        table_error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">❌ {error_msg}</div>'
        return binary_error_html, table_error_html, "", ""
# Voting functions
def vote_thumbs_up(text_id):
    """Handle thumbs up vote.

    Records an "agree" vote for ``text_id`` when there is an analysis to vote
    on and Google Sheets logging is configured.

    Args:
        text_id: Identifier of the analysis being rated; empty when no
            analysis has been run.

    Returns:
        HTML thank-you banner, or a grey "Voting not available" notice when
        the vote could not be attempted. save_vote_data swallows its own
        errors, so the banner does not confirm that the vote was stored.
    """
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'

    save_vote_data(text_id, True)
    return '''
    <div style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); color: #34d399; padding: 24px; border-radius: 16px; text-align: center; font-weight: 700; border: 2px solid #10b981; margin: 16px 0; box-shadow: 0 8px 25px rgba(16, 185, 129, 0.4); animation: bounceIn 0.6s ease-out;">
        <div style="font-size: 48px; margin-bottom: 12px;">🎉</div>
        <h3 style="color: #ffffff; font-size: 20px; margin-bottom: 8px; text-shadow: 0 2px 4px rgba(0,0,0,0.3);">
            Awesome! Thank you!
        </h3>
        <p style="color: #a7f3d0; font-size: 14px; margin: 0;">
            Your positive feedback helps improve our AI safety models
        </p>
    </div>
    '''
def vote_thumbs_down(text_id):
    """Handle thumbs down vote.

    Records a "disagree" vote for ``text_id`` when there is an analysis to
    vote on and Google Sheets logging is configured.

    Args:
        text_id: Identifier of the analysis being rated; empty when no
            analysis has been run.

    Returns:
        HTML acknowledgement banner, or a grey "Voting not available" notice
        when the vote could not be attempted. save_vote_data swallows its own
        errors, so the banner does not confirm that the vote was stored.
    """
    if not (text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS):
        return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'

    save_vote_data(text_id, False)
    return '''
    <div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 24px; border-radius: 16px; text-align: center; font-weight: 700; border: 2px solid #ef4444; margin: 16px 0; box-shadow: 0 8px 25px rgba(239, 68, 68, 0.4); animation: bounceIn 0.6s ease-out;">
        <div style="font-size: 48px; margin-bottom: 12px;">📝</div>
        <h3 style="color: #ffffff; font-size: 20px; margin-bottom: 8px; text-shadow: 0 2px 4px rgba(0,0,0,0.3);">
            Thanks for the feedback!
        </h3>
        <p style="color: #fecaca; font-size: 14px; margin: 0;">
            Critical feedback helps us identify and fix model weaknesses
        </p>
    </div>
    '''
# Create Gradio interface with dark theme.
# Layout: input column + result column, then the voting area (buttons hidden
# until an analysis succeeds), then a collapsible explanation of the scoring.
with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
    body_background_fill="*neutral_950",
    background_fill_primary="*neutral_900",
    background_fill_secondary="*neutral_800",
    border_color_primary="*neutral_700",
    color_accent_soft="*blue_500"
)) as demo:

    # Empty spacer block at the top of the page
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 40px; padding: 20px;">
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1, min_width=400):
            text_input = gr.Textbox(
                label="Enter text to analyze:",
                placeholder="Type your text here...",
                lines=12,
                max_lines=20,
                container=True
            )

            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary")

        with gr.Column(scale=1, min_width=400):
            gr.HTML("""
            <div style="margin-bottom: 24px; text-align: center;">
                <h2 style="color: #f9fafb; font-size: 22px; font-weight: 600; margin-bottom: 8px;">Overall Safety Score</h2>
                <p style="color: #d1d5db; font-size: 14px; margin: 0; opacity: 0.8;">Higher percentages indicate higher likelihood of harmful content</p>
            </div>
            """)
            binary_output = gr.HTML(
                value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
            )

            category_table = gr.HTML(
                value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
            )

    # Enhanced Voting section - always visible and prominent
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="background: linear-gradient(135deg, #1e40af 0%, #3b82f6 100%); border-radius: 16px; padding: 24px; margin: 20px 0; text-align: center; border: 2px solid #60a5fa; box-shadow: 0 8px 25px rgba(59, 130, 246, 0.3); position: relative; overflow: hidden;">
                <div style="position: absolute; top: -50%; left: -50%; width: 200%; height: 200%; background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%); animation: pulse 3s ease-in-out infinite;"></div>
                <div style="position: relative; z-index: 1;">
                    <h3 style="color: #ffffff; font-size: 22px; font-weight: 700; margin-bottom: 12px; text-shadow: 0 2px 4px rgba(0,0,0,0.3);">
                        🎯 Help Us Improve! Rate the Analysis
                    </h3>
                    <p style="color: #dbeafe; font-size: 16px; margin-bottom: 0; font-weight: 500;">
                        Your feedback trains better AI safety models
                    </p>
                </div>
            </div>
            """)

            voting_instruction = gr.HTML(value="""
            <div id="voting-instructions" style="background: #374151; border-radius: 12px; padding: 20px; margin: 16px 0; text-align: center; border: 2px dashed #6b7280; box-shadow: 0 4px 12px rgba(0,0,0,0.2);">
                <p style="color: #d1d5db; font-size: 16px; margin: 0; font-weight: 500;">
                    ⬆️ Analyze some text above to unlock voting ⬆️
                </p>
            </div>
            """)

            # Hidden until an analysis produces a text_id
            with gr.Row(visible=False) as voting_buttons_row:
                with gr.Column(scale=1):
                    thumbs_up_btn = gr.Button(
                        "👍 Results Look Accurate",
                        variant="primary",
                        size="lg",
                        elem_classes=["voting-btn-positive"]
                    )
                with gr.Column(scale=1):
                    thumbs_down_btn = gr.Button(
                        "👎 Results Look Wrong",
                        variant="secondary",
                        size="lg",
                        elem_classes=["voting-btn-negative"]
                    )

            vote_feedback = gr.HTML(value="")

            # Hidden text_id to track current analysis
            current_text_id = gr.Textbox(value="", visible=False)

    # Add information about the categories
    with gr.Row():
        with gr.Accordion("ℹ️ About the Scoring System", open=False):
            gr.HTML("""
            <div style="font-size: 14px; line-height: 1.6; color: #f3f4f6; padding: 10px;">
                <h3 style="color: #f9fafb; margin-bottom: 16px;">How Scoring Works:</h3>
                <ul style="color: #d1d5db; margin-bottom: 24px;">
                    <li><b>Percentages represent likelihood of harmful content</b> - Higher % = More likely to be harmful</li>
                    <li><b>0-40%:</b> Content appears safe</li>
                    <li><b>40-70%:</b> Potentially concerning content that warrants review</li>
                    <li><b>70-100%:</b> High likelihood of policy violation</li>
                </ul>
                <h3 style="color: #f9fafb; margin-bottom: 16px;">Content Categories (Singapore Context):</h3>
                <ul style="color: #d1d5db;">
                    <li><b>🤬 Hateful:</b> Content targeting Singapore's protected traits (e.g., race, religion), including discriminatory remarks and explicit calls for harm/violence.</li>
                    <li><b>💢 Insults:</b> Personal attacks on non-protected attributes (e.g., appearance). Note: Sexuality attacks are classified as insults, not hateful, in Singapore.</li>
                    <li><b>🔞 Sexual:</b> Sexual content or adult themes, ranging from mild content inappropriate for minors to explicit content inappropriate for general audiences.</li>
                    <li><b>⚔️ Physical Violence:</b> Threats, descriptions, or glorification of physical harm against individuals or groups (not property damage).</li>
                    <li><b>☹️ Self Harm:</b> Content about self-harm or suicide, including ideation, encouragement, or descriptions of ongoing actions.</li>
                    <li><b>🙅‍♀️ All Other Misconduct:</b> Unethical/criminal conduct not covered above, from socially condemned behavior to clearly illegal activities under Singapore law.</li>
                </ul>
            </div>
            """)

    # Function to handle analysis and show voting buttons
    def analyze_and_show_voting(text):
        """Run analyze_text and derive the voting-area updates from its result.

        Returns values for, in order: the overall score HTML, the category
        table HTML, the hidden text id, the voting button row, the voting
        instruction HTML, and the vote feedback HTML (always cleared).
        """
        # The voting HTML returned by analyze_text is not displayed here.
        binary_html, table_html, text_id, _ = analyze_text(text)

        # Show voting buttons and hide instructions if we have results
        if text_id:
            voting_buttons_update = gr.update(visible=True)
            voting_instruction_update = gr.update(value="""
            <div style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); border-radius: 12px; padding: 20px; margin: 16px 0; text-align: center; border: 2px solid #10b981; box-shadow: 0 4px 12px rgba(16, 185, 129, 0.3);">
                <p style="color: #34d399; font-size: 16px; margin: 0; font-weight: 600;">
                    ✅ Analysis Complete! Please rate the accuracy below 👇
                </p>
            </div>
            """)
        else:
            voting_buttons_update = gr.update(visible=False)
            voting_instruction_update = gr.update(value="""
            <div style="background: #374151; border-radius: 12px; padding: 20px; margin: 16px 0; text-align: center; border: 2px dashed #6b7280; box-shadow: 0 4px 12px rgba(0,0,0,0.2);">
                <p style="color: #d1d5db; font-size: 16px; margin: 0; font-weight: 500;">
                    ⬆️ Analyze some text above to unlock voting ⬆️
                </p>
            </div>
            """)

        return binary_html, table_html, text_id, voting_buttons_update, voting_instruction_update, ""

    # Both the analyze button and the textbox's submit event run the same
    # handler with the same inputs and outputs.
    analysis_outputs = [
        binary_output,
        category_table,
        current_text_id,
        voting_buttons_row,
        voting_instruction,
        vote_feedback,
    ]
    for register_listener in (analyze_btn.click, text_input.submit):
        register_listener(
            fn=analyze_and_show_voting,
            inputs=[text_input],
            outputs=analysis_outputs
        )

    # Connect voting buttons
    thumbs_up_btn.click(
        fn=vote_thumbs_up,
        inputs=[current_text_id],
        outputs=[vote_feedback]
    )

    thumbs_down_btn.click(
        fn=vote_thumbs_down,
        inputs=[current_text_id],
        outputs=[vote_feedback]
    )
# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all network interfaces at port 7860; share=True additionally
    # asks Gradio for a public share link.
    demo.launch(share=True, server_name="0.0.0.0", server_port=7860)