bakhil-aissa's picture
Update app.py
5f85e8f verified
raw
history blame
2.55 kB
import streamlit as st
import pandas as pd
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
import os
# Module-level slots for the tokenizer and ONNX Runtime session. load_models()
# fills them when they are None, and predict() reads them. Streamlit reruns
# the whole script on every interaction, so this line resets both to None on
# each rerun.
tokenizer = sess = None
def load_models():
"""Load tokenizer and model"""
global tokenizer, sess
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained('answerdotai/ModernBERT-large')
if sess is None:
if os.path.exists("model_f16.onnx"):
st.write("Model already downloaded.")
model_path = "model_f16.onnx"
else:
st.write("Downloading model...")
model_path = hf_hub_download(
repo_id="bakhil-aissa/anti_prompt_injection",
filename="model_f16.onnx",
local_dir_use_symlinks=False,
)
sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
return tokenizer, sess
def predict(text, max_length=2048):
    """Predict function that uses the loaded models.

    Tokenizes ``text`` as a batch of one, runs the ONNX session, and turns the
    ``"logits"`` output into per-class probabilities with a softmax over axis 1.

    Args:
        text: The string to classify.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated. Defaults to 2048, the previously hard-coded value.

    Returns:
        A NumPy array of shape (1, num_classes) whose row sums to 1.

    Requires load_models() to have populated the module-level ``tokenizer``
    and ``sess`` globals first.
    """
    enc = tokenizer([text], return_tensors="np", truncation=True, max_length=max_length)
    inputs = {"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}
    # Upcast first: an fp16 model may emit float16 logits, and exp() overflows
    # float16 (max 65504) for inputs above roughly 11.
    logits = np.asarray(sess.run(["logits"], inputs)[0], dtype=np.float32)
    # Stable softmax: subtracting each row's max leaves the result unchanged
    # but keeps exp() from overflowing to inf (which would yield inf/inf = nan).
    exp = np.exp(logits - logits.max(axis=1, keepdims=True))
    return exp / exp.sum(axis=1, keepdims=True)  # shape (1, num_classes)
def main():
    """Render the prompt-injection checker page.

    Loads the models, then shows a text area and a confidence-threshold
    slider. When "Check" is pressed, it reports whether the class-1
    probability returned by predict() meets the threshold.
    """
    st.title("Anti Prompt Injection Detection")
    # load_models() stores the tokenizer and session in the module globals that
    # predict() reads, so its return value is not needed here.
    load_models()
    st.subheader("Enter your text to check for prompt injection:")
    text_input = st.text_area("Text Input", height=200)
    confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.5)
    if not st.button("Check"):
        return
    # Whitespace-only input has nothing to classify; treat it like empty input.
    if not (text_input and text_input.strip()):
        st.warning("Please enter some text to check.")
        return
    try:
        with st.spinner("Processing..."):
            probs = predict(text_input)
            # NOTE(review): assumes column 1 is the jailbreak/injection class;
            # confirm against the model's label mapping.
            jailbreak_prob = float(probs[0][1])  # index into batch
    except Exception as e:  # UI boundary: report failures instead of crashing the page
        st.error(f"Error: {e}")
        return
    is_jailbreak = jailbreak_prob >= confidence_threshold
    # A detected injection is shown in an error box rather than a green
    # success box, so the colour matches the meaning; clean results stay green.
    show_verdict = st.error if is_jailbreak else st.success
    show_verdict(f"Is Jailbreak: {is_jailbreak}")
    st.info(f"Jailbreak Probability: {jailbreak_prob:.4f}")
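# NOTE(review): Streamlit executes this script top to bottom on every rerun, so
# nothing renders unless main() is invoked at module level (e.g. `main()` or an
# `if __name__ == "__main__": main()` guard). Confirm such a call exists.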