File size: 3,648 Bytes
79efd3c
e54e4ba
 
 
79efd3c
e54e4ba
 
efec1ff
 
 
e54e4ba
efec1ff
51d1021
efec1ff
 
 
 
 
 
 
79efd3c
efec1ff
 
79efd3c
efec1ff
 
 
 
 
 
 
 
 
e54e4ba
79efd3c
efec1ff
79efd3c
 
efec1ff
79efd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efec1ff
79efd3c
efec1ff
79efd3c
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os

# Global variables to store loaded models
tokenizer = None
sess = None

def load_models():
    """Load tokenizer and model"""
    global tokenizer, sess
    
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
    
    if sess is None:
        if os.path.exists("model_f16.onnx"):
            print("Model already downloaded.")
            model_path = "model_f16.onnx"
        else:
            print("Downloading model...")
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename="model_f16.onnx",
                local_dir_use_symlinks=False,
            )
        
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    
    return tokenizer, sess

def predict(text, confidence_threshold):
    """Predict function that uses the loaded models"""
    if not text.strip():
        return "Please enter some text to check.", 0.0, False
    
    try:
        # Load models if not already loaded
        load_models()
        
        # Make prediction
        enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
        inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
        logits = sess.run(["logits"], inputs)[0]
        exp = np.exp(logits)
        probs = exp / exp.sum(axis=1, keepdims=True)
        
        jailbreak_prob = float(probs[0][1])
        is_jailbreak = jailbreak_prob >= confidence_threshold
        
        result_text = f"Is Jailbreak: {is_jailbreak}"
        return result_text, jailbreak_prob, is_jailbreak
        
    except Exception as e:
        return f"Error: {str(e)}", 0.0, False

# Create Gradio interface
def create_interface():
    with gr.Blocks(title="Anti Prompt Injection Detection") as demo:
        gr.Markdown("# 🚫 Anti Prompt Injection Detection")
        gr.Markdown("Enter your text to check for prompt injection attempts.")
        
        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text Input",
                    placeholder="Enter text to analyze...",
                    lines=5,
                    max_lines=10
                )
                confidence_threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.5,
                    step=0.01,
                    label="Confidence Threshold"
                )
                check_button = gr.Button("Check Text", variant="primary")
            
            with gr.Column():
                result_text = gr.Textbox(label="Result", interactive=False)
                probability = gr.Number(label="Jailbreak Probability", precision=4)
                is_jailbreak = gr.Checkbox(label="Is Jailbreak", interactive=False)
        
        # Set up the prediction
        check_button.click(
            fn=predict,
            inputs=[text_input, confidence_threshold],
            outputs=[result_text, probability, is_jailbreak]
        )
        
        gr.Markdown("---")
        gr.Markdown("**How it works:** This tool analyzes text to detect potential prompt injection attempts that could bypass AI safety measures.")
    
    return demo

# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()
else:
    # For Hugging Face Spaces
    demo = create_interface()