File size: 1,153 Bytes
d184887
 
 
 
9b9207f
 
 
 
d9a9e61
9b9207f
d184887
 
 
c94eabd
 
 
d184887
 
 
 
 
 
 
c94eabd
d184887
c94eabd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import streamlit as st
import torch
import joblib

from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer

@st.cache_resource
def _load_artifacts():
    """Load the classifier, tokenizer and label encoder once per server process.

    Streamlit re-executes this whole script on every widget interaction; caching
    the resources avoids deserialising the BERT model on every keystroke or
    slider move.

    Returns:
        A ``(model, tokenizer, label_encoder)`` tuple.
    """
    # SECURITY: torch.load(weights_only=False) and joblib.load both unpickle,
    # which can execute arbitrary code embedded in the file. Only deploy
    # artifact files from a trusted source.
    # weights_only=False is required because the file holds a whole pickled
    # model object (it is called directly below); PyTorch >= 2.6 defaults to
    # weights_only=True and would refuse to load it.
    model = torch.load(
        "WSP_Model_CPU.pkl",
        map_location=torch.device("cpu"),
        weights_only=False,
    )
    # A pickled module keeps the train/eval mode it was saved in; force eval
    # mode so dropout is disabled and predictions are deterministic.
    model.eval()
    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-german-cased")
    encoder = joblib.load("WSP_label_encoder.joblib")
    return model, bert_tokenizer, encoder


loaded_model, tokenizer, label_encoder = _load_artifacts()

input_phrase = st.text_input(label="Input search text")

# Slider for the minimum confidence (integer percent in [0, 100], default 50)
# that the model's top prediction must reach to be reported.
confidence_threshold = st.slider(
    "Confidence Threshold (%)",
    min_value=0,
    max_value=100,
    value=50,
)

# Classify only when the user has typed something other than whitespace;
# blank input would otherwise be run through the model as an empty query.
if input_phrase and input_phrase.strip():
    tokenized_input = tokenizer(input_phrase, truncation=True, padding=True, return_tensors='pt')

    with torch.no_grad():
        outputs = loaded_model(**tokenized_input)

    # One input phrase -> batch of size 1; take row 0 of the class
    # probability distribution and pick its most likely class in one pass.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    top_probability, top_index = torch.max(probabilities, dim=0)
    predicted_class = top_index.item()
    predicted_confidence = top_probability.item() * 100  # percent, like the slider

    if predicted_confidence >= confidence_threshold:
        predicted_category = label_encoder.inverse_transform([predicted_class])[0]
        st.text(f"Output: {predicted_category}")
    else:
        # Not confident enough: fall back to echoing the raw search text.
        st.text(f"Output: {input_phrase}")