File size: 2,552 Bytes
e54e4ba
 
 
 
 
 
 
 
efec1ff
 
 
e54e4ba
efec1ff
51d1021
efec1ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e54e4ba
efec1ff
 
e54e4ba
 
 
 
 
 
 
efec1ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f85e8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import streamlit as st 
import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download 
import os

# Lazily-populated model handles shared by load_models(), predict() and main();
# None means "not loaded yet".
tokenizer, sess = None, None

def load_models():
    """Load the tokenizer and ONNX inference session, reusing any already loaded.

    Populates the module-level ``tokenizer`` and ``sess`` globals the first
    time each is needed and returns both on every call.

    The ONNX weights are looked up as ``model_f16.onnx`` in the current working
    directory. If that file is absent, it is downloaded from the Hugging Face
    Hub into the same directory, so the existence check finds it on later runs.

    NOTE(review): Streamlit re-executes the whole script on every rerun, which
    includes the module-level ``None`` initialisation of these globals. If this
    file is the Streamlit script itself, the globals do not survive reruns and
    the models are reloaded each time. Consider ``st.cache_resource`` for that
    case.

    Returns:
        tuple: ``(tokenizer, sess)``.
    """
    global tokenizer, sess

    model_filename = "model_f16.onnx"

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-large")

    if sess is None:
        if os.path.exists(model_filename):
            st.write("Model already downloaded.")
            model_path = model_filename
        else:
            st.write("Downloading model...")
            # Download into the working directory rather than the shared HF
            # cache, so the os.path.exists() check above finds the file next
            # time. The old ``local_dir_use_symlinks`` argument is deprecated
            # in huggingface_hub and had no effect without ``local_dir``.
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename=model_filename,
                local_dir=".",
            )

        # Explicitly restrict inference to the CPU execution provider.
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    return tokenizer, sess

def predict(text, max_length=2048):
    """Return class probabilities for *text* using the loaded tokenizer and session.

    If either module-level model handle is still unset, ``load_models()`` is
    called first.

    Args:
        text: The string to classify.
        max_length: Token limit passed to the tokenizer. Longer input is
            truncated. Defaults to 2048.

    Returns:
        numpy.ndarray: Softmax probabilities of shape ``(1, num_classes)``,
        one row because a single text is batched.
    """
    if tokenizer is None or sess is None:
        load_models()

    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=max_length)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]

    # Numerically stable softmax. First upcast: an f16 model may emit float16
    # logits, whose exp() overflows to inf above ~11. Then subtract the row
    # max, so every exponent is <= 0 and the result can never be inf/inf = nan.
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

def main():
    """Render the Streamlit page that checks user text for prompt injection.

    Loads the models, then shows a text area and a confidence-threshold slider.
    When "Check" is pressed, the text is classified. A detection is shown in an
    error box and a clean result in a success box, followed by the raw
    jailbreak probability.
    """
    global tokenizer, sess

    st.title("Anti Prompt Injection Detection")

    # Load models when needed; this also refreshes the module-level globals.
    tokenizer, sess = load_models()

    st.subheader("Enter your text to check for prompt injection:")
    text_input = st.text_area("Text Input", height=200)
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

    if not st.button("Check"):
        return

    # Whitespace-only input carries nothing to classify.
    if not (text_input and text_input.strip()):
        st.warning("Please enter some text to check.")
        return

    try:
        with st.spinner("Processing..."):
            probs = predict(text_input)
        # Column 1 of the single batch row is treated as the jailbreak class.
        jailbreak_prob = float(probs[0][1])
    except Exception as e:  # UI boundary: surface any inference failure to the user.
        st.error(f"Error: {e}")
        return

    is_jailbreak = jailbreak_prob >= confidence_threshold

    # A positive detection must not be rendered in the green "success" style.
    show_verdict = st.error if is_jailbreak else st.success
    show_verdict(f"Is Jailbreak: {is_jailbreak}")
    st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")

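# This module only defines functions; nothing runs at import time.
# Streamlit executes the script top to bottom but never calls main() by itself,
# so main() must be invoked (here or from the entry-point script) for the UI to render.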