"""Streamlit app that classifies prompts as malicious or valid and explains predictions with LIME."""

import streamlit as st
import streamlit.components.v1 as components
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
from lime.lime_text import LimeTextExplainer
# Load the model
def load_model(filepath):
    return tf.keras.models.load_model(filepath)

# Load tokenizer
def load_tokenizer(filepath):
    with open(filepath, 'rb') as handle:
        return pickle.load(handle)

# Preprocess prompt into a padded integer sequence
def preprocess_prompt(prompt, tokenizer, max_length=100):
    sequence = tokenizer.texts_to_sequences([prompt])
    padded_sequence = pad_sequences(sequence, maxlen=max_length)
    return padded_sequence
# Predict prompt class
def detect_prompt(prompt, tokenizer, model, max_length=100):
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    prediction = model.predict(processed_prompt)[0][0]
    class_label = "Malicious" if prediction >= 0.5 else "Valid"
    # Report confidence relative to the predicted class
    confidence_score = prediction * 100 if prediction >= 0.5 else (1 - prediction) * 100
    return class_label, confidence_score
# LIME explanation
def lime_explain(prompt, model, tokenizer, max_length=100):
    # LIME expects one probability column per class: [P(Valid), P(Malicious)]
    def predict_fn(prompts):
        sequences = tokenizer.texts_to_sequences(prompts)
        padded_sequences = pad_sequences(sequences, maxlen=max_length)
        predictions = model.predict(padded_sequences)
        return np.hstack([1 - predictions, predictions])

    class_names = ["Valid", "Malicious"]
    explainer = LimeTextExplainer(class_names=class_names)
    explanation = explainer.explain_instance(prompt, predict_fn, num_features=10)
    return explanation
# Streamlit App
st.title("Prompt Injection Detection and Prevention")
st.write("Classify prompts as malicious or valid and understand predictions using LIME.")
# Model input
model_path = st.text_input("Enter the path to your trained model (.h5):")
if model_path:
    try:
        model = load_model(model_path)
        st.success("Model Loaded Successfully!")
    except Exception as e:
        st.error(f"Error Loading Model: {e}")
        model = None
else:
    model = None
# Tokenizer input
tokenizer_path = st.text_input("Enter the path to your tokenizer file (.pickle):")
if tokenizer_path:
    try:
        tokenizer = load_tokenizer(tokenizer_path)
        st.success("Tokenizer Loaded Successfully!")
    except Exception as e:
        st.error(f"Error Loading Tokenizer: {e}")
        tokenizer = None
else:
    tokenizer = None
# Prompt classification
if model is not None and tokenizer is not None:
    user_prompt = st.text_input("Enter a prompt to classify:")
    if user_prompt:
        st.subheader("Model Prediction")
        try:
            # Classify the prompt
            class_label, confidence_score = detect_prompt(user_prompt, tokenizer, model)
            st.write(f"Predicted Class: **{class_label}**")
            st.write(f"Confidence Score: **{confidence_score:.2f}%**")

            # Debugging information
            st.write("Debugging Information:")
            st.write(f"Tokenized Sequence: {tokenizer.texts_to_sequences([user_prompt])}")
            st.write(f"Padded Sequence: {preprocess_prompt(user_prompt, tokenizer)}")
            st.write(f"Raw Model Output: {model.predict(preprocess_prompt(user_prompt, tokenizer))[0][0]}")

            # Generate LIME explanation and render it as HTML
            explanation = lime_explain(user_prompt, model, tokenizer)
            explanation_as_html = explanation.as_html()
            components.html(explanation_as_html, height=500)
        except Exception as e:
            st.error(f"Error during prediction: {e}")