deepugaur commited on
Commit
58990ee
·
verified ·
1 Parent(s): 8b4479c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -74
app.py CHANGED
@@ -1,88 +1,59 @@
1
  import streamlit as st
2
  import tensorflow as tf
3
- from tensorflow.keras.preprocessing.text import Tokenizer
4
- from tensorflow.keras.preprocessing.sequence import pad_sequences
5
  import numpy as np
6
- from lime.lime_text import LimeTextExplainer
7
- import matplotlib.pyplot as plt
8
 
9
- # Streamlit Title
 
10
  st.title("Prompt Injection Detection and Prevention")
11
- st.write("Classify prompts as malicious or valid and understand predictions using LIME.")
12
 
13
- # Cache Model Loading
14
  @st.cache_resource
15
- def load_model(filepath):
16
- return tf.keras.models.load_model(filepath)
 
 
 
 
17
 
18
- # Tokenizer Setup
19
  @st.cache_resource
20
- def setup_tokenizer():
21
- tokenizer = Tokenizer(num_words=5000)
22
- # Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
23
- tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
24
- return tokenizer
25
-
26
- # Preprocessing Function
27
- def preprocess_prompt(prompt, tokenizer, max_length=100):
28
- sequence = tokenizer.texts_to_sequences([prompt])
29
- return pad_sequences(sequence, maxlen=max_length)
30
-
31
- # Prediction Function
32
- def detect_prompt(prompt, tokenizer, model):
33
- processed_prompt = preprocess_prompt(prompt, tokenizer)
34
- prediction = model.predict(processed_prompt)[0][0]
35
- class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
36
- confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
37
- return class_label, confidence_score
38
-
39
- # LIME Explanation
40
- def lime_explain(prompt, model, tokenizer, max_length=100):
41
- explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
42
-
43
- def predict_fn(prompts):
44
- sequences = tokenizer.texts_to_sequences(prompts)
45
- padded_sequences = pad_sequences(sequences, maxlen=max_length)
46
- predictions = model.predict(padded_sequences)
47
- return np.hstack([1 - predictions, predictions])
48
-
49
- explanation = explainer.explain_instance(
50
- prompt,
51
- predict_fn,
52
- num_features=10
53
- )
54
- return explanation
55
-
56
- # Load Model Section
57
- st.subheader("Load Your Trained Model")
58
- model_path = st.text_input("Enter the path to your trained model (.h5):")
59
- model = None
60
- tokenizer = None
61
-
62
- if model_path:
63
  try:
64
- model = load_model(model_path)
65
- tokenizer = setup_tokenizer()
66
- st.success("Model Loaded Successfully!")
67
-
68
- # User Prompt Input
69
- st.subheader("Classify Your Prompt")
70
- user_prompt = st.text_input("Enter a prompt to classify:")
71
-
72
- if user_prompt:
73
- class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
74
- st.write(f"Predicted Class: **{class_label}**")
75
- st.write(f"Confidence Score: **{confidence_score:.2f}%**")
76
-
77
- # LIME Explanation
78
- st.subheader("LIME Explanation")
79
- explanation = lime_explain(user_prompt, model, tokenizer)
80
- explanation_as_html = explanation.as_html()
81
- st.components.v1.html(explanation_as_html, height=500)
82
-
83
  except Exception as e:
84
- st.error(f"Error Loading Model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  # Footer
87
  st.write("---")
88
- st.write("Developed for detecting and preventing prompt injection attacks.")
 
1
  import streamlit as st
2
  import tensorflow as tf
 
 
3
  import numpy as np
4
+ import pickle
 
5
 
6
+ # Set page title and header
7
+ st.set_page_config(page_title="Prompt Injection Detection and Prevention")
8
  st.title("Prompt Injection Detection and Prevention")
9
+ st.subheader("Classify prompts as malicious or valid and understand predictions using LIME.")
10
 
11
+ # Load the trained model
12
  @st.cache_resource
13
+ def load_model(model_path):
14
+ try:
15
+ return tf.keras.models.load_model(model_path)
16
+ except Exception as e:
17
+ st.error(f"Error loading model: {e}")
18
+ return None
19
 
20
+ # Load the tokenizer
21
  @st.cache_resource
22
+ def load_tokenizer(tokenizer_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  try:
24
+ with open(tokenizer_path, "rb") as f:
25
+ return pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  except Exception as e:
27
+ st.error(f"Error loading tokenizer: {e}")
28
+ return None
29
+
30
+ # Paths to your files (these should be present in your Hugging Face repository)
31
+ MODEL_PATH = "model.h5"
32
+ TOKENIZER_PATH = "tokenizer.pkl"
33
+
34
+ # Load model and tokenizer
35
+ model = load_model(MODEL_PATH)
36
+ tokenizer = load_tokenizer(TOKENIZER_PATH)
37
+
38
+ if model and tokenizer:
39
+ st.success("Model and tokenizer loaded successfully!")
40
+
41
+ # User input for prompt classification
42
+ st.write("## Classify a Prompt")
43
+ user_input = st.text_area("Enter a prompt for classification:")
44
+ if st.button("Classify"):
45
+ if user_input:
46
+ # Preprocess the user input
47
+ sequence = tokenizer.texts_to_sequences([user_input])
48
+ padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(sequence, maxlen=50)
49
+
50
+ # Make prediction
51
+ prediction = model.predict(padded_sequence)
52
+ label = "Malicious" if prediction[0] > 0.5 else "Valid"
53
+ st.write(f"Prediction: **{label}** (Confidence: {prediction[0][0]:.2f})")
54
+ else:
55
+ st.error("Please enter a prompt for classification.")
56
 
57
  # Footer
58
  st.write("---")
59
+ st.caption("Developed for detecting and preventing prompt injection attacks.")