deepugaur committed · verified
Commit c95a6a9 · 1 Parent(s): d63a5db

Update app.py

Files changed (1)
  1. app.py +59 -70
app.py CHANGED
@@ -1,100 +1,89 @@
 
  import streamlit as st
- import pandas as pd
  import tensorflow as tf
  from tensorflow.keras.preprocessing.text import Tokenizer
  from tensorflow.keras.preprocessing.sequence import pad_sequences
- from sklearn.model_selection import train_test_split
  import numpy as np
  from lime.lime_text import LimeTextExplainer
  import matplotlib.pyplot as plt

  # Streamlit Title
  st.title("Prompt Injection Detection and Prevention")
- st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")
-
- # Cache Data Loading
- @st.cache_data
- def load_data(filepath):
-     return pd.read_csv(filepath)

  # Cache Model Loading
  @st.cache_resource
  def load_model(filepath):
      return tf.keras.models.load_model(filepath)

- # File Upload Section
- uploaded_file = st.file_uploader("Upload your dataset (.csv)", type=["csv"])
- if uploaded_file is not None:
-     data = load_data(uploaded_file)
-     st.write("Dataset Preview:")
-     st.write(data.head())
-
-     # Data Preprocessing
-     data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
-     X = data['input'].values
-     y = data['label'].values
-     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-
-     # Tokenization and Padding
      tokenizer = Tokenizer(num_words=5000)
-     tokenizer.fit_on_texts(X_train)
-     max_length = 100
-     X_train_pad = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=max_length)
-     X_test_pad = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=max_length)

-     # Load Deep Learning Model
-     model_path = st.text_input("Enter the path to your trained model (.h5):")
-     if model_path:
-         try:
-             model = load_model(model_path)
-             st.success("Model Loaded Successfully!")

-             # Test Prediction Functionality
-             def preprocess_prompt(prompt, tokenizer, max_length):
-                 sequence = tokenizer.texts_to_sequences([prompt])
-                 return pad_sequences(sequence, maxlen=max_length)

-             def detect_prompt(prompt):
-                 processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
-                 prediction = model.predict(processed_prompt)[0][0]
-                 class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
-                 confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
-                 return class_label, confidence_score

-             # User Input for Prompt Detection
-             st.subheader("Test a Prompt")
-             user_prompt = st.text_input("Enter a prompt to test:")
-             if user_prompt:
-                 class_label, confidence_score = detect_prompt(user_prompt)
-                 st.write(f"Predicted Class: **{class_label}**")
-                 st.write(f"Confidence Score: **{confidence_score:.2f}%**")

-             # LIME Explanation
-             explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

-             def lime_explain(prompt):
-                 def predict_fn(prompts):
-                     sequences = tokenizer.texts_to_sequences(prompts)
-                     padded_sequences = pad_sequences(sequences, maxlen=max_length)
-                     predictions = model.predict(padded_sequences)
-                     return np.hstack([1 - predictions, predictions])

-                 explanation = explainer.explain_instance(
-                     prompt,
-                     predict_fn,
-                     num_features=10
-                 )
-                 return explanation

              st.subheader("LIME Explanation")
-             if user_prompt:
-                 explanation = lime_explain(user_prompt)
-                 explanation_as_html = explanation.as_html()
-                 st.components.v1.html(explanation_as_html, height=500)

-         except Exception as e:
-             st.error(f"Error Loading Model: {e}")

  # Footer
  st.write("---")
- st.write("Developed for detecting and preventing prompt injection attacks using Streamlit.")
 
+
  import streamlit as st
  import tensorflow as tf
  from tensorflow.keras.preprocessing.text import Tokenizer
  from tensorflow.keras.preprocessing.sequence import pad_sequences
  import numpy as np
  from lime.lime_text import LimeTextExplainer
  import matplotlib.pyplot as plt

  # Streamlit Title
  st.title("Prompt Injection Detection and Prevention")
+ st.write("Classify prompts as malicious or valid and understand predictions using LIME.")

  # Cache Model Loading
  @st.cache_resource
  def load_model(filepath):
      return tf.keras.models.load_model(filepath)

+ # Tokenizer Setup
+ @st.cache_resource
+ def setup_tokenizer():
      tokenizer = Tokenizer(num_words=5000)
+     # Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
+     tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
+     return tokenizer

+ # Preprocessing Function
+ def preprocess_prompt(prompt, tokenizer, max_length=100):
+     sequence = tokenizer.texts_to_sequences([prompt])
+     return pad_sequences(sequence, maxlen=max_length)

+ # Prediction Function
+ def detect_prompt(prompt, tokenizer, model):
+     processed_prompt = preprocess_prompt(prompt, tokenizer)
+     prediction = model.predict(processed_prompt)[0][0]
+     class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
+     confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
+     return class_label, confidence_score

+ # LIME Explanation
+ def lime_explain(prompt, model, tokenizer, max_length=100):
+     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
+
+     def predict_fn(prompts):
+         sequences = tokenizer.texts_to_sequences(prompts)
+         padded_sequences = pad_sequences(sequences, maxlen=max_length)
+         predictions = model.predict(padded_sequences)
+         return np.hstack([1 - predictions, predictions])

+     explanation = explainer.explain_instance(
+         prompt,
+         predict_fn,
+         num_features=10
+     )
+     return explanation

+ # Load Model Section
+ st.subheader("Load Your Trained Model")
+ model_path = st.text_input("Enter the path to your trained model (.h5):")
+ model = None
+ tokenizer = None

+ if model_path:
+     try:
+         model = load_model(model_path)
+         tokenizer = setup_tokenizer()
+         st.success("Model Loaded Successfully!")
+
+         # User Prompt Input
+         st.subheader("Classify Your Prompt")
+         user_prompt = st.text_input("Enter a prompt to classify:")

+         if user_prompt:
+             class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
+             st.write(f"Predicted Class: **{class_label}**")
+             st.write(f"Confidence Score: **{confidence_score:.2f}%**")

+             # LIME Explanation
              st.subheader("LIME Explanation")
+             explanation = lime_explain(user_prompt, model, tokenizer)
+             explanation_as_html = explanation.as_html()
+             st.components.v1.html(explanation_as_html, height=500)

+     except Exception as e:
+         st.error(f"Error Loading Model: {e}")

  # Footer
  st.write("---")
+ st.write("Developed for detecting and preventing prompt injection attacks.")
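
Note on the new setup_tokenizer(): as its inline comment says, the three-phrase vocabulary is only a placeholder, so the word indices it produces will not match the ones the .h5 model was trained with. A minimal sketch of one way to close that gap, assuming the training script saved its fitted tokenizer with Keras' tokenizer.to_json() to a file such as tokenizer.json (the file name and the saving step are assumptions, not part of this commit):

import streamlit as st
from tensorflow.keras.preprocessing.text import tokenizer_from_json

@st.cache_resource
def setup_tokenizer(filepath="tokenizer.json"):  # assumed path, not produced by this commit
    # Rebuild the exact vocabulary fitted during training instead of
    # refitting on a placeholder corpus at app start-up.
    with open(filepath, "r", encoding="utf-8") as f:
        return tokenizer_from_json(f.read())

The app itself is launched the same way as before, for example with streamlit run app.py.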