deepugaur committed
Commit 5408638 · verified · 1 Parent(s): 3d67bb8

Update app.py

Files changed (1):
  1. app.py: +64 -69
app.py CHANGED
@@ -2,103 +2,98 @@
import streamlit as st
import tensorflow as tf
- import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
- import pickle
- from lime import lime_text
from lime.lime_text import LimeTextExplainer

- # Load the model
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

- # Load tokenizer
- import pickle
-
- # Function to load the tokenizer
@st.cache_resource
- def load_tokenizer():
-     with open('tokenizer.pickle', 'rb') as handle:
-         tokenizer = pickle.load(handle)
-     return tokenizer

- # Preprocess prompt
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
-     padded_sequence = pad_sequences(sequence, maxlen=max_length)
-     return padded_sequence

- # Predict prompt class
- def detect_prompt(prompt, tokenizer, model, max_length=100):
-     processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    prediction = model.predict(processed_prompt)[0][0]
-     class_label = "Malicious" if prediction >= 0.5 else "Valid"
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

- # LIME explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])
-
-     class_names = ["Valid", "Malicious"]
-     explainer = LimeTextExplainer(class_names=class_names)
-     explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
-     return explanation

- # Streamlit App
- st.title("Prompt Injection Detection and Prevention")
- st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

- # Model input
model_path = st.text_input("Enter the path to your trained model (.h5):")
- if model_path:
-     try:
-         model = load_model(model_path)
-         st.success("Model Loaded Successfully!")
-     except Exception as e:
-         st.error(f"Error Loading Model: {e}")
-         model = None
- else:
-     model = None
-
- # Tokenizer input
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
- if tokenizer_path:
    try:
        tokenizer = load_tokenizer(tokenizer_path)
-         st.success("Tokenizer Loaded Successfully!")
-     except Exception as e:
-         st.error(f"Error Loading Tokenizer: {e}")
-         tokenizer = None
- else:
-     tokenizer = None
-
- # Prompt classification
- if model and tokenizer:
-     user_prompt = st.text_input("Enter a prompt to classify:")
-     if user_prompt:
-         st.subheader("Model Prediction")
-         try:
-             # Classify the prompt
-             class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
-             st.write(f"Predicted Class: **{class_label}**")
-             st.write(f"Confidence Score: **{confidence_score:.2f}%**")

-             # Debugging information
-             st.write("Debugging Information:")
-             st.write(f"Tokenized Sequence: {tokenizer.texts_to_sequences([user_prompt])}")
-             st.write(f"Padded Sequence: {preprocess_prompt(user_prompt, tokenizer)}")
-             st.write(f"Raw Model Output: {model.predict(preprocess_prompt(user_prompt, tokenizer))[0][0]}")

-             # Generate LIME explanation
-             explanation = lime_explain(user_prompt, model, tokenizer)
-             explanation_as_html = explanation.as_html()
-             st.components.v1.html(explanation_as_html, height=500)
-         except Exception as e:
-             st.error(f"Error during prediction: {e}")

import streamlit as st
import tensorflow as tf
+ from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
+ import numpy as np
from lime.lime_text import LimeTextExplainer
+ import pickle
+ import matplotlib.pyplot as plt
+
+ # Streamlit Title
+ st.title("Prompt Injection Detection and Prevention")
+ st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

+ # Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

+ # Load Tokenizer Function
@st.cache_resource
+ def load_tokenizer(filepath):
+     try:
+         with open(filepath, 'rb') as handle:
+             tokenizer = pickle.load(handle)
+         return tokenizer
+     except Exception as e:
+         st.error(f"Error loading tokenizer: {e}")
+         return None

+ # Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
+     return pad_sequences(sequence, maxlen=max_length)

+ # Prediction Function
+ def detect_prompt(prompt, tokenizer, model):
+     processed_prompt = preprocess_prompt(prompt, tokenizer)
    prediction = model.predict(processed_prompt)[0][0]
+     class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score

+ # LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
+     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
+
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

+     explanation = explainer.explain_instance(
+         prompt,
+         predict_fn,
+         num_features=10
+     )
+     return explanation

+ # Load Model Section
+ st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
+ model = None
+ tokenizer = None
+
+ if model_path and tokenizer_path:
    try:
+         model = load_model(model_path)
        tokenizer = load_tokenizer(tokenizer_path)
+
+         if model and tokenizer:
+             st.success("Model and Tokenizer Loaded Successfully!")
+
+             # User Prompt Input
+             st.subheader("Classify Your Prompt")
+             user_prompt = st.text_input("Enter a prompt to classify:")
+
+             if user_prompt:
+                 class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
+                 st.write(f"Predicted Class: **{class_label}**")
+                 st.write(f"Confidence Score: **{confidence_score:.2f}%**")
+
+                 # LIME Explanation
+                 st.subheader("LIME Explanation")
+                 explanation = lime_explain(user_prompt, model, tokenizer)
+                 explanation_as_html = explanation.as_html()
+                 st.components.v1.html(explanation_as_html, height=500)
+         else:
+             st.error("Failed to load model or tokenizer.")
+
+     except Exception as e:
+         st.error(f"Error Loading Model or Tokenizer: {e}")

+ # Footer
+ st.write("---")
+ st.write("Developed for detecting and preventing prompt injection attacks.")