# Source: app.py, uploaded by bakhil-aissa (commit efec1ff, verified; 2.51 kB).
import streamlit as st
import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
# Module-level handles shared by the functions below. Both start out as None
# and are filled in once the tokenizer and ONNX session have been loaded.
sess = None
tokenizer = None
@st.cache_resource
def load_models():
    """Create (or reuse) the tokenizer and ONNX inference session.

    The function is wrapped in ``st.cache_resource`` and additionally stores
    both objects in the module-level ``tokenizer`` and ``sess`` names, only
    building whichever of the two is still ``None``.

    Returns:
        A ``(tokenizer, sess)`` tuple.
    """
    global tokenizer, sess

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')

    # Session already built: nothing left to do.
    if sess is not None:
        return tokenizer, sess

    onnx_file = "model_f16.onnx"
    if not os.path.exists(onnx_file):
        st.write("Downloading model...")
        # NOTE(review): local_dir_use_symlinks may be deprecated in recent
        # huggingface_hub releases -- confirm against the pinned version.
        onnx_path = hf_hub_download(
            repo_id="bakhil-aissa/anti_prompt_injection",
            filename=onnx_file,
            local_dir_use_symlinks=False,
        )
    else:
        # A copy in the working directory takes precedence over the hub.
        st.write("Model already downloaded.")
        onnx_path = onnx_file

    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    return tokenizer, sess
def predict(text):
    """Return softmax class probabilities for ``text``.

    The text is tokenized as a batch of one (truncated to 2048 tokens), fed
    to the module-level ``sess`` and the "logits" output is converted to
    probabilities.

    Args:
        text: The string to classify.

    Returns:
        A float32 NumPy array of probabilities with one row per input (a
        single row here); each row sums to 1.

    Note:
        Uses the module-level ``tokenizer`` and ``sess``; both must be set
        before this function is called.
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=2048)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    logits = sess.run(["logits"], inputs)[0]

    # Upcast before exponentiating: if the session emits float16 logits
    # (presumably the case for a half-precision model -- confirm), exp()
    # overflows to inf for values above ~11 and the division yields NaN.
    logits = np.asarray(logits, dtype=np.float32)
    # Subtracting the row-wise max leaves the softmax unchanged mathematically
    # but keeps every exponent <= 0, so exp() cannot overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    probs = exp / exp.sum(axis=1, keepdims=True)  # shape (1, num_classes)
    return probs
def main():
    """Render the Streamlit page and run a prompt-injection check on demand.

    Shows a text area and a confidence-threshold slider. When the "Check"
    button is pressed, the entered text is scored with ``predict`` and the
    probability in column 1 of the result (treated here as the jailbreak
    class) is compared against the threshold. Empty or whitespace-only input
    produces a warning instead of a model call; any exception raised while
    scoring is shown as an error message on the page.
    """
    st.title("Anti Prompt Injection Detection")

    # predict() reads the module-level tokenizer/sess, so publish the loaded
    # objects under those names.
    global tokenizer, sess
    tokenizer, sess = load_models()

    st.subheader("Enter your text to check for prompt injection:")
    text_input = st.text_area("Text Input", height=200)
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)

    if not st.button("Check"):
        return

    # Whitespace-only input carries nothing to classify, so treat it the same
    # as an empty box rather than sending it to the model.
    if not text_input or not text_input.strip():
        st.warning("Please enter some text to check.")
        return

    # Top-level UI boundary: surface any failure to the user instead of
    # letting the page crash.
    try:
        with st.spinner("Processing..."):
            probs = predict(text_input)
            jailbreak_prob = float(probs[0][1])  # first (only) row, column 1
            is_jailbreak = jailbreak_prob >= confidence_threshold
            st.success(f"Is Jailbreak: {is_jailbreak}")
            st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
    except Exception as e:
        st.error(f"Error: {str(e)}")
# Build the page only when this file is executed as the main module.
if __name__ == "__main__":
    main()