File size: 2,994 Bytes
c69f467
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import streamlit as st
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt

# Load model and tokenizer
@st.cache(allow_output_mutation=True)
def load_model_and_tokenizer(model_path, tokenizer_path):
    model = tf.keras.models.load_model(model_path)
    tokenizer = pd.read_pickle(tokenizer_path)
    return model, tokenizer

# Preprocess input for prediction
def preprocess_prompt(prompt, tokenizer, max_length):
    sequence = tokenizer.texts_to_sequences([prompt])
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    return padded_sequence

# Make predictions
def detect_prompt(prompt, model, tokenizer, max_length):
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = "Malicious" if prediction >= 0.5 else "Valid"
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# Explain predictions using LIME
def lime_explain(prompt, model, tokenizer, max_length):
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])  # [P(valid), P(malicious)]

    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    return explanation

# Set up Streamlit app
st.title("Prompt Injection Detection and Prevention")
st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")

# Load model and tokenizer
model_path = "path/to/your/saved_model"
tokenizer_path = "path/to/your/tokenizer.pkl"
max_length = 100  # Update based on your model
model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_path)

# Input prompt
user_input = st.text_area("Enter your prompt:", height=150)

if st.button("Detect"):
    if user_input.strip() == "":
        st.error("Please enter a prompt.")
    else:
        # Prediction
        class_label, confidence_score = detect_prompt(user_input, model, tokenizer, max_length)
        st.subheader("Detection Result:")
        st.write(f"**Class:** {class_label}")
        st.write(f"**Confidence Score:** {confidence_score:.2f}%")

        # Generate LIME explanation
        st.subheader("Explanation:")
        explanation = lime_explain(user_input, model, tokenizer, max_length)
        fig = explanation.as_pyplot_figure()
        st.pyplot(fig)

# Sidebar information
st.sidebar.title("About")
st.sidebar.info(
    """
    This app uses a deep learning model to classify prompts as "Malicious" or "Valid."
    LIME explanations are provided to interpret the predictions.
    """
)