# Hugging Face Space page header (status: Running)
import os | |
import json | |
import numpy as np | |
from tokenizers import Tokenizer | |
import onnxruntime as ort | |
from huggingface_hub import hf_hub_download | |
import gradio as gr | |
class ONNXInferencePipeline:
    """Text classifier that runs the RudeRater ONNX model from the Hugging Face Hub.

    Construction downloads the ONNX model, the BPE tokenizer and the
    hyperparameter JSON from ``repo_id`` and creates an ONNX Runtime session.
    ``predict`` then maps a single string to a label, a confidence and the
    per-class probabilities.
    """

    # Model output index -> human-readable label. The order must match the
    # model's output layout (presumably fixed at training time — confirm).
    CLASS_LABELS = ('Inappropriate Content', 'Not Inappropriate')

    def __init__(self, repo_id):
        """Download model artifacts from ``repo_id`` and build the ONNX session.

        Args:
            repo_id: Hugging Face Hub repository id, e.g. ``"iimran/RudeRater"``.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is not set.
        """
        # Retrieve the Hugging Face token from the environment variable.
        hf_token = os.getenv("HF_TOKEN")
        if hf_token is None:
            raise ValueError("HF_TOKEN environment variable is not set.")

        # Download files from the Hub. `token=` replaces the deprecated
        # `use_auth_token=` argument of hf_hub_download.
        self.onnx_path = hf_hub_download(repo_id=repo_id, filename="RudeRater.onnx", token=hf_token)
        self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json", token=hf_token)
        self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json", token=hf_token)

        # Load configuration.
        with open(self.config_path, encoding="utf-8") as f:
            self.config = json.load(f)

        # Initialize the tokenizer; sequences are truncated/padded to max_len.
        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = self.config["tokenizer"]["max_len"]

        # Choose execution providers *before* creating the session. Passing them
        # to the constructor avoids building the session twice, and onnxruntime
        # GPU builds (>= 1.9) require providers to be given explicitly.
        # CPU stays in the list as a fallback behind CUDA.
        self.providers = ['CPUExecutionProvider']
        if 'CUDAExecutionProvider' in ort.get_available_providers():
            self.providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(self.onnx_path, providers=self.providers)

    def preprocess(self, text):
        """Tokenize ``text`` into a ``(1, max_len)`` int64 array.

        Token ids beyond ``max_len`` are truncated. Shorter sequences are
        right-padded with id 0 (assumed to be the tokenizer's padding id —
        TODO confirm against the vocabulary).
        """
        ids = self.tokenizer.encode(text).ids[:self.max_len]
        padded = ids + [0] * (self.max_len - len(ids))
        return np.array(padded, dtype=np.int64).reshape(1, -1)

    def predict(self, text):
        """Classify a single string.

        Args:
            text: input string to classify.

        Returns:
            dict with ``'label'`` (one of ``CLASS_LABELS``), ``'confidence'``
            (probability of that label, as float) and ``'probabilities'``
            (list of per-class probabilities in ``CLASS_LABELS`` order).
        """
        input_array = self.preprocess(text)

        # Run inference; the model is fed through its input named 'input'.
        logits = self.session.run(None, {'input': input_array})[0]

        # Numerically stable softmax. Shifting by the row max leaves the result
        # unchanged but keeps np.exp from overflowing (inf/inf -> NaN) on
        # large logits.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exp_logits = np.exp(shifted)
        probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

        # Single-example batch: only row 0 is meaningful.
        row = probabilities[0]
        predicted_class = int(np.argmax(row))
        return {
            'label': self.CLASS_LABELS[predicted_class],
            'confidence': float(row[predicted_class]),
            'probabilities': row.tolist()
        }
def format_prediction(result):
    """Render a prediction result as a two-line, human-readable summary.

    Args:
        result: dict as returned by ``ONNXInferencePipeline.predict``, with keys
            ``'label'``, ``'confidence'`` and ``'probabilities'``. Index 0 of
            ``'probabilities'`` is "Inappropriate" and index 1 is
            "Not Inappropriate", matching the pipeline's label order.

    Returns:
        ``"Prediction: <label> (<conf>%)\\nProbabilities: Inappropriate=..., Not Inappropriate=..."``
    """
    probs = result['probabilities']
    return (
        f"Prediction: {result['label']} ({result['confidence']:.2%})\n"
        f"Probabilities: Inappropriate={probs[0]:.2%}, Not Inappropriate={probs[1]:.2%}"
    )


# Example usage
if __name__ == "__main__":
    # Initialize the pipeline with the Hugging Face repository ID.
    pipeline = ONNXInferencePipeline(repo_id="iimran/RudeRater")

    # Example texts for a console sanity check before the UI starts.
    example_texts = [
        "This content contains explicit language and violent threats",
        "The weather today is pleasant and suitable for all ages",
        "You're a worthless piece of garbage who should die",
        "Please remember to submit your reports by Friday"
    ]
    for text in example_texts:
        result = pipeline.predict(text)
        print(f"Input: {text}")
        print(format_prediction(result))
        print("-" * 80)

    def gradio_predict(text):
        """Gradio callback: classify ``text`` and return a formatted summary."""
        # An empty box would be classified as a pure-padding sequence, which
        # is meaningless, so prompt the user instead of calling the model.
        if not text or not text.strip():
            return "Please enter some text to classify."
        return format_prediction(pipeline.predict(text))

    # Create a Gradio interface.
    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter text here..."),
        outputs="text",
        title="RudeRater - Content Appropriateness Classifier",
        description="RudeRater is designed to identify inappropriate content in text. It analyzes input for offensive language, explicit content, or harmful material. Enter text to check its appropriateness.",
        examples=[
            "This is completely unacceptable behavior and I'll make sure you regret it",
            "The community guidelines clearly prohibit any form of discrimination",
            "Your mother should have done better raising such a useless idiot",
            "We appreciate your feedback and will improve our services",
            "I'm going to find you and make you pay for what you've done",
            "The park maintenance schedule has been updated for summer"
        ]
    )

    # Launch the Gradio app.
    iface.launch()