deepugaur committed on
Commit c69f467 · verified · 1 Parent(s): 829b00f

Create app.py

Files changed (1)
  1. app.py +78 -0
app.py ADDED
@@ -0,0 +1,78 @@
+ import streamlit as st
+ import numpy as np
+ import pandas as pd
+ import tensorflow as tf
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ from lime.lime_text import LimeTextExplainer
+ import matplotlib.pyplot as plt
+
+ # Load model and tokenizer
+ @st.cache(allow_output_mutation=True)
+ def load_model_and_tokenizer(model_path, tokenizer_path):
+     model = tf.keras.models.load_model(model_path)
+     tokenizer = pd.read_pickle(tokenizer_path)
+     return model, tokenizer
+
+ # Preprocess input for prediction
+ def preprocess_prompt(prompt, tokenizer, max_length):
+     sequence = tokenizer.texts_to_sequences([prompt])
+     padded_sequence = pad_sequences(sequence, maxlen=max_length)
+     return padded_sequence
+
+ # Make predictions
+ def detect_prompt(prompt, model, tokenizer, max_length):
+     processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
+     prediction = model.predict(processed_prompt)[0][0]
+     class_label = "Malicious" if prediction >= 0.5 else "Valid"
+     confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
+     return class_label, confidence_score
+
+ # Explain predictions using LIME
+ def lime_explain(prompt, model, tokenizer, max_length):
+     def predict_fn(prompts):
+         sequences = tokenizer.texts_to_sequences(prompts)
+         padded_sequences = pad_sequences(sequences, maxlen=max_length)
+         predictions = model.predict(padded_sequences)
+         return np.hstack([1 - predictions, predictions])  # [P(valid), P(malicious)]
+
+     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
+     explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
+     return explanation
+
+ # Set up Streamlit app
+ st.title("Prompt Injection Detection and Prevention")
+ st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")
+
+ # Load model and tokenizer
+ model_path = "path/to/your/saved_model"
+ tokenizer_path = "path/to/your/tokenizer.pkl"
+ max_length = 100  # Update based on your model
+ model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_path)
+
+ # Input prompt
+ user_input = st.text_area("Enter your prompt:", height=150)
+
+ if st.button("Detect"):
+     if user_input.strip() == "":
+         st.error("Please enter a prompt.")
+     else:
+         # Prediction
+         class_label, confidence_score = detect_prompt(user_input, model, tokenizer, max_length)
+         st.subheader("Detection Result:")
+         st.write(f"**Class:** {class_label}")
+         st.write(f"**Confidence Score:** {confidence_score:.2f}%")
+
+         # Generate LIME explanation
+         st.subheader("Explanation:")
+         explanation = lime_explain(user_input, model, tokenizer, max_length)
+         fig = explanation.as_pyplot_figure()
+         st.pyplot(fig)
+
+ # Sidebar information
+ st.sidebar.title("About")
+ st.sidebar.info(
+     """
+     This app uses a deep learning model to classify prompts as "Malicious" or "Valid."
+     LIME explanations are provided to interpret the predictions.
+     """
+ )