wangjin2000 commited on
Commit
4713da5
·
verified ·
1 Parent(s): 30609c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -0
app.py CHANGED
@@ -134,6 +134,17 @@ with torch.no_grad():
134
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
135
  predictions = torch.argmax(logits, dim=2)
136
 
 
 
 
 
 
 
 
 
 
 
 
137
  # debug result
138
  dubug_result = predictions #class_weights
139
 
 
134
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
135
  predictions = torch.argmax(logits, dim=2)
136
 
137
+ # Define labels
138
+ id2label = {
139
+ 0: "No binding site",
140
+ 1: "Binding site"
141
+ }
142
+
143
+ # Print the predicted labels for each token
144
+ for token, prediction in zip(tokens, predictions[0].numpy()):
145
+ if token not in ['<pad>', '<cls>', '<eos>']:
146
+ print((token, id2label[prediction]))
147
+
148
  # debug result
149
  dubug_result = predictions #class_weights
150