Upload app.py
Browse files
app.py
CHANGED
@@ -1,35 +1,41 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
3 |
import pandas as pd
|
4 |
import numpy as np
|
5 |
import onnxruntime as ort
|
6 |
from transformers import AutoTokenizer
|
7 |
from huggingface_hub import hf_hub_download
|
8 |
-
|
9 |
-
|
10 |
import os
|
11 |
|
|
|
|
|
|
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
if os.path.exists("model_f16.onnx"):
|
17 |
-
st.write("Model already downloaded.")
|
18 |
-
else:
|
19 |
-
st.write("Downloading model...")
|
20 |
-
model_path = hf_hub_download(
|
21 |
-
repo_id="bakhil-aissa/anti_prompt_injection",
|
22 |
-
filename="model_f16.onnx",
|
23 |
-
local_dir_use_symlinks=False,
|
24 |
-
)
|
25 |
-
|
26 |
-
st.title("Anti Prompt Injection Detection")
|
27 |
-
|
28 |
-
|
29 |
-
# Load the ONNX model
|
30 |
-
sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
|
31 |
-
# Define the input form
|
32 |
-
def predict ( text ):
|
33 |
enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
|
34 |
inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
|
35 |
logits = sess.run(["logits"], inputs)[0]
|
@@ -37,21 +43,32 @@ def predict ( text ):
|
|
37 |
probs = exp / exp.sum(axis=1, keepdims=True) # shape (1, num_classes)
|
38 |
return probs
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
|
|
2 |
import pandas as pd
|
3 |
import numpy as np
|
4 |
import onnxruntime as ort
|
5 |
from transformers import AutoTokenizer
|
6 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
7 |
import os
|
8 |
|
9 |
+
# Module-level handles shared by load_models(), predict() and main().
# Both start out unset; they are assigned once the tokenizer and the ONNX
# inference session have been loaded.
tokenizer = sess = None
|
12 |
|
13 |
+
@st.cache_resource
def load_models():
    """Load the tokenizer and the ONNX inference session.

    The function is wrapped in ``st.cache_resource`` so the loaded objects
    can be reused instead of being rebuilt on every call.

    Side effects:
        Assigns the module-level ``tokenizer`` and ``sess`` globals. A global
        that is already set (not ``None``) is left untouched and returned
        as-is.

    Returns:
        tuple: ``(tokenizer, sess)`` -- the Hugging Face tokenizer and the
        ``onnxruntime.InferenceSession`` running on the CPU provider.
    """
    global tokenizer, sess

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')

    if sess is None:
        model_file = "model_f16.onnx"

        if os.path.exists(model_file):
            # A copy sitting in the working directory takes precedence over
            # fetching from the Hub.
            st.write("Model already downloaded.")
            model_path = model_file
        else:
            st.write("Downloading model...")
            # `local_dir_use_symlinks` is deliberately not passed: it is
            # deprecated in huggingface_hub and only has meaning together
            # with `local_dir`, which is not used here.
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename=model_file,
            )

        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    return tokenizer, sess
|
36 |
|
37 |
+
def predict(text):
|
38 |
+
"""Predict function that uses the loaded models"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
|
40 |
inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
|
41 |
logits = sess.run(["logits"], inputs)[0]
|
|
|
43 |
probs = exp / exp.sum(axis=1, keepdims=True) # shape (1, num_classes)
|
44 |
return probs
|
45 |
|
46 |
+
def main():
    """Render the Streamlit page and run prompt-injection detection on demand.

    Draws the title, a text area and a confidence-threshold slider. When the
    "Check" button is pressed, the entered text is scored with ``predict``
    and the result is compared against the chosen threshold.

    Side effects:
        Rebinds the module-level ``tokenizer`` and ``sess`` globals to the
        objects returned by ``load_models`` (``predict`` reads them from
        module scope).
    """
    st.title("Anti Prompt Injection Detection")

    # predict() looks up tokenizer/sess as module globals, so publish the
    # loaded objects there before any prediction can be requested.
    global tokenizer, sess
    tokenizer, sess = load_models()

    st.subheader("Enter your text to check for prompt injection:")
    text_input = st.text_area("Text Input", height=200)
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

    if not st.button("Check"):
        return

    # Treat whitespace-only input like empty input: there is nothing
    # meaningful to score, so do not send it to the model.
    if not (text_input and text_input.strip()):
        st.warning("Please enter some text to check.")
        return

    # Broad catch is intentional: this is the UI boundary, and any failure
    # from tokenization or inference should be shown to the user rather than
    # crash the page.
    try:
        with st.spinner("Processing..."):
            probs = predict(text_input)
            # Row 0 is the single submitted text; column 1 is treated as the
            # jailbreak class (assumption carried over from the original
            # indexing -- confirm against the model's label order).
            jailbreak_prob = float(probs[0][1])
            is_jailbreak = jailbreak_prob >= confidence_threshold

            # Use an error banner for a positive detection so the result is
            # not presented as a green "success".
            verdict = st.error if is_jailbreak else st.success
            verdict(f"Is Jailbreak: {is_jailbreak}")
            st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
    except Exception as e:
        st.error(f"Error: {str(e)}")
|
72 |
+
|
73 |
+
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|