deepugaur committed on
Commit d63a5db · verified · 1 Parent(s): c69f467

Update app.py

Files changed (1)
  1. app.py +83 -61
app.py CHANGED
@@ -1,78 +1,100 @@
  import streamlit as st
- import numpy as np
  import pandas as pd
  import tensorflow as tf
+ from tensorflow.keras.preprocessing.text import Tokenizer
  from tensorflow.keras.preprocessing.sequence import pad_sequences
+ from sklearn.model_selection import train_test_split
+ import numpy as np
  from lime.lime_text import LimeTextExplainer
  import matplotlib.pyplot as plt

- # Load model and tokenizer
- @st.cache(allow_output_mutation=True)
- def load_model_and_tokenizer(model_path, tokenizer_path):
-     model = tf.keras.models.load_model(model_path)
-     tokenizer = pd.read_pickle(tokenizer_path)
-     return model, tokenizer
-
- # Preprocess input for prediction
- def preprocess_prompt(prompt, tokenizer, max_length):
-     sequence = tokenizer.texts_to_sequences([prompt])
-     padded_sequence = pad_sequences(sequence, maxlen=max_length)
-     return padded_sequence
-
- # Make predictions
- def detect_prompt(prompt, model, tokenizer, max_length):
-     processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
-     prediction = model.predict(processed_prompt)[0][0]
-     class_label = "Malicious" if prediction >= 0.5 else "Valid"
-     confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
-     return class_label, confidence_score
-
- # Explain predictions using LIME
- def lime_explain(prompt, model, tokenizer, max_length):
-     def predict_fn(prompts):
-         sequences = tokenizer.texts_to_sequences(prompts)
-         padded_sequences = pad_sequences(sequences, maxlen=max_length)
-         predictions = model.predict(padded_sequences)
-         return np.hstack([1 - predictions, predictions])  # [P(valid), P(malicious)]
-
-     explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
-     explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
-     return explanation
-
- # Set up Streamlit app
- st.title("Prompt Injection Detection and Prevention")
- st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")
-
- # Load model and tokenizer
- model_path = "path/to/your/saved_model"
- tokenizer_path = "path/to/your/tokenizer.pkl"
- max_length = 100  # Update based on your model
- model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_path)
-
- # Input prompt
- user_input = st.text_area("Enter your prompt:", height=150)
-
- if st.button("Detect"):
-     if user_input.strip() == "":
-         st.error("Please enter a prompt.")
-     else:
-         # Prediction
-         class_label, confidence_score = detect_prompt(user_input, model, tokenizer, max_length)
-         st.subheader("Detection Result:")
-         st.write(f"**Class:** {class_label}")
-         st.write(f"**Confidence Score:** {confidence_score:.2f}%")
-
-         # Generate LIME explanation
-         st.subheader("Explanation:")
-         explanation = lime_explain(user_input, model, tokenizer, max_length)
-         fig = explanation.as_pyplot_figure()
-         st.pyplot(fig)
-
- # Sidebar information
- st.sidebar.title("About")
- st.sidebar.info(
-     """
-     This app uses a deep learning model to classify prompts as "Malicious" or "Valid."
-     LIME explanations are provided to interpret the predictions.
-     """
- )
+ # Streamlit Title
+ st.title("Prompt Injection Detection and Prevention")
+ st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")
+
+ # Cache Data Loading
+ @st.cache_data
+ def load_data(filepath):
+     return pd.read_csv(filepath)
+
+ # Cache Model Loading
+ @st.cache_resource
+ def load_model(filepath):
+     return tf.keras.models.load_model(filepath)
+
+ # File Upload Section
+ uploaded_file = st.file_uploader("Upload your dataset (.csv)", type=["csv"])
+ if uploaded_file is not None:
+     data = load_data(uploaded_file)
+     st.write("Dataset Preview:")
+     st.write(data.head())
+
+     # Data Preprocessing
+     data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
+     X = data['input'].values
+     y = data['label'].values
+     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+     # Tokenization and Padding
+     tokenizer = Tokenizer(num_words=5000)
+     tokenizer.fit_on_texts(X_train)
+     max_length = 100
+     X_train_pad = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=max_length)
+     X_test_pad = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=max_length)
+
+     # Load Deep Learning Model
+     model_path = st.text_input("Enter the path to your trained model (.h5):")
+     if model_path:
+         try:
+             model = load_model(model_path)
+             st.success("Model Loaded Successfully!")
+
+             # Test Prediction Functionality
+             def preprocess_prompt(prompt, tokenizer, max_length):
+                 sequence = tokenizer.texts_to_sequences([prompt])
+                 return pad_sequences(sequence, maxlen=max_length)
+
+             def detect_prompt(prompt):
+                 processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
+                 prediction = model.predict(processed_prompt)[0][0]
+                 class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
+                 confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
+                 return class_label, confidence_score
+
+             # User Input for Prompt Detection
+             st.subheader("Test a Prompt")
+             user_prompt = st.text_input("Enter a prompt to test:")
+             if user_prompt:
+                 class_label, confidence_score = detect_prompt(user_prompt)
+                 st.write(f"Predicted Class: **{class_label}**")
+                 st.write(f"Confidence Score: **{confidence_score:.2f}%**")
+
+             # LIME Explanation
+             explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
+
+             def lime_explain(prompt):
+                 def predict_fn(prompts):
+                     sequences = tokenizer.texts_to_sequences(prompts)
+                     padded_sequences = pad_sequences(sequences, maxlen=max_length)
+                     predictions = model.predict(padded_sequences)
+                     return np.hstack([1 - predictions, predictions])
+
+                 explanation = explainer.explain_instance(
+                     prompt,
+                     predict_fn,
+                     num_features=10
+                 )
+                 return explanation
+
+             st.subheader("LIME Explanation")
+             if user_prompt:
+                 explanation = lime_explain(user_prompt)
+                 explanation_as_html = explanation.as_html()
+                 st.components.v1.html(explanation_as_html, height=500)
+
+         except Exception as e:
+             st.error(f"Error Loading Model: {e}")
+
+ # Footer
+ st.write("---")
+ st.write("Developed for detecting and preventing prompt injection attacks using Streamlit.")