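"""Streamlit app for detecting prompt injection attacks.

Loads a pre-trained Keras model, classifies a user-entered prompt as Valid or
Malicious, and visualizes a LIME explanation of the prediction.
"""
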
import streamlit as st
import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import load_model
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt

# Load the trained model
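# st.cache_resource keeps the loaded Keras model in memory across Streamlit reruns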
@st.cache_resource
def load_trained_model():
    model = load_model("deep_learning_model.h5")
    return model

model = load_trained_model()

# Tokenizer setup
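# NOTE: the tokenizer is re-fit on the training CSV below; this assumes the model
# was trained with the same Tokenizer settings (num_words=5000) and the same
# max_length, so that word indices line up with the model's embedding layer.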
tokenizer = Tokenizer(num_words=5000)
max_length = 100

# Load Data
@st.cache_data
def load_data():
    # quoting=3 (csv.QUOTE_NONE) and on_bad_lines='skip' tolerate malformed rows in the CSV
    data = pd.read_csv("train prompt.csv", sep=',', quoting=3, encoding='ISO-8859-1', on_bad_lines='skip', engine='python')
    # Map string labels to binary targets: valid -> 0, malicious -> 1
    data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
    return data

data = load_data()
tokenizer.fit_on_texts(data['input'].values)

# Preprocessing functions
def preprocess_prompt(prompt, tokenizer, max_length):
    # Convert the prompt into a padded integer sequence of length max_length
    sequence = tokenizer.texts_to_sequences([prompt])
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    return padded_sequence

def detect_prompt(prompt, model, tokenizer, max_length):
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    # The model returns a single score, interpreted as the probability of 'Malicious'
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    # Report the probability of the predicted class as a percentage
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score


# Streamlit app
st.title("Prompt Injection Attack Detection")
st.write("This application detects malicious prompts to prevent injection attacks.")

prompt = st.text_input("Enter a prompt to analyze:")

if prompt:
    class_label, confidence = detect_prompt(prompt, model, tokenizer, max_length)
    st.write(f"### Prediction: {class_label}")
    st.write(f"Confidence: {confidence:.2f}%")

    # LIME explanation
    st.write("Generating LIME Explanation...")
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
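    # Class order must match the probability columns returned by predict_fn below
    # (index 0 = Valid, index 1 = Malicious)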

    def predict_fn(prompts):
        # LIME expects an (n_samples, n_classes) probability array, so stack
        # P(Valid) = 1 - p and P(Malicious) = p as two columns
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    fig = explanation.as_pyplot_figure()
    st.pyplot(fig)
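
# To try the app locally (assuming this script is saved as app.py and that
# "deep_learning_model.h5" and "train prompt.csv" sit in the same directory):
#
#   streamlit run app.py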