deepugaur committed on
Commit 1c1f9d0 · verified · 1 Parent(s): 466d440

Update app.py

Files changed (1)
  1. app.py +68 -58
app.py CHANGED
@@ -2,89 +2,99 @@

  import streamlit as st
  import tensorflow as tf
- from tensorflow.keras.preprocessing.text import Tokenizer
- from tensorflow.keras.preprocessing.sequence import pad_sequences
  import numpy as np
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ import pickle
+ from lime import lime_text
  from lime.lime_text import LimeTextExplainer
- import matplotlib.pyplot as plt
-
- # Streamlit Title
- st.title("Prompt Injection Detection and Prevention")
- st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

- # Cache Model Loading
+ # Load the model
  @st.cache_resource
  def load_model(filepath):
      return tf.keras.models.load_model(filepath)

- # Tokenizer Setup
+ # Load tokenizer
  @st.cache_resource
- def setup_tokenizer():
-     tokenizer = Tokenizer(num_words=5000)
-     # Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
-     tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
-     return tokenizer
+ def load_tokenizer(filepath):
+     with open(filepath, 'rb') as handle:
+         return pickle.load(handle)

- # Preprocessing Function
+ # Preprocess prompt
  def preprocess_prompt(prompt, tokenizer, max_length=100):
      sequence = tokenizer.texts_to_sequences([prompt])
-     return pad_sequences(sequence, maxlen=max_length)
+     padded_sequence = pad_sequences(sequence, maxlen=max_length)
+     return padded_sequence

- # Prediction Function
- def detect_prompt(prompt, tokenizer, model):
-     processed_prompt = preprocess_prompt(prompt, tokenizer)
+ # Predict prompt class
+ def detect_prompt(prompt, tokenizer, model, max_length=100):
+     processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
      prediction = model.predict(processed_prompt)[0][0]
-     class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
+     class_label = "Malicious" if prediction >= 0.5 else "Valid"
      confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
      return class_label, confidence_score

- # LIME Explanation
+ # LIME explanation
  def lime_explain(prompt, model, tokenizer, max_length=100):
-     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
-
      def predict_fn(prompts):
          sequences = tokenizer.texts_to_sequences(prompts)
          padded_sequences = pad_sequences(sequences, maxlen=max_length)
          predictions = model.predict(padded_sequences)
          return np.hstack([1 - predictions, predictions])
-
-     explanation = explainer.explain_instance(
-         prompt,
-         predict_fn,
-         num_features=10
-     )
+
+     class_names = ["Valid", "Malicious"]
+     explainer = LimeTextExplainer(class_names=class_names)
+     explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
      return explanation

- # Load Model Section
- st.subheader("Load Your Trained Model")
- model = None
- tokenizer = None
- model_path = "deep_learning_model (1).h5" # Ensure this file is in the same directory as app.py
-
- try:
-     model = load_model(model_path)
-     tokenizer = setup_tokenizer()
-     st.success("Model Loaded Successfully!")
+ # Streamlit App
+ st.title("Prompt Injection Detection and Prevention")
+ st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

-     # User Prompt Input
-     st.subheader("Classify Your Prompt")
+ # Model input
+ model_path = st.text_input("Enter the path to your trained model (.h5):")
+ if model_path:
+     try:
+         model = load_model(model_path)
+         st.success("Model Loaded Successfully!")
+     except Exception as e:
+         st.error(f"Error Loading Model: {e}")
+         model = None
+ else:
+     model = None
+
+ # Tokenizer input
+ tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
+ if tokenizer_path:
+     try:
+         tokenizer = load_tokenizer(tokenizer_path)
+         st.success("Tokenizer Loaded Successfully!")
+     except Exception as e:
+         st.error(f"Error Loading Tokenizer: {e}")
+         tokenizer = None
+ else:
+     tokenizer = None
+
+ # Prompt classification
+ if model and tokenizer:
      user_prompt = st.text_input("Enter a prompt to classify:")
-
      if user_prompt:
-         class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
-         st.write(f"Predicted Class: **{class_label}**")
-         st.write(f"Confidence Score: **{confidence_score:.2f}%**")
-
-         # LIME Explanation
-         st.subheader("LIME Explanation")
-         explanation = lime_explain(user_prompt, model, tokenizer)
-         explanation_as_html = explanation.as_html()
-         st.components.v1.html(explanation_as_html, height=500)
-
- except Exception as e:
-     st.error(f"Error Loading Model: {e}")
-
+         st.subheader("Model Prediction")
+         try:
+             # Classify the prompt
+             class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
+             st.write(f"Predicted Class: **{class_label}**")
+             st.write(f"Confidence Score: **{confidence_score:.2f}%**")
+
+             # Debugging information
+             st.write("Debugging Information:")
+             st.write(f"Tokenized Sequence: {tokenizer.texts_to_sequences([user_prompt])}")
+             st.write(f"Padded Sequence: {preprocess_prompt(user_prompt, tokenizer)}")
+             st.write(f"Raw Model Output: {model.predict(preprocess_prompt(user_prompt, tokenizer))[0][0]}")
+
+             # Generate LIME explanation
+             explanation = lime_explain(user_prompt, model, tokenizer)
+             explanation_as_html = explanation.as_html()
+             st.components.v1.html(explanation_as_html, height=500)
+         except Exception as e:
+             st.error(f"Error during prediction: {e}")

- # Footer
- st.write("---")
- st.write("Developed for detecting and preventing prompt injection attacks.")
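
Note on the new tokenizer input: the updated app.py no longer builds a throwaway Tokenizer in the app; it unpickles a previously fitted one from a user-supplied path. A minimal sketch of how such a file could be produced is shown below, assuming the tokenizer used in training was a tensorflow.keras Tokenizer like the setup_tokenizer() helper removed by this commit; the corpus, the num_words value, and the tokenizer.pickle filename are illustrative placeholders, not values taken from this repository.

import pickle
from tensorflow.keras.preprocessing.text import Tokenizer

# Placeholder corpus; in practice, use the same prompts the classifier was trained on.
train_prompts = ["example prompt", "malicious attack", "valid input prompt"]

# Assumed vocabulary size; it should match whatever was used when training the model.
tokenizer = Tokenizer(num_words=5000)
tokenizer.fit_on_texts(train_prompts)

# Save the fitted tokenizer so the Streamlit app can load it via load_tokenizer().
with open("tokenizer.pickle", "wb") as handle:
    pickle.dump(tokenizer, handle, protocol=pickle.HIGHEST_PROTOCOL)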