# Spaces: Sleeping  (page-status text captured along with the source; not part of the program)
import streamlit as st | |
import numpy as np | |
import pandas as pd | |
from tensorflow.keras.preprocessing.text import Tokenizer | |
from tensorflow.keras.preprocessing.sequence import pad_sequences | |
from tensorflow.keras.models import load_model | |
from lime.lime_text import LimeTextExplainer | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Load the trained model
def load_trained_model(model_path="deep_learning_model.h5"):
    """Load the saved Keras classifier from disk.

    Args:
        model_path: Path of the saved model file. Defaults to
            ``deep_learning_model.h5``, the file this app has always
            loaded, so existing zero-argument calls behave as before.

    Returns:
        The model object returned by ``load_model``.
    """
    # NOTE(review): Streamlit re-executes the script on every interaction,
    # so each rerun repeats this load; consider caching the result.
    return load_model(model_path)
# Every prompt is padded/truncated to this many tokens before prediction
max_length = 100

# Tokenizer setup (vocabulary size capped through num_words=5000)
tokenizer = Tokenizer(num_words=5000)

# Classifier instance shared by the prediction helpers and the UI below
model = load_trained_model()
# Load Data
def load_data(path="train prompt.csv"):
    """Read the labelled prompt dataset and encode its labels as integers.

    Args:
        path: Location of the CSV file. Defaults to ``train prompt.csv``,
            the file this app has always read, so existing zero-argument
            calls behave as before.

    Returns:
        A DataFrame in which the ``label`` values ``'valid'`` and
        ``'malicious'`` are replaced by ``0`` and ``1``, and rows without
        an ``input`` value are removed.
    """
    data = pd.read_csv(
        path,
        sep=',',
        quoting=3,  # csv.QUOTE_NONE: quote characters are kept as ordinary text
        encoding='ISO-8859-1',
        on_bad_lines='skip',  # malformed rows are dropped instead of raising
        engine='python',
    )
    data['label'] = data['label'].replace({'valid': 0, 'malicious': 1})
    # A row with an empty `input` field is parsed as NaN (a float), which the
    # tokenizer fitted on this column cannot process; drop such rows here.
    return data.dropna(subset=['input'])
# Build the tokenizer's vocabulary from the prompt column of the dataset,
# so that texts_to_sequences() can map words to integer ids later on.
data = load_data()
tokenizer.fit_on_texts(data["input"].values)
# Preprocessing functions
def preprocess_prompt(prompt, tokenizer, max_length):
    """Convert one raw prompt into a padded integer sequence.

    Args:
        prompt: The text to encode.
        tokenizer: Object providing ``texts_to_sequences``.
        max_length: Value passed to ``pad_sequences`` as ``maxlen``.

    Returns:
        The result of ``pad_sequences`` for the single-prompt batch.
    """
    # The tokenizer works on batches, so the prompt is wrapped in a one-item list.
    token_ids = tokenizer.texts_to_sequences([prompt])
    return pad_sequences(token_ids, maxlen=max_length)
def detect_prompt(prompt, model, tokenizer, max_length):
    """Classify a prompt and report how confident the model is.

    Args:
        prompt: The text to classify.
        model: Object providing ``predict``.
        tokenizer: Tokenizer forwarded to ``preprocess_prompt``.
        max_length: Padding length forwarded to ``preprocess_prompt``.

    Returns:
        A ``(class_label, confidence_score)`` tuple. The label is
        ``'Malicious'`` when the score is at least 0.5 and ``'Valid'``
        otherwise; the confidence is the percentage assigned to the
        chosen label.
    """
    model_input = preprocess_prompt(prompt, tokenizer, max_length)
    # Only one prompt is submitted, so take the first value of the first row.
    score = model.predict(model_input)[0][0]

    if score >= 0.5:
        return 'Malicious', score * 100
    # Below the threshold the confidence is the complement of the score.
    return 'Valid', (1 - score) * 100
# Streamlit app
# `model` is already bound at module level above, so the saved model is not
# read from disk a second time here.
st.title("Prompt Injection Attack Detection")
st.write("This application detects malicious prompts to prevent injection attacks.")

prompt = st.text_input("Enter a prompt to analyze:")

# Blank or whitespace-only input has nothing to classify or explain, so it is skipped.
if prompt and prompt.strip():
    class_label, confidence = detect_prompt(prompt, model, tokenizer, max_length)
    st.write(f"### Prediction: {class_label}")
    st.write(f"Confidence: {confidence:.2f}%")

    # LIME explanation
    st.write("Generating LIME Explanation...")
    explainer = LimeTextExplainer(class_names=["Valid", "Malicious"])

    def predict_fn(prompts):
        """Return per-class probabilities for a batch of texts.

        The texts are encoded exactly as in ``preprocess_prompt``. The
        columns of the result follow the order of ``class_names``:
        ``[Valid, Malicious]``.
        """
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        # assumes the model emits one probability per row, shape (n, 1) — TODO confirm
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    fig = explanation.as_pyplot_figure()
    st.pyplot(fig)
    # Release the figure once rendered; otherwise every rerun of the script
    # leaves another open figure registered with pyplot.
    plt.close(fig)