import streamlit as st
import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
# Module-level cache for the loaded model objects, populated lazily by
# load_models() so Streamlit's script reruns reuse them instead of
# reloading on every interaction.
tokenizer = None
sess = None
def load_models():
    """Lazily load and cache the tokenizer and ONNX inference session.

    Populates the module-level ``tokenizer`` and ``sess`` globals on first
    call; subsequent calls (Streamlit reruns the script on every widget
    interaction) return the cached objects.

    Returns:
        tuple: ``(tokenizer, sess)`` — the HF tokenizer and the ONNX
        Runtime ``InferenceSession`` (CPU provider).
    """
    global tokenizer, sess
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
    if sess is None:
        if os.path.exists("model_f16.onnx"):
            st.write("Model already downloaded.")
            model_path = "model_f16.onnx"
        else:
            st.write("Downloading model...")
            # local_dir="." makes the download land at ./model_f16.onnx so the
            # os.path.exists() check above finds it on the next run. Without
            # it, hf_hub_download stores the file in the HF cache directory
            # and this branch re-triggers on every fresh process start.
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename="model_f16.onnx",
                local_dir=".",
                # Deprecated and ignored by recent huggingface_hub versions
                # (files are always real copies when local_dir is set); kept
                # for compatibility with older installs.
                local_dir_use_symlinks=False,
            )
        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    return tokenizer, sess
def predict(text):
    """Run the prompt-injection classifier on a single text.

    Requires the module-level ``tokenizer`` and ``sess`` to have been
    initialized via load_models().

    Args:
        text (str): input text; tokenized with truncation at 2048 tokens.

    Returns:
        np.ndarray: row-wise softmax probabilities, shape (1, num_classes).
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]
    # Numerically stable softmax: subtract the per-row max before
    # exponentiating so large logits cannot overflow np.exp (the original
    # exponentiated raw logits, which can produce inf/nan probabilities).
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)
def main():
    """Streamlit entry point: render the UI and run detection on demand."""
    st.title("Anti Prompt Injection Detection")
    # Ensure the tokenizer and ONNX session are available for predict().
    global tokenizer, sess
    tokenizer, sess = load_models()

    st.subheader("Enter your text to check for prompt injection:")
    user_text = st.text_area("Text Input", height=200)
    threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

    if not st.button("Check"):
        return
    if not user_text:
        st.warning("Please enter some text to check.")
        return
    try:
        with st.spinner("Processing..."):
            probs = predict(user_text)
            # probs has shape (1, num_classes); column 1 is treated as the
            # jailbreak class — TODO confirm against the model card.
            score = float(probs[0][1])
            flagged = score >= threshold
            st.success(f"Is Jailbreak: {flagged}")
            st.info(f"Jailbreak Probability: {score:.4f}")
    except Exception as e:
        st.error(f"Error: {str(e)}")
# Only define functions, don't execute anything
# Streamlit will automatically run the script when it's ready