Spaces:
Sleeping
Sleeping
File size: 2,337 Bytes
fdb1a12 23f87a3 cf54d78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import streamlit as st
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
# Constants
# Length every tokenized prompt is padded/truncated to before inference
# (presumably must equal the length used in training — confirm).
MAX_LENGTH: int = 100
# The tokenizer is stored as its own JSON file, separate from the model.
TOKENIZER_PATH: str = "tokenizer.json"
# Load pre-trained model
@st.cache_resource
def load_trained_model(model_path="deep_learning_model.h5"):
    """Load the trained Keras model, cached across Streamlit reruns.

    Args:
        model_path: Path of the saved model file. Defaults to
            ``deep_learning_model.h5``, the file name the app's training
            messages refer to.

    Returns:
        The model returned by ``keras.models.load_model``. ``st.cache_resource``
        keeps one instance per distinct ``model_path``; call
        ``load_trained_model.clear()`` after retraining so a rerun picks up
        the new file.
    """
    return load_model(model_path)
# Load tokenizer
@st.cache_resource
def load_tokenizer():
    """Load the Keras tokenizer stored at ``TOKENIZER_PATH`` (cached).

    Two on-disk layouts are accepted:

    * the string from ``Tokenizer.to_json()`` written directly to the file,
      in which case ``json.load`` yields a dict; and
    * that string re-encoded with ``json.dump``, in which case ``json.load``
      yields a str.

    Returns:
        The tokenizer rebuilt by ``tokenizer_from_json``.
    """
    import json
    from tensorflow.keras.preprocessing.text import tokenizer_from_json

    with open(TOKENIZER_PATH, "r", encoding="utf-8") as f:
        tokenizer_data = json.load(f)
    # tokenizer_from_json() runs json.loads() on its argument, so it needs a
    # JSON *string*; handing it the dict from the first layout raises
    # TypeError. Re-serialize in that case.
    if not isinstance(tokenizer_data, str):
        tokenizer_data = json.dumps(tokenizer_data)
    return tokenizer_from_json(tokenizer_data)
# Preprocessing function
def preprocess_prompt(prompt, tokenizer, max_length):
    """Turn one raw prompt into the padded id matrix the model consumes.

    Args:
        prompt: The text to encode.
        tokenizer: Fitted tokenizer providing ``texts_to_sequences``.
        max_length: Target length passed to ``pad_sequences`` as ``maxlen``.

    Returns:
        The ``pad_sequences`` result for a batch containing just this prompt.
        Keras' default pre-padding/pre-truncation is used; presumably this
        matches training — confirm.
    """
    batch = [prompt]
    token_ids = tokenizer.texts_to_sequences(batch)
    return pad_sequences(token_ids, maxlen=max_length)
# Predict function
def detect_prompt(prompt, model, tokenizer, max_length, threshold=0.5):
    """Classify a prompt as "Malicious" or "Valid".

    Args:
        prompt: Raw prompt text.
        model: Loaded Keras model. Assumed to output a single score per
            sample where higher means more likely malicious — TODO confirm
            against the training script.
        tokenizer: Tokenizer used when the model was trained.
        max_length: Sequence length forwarded to ``preprocess_prompt``.
        threshold: Scores at or above this value are labelled "Malicious"
            (default 0.5, the previous hard-coded cut-off).

    Returns:
        ``(class_label, confidence_score)``. The confidence is the score for
        the returned label as a percentage: ``score * 100`` for "Malicious",
        ``(1 - score) * 100`` for "Valid".
    """
    processed_prompt = preprocess_prompt(prompt, tokenizer, max_length)
    # verbose=0 keeps Keras from printing a progress bar on every request;
    # float() converts the numpy scalar to a plain Python float.
    prediction = float(model.predict(processed_prompt, verbose=0)[0][0])
    if prediction >= threshold:
        return "Malicious", prediction * 100
    return "Valid", (1 - prediction) * 100
# Streamlit App
import os
import subprocess
import sys

MODEL_FILE = "deep_learning_model.h5"
TRAINING_SCRIPT = "train_model.py"


def _run_training_script():
    """Run the training script and report whether it exited successfully.

    Returns:
        True if the script exited with status 0. Otherwise an error is shown
        in the app and False is returned.
    """
    # sys.executable runs the script in the same environment (and with the
    # same TensorFlow install) as this app, unlike a bare `python` on PATH.
    result = subprocess.run([sys.executable, TRAINING_SCRIPT], check=False)
    if result.returncode != 0:
        st.error(f"Training failed (exit code {result.returncode}).")
        return False
    return True


st.title("Prompt Injection Detection App")
st.write("Detect and prevent prompt injection attacks using a deep learning model.")

# First-run training must happen *before* loading: load_trained_model()
# fails when the model file does not exist yet.
# NOTE(review): tokenizer.json is presumably also produced by the training
# script — confirm, or check for it here as well.
if not os.path.exists(MODEL_FILE):
    st.info("Training the model for the first time...")
    with st.spinner("Training..."):
        first_run_ok = _run_training_script()
    if not first_run_ok:
        st.stop()
    st.success("Model trained successfully and saved as deep_learning_model.h5")

# Load model and tokenizer (cached across reruns by st.cache_resource)
model = load_trained_model()
tokenizer = load_tokenizer()

# Input Section
user_input = st.text_area("Enter a prompt to test:", "")
if st.button("Detect"):
    # Treat whitespace-only input the same as empty input.
    if user_input.strip():
        label, confidence = detect_prompt(user_input, model, tokenizer, MAX_LENGTH)
        st.write(f"**Predicted Class:** {label}")
        st.write(f"**Confidence Score:** {confidence:.2f}%")
    else:
        st.warning("Please enter a prompt to test.")

if st.button("Train Model"):
    with st.spinner("Training..."):
        retrain_ok = _run_training_script()
    if retrain_ok:
        # Drop the cached resources so the next rerun loads the freshly
        # written artifacts instead of the pre-training model.
        load_trained_model.clear()
        load_tokenizer.clear()
        st.success("Model training complete. Saved as deep_learning_model.h5")
|