|
import streamlit as st |
|
import pandas as pd |
|
import numpy as np |
|
import onnxruntime as ort |
|
from transformers import AutoTokenizer |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
|
|
|
|
# Module-level handles read by predict(); main() assigns them from
# load_models() before predict() is called.
tokenizer = None

sess = None


@st.cache_resource
def load_models():
    """Load the ModernBERT tokenizer and the ONNX inference session.

    The result is cached by ``st.cache_resource``. The ONNX weights are read
    from ``./model_f16.onnx`` when that file exists. Otherwise they are
    downloaded from the ``bakhil-aissa/anti_prompt_injection`` Hub repo into
    the current directory, so the same existence check finds them next time.

    Returns:
        ``(tokenizer, sess)``: the Hugging Face tokenizer and an
        ``onnxruntime.InferenceSession`` on the CPU execution provider. Both
        are also stored in the module-level globals of the same names.
    """
    global tokenizer, sess

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')

    if sess is None:
        model_file = "model_f16.onnx"
        if os.path.exists(model_file):
            st.write("Model already downloaded.")
            model_path = model_file
        else:
            st.write("Downloading model...")
            # local_dir="." places the file at ./model_f16.onnx, which is the
            # path checked above. Without it the file goes to the HF cache and
            # the check never succeeds. local_dir_use_symlinks is deprecated
            # (and was a no-op without local_dir), so it is not passed.
            model_path = hf_hub_download(
                repo_id="bakhil-aissa/anti_prompt_injection",
                filename=model_file,
                local_dir=".",
            )

        sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    return tokenizer, sess
|
|
|
def predict(text, max_length=2048):
    """Classify ``text`` with the loaded ONNX model and return probabilities.

    Relies on the module-level ``tokenizer`` and ``sess`` already being set
    (``main`` assigns them from ``load_models()``).

    Args:
        text: The string to classify.
        max_length: Token budget passed to the tokenizer. Longer inputs are
            truncated. Defaults to 2048, the previously hard-coded value.

    Returns:
        A float64 NumPy array of softmax probabilities over the model's
        ``logits`` output, normalised along axis 1. There is one row, since a
        single text is passed. The shape is presumably (1, num_classes);
        confirm against the exported model.
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=max_length)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]

    # Numerically stable softmax. The model file is model_f16.onnx, so logits
    # may be float16, where exp() overflows to inf for inputs above ~11 and
    # inf/inf gives NaN. Upcast first, then subtract the row max so exp() only
    # sees non-positive arguments.
    logits = np.asarray(logits, dtype=np.float64)
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)

    return probs
|
|
|
def main():
    """Render the prompt-injection checker UI and handle one "Check" request.

    Loads the tokenizer and session via ``load_models()`` and publishes them
    to the module-level globals that ``predict`` reads. When the button is
    pressed, the entered text is classified and the jailbreak probability is
    compared against the slider threshold.
    """
    # predict() reads the module-level tokenizer/sess, so rebind them here.
    global tokenizer, sess

    st.title("Anti Prompt Injection Detection")

    tokenizer, sess = load_models()

    st.subheader("Enter your text to check for prompt injection:")
    text_input = st.text_area("Text Input", height=200)
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

    if not st.button("Check"):
        return

    # Whitespace-only input has nothing to classify; treat it like empty input.
    if not (text_input and text_input.strip()):
        st.warning("Please enter some text to check.")
        return

    try:
        with st.spinner("Processing..."):
            probs = predict(text_input)
        # Column 1 of the single output row is read as the jailbreak class.
        jailbreak_prob = float(probs[0][1])
    except Exception as e:  # UI boundary: show any failure to the user.
        st.error(f"Error: {e}")
        return

    is_jailbreak = jailbreak_prob >= confidence_threshold
    # A detected injection must not be shown in the green "success" style.
    report = st.error if is_jailbreak else st.success
    report(f"Is Jailbreak: {is_jailbreak}")
    st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")


if __name__ == "__main__":
    main()
|
|