mjpsm commited on
Commit
2517ed4
Β·
verified Β·
1 Parent(s): 150d94b

Create check-in

Browse files
Files changed (1) hide show
  1. check-in +55 -0
check-in ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ # Map model labels to human-readable labels
7
+ LABEL_MAP = {
8
+ "LABEL_0": "Bad",
9
+ "LABEL_1": "Mediocre",
10
+ "LABEL_2": "Good"
11
+ }
12
+
13
+ # Load model and tokenizer
14
+ @st.cache_resource
15
+ def load_model():
16
+ tokenizer = AutoTokenizer.from_pretrained("mjpsm/check-ins-classifier")
17
+ model = AutoModelForSequenceClassification.from_pretrained("mjpsm/check-ins-classifier")
18
+ return tokenizer, model
19
+
20
+ tokenizer, model = load_model()
21
+
22
+ st.title("Check-In Classifier")
23
+ st.write("Enter your check-in so I can see if it's **Good**, **Mediocre**, or **Bad**.")
24
+
25
+ # User input
26
+ user_input = st.text_area("πŸ’¬ Your Check-In Message:", height=50)
27
+
28
+ if st.button("πŸ” Analyze"):
29
+ if user_input.strip() == "":
30
+ st.warning("Please enter some text first!")
31
+ else:
32
+ # Tokenize input
33
+ inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
34
+
35
+ # Run inference
36
+ with torch.no_grad():
37
+ outputs = model(**inputs)
38
+
39
+ logits = outputs.logits
40
+ probs = F.softmax(logits, dim=1)
41
+
42
+ # Get prediction
43
+ predicted_class = torch.argmax(probs, dim=1).item()
44
+ label_key = model.config.id2label[predicted_class]
45
+ human_label = LABEL_MAP.get(label_key, label_key)
46
+ confidence = torch.max(probs).item()
47
+
48
+ st.success(f"🧾 Prediction: **{human_label}** (Confidence: {confidence:.2%})")
49
+
50
+ # Show all class probabilities with human-readable labels
51
+ st.subheader("πŸ“Š Class Probabilities:")
52
+ for idx, prob in enumerate(probs[0]):
53
+ label_key = model.config.id2label.get(idx, f"LABEL_{idx}")
54
+ label_name = LABEL_MAP.get(label_key, label_key)
55
+ st.write(f"{label_name}: {prob:.2%}")