File size: 3,648 Bytes
79efd3c e54e4ba 79efd3c e54e4ba efec1ff e54e4ba efec1ff 51d1021 efec1ff 79efd3c efec1ff 79efd3c efec1ff e54e4ba 79efd3c efec1ff 79efd3c efec1ff 79efd3c efec1ff 79efd3c efec1ff 79efd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
# Global variables to store loaded models.
# Both start as None and are lazily populated by load_models() on first use,
# so importing this module stays cheap (no download / no session creation).
tokenizer = None
sess = None
def load_models():
    """Lazily load the tokenizer and ONNX session, caching them in globals.

    Returns:
        tuple: ``(tokenizer, sess)`` — the ModernBERT tokenizer and an
        onnxruntime ``InferenceSession`` for the f16 classifier model.
    """
    global tokenizer, sess
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
    if sess is None:
        if os.path.exists("model_f16.onnx"):
            print("Model already downloaded.")
            model_path = "model_f16.onnx"
        else:
            print("Downloading model...")
            # Download into the working directory so the os.path.exists()
            # fast-path above actually triggers on subsequent runs.
            # (Previously the file went to the HF cache because no local_dir
            # was given, and the deprecated local_dir_use_symlinks flag was
            # passed; that flag is ignored by modern huggingface_hub.)
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename="model_f16.onnx",
                local_dir=".",
            )
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return tokenizer, sess
def predict(text, confidence_threshold):
    """Classify *text* as a jailbreak/prompt-injection attempt.

    Args:
        text: The user-supplied string to analyze.
        confidence_threshold: Probability cutoff in [0, 1]; the text is
            flagged when the jailbreak probability meets or exceeds it.

    Returns:
        tuple: ``(result_text, jailbreak_prob, is_jailbreak)`` — a display
        string, the float probability of the jailbreak class, and a bool
        verdict. On empty input or any error a safe default of
        ``(message, 0.0, False)`` is returned.
    """
    if not text.strip():
        return "Please enter some text to check.", 0.0, False
    try:
        # Load models if not already loaded (cached in module globals).
        load_models()
        # Tokenize and run the ONNX session.
        enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
        inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
        logits = sess.run(["logits"], inputs)[0]
        # Numerically stable softmax: subtract the per-row max before exp
        # so large logits cannot overflow to inf/nan. Shifting by a constant
        # per row leaves the softmax result unchanged.
        shifted = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        probs = exp / exp.sum(axis=1, keepdims=True)
        # Index [0][1]: first (only) batch item, class 1 = jailbreak.
        jailbreak_prob = float(probs[0][1])
        is_jailbreak = jailbreak_prob >= confidence_threshold
        result_text = f"Is Jailbreak: {is_jailbreak}"
        return result_text, jailbreak_prob, is_jailbreak
    except Exception as e:
        # Gradio UI boundary: surface the error as text instead of crashing.
        return f"Error: {str(e)}", 0.0, False
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI for the detector."""
    with gr.Blocks(title="Anti Prompt Injection Detection") as demo:
        gr.Markdown("# 🚫 Anti Prompt Injection Detection")
        gr.Markdown("Enter your text to check for prompt injection attempts.")

        with gr.Row():
            # Left column: input text, threshold control, submit button.
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text Input",
                    placeholder="Enter text to analyze...",
                    lines=5,
                    max_lines=10,
                )
                threshold_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Confidence Threshold",
                )
                submit_btn = gr.Button("Check Text", variant="primary")
            # Right column: verdict text, probability, boolean flag.
            with gr.Column():
                verdict_box = gr.Textbox(label="Result", interactive=False)
                prob_out = gr.Number(label="Jailbreak Probability", precision=4)
                flag_out = gr.Checkbox(label="Is Jailbreak", interactive=False)

        # Wire the button to the classifier.
        submit_btn.click(
            fn=predict,
            inputs=[text_input, threshold_slider],
            outputs=[verdict_box, prob_out, flag_out],
        )

        gr.Markdown("---")
        gr.Markdown("**How it works:** This tool analyzes text to detect potential prompt injection attempts that could bypass AI safety measures.")
    return demo
# Create and launch the interface
if __name__ == "__main__":
    # Run as a script: build the UI and start the local Gradio server.
    demo = create_interface()
    demo.launch()
else:
    # For Hugging Face Spaces
    # Spaces imports this module and expects a module-level `demo` object;
    # it handles launching itself, so we only build the interface here.
    demo = create_interface()
|