bakhil-aissa commited on
Commit
efec1ff
·
verified ·
1 Parent(s): 3d9c665

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -40
app.py CHANGED
@@ -1,35 +1,41 @@
1
  import streamlit as st
2
-
3
  import pandas as pd
4
  import numpy as np
5
  import onnxruntime as ort
6
  from transformers import AutoTokenizer
7
  from huggingface_hub import hf_hub_download
8
-
9
-
10
  import os
11
 
 
 
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- # download the model from Hugging Face
15
- tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
16
- if os.path.exists("model_f16.onnx"):
17
- st.write("Model already downloaded.")
18
- else:
19
- st.write("Downloading model...")
20
- model_path = hf_hub_download(
21
- repo_id="bakhil-aissa/anti_prompt_injection",
22
- filename="model_f16.onnx",
23
- local_dir_use_symlinks=False,
24
- )
25
-
26
- st.title("Anti Prompt Injection Detection")
27
-
28
-
29
- # Load the ONNX model
30
- sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
31
- # Define the input form
32
- def predict ( text ):
33
  enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
34
  inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
35
  logits = sess.run(["logits"], inputs)[0]
@@ -37,21 +43,32 @@ def predict ( text ):
37
  probs = exp / exp.sum(axis=1, keepdims=True) # shape (1, num_classes)
38
  return probs
39
 
40
- st.subheader("Enter your text to check for prompt injection:")
41
- text_input = st.text_area("Text Input", height=200)
42
- confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
43
- if st.button("Check"):
44
- if text_input:
45
- try:
46
- with st.spinner("Processing..."):
47
- # Call the predict function
48
- probs = predict(text_input)
49
- jailbreak_prob = float(probs[0][1]) # index into batch
50
- is_jailbreak = jailbreak_prob >= confidence_threshold
51
-
52
- st.success(f"Is Jailbreak: {is_jailbreak}")
53
- st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
54
- except Exception as e:
55
- st.error(f"Error: {str(e)}")
56
- else:
57
- st.warning("Please enter some text to check.")
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
 
2
  import pandas as pd
3
  import numpy as np
4
  import onnxruntime as ort
5
  from transformers import AutoTokenizer
6
  from huggingface_hub import hf_hub_download
 
 
7
  import os
8
 
9
+ # Global variables to store loaded models
10
+ tokenizer = None
11
+ sess = None
12
 
13
+ @st.cache_resource
14
+ def load_models():
15
+ """Load tokenizer and model with caching"""
16
+ global tokenizer, sess
17
+
18
+ if tokenizer is None:
19
+ tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
20
+
21
+ if sess is None:
22
+ if os.path.exists("model_f16.onnx"):
23
+ st.write("Model already downloaded.")
24
+ model_path = "model_f16.onnx"
25
+ else:
26
+ st.write("Downloading model...")
27
+ model_path = hf_hub_download(
28
+ repo_id="bakhil-aissa/anti_prompt_injection",
29
+ filename="model_f16.onnx",
30
+ local_dir_use_symlinks=False,
31
+ )
32
+
33
+ sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
34
+
35
+ return tokenizer, sess
36
 
37
+ def predict(text):
38
+ """Predict function that uses the loaded models"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
40
  inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
41
  logits = sess.run(["logits"], inputs)[0]
 
43
  probs = exp / exp.sum(axis=1, keepdims=True) # shape (1, num_classes)
44
  return probs
45
 
46
+ def main():
47
+ st.title("Anti Prompt Injection Detection")
48
+
49
+ # Load models when needed
50
+ global tokenizer, sess
51
+ tokenizer, sess = load_models()
52
+
53
+ st.subheader("Enter your text to check for prompt injection:")
54
+ text_input = st.text_area("Text Input", height=200)
55
+ confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
56
+
57
+ if st.button("Check"):
58
+ if text_input:
59
+ try:
60
+ with st.spinner("Processing..."):
61
+ # Call the predict function
62
+ probs = predict(text_input)
63
+ jailbreak_prob = float(probs[0][1]) # index into batch
64
+ is_jailbreak = jailbreak_prob >= confidence_threshold
65
+
66
+ st.success(f"Is Jailbreak: {is_jailbreak}")
67
+ st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
68
+ except Exception as e:
69
+ st.error(f"Error: {str(e)}")
70
+ else:
71
+ st.warning("Please enter some text to check.")
72
+
73
+ if __name__ == "__main__":
74
+ main()