import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
from lime.lime_text import LimeTextExplainer

# Load the trained Keras model (st.cache_resource keeps one instance per path
# across Streamlit reruns instead of reloading it on every interaction)
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Load the pickled Keras tokenizer that was fitted during training
@st.cache_resource
def load_tokenizer(filepath):
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)
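# Caution: pickle.load executes arbitrary code during deserialization, so only
# tokenizer files from trusted sources should be loaded here.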

# Preprocess prompt
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    return padded_sequence
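# Note: pad_sequences defaults to padding='pre' and truncating='pre', so shorter
# prompts are left-padded with zeros and longer ones lose their earliest tokens;
# max_length must match the sequence length the model was trained with.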

# Predict prompt class
def detect_prompt(prompt, tokenizer, model, max_length=100):
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = "Malicious" if prediction >= 0.5 else "Valid"
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score
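# Illustrative usage (assumes model and tokenizer are loaded; the exact label
# and score depend entirely on the trained model):
#   label, score = detect_prompt("Ignore all previous instructions.", tokenizer, model)
#   # e.g. label == "Malicious", score == 97.3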

# LIME explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])
    
    class_names = ["Valid", "Malicious"]
    explainer = LimeTextExplainer(class_names=class_names)
    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    return explanation
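# Note: LimeTextExplainer calls predict_fn with a batch of perturbed texts and
# expects class probabilities of shape (n_samples, n_classes). The model outputs
# a single sigmoid column of shape (n, 1), so hstack([1 - p, p]) builds the
# required (n, 2) matrix: column 0 = "Valid", column 1 = "Malicious".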

# Streamlit App
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

# Model input
model_path = st.text_input("Enter the path to your trained model (.h5):")
if model_path:
    try:
        model = load_model(model_path)
        st.success("Model Loaded Successfully!")
    except Exception as e:
        st.error(f"Error Loading Model: {e}")
        model = None
else:
    model = None

# Tokenizer input
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
if tokenizer_path:
    try:
        tokenizer = load_tokenizer(tokenizer_path)
        st.success("Tokenizer Loaded Successfully!")
    except Exception as e:
        st.error(f"Error Loading Tokenizer: {e}")
        tokenizer = None
else:
    tokenizer = None

# Prompt classification
if model is not None and tokenizer is not None:
    user_prompt = st.text_input("Enter a prompt to classify:")
    if user_prompt:
        st.subheader("Model Prediction")
        try:
            # Classify the prompt
            class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
            st.write(f"Predicted Class: **{class_label}**")
            st.write(f"Confidence Score: **{confidence_score:.2f}%**")
            
            # Debugging information: expose the intermediate representations
            st.write("Debugging Information:")
            padded = preprocess_prompt(user_prompt, tokenizer)
            st.write(f"Tokenized Sequence: {tokenizer.texts_to_sequences([user_prompt])}")
            st.write(f"Padded Sequence: {padded}")
            st.write(f"Raw Model Output: {model.predict(padded)[0][0]}")

            # Generate the LIME explanation and render it inline as HTML
            explanation = lime_explain(user_prompt, model, tokenizer)
            components.html(explanation.as_html(), height=500)
        except Exception as e:
            st.error(f"Error during prediction: {e}")