mjpsm commited on
Commit
9ef5418
Β·
verified Β·
1 Parent(s): edd52f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -10,20 +10,23 @@ LABEL_MAP = {
10
  "LABEL_2": "Good"
11
  }
12
 
13
- # Load model and tokenizer
14
  @st.cache_resource
15
  def load_model():
 
 
16
  tokenizer = AutoTokenizer.from_pretrained("mjpsm/check-ins-classifier")
17
  model = AutoModelForSequenceClassification.from_pretrained("mjpsm/check-ins-classifier")
18
- return tokenizer, model
 
19
 
20
- tokenizer, model = load_model()
21
 
22
  st.title("Check-In Classifier")
23
  st.write("Enter your check-in so I can see if it's **Good**, **Mediocre**, or **Bad**.")
24
 
25
  # User input
26
- user_input = st.text_area("πŸ’¬ Your Check-In Message:", height=50)
27
 
28
  if st.button("πŸ” Analyze"):
29
  if user_input.strip() == "":
@@ -32,6 +35,9 @@ if st.button("πŸ” Analyze"):
32
  # Tokenize input
33
  inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
34
 
 
 
 
35
  # Run inference
36
  with torch.no_grad():
37
  outputs = model(**inputs)
@@ -53,3 +59,6 @@ if st.button("πŸ” Analyze"):
53
  label_key = model.config.id2label.get(idx, f"LABEL_{idx}")
54
  label_name = LABEL_MAP.get(label_key, label_key)
55
  st.write(f"{label_name}: {prob:.2%}")
 
 
 
 
10
  "LABEL_2": "Good"
11
  }
12
 
13
+ # Load model and tokenizer (force CPU usage)
14
  @st.cache_resource
15
  def load_model():
16
+ # Check for CUDA (GPU) availability
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  tokenizer = AutoTokenizer.from_pretrained("mjpsm/check-ins-classifier")
19
  model = AutoModelForSequenceClassification.from_pretrained("mjpsm/check-ins-classifier")
20
+ model.to(device) # Move model to the available device
21
+ return tokenizer, model, device
22
 
23
+ tokenizer, model, device = load_model()
24
 
25
  st.title("Check-In Classifier")
26
  st.write("Enter your check-in so I can see if it's **Good**, **Mediocre**, or **Bad**.")
27
 
28
  # User input
29
+ user_input = st.text_area("πŸ’¬ Your Check-In Message:")
30
 
31
  if st.button("πŸ” Analyze"):
32
  if user_input.strip() == "":
 
35
  # Tokenize input
36
  inputs = tokenizer(user_input, return_tensors="pt", truncation=True, padding=True)
37
 
38
+ # Move input tensors to the same device as the model
39
+ inputs = {key: value.to(device) for key, value in inputs.items()}
40
+
41
  # Run inference
42
  with torch.no_grad():
43
  outputs = model(**inputs)
 
59
  label_key = model.config.id2label.get(idx, f"LABEL_{idx}")
60
  label_name = LABEL_MAP.get(label_key, label_key)
61
  st.write(f"{label_name}: {prob:.2%}")
62
+
63
+
64
+