import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import tensorflow as tf
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
import numpy as np
from lime.lime_text import LimeTextExplainer
# Streamlit Title
st.title("Prompt Injection Detection and Prevention")
st.write("Detect malicious prompts and understand predictions using deep learning and LIME.")
# Cache Data Loading
@st.cache_data
def load_data(filepath):
    return pd.read_csv(filepath)
# Cache Model Loading
@st.cache_resource
def load_model(filepath):
    return tf.keras.models.load_model(filepath)
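# Note: st.cache_resource (rather than st.cache_data) is used for the model,
# since a Keras model is a shared resource that st.cache_data cannot serialize.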
# File Upload Section
uploaded_file = st.file_uploader("Upload your dataset (.csv)", type=["csv"])
if uploaded_file is not None:
    data = load_data(uploaded_file)
    st.write("Dataset Preview:")
    st.write(data.head())
    # Data Preprocessing: map string labels to integers (0 = valid, 1 = malicious)
    data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
    X = data['input'].values
    y = data['label'].values
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    # Tokenization and Padding
    tokenizer = Tokenizer(num_words=5000)
    tokenizer.fit_on_texts(X_train)
    max_length = 100
    X_train_pad = pad_sequences(tokenizer.texts_to_sequences(X_train), maxlen=max_length)
    X_test_pad = pad_sequences(tokenizer.texts_to_sequences(X_test), maxlen=max_length)
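    # The .h5 model loaded below is expected to accept integer sequences of
    # shape (batch, 100) over a 5000-word vocabulary and emit a single sigmoid
    # score. A minimal sketch of a compatible architecture (an assumption for
    # illustration, not necessarily the model that was actually trained):
    def build_example_model(vocab_size=5000, max_length=100):
        model = tf.keras.Sequential([
            tf.keras.layers.Embedding(vocab_size, 64),
            tf.keras.layers.LSTM(64),
            tf.keras.layers.Dense(1, activation='sigmoid'),
        ])
        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
        return model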
    # Load Deep Learning Model
    model_path = st.text_input("Enter the path to your trained model (.h5):")
    if model_path:
        try:
            model = load_model(model_path)
            st.success("Model Loaded Successfully!")
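            # Optional evaluation (sketch): the held-out split prepared above
            # could be used here to report test accuracy, e.g.:
            #   loss, acc = model.evaluate(X_test_pad, y_test, verbose=0)
            #   st.write(f"Held-out accuracy: {acc:.2%}")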
            # Test Prediction Functionality
            def preprocess_prompt(prompt, tokenizer, max_length):
                sequence = tokenizer.texts_to_sequences([prompt])
                return pad_sequences(sequence, maxlen=max_length)
            def detect_prompt(prompt):
                processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
                # The model outputs a single sigmoid score: P(malicious)
                prediction = model.predict(processed_prompt)[0][0]
                class_label = 'Malicious' if prediction >= 0.5 else 'Valid'
                confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
                return class_label, confidence_score
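            # Example (hypothetical values): a sigmoid score of 0.92 yields
            # ("Malicious", 92.00); a score of 0.10 yields ("Valid", 90.00).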
            # User Input for Prompt Detection
            st.subheader("Test a Prompt")
            user_prompt = st.text_input("Enter a prompt to test:")
            if user_prompt:
                class_label, confidence_score = detect_prompt(user_prompt)
                st.write(f"Predicted Class: **{class_label}**")
                st.write(f"Confidence Score: **{confidence_score:.2f}%**")
            # LIME Explanation (class 0 = Valid, class 1 = Malicious, matching
            # the column order returned by predict_fn below)
            explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

            def lime_explain(prompt):
                def predict_fn(prompts):
                    sequences = tokenizer.texts_to_sequences(prompts)
                    padded_sequences = pad_sequences(sequences, maxlen=max_length)
                    predictions = model.predict(padded_sequences)
                    # LIME expects one probability per class; the model emits
                    # P(malicious) with shape (n, 1), so stack [P(valid), P(malicious)]
                    return np.hstack([1 - predictions, predictions])
                explanation = explainer.explain_instance(
                    prompt,
                    predict_fn,
                    num_features=10
                )
                return explanation
            st.subheader("LIME Explanation")
            if user_prompt:
                explanation = lime_explain(user_prompt)
                explanation_as_html = explanation.as_html()
                components.html(explanation_as_html, height=500)
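                # Alternative rendering (sketch): LIME can also produce a
                # matplotlib figure instead of HTML:
                #   st.pyplot(explanation.as_pyplot_figure())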
        except Exception as e:
            st.error(f"Error Loading Model: {e}")
# Footer
st.write("---")
st.write("Developed for detecting and preventing prompt injection attacks using Streamlit.")
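# Usage (assumed filename): run the app with
#   streamlit run app.py
# The uploaded CSV is expected to contain two columns:
#   input -- the raw prompt text
#   label -- 'valid' or 'malicious'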