import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os

# Global variables to store loaded models
tokenizer = None
sess = None


def load_models():
    """Load the tokenizer and ONNX model, caching them in module-level globals."""
    global tokenizer, sess
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")
    if sess is None:
        if os.path.exists("model_f16.onnx"):
            print("Model already downloaded.")
            model_path = "model_f16.onnx"
        else:
            print("Downloading model...")
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename="model_f16.onnx",
            )
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return tokenizer, sess


def predict(text, confidence_threshold):
    """Predict function that uses the loaded models."""
    if not text.strip():
        return "Please enter some text to check.", 0.0, False

    try:
        # Load models if not already loaded
        load_models()

        # Tokenize and run the ONNX session
        enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
        inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
        logits = sess.run(["logits"], inputs)[0]

        # Numerically stable softmax over the two classes
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        probs = exp / exp.sum(axis=1, keepdims=True)
        jailbreak_prob = float(probs[0][1])
        is_jailbreak = jailbreak_prob >= confidence_threshold

        result_text = f"Is Jailbreak: {is_jailbreak}"
        return result_text, jailbreak_prob, is_jailbreak
    except Exception as e:
        return f"Error: {str(e)}", 0.0, False


# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Anti Prompt Injection Detection") as demo:
        gr.Markdown("# 🚫 Anti Prompt Injection Detection")
        gr.Markdown("Enter your text to check for prompt injection attempts.")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text Input",
                    placeholder="Enter text to analyze...",
                    lines=5,
                    max_lines=10,
                )
                confidence_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Confidence Threshold",
                )
                check_button = gr.Button("Check Text", variant="primary")

            with gr.Column():
                result_text = gr.Textbox(label="Result", interactive=False)
                probability = gr.Number(label="Jailbreak Probability", precision=4)
                is_jailbreak = gr.Checkbox(label="Is Jailbreak", interactive=False)

        # Set up the prediction
        check_button.click(
            fn=predict,
            inputs=[text_input, confidence_threshold],
            outputs=[result_text, probability, is_jailbreak],
        )

        gr.Markdown("---")
        gr.Markdown("**How it works:** This tool analyzes text to detect potential prompt injection attempts that could bypass AI safety measures.")

    return demo


# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
else:
    # For Hugging Face Spaces
    demo = create_interface()