import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt

# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Tokenizer Setup
@st.cache_resource
def setup_tokenizer():
    tokenizer = Tokenizer(num_words=5000)
    # Placeholder vocabulary for demonstration only; reuse the tokenizer fitted
    # during training for real predictions (see the note after this function).
    tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
    return tokenizer
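
# NOTE: for meaningful predictions, reuse the tokenizer that was fitted on the
# training data. If it was saved (e.g. pickled), it could be loaded here
# instead of refitting; a minimal sketch, assuming a file named "tokenizer.pkl":
#   import pickle
#   with open("tokenizer.pkl", "rb") as f:
#       tokenizer = pickle.load(f)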

# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(sequence, maxlen=max_length)

# Prediction Function
def detect_prompt(prompt, tokenizer, model):
    processed_prompt = preprocess_prompt(prompt, tokenizer)
    # Assumes the model exposes a single sigmoid output giving P(malicious).
    prediction = model.predict(processed_prompt)[0][0]
    class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    # Confidence is reported relative to the predicted class.
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
    
    def predict_fn(prompts):
        # LIME expects per-class probabilities of shape (n_samples, n_classes),
        # ordered to match class_names: [P(Valid), P(Malicious)].
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(
        prompt,
        predict_fn,
        num_features=10
    )
    return explanation

# Load Model Section
st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
model = None
tokenizer = None

if model_path:
    try:
        model = load_model(model_path)
        tokenizer = setup_tokenizer()
        st.success("Model Loaded Successfully!")
        
        # User Prompt Input
        st.subheader("Classify Your Prompt")
        user_prompt = st.text_input("Enter a prompt to classify:")

        if user_prompt:
            class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
            st.write(f"Predicted Class: **{class_label}**")
            st.write(f"Confidence Score: **{confidence_score:.2f}%**")

            # LIME Explanation
            st.subheader("LIME Explanation")
            explanation = lime_explain(user_prompt, model, tokenizer)
            explanation_as_html = explanation.as_html()
            components.html(explanation_as_html, height=500, scrolling=True)
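            # Alternative rendering (an assumption, not in the original flow):
            # LIME's Explanation also offers as_pyplot_figure(), which could be
            # shown with st.pyplot() if a static matplotlib chart is preferred:
            #   fig = explanation.as_pyplot_figure()
            #   st.pyplot(fig)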

    except Exception as e:
        st.error(f"Error Loading Model: {e}")

# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")