import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt

# Load model and tokenizer
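# Optional: cache the loaded model and tokenizer so they are not reloaded on
# every Streamlit rerun (assumes Streamlit >= 1.18, where st.cache_resource
# is available; remove the decorator if your version lacks it).
@st.cache_resource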
def load_model_and_tokenizer(model_path, tokenizer_path):
    model = tf.keras.models.load_model(model_path)
    tokenizer = pd.read_pickle(tokenizer_path)
    return model, tokenizer

# Preprocess input for prediction
def preprocess_prompt(prompt, tokenizer, max_length):
    sequence = tokenizer.texts_to_sequences([prompt])
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    return padded_sequence

# Make predictions
def detect_prompt(prompt, model, tokenizer, max_length):
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = "Malicious" if prediction >= 0.5 else "Valid"
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# Explain predictions using LIME
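# LimeTextExplainer perturbs the input text and fits a local surrogate model,
# so predict_fn must return class probabilities of shape (n_samples, 2).
# The model is assumed to output a single sigmoid probability P(malicious).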
def lime_explain(prompt, model, tokenizer, max_length):
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])  # [P(valid), P(malicious)]

    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    return explanation
# Set up Streamlit app
st.title("Prompt Injection Detection and Prevention")
st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")

# Load model and tokenizer
model_path = "path/to/your/saved_model"
tokenizer_path = "path/to/your/tokenizer.pkl"
max_length = 100  # Update based on your model
model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_path)
# Input prompt
user_input = st.text_area("Enter your prompt:", height=150)

if st.button("Detect"):
    if user_input.strip() == "":
        st.error("Please enter a prompt.")
    else:
        # Prediction
        class_label, confidence_score = detect_prompt(user_input, model, tokenizer, max_length)
        st.subheader("Detection Result:")
        st.write(f"**Class:** {class_label}")
        st.write(f"**Confidence Score:** {confidence_score:.2f}%")

        # Generate LIME explanation
        st.subheader("Explanation:")
        explanation = lime_explain(user_input, model, tokenizer, max_length)
        fig = explanation.as_pyplot_figure()
        st.pyplot(fig)
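        # Close the figure so repeated detections don't accumulate open
        # matplotlib figures in the Streamlit process (optional cleanup).
        plt.close(fig)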
# Sidebar information
st.sidebar.title("About")
st.sidebar.info(
    """
    This app uses a deep learning model to classify prompts as "Malicious" or "Valid."
    LIME explanations are provided to interpret the predictions.
    """
)
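
# To run the app locally (assuming this script is saved as app.py):
#   streamlit run app.py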