import streamlit as st
import tensorflow as tf
import numpy as np
import pickle

# Set page title and header
st.set_page_config(page_title="Prompt Injection Detection and Prevention")
st.title("Prompt Injection Detection and Prevention")
st.subheader("Classify prompts as malicious or valid and understand predictions using LIME.")

# Load the trained model
@st.cache_resource
def load_model(model_path):
    try:
        return tf.keras.models.load_model(model_path)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None

# Load the tokenizer
@st.cache_resource
def load_tokenizer(tokenizer_path):
    try:
        with open(tokenizer_path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        st.error(f"Error loading tokenizer: {e}")
        return None

# Paths to your files (these should be present in your Hugging Face repository)
MODEL_PATH = "model.h5"
TOKENIZER_PATH = "tokenizer.pkl"

# Load model and tokenizer
model = load_model(MODEL_PATH)
tokenizer = load_tokenizer(TOKENIZER_PATH)

if model and tokenizer:
    st.success("Model and tokenizer loaded successfully!")
else:
    st.stop()  # Don't render the classifier UI without a model and tokenizer

# User input for prompt classification
st.write("## Classify a Prompt")
user_input = st.text_area("Enter a prompt for classification:")
if st.button("Classify"):
    if user_input:
        # Preprocess the user input
        sequence = tokenizer.texts_to_sequences([user_input])
        padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(sequence, maxlen=50)

        # Make prediction
        prediction = model.predict(padded_sequence)
        label = "Malicious" if prediction[0] > 0.5 else "Valid"
        st.write(f"Prediction: **{label}** (Confidence: {prediction[0][0]:.2f})")
    else:
        st.error("Please enter a prompt for classification.")

# Footer
st.write("---")
st.caption("Developed for detecting and preventing prompt injection attacks.")