|
import gradio as gr |
|
import numpy as np |
|
import onnxruntime as ort |
|
from transformers import AutoTokenizer |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
|
|
|
|
tokenizer = None |
|
sess = None |
|
|
|
def load_models(): |
|
"""Load tokenizer and model""" |
|
global tokenizer, sess |
|
|
|
if tokenizer is None: |
|
tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large') |
|
|
|
if sess is None: |
|
if os.path.exists("model_f16.onnx"): |
|
print("Model already downloaded.") |
|
model_path = "model_f16.onnx" |
|
else: |
|
print("Downloading model...") |
|
model_path = hf_hub_download( |
|
repo_id="bakhil-aissa/anti_prompt_injection", |
|
filename="model_f16.onnx", |
|
local_dir_use_symlinks=False, |
|
) |
|
|
|
sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) |
|
|
|
return tokenizer, sess |
|
|
|
def predict(text, confidence_threshold): |
|
"""Predict function that uses the loaded models""" |
|
if not text.strip(): |
|
return "Please enter some text to check.", 0.0, False |
|
|
|
try: |
|
|
|
load_models() |
|
|
|
|
|
enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048) |
|
inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]} |
|
logits = sess.run(["logits"], inputs)[0] |
|
exp = np.exp(logits) |
|
probs = exp / exp.sum(axis=1, keepdims=True) |
|
|
|
jailbreak_prob = float(probs[0][1]) |
|
is_jailbreak = jailbreak_prob >= confidence_threshold |
|
|
|
result_text = f"Is Jailbreak: {is_jailbreak}" |
|
return result_text, jailbreak_prob, is_jailbreak |
|
|
|
except Exception as e: |
|
return f"Error: {str(e)}", 0.0, False |
|
|
|
|
|
def create_interface(): |
|
with gr.Blocks(title="Anti Prompt Injection Detection") as demo: |
|
gr.Markdown("# π« Anti Prompt Injection Detection") |
|
gr.Markdown("Enter your text to check for prompt injection attempts.") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
text_input = gr.Textbox( |
|
label="Text Input", |
|
placeholder="Enter text to analyze...", |
|
lines=5, |
|
max_lines=10 |
|
) |
|
confidence_threshold = gr.Slider( |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=0.5, |
|
step=0.01, |
|
label="Confidence Threshold" |
|
) |
|
check_button = gr.Button("Check Text", variant="primary") |
|
|
|
with gr.Column(): |
|
result_text = gr.Textbox(label="Result", interactive=False) |
|
probability = gr.Number(label="Jailbreak Probability", precision=4) |
|
is_jailbreak = gr.Checkbox(label="Is Jailbreak", interactive=False) |
|
|
|
|
|
check_button.click( |
|
fn=predict, |
|
inputs=[text_input, confidence_threshold], |
|
outputs=[result_text, probability, is_jailbreak] |
|
) |
|
|
|
gr.Markdown("---") |
|
gr.Markdown("**How it works:** This tool analyzes text to detect potential prompt injection attempts that could bypass AI safety measures.") |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
demo = create_interface() |
|
demo.launch() |
|
else: |
|
|
|
demo = create_interface() |
|
|