Update app.py
app.py CHANGED
@@ -10,20 +10,23 @@ LABEL_MAP = {
     "LABEL_2": "Good"
 }
 
-# Load model and tokenizer
+# Load model and tokenizer (use GPU if available, otherwise CPU)
 @st.cache_resource
 def load_model():
+    # Check for CUDA (GPU) availability
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     tokenizer = AutoTokenizer.from_pretrained("mjpsm/check-ins-classifier")
     model = AutoModelForSequenceClassification.from_pretrained("mjpsm/check-ins-classifier")
-    return tokenizer, model
+    model.to(device)  # Move model to the available device
+    return tokenizer, model, device
 
-tokenizer, model = load_model()
+tokenizer, model, device = load_model()
 
 st.title("Check-In Classifier")
 st.write("Enter your check-in so I can see if it's **Good**, **Mediocre**, or **Bad**.")
 
 # User input
-user_input = st.text_area("💬 Your Check-In Message:"
+user_input = st.text_area("💬 Your Check-In Message:")
 
 if st.button("🔍 Analyze"):
     if user_input.strip() == "":
@@ -32,6 +35,9 @@ if st.button("🔍 Analyze"):
         # Tokenize input
         inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
 
+        # Move input tensors to the same device as the model
+        inputs = {key: value.to(device) for key, value in inputs.items()}
+
         # Run inference
         with torch.no_grad():
             outputs = model(**inputs)
@@ -53,3 +59,6 @@ if st.button("🔍 Analyze"):
             label_key = model.config.id2label.get(idx, f"LABEL_{idx}")
             label_name = LABEL_MAP.get(label_key, label_key)
             st.write(f"{label_name}: {prob:.2%}")
+
+
+
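For reference, the device-handling pattern this commit introduces can be exercised outside Streamlit. The sketch below is a minimal, non-authoritative example: it assumes the public mjpsm/check-ins-classifier checkpoint and the label wording from the app's UI text (only "LABEL_2": "Good" is visible in the diff, so the LABEL_0/LABEL_1 entries are assumptions), and the classify helper is illustrative rather than a function defined in app.py.

# Minimal sketch of the commit's device-aware inference pattern (not part of app.py).
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Assumed mapping: only "LABEL_2": "Good" appears in the diff above;
# "LABEL_0"/"LABEL_1" are inferred from the app's Good/Mediocre/Bad wording.
LABEL_MAP = {"LABEL_0": "Bad", "LABEL_1": "Mediocre", "LABEL_2": "Good"}

# Pick GPU when available, otherwise CPU, and keep model and inputs together.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("mjpsm/check-ins-classifier")
model = AutoModelForSequenceClassification.from_pretrained("mjpsm/check-ins-classifier")
model.to(device)
model.eval()

def classify(text: str) -> dict:
    # Tokenize, then move every input tensor to the model's device.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)[0]
    # Map class indices to readable labels, mirroring the app's display loop.
    results = {}
    for idx, prob in enumerate(probs):
        label_key = model.config.id2label.get(idx, f"LABEL_{idx}")
        results[LABEL_MAP.get(label_key, label_key)] = float(prob)
    return results

if __name__ == "__main__":
    print(classify("Wrapped up the sprint tasks early and reviewed two PRs."))  # example input

In the Streamlit app itself, @st.cache_resource keeps the loaded tokenizer, model, and device across reruns, which is why load_model() now returns all three instead of reloading the checkpoint on every button click.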