25b3nk commited on
Commit
727746f
·
verified ·
1 Parent(s): a531d01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -13,20 +13,29 @@ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
 
 
16
  # Function to perform inference
17
  def predict(text):
18
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
19
  outputs = model(**inputs)
20
  logits = outputs.logits
21
  probabilities = torch.sigmoid(logits) # Use sigmoid for multi-label classification
 
22
 
23
  # Get predicted labels based on a threshold (e.g., 0.5)
24
- predicted_labels = (probabilities > 0.5).nonzero()[:, 1].tolist()
 
 
 
 
25
 
26
  # Map label IDs back to label names
27
  predicted_labels_names = [model.config.id2label[label_id] for label_id in predicted_labels]
28
-
29
- return predicted_labels_names
 
 
30
 
31
 
32
  # Create the Gradio interface
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  model.to(device)
15
 
16
+ prob_thresh = 0.3
17
+
18
  # Function to perform inference
19
  def predict(text):
20
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
21
  outputs = model(**inputs)
22
  logits = outputs.logits
23
  probabilities = torch.sigmoid(logits) # Use sigmoid for multi-label classification
24
+ print(probabilities)
25
 
26
  # Get predicted labels based on a threshold (e.g., 0.5)
27
+ predicted_labels = (probabilities > prob_thresh).nonzero()[:, 1].tolist()
28
+ # positions = (probabilities > 0.5).nonzero(as_tuple=False)
29
+ prob_values = probabilities[probabilities > prob_thresh].tolist()
30
+ print(predicted_labels)
31
+ print(prob_values)
32
 
33
  # Map label IDs back to label names
34
  predicted_labels_names = [model.config.id2label[label_id] for label_id in predicted_labels]
35
+ labels_dict = {model.config.id2label[label_id]: prob for label_id, prob in zip(predicted_labels, prob_values)}
36
+ print(labels_dict)
37
+ # labels_dict = {label: 1/len(predicted_labels_names) for label in predicted_labels_names}
38
+ return labels_dict
39
 
40
 
41
  # Create the Gradio interface