import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from lime.lime_text import LimeTextExplainer

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")

# Cache Data Loading
@st.cache_data
def load_data(filepath):
    return pd.read_csv(filepath)

# Cache Model Loading
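# (st.cache_resource is used here instead of st.cache_data because a loaded Keras
# model is a non-serializable resource that should be created once and reused
# across reruns.)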
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# File Upload Section
uploaded_file = st.file_uploader("Upload your dataset (.csv)", type=["csv"])
if uploaded_file is not None:
    data = load_data(uploaded_file)
    st.write("Dataset Preview:")
    st.write(data.head())
    
    # Data Preprocessing
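    # Assumption: the uploaded CSV has an 'input' column with the prompt text and a
    # 'label' column containing the strings 'valid' or 'malicious'.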
    data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
    X = data['input'].values
    y = data['label'].values
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Tokenization and Padding
    tokenizer = Tokenizer(num_words=5000)
    tokenizer.fit_on_texts(X_train)
    max_length = 100
    X_train_pad = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=max_length)
    X_test_pad = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=max_length)
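    # Caveat: the tokenizer is re-fit on the uploaded dataset on every run. For the
    # predictions below to be meaningful, the loaded model is assumed to have been
    # trained with the same vocabulary size (5000) and sequence length (100).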

    # Load Deep Learning Model
    model_path = st.text_input("Enter the path to your trained model (.h5):")
    if model_path:
        try:
            model = load_model(model_path)
            st.success("Model Loaded Successfully!")

            # Test Prediction Functionality
            def preprocess_prompt(prompt, tokenizer, max_length):
                sequence = tokenizer.texts_to_sequences([prompt])
                return pad_sequences(sequence, maxlen=max_length)

            def detect_prompt(prompt):
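                # Assumption: the model is a binary classifier whose single sigmoid
                # output is P(malicious); scores >= 0.5 are labelled 'Malicious'.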
                processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
                prediction = model.predict(processed_prompt)[0][0]
                class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
                confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
                return class_label, confidence_score

            # User Input for Prompt Detection
            st.subheader("Test a Prompt")
            user_prompt = st.text_input("Enter a prompt to test:")
            if user_prompt:
                class_label, confidence_score = detect_prompt(user_prompt)
                st.write(f"Predicted Class: **{class_label}**")
                st.write(f"Confidence Score: **{confidence_score:.2f}%**")

            # LIME Explanation
            explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

            def lime_explain(prompt):
                def predict_fn(prompts):
                    sequences = tokenizer.texts_to_sequences(prompts)
                    padded_sequences = pad_sequences(sequences, maxlen=max_length)
                    predictions = model.predict(padded_sequences)
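                    # LIME expects an (n_samples, n_classes) probability matrix, so
                    # stack [P(valid), P(malicious)] to match the column order of
                    # class_names=["Valid", "Malicious"] above.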
                    return np.hstack([1 - predictions, predictions])

                explanation = explainer.explain_instance(
                    prompt,
                    predict_fn,
                    num_features=10
                )
                return explanation

            st.subheader("LIME Explanation")
            if user_prompt:
                explanation = lime_explain(user_prompt)
                explanation_as_html = explanation.as_html()
                components.html(explanation_as_html, height=500)

        except Exception as e:
            st.error(f"Error Loading Model: {e}")

# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks using Streamlit.")