Spaces:
Sleeping
Sleeping
import json
import os
import re

import gradio as gr
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
class ONNXInferencePipeline:
    """Two-stage text classifier: a keyword blocklist followed by an ONNX model.

    Text containing a banned keyword (matched case-insensitively as a whole
    word or phrase) is rejected immediately. Everything else is tokenized and
    scored by an ONNX model downloaded from the Hugging Face Hub.
    """

    # Index i of every ``probabilities`` list returned by ``predict`` refers to
    # CLASS_LABELS[i]; the keyword-filter shortcut uses the same ordering.
    CLASS_LABELS = ['Inappropriate Content', 'Appropriate']

    def __init__(self, repo_id: str):
        """Download the model artifacts from *repo_id* and build the ONNX session.

        Args:
            repo_id: Hub repository holding ``model.onnx``,
                ``train_bpe_tokenizer.json`` and ``hyperparameters.json``.

        Raises:
            ValueError: If the ``HF_TOKEN`` environment variable is not set.
        """
        # Retrieve the Hugging Face token from the environment variable
        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            raise ValueError("HF_TOKEN environment variable is not set.")

        # Load banned keywords list
        self.banned_keywords = self.load_banned_keywords()
        print(f"Loaded {len(self.banned_keywords)} banned keywords")

        # Download files from the Hub. ``token`` replaces the deprecated
        # ``use_auth_token`` argument of ``hf_hub_download``.
        self.onnx_path = hf_hub_download(repo_id=repo_id, filename="model.onnx", token=hf_token)
        self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json", token=hf_token)
        self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json", token=hf_token)

        # Load configuration (stored on the instance; not read elsewhere in this class).
        with open(self.config_path, encoding="utf-8") as f:
            self.config = json.load(f)

        # Initialize tokenizer; inputs are truncated/padded to max_len token ids.
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = 256

        # Pick execution providers *before* creating the session: onnxruntime >= 1.9
        # builds exposing several providers require an explicit ``providers`` list,
        # and listing CPU after CUDA gives a fallback for ops CUDA cannot run.
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            self.providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(self.onnx_path, providers=self.providers)

    def load_banned_keywords(self) -> list:
        """Return the lowercase keywords/phrases used by the pre-model filter."""
        # For testing purposes, using a small list
        # In production, load your full list
        return [
            "fuck", "shit", "bitch", "cunt", "asshole", "faggot", "nigger",
            "bastard", "damn", "crap", "ass", "dick", "piss", "wanker",
            "whore", "slut", "motherfucker", "son of a bitch", "kill yourself",
            "twat", "idiot", "stupid", "retard", "dumb"
            # Add more keywords or load from your full list
        ]

    def contains_banned_keyword(self, text: str) -> bool:
        """Return True if *text* contains any banned keyword as a whole word/phrase.

        Matching is case-insensitive and requires non-word characters (or the
        start/end of the text) on both sides, so ``"ass"`` matches "ass!" but
        not "class", "pass" or "assistant". Words inside multi-word phrases
        may be separated by any run of whitespace. The first keyword found
        is printed for debugging.
        """
        text_lower = text.lower()
        for keyword in self.banned_keywords:
            words = keyword.lower().split()
            if not words:
                # An empty entry would otherwise match every input.
                continue
            # Lookarounds instead of a plain substring test avoid the classic
            # "Scunthorpe" false positives on innocent words.
            pattern = r"(?<!\w)" + r"\s+".join(re.escape(w) for w in words) + r"(?!\w)"
            if re.search(pattern, text_lower):
                # Print which keyword was found (for debugging)
                print(f"Keyword detected: '{keyword}'")
                return True
        print("Keywords Passed - No inappropriate keywords found")
        return False

    def preprocess(self, text: str) -> np.ndarray:
        """Tokenize *text* into a ``(1, max_len)`` int64 array of token ids.

        Sequences longer than ``max_len`` are truncated; shorter ones are
        right-padded with 0 (assumed to be the model's padding id — TODO
        confirm against the tokenizer vocabulary).
        """
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[:self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    @staticmethod
    def _softmax(logits: np.ndarray) -> np.ndarray:
        """Row-wise softmax over axis 1, shifted by the row max to avoid overflow."""
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exps = np.exp(shifted)
        return exps / np.sum(exps, axis=1, keepdims=True)

    def predict(self, text: str) -> dict:
        """Classify *text* with the keyword filter, then (if it passes) the model.

        Returns:
            ``{'label': str, 'probabilities': list[float]}`` where ``label`` is
            one of ``CLASS_LABELS`` and ``probabilities`` is ordered like
            ``CLASS_LABELS``. A keyword hit short-circuits to ``[1.0, 0.0]``.
        """
        print(f"\nProcessing input: '{text[:50]}...' ({len(text)} characters)")

        # First check if the text contains any banned keywords
        if self.contains_banned_keyword(text):
            print("Input rejected by keyword filter")
            return {
                'label': self.CLASS_LABELS[0],
                'probabilities': [1.0, 0.0],  # ordered like CLASS_LABELS
            }

        # If no banned keywords found, proceed with model prediction
        print("Running ML model for classification...")
        input_array = self.preprocess(text)
        results = self.session.run(None, {'input': input_array})

        # Post-process: first output is assumed to be (batch=1, num_classes) logits.
        probabilities = self._softmax(results[0])
        predicted_class = int(np.argmax(probabilities[0]))
        label = self.CLASS_LABELS[predicted_class]

        # Log model result
        print(f"Model Passed - Result: {label} (Confidence: {probabilities[0][predicted_class]:.2%})")
        return {
            'label': label,
            'probabilities': probabilities[0].tolist(),
        }
# Example usage
if __name__ == "__main__":
    # Build the two-stage filter once; the Gradio callback below reuses it.
    print("Initializing content filter pipeline...")
    pipeline = ONNXInferencePipeline(repo_id="iimran/abuse-detector")
    print("Pipeline initialized successfully")

    # Start-up smoke test: one abusive and one benign sample, logged to stdout.
    example_texts = [
        "You're a worthless piece of garbage who should die",
        "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
    ]
    separator = "-" * 80
    for text in example_texts:
        result = pipeline.predict(text)
        print(f"Input: {text[:50]}...")
        print(f"Prediction: {result['label']} ")
        print(separator)

    def gradio_predict(text):
        """Gradio callback: classify *text* and report only the predicted label."""
        result = pipeline.predict(text)
        return f"Prediction: {result['label']} \n"

    # User-facing copy shown under the title.
    app_description = (
        "Abuse detector is designed to identify inappropriate content in text. "
        "It analyzes input for Australian Slang language and abuses. "
        "While it's trained on a compact dataset and may not catch highly nuanced or sophisticated language, "
        "it effectively detects day-to-day offensive language commonly used in conversations."
    )

    # Clickable sample inputs, grouped by intent and shown in this order.
    offensive_examples = [
        "Congrats, you fuckbrain arsehole, you've outdone yourself in stupidity. A real cock-up of a human—should we clap for your bollocks-faced greatness or just pity you?",
        "You're a mad bastard, but I'd still grab a beer with you! Fuck around all you like, you cockheaded legend—your arsehole antics are bloody brilliant.",
        "Your mother should have done better raising such a useless idiot.",
    ]
    neutral_examples = [
        "Hello HR, I hope this message finds you well. I'm writing to express my gratitude for the opportunity to interview for the Financial Analyst position last week. It was a pleasure to meet you and learn more about the role and your team.",
        "Thank you for your time and consideration. Please don't hesitate to reach out if you need additional information—I'd be happy to discuss further. Looking forward to hearing from you soon!",
        "The weather today is lovely, and I'm looking forward to a productive day at work.",
    ]
    mixed_examples = [
        "I appreciate your help, but honestly, you're such a clueless idiot sometimes. Still, thanks for trying.",
    ]

    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="Abuse Detector - Offensive Language Detector",
        description=app_description,
        examples=offensive_examples + neutral_examples + mixed_examples,
    )

    # Launch the Gradio app
    print("Launching Gradio interface...")
    iface.launch()