Spaces:
Sleeping
Sleeping
File size: 3,144 Bytes
c69f467 d63a5db c69f467 d63a5db c69f467 d63a5db c95a6a9 c69f467 d63a5db c69f467 c95a6a9 d63a5db c95a6a9 c69f467 c95a6a9 d63a5db c95a6a9 d63a5db c95a6a9 d63a5db c95a6a9 d63a5db c95a6a9 c69f467 c95a6a9 c69f467 c95a6a9 c69f467 c95a6a9 d63a5db c95a6a9 c69f467 c95a6a9 c69f467 d63a5db c95a6a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import streamlit as st
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from lime.lime_text import LimeTextExplainer
import matplotlib.pyplot as plt
# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")
# Cache Model Loading
@st.cache_resource
def load_model(filepath):
return tf.keras.models.load_model(filepath)
# Tokenizer Setup
@st.cache_resource
def setup_tokenizer():
tokenizer = Tokenizer(num_words=5000)
# Predefined vocabulary for demonstration purposes; replace with your actual tokenizer setup.
tokenizer.fit_on_texts(["example prompt", "malicious attack", "valid input prompt"])
return tokenizer
# Preprocessing Function
def preprocess_prompt(prompt, tokenizer, max_length=100):
sequence = tokenizer.texts_to_sequences([prompt])
return pad_sequences(sequence, maxlen=max_length)
# Prediction Function
def detect_prompt(prompt, tokenizer, model):
processed_prompt = preprocess_prompt(prompt, tokenizer)
prediction = model.predict(processed_prompt)[0][0]
class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
return class_label, confidence_score
# LIME Explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])
def predict_fn(prompts):
sequences = tokenizer.texts_to_sequences(prompts)
padded_sequences = pad_sequences(sequences, maxlen=max_length)
predictions = model.predict(padded_sequences)
return np.hstack([1 - predictions, predictions])
explanation = explainer.explain_instance(
prompt,
predict_fn,
num_features=10
)
return explanation
# Load Model Section
st.subheader("Load Your Trained Model")
model_path = st.text_input("Enter the path to your trained model (.h5):")
model = None
tokenizer = None
if model_path:
try:
model = load_model(model_path)
tokenizer = setup_tokenizer()
st.success("Model Loaded Successfully!")
# User Prompt Input
st.subheader("Classify Your Prompt")
user_prompt = st.text_input("Enter a prompt to classify:")
if user_prompt:
class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
st.write(f"Predicted Class: **{class_label}**")
st.write(f"Confidence Score: **{confidence_score:.2f}%**")
# LIME Explanation
st.subheader("LIME Explanation")
explanation = lime_explain(user_prompt, model, tokenizer)
explanation_as_html = explanation.as_html()
st.components.v1.html(explanation_as_html, height=500)
except Exception as e:
st.error(f"Error Loading Model: {e}")
# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks.")
|