import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Tokenizer Setup
@st.cache_resource
def setup_tokenizer():
    tokenizer = Tokenizer(num_words=5000)
    # Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
    tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
    return tokenizer
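
# A hedged sketch, not part of the original flow: in a real deployment the
# tokenizer fitted during training should be reused so the word index matches
# the saved model. Assuming the training script pickled it (the filename
# "tokenizer.pickle" below is a hypothetical placeholder), it could be loaded
# like this and used in place of setup_tokenizer():
@st.cache_resource
def load_trained_tokenizer(path="tokenizer.pickle"):
    import pickle  # local import; only needed if this loader is actually used
    with open(path, "rb") as f:
        return pickle.load(f)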

# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)

# Prediction Function
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
    
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        # LIME expects one probability per class; stack [P(Valid), P(Malicious)]
        # from the model's single sigmoid output (shape: n_samples x 1).
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation

# Load Model Section
st.subheader("Load Your Trained Model")
model = None
tokenizer = None
model_path = "deep_learning_model (1).h5"  # Ensure this file is in the same directory as app.py

try:
    model = load_model(model_path)
    tokenizer = setup_tokenizer()
    st.success("Model Loaded Successfully!")

    # User Prompt Input
    st.subheader("Classify Your Prompt")
    user_prompt = st.text_input("Enter a prompt to classify:")

    if user_prompt:
        class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
        st.write(f"Predicted Class: **{class_label}**")
        st.write(f"Confidence Score: **{confidence_score:.2f}%**")

        # LIME Explanation
        st.subheader("LIME Explanation")
        explanation = lime_explain(user_prompt, model, tokenizer)
        explanation_as_html = explanation.as_html()
        components.html(explanation_as_html, height=500, scrolling=True)
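
        # The LIME Explanation object can also be rendered as a matplotlib bar
        # chart of word weights; st.pyplot shows it inline as a supplementary
        # view alongside the HTML report.
        st.pyplot(explanation.as_pyplot_figure())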

except Exception as e:
    st.error(f"Error loading the model or classifying the prompt: {e}")


# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")