# NOTE: scraped page-header residue, not part of the program: "Spaces: Sleeping"
"""Streamlit app that maps a free-text search phrase to a predicted category.

The phrase is lower-cased, tokenized, and passed through a pickled
classification model. If the model's top-class probability reaches the
threshold chosen on the slider, the decoded category label is shown;
otherwise the user's original phrase is echoed back unchanged.
"""
import streamlit as st
import torch
import joblib
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer

# Streamlit re-executes this whole script on every widget interaction.
# Cache the heavyweight objects so they are created once per server process
# instead of on every keystroke / slider move. ``st.cache_resource`` is the
# current API; older releases expose the same behaviour under the name
# ``st.experimental_singleton``.
_cache_resource = getattr(st, "cache_resource", None) or st.experimental_singleton


@_cache_resource
def _load_model():
    """Load the pickled classifier onto the CPU and put it in inference mode."""
    # SECURITY: torch.load unpickles arbitrary Python objects. Only ever point
    # this at a model file you produced or otherwise trust.
    # NOTE(review): recent PyTorch releases default to ``weights_only=True``,
    # which rejects fully pickled models; if loading fails there and the file
    # is trusted, pass ``weights_only=False`` explicitly.
    model = torch.load('WSP_Model_CPU.pkl', map_location=torch.device('cpu'))
    # Disable dropout and other train-time behaviour; without this the
    # reported confidence can differ between identical requests.
    model.eval()
    return model


@_cache_resource
def _load_tokenizer():
    """Load the tokenizer matching the model's pretrained checkpoint."""
    return BertTokenizer.from_pretrained('bert-base-german-cased')


@_cache_resource
def _load_label_encoder():
    """Load the encoder that maps class indices back to category labels."""
    return joblib.load('WSP_label_encoder.joblib')


# Load model to CPU
loaded_model = _load_model()
tokenizer = _load_tokenizer()
# Load labels
label_encoder = _load_label_encoder()

# Create text input field for input phrase
input_phrase = st.text_input("Input search text")
# Add a slider to select the minimum confidence (in percent)
confidence_threshold = st.slider("Confidence Threshold (%)", 0, 100, 50)

# Skip inference for empty or whitespace-only input.
if input_phrase.strip():
    tokenized_input = tokenizer(
        input_phrase.lower(),
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = loaded_model(**tokenized_input)

    # Get the predicted class and its probability (as a percentage).
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    class_probabilities = torch.softmax(outputs.logits, dim=1)[0]
    predicted_confidence = class_probabilities[predicted_class].item() * 100

    # If the confidence reaches the selected threshold, show the predicted
    # category; otherwise fall back to the user's own input.
    if predicted_confidence >= confidence_threshold:
        predicted_category = label_encoder.inverse_transform([predicted_class])[0]
        st.text(f"Output: {predicted_category}")
    else:
        st.text(f"Output: {input_phrase}")