# Commit db38025 (ddovidovich): "Add comments to code"
import streamlit as st
import torch
import joblib
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer
# Streamlit re-runs this whole script on every widget interaction, so the heavy
# artifacts are loaded once per server process and reused across reruns instead
# of being reloaded each time the user types or moves the slider.
@st.cache_resource
def _load_artifacts():
    """Load the classifier (onto the CPU), its tokenizer and the label encoder.

    Returns:
        A ``(model, tokenizer, label_encoder)`` tuple.
    """
    # The .pkl holds a whole pickled model object (it is called directly for
    # inference), not a state_dict, so weights_only=False is required on
    # torch >= 2.6, where the default became True (the argument exists since 1.13).
    # Unpickling can execute arbitrary code: only ever point this at a trusted file.
    model = torch.load('WSP_Model_CPU.pkl', map_location=torch.device('cpu'), weights_only=False)
    # Switch to inference mode (disables dropout) in case the model was pickled
    # while still in training mode.
    model.eval()
    tok = BertTokenizer.from_pretrained('bert-base-german-cased')
    # Encoder used to map predicted class indices back to category labels.
    encoder = joblib.load('WSP_label_encoder.joblib')
    return model, tok, encoder


loaded_model, tokenizer, label_encoder = _load_artifacts()
# Free-text field holding the phrase to classify.
input_phrase = st.text_input(label="Input search text")
# Minimum confidence, in percent (range 0-100, default 50), that a prediction
# must reach to be shown instead of the raw input.
confidence_threshold = st.slider(
    "Confidence Threshold (%)",
    min_value=0,
    max_value=100,
    value=50,
)
# Classify only when the user typed something other than whitespace; a blank
# phrase would otherwise still be scored and mapped to some category.
if input_phrase.strip():
    # NOTE(review): the phrase is lowercased although the tokenizer is the *cased*
    # German BERT; presumably this mirrors how the training data was preprocessed —
    # confirm before changing, since it affects which tokens the model sees.
    tokenized_input = tokenizer(input_phrase.lower(), truncation=True, padding=True, return_tensors='pt')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = loaded_model(**tokenized_input)
    # A single phrase was tokenized, so row 0 holds its class scores.
    logits = outputs.logits[0]
    # Get the predicted class and its probability, as a percentage.
    predicted_class = torch.argmax(logits).item()
    predicted_confidence = torch.softmax(logits, dim=0)[predicted_class].item() * 100
    # If the probability is at least the selected threshold, show the predicted
    # category; otherwise echo the user's input unchanged.
    if predicted_confidence >= confidence_threshold:
        predicted_category = label_encoder.inverse_transform([predicted_class])[0]
        st.text(f"Output: {predicted_category}")
    else:
        st.text(f"Output: {input_phrase}")