Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -134,6 +134,17 @@ with torch.no_grad():
|
|
134 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
135 |
predictions = torch.argmax(logits, dim=2)
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
# debug result
|
138 |
dubug_result = predictions #class_weights
|
139 |
|
|
|
134 |
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens
|
135 |
predictions = torch.argmax(logits, dim=2)
|
136 |
|
137 |
+
# Define labels
|
138 |
+
id2label = {
|
139 |
+
0: "No binding site",
|
140 |
+
1: "Binding site"
|
141 |
+
}
|
142 |
+
|
143 |
+
# Print the predicted labels for each token
|
144 |
+
for token, prediction in zip(tokens, predictions[0].numpy()):
|
145 |
+
if token not in ['<pad>', '<cls>', '<eos>']:
|
146 |
+
print((token, id2label[prediction]))
|
147 |
+
|
148 |
# debug result
|
149 |
dubug_result = predictions #class_weights
|
150 |
|