Update tasks/text.py
Browse files- tasks/text.py +5 -2
tasks/text.py
CHANGED
|
@@ -63,6 +63,9 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 63 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 64 |
import torch
|
| 65 |
from torch.utils.data import DataLoader, TensorDataset
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# Load model and tokenizer from Hugging Face Hub
|
| 68 |
MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
|
|
@@ -71,7 +74,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 71 |
MAX_LENGTH = 365
|
| 72 |
|
| 73 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
|
| 74 |
-
|
| 75 |
model.eval() # Set to evaluation mode
|
| 76 |
|
| 77 |
|
|
@@ -85,7 +88,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 85 |
predictions = []
|
| 86 |
with torch.no_grad():
|
| 87 |
for batch in test_loader:
|
| 88 |
-
|
| 89 |
outputs = model(input_ids, attention_mask=attention_mask)
|
| 90 |
preds = torch.argmax(outputs.logits, dim=1)
|
| 91 |
predictions.extend(preds.cpu().numpy())
|
|
|
|
| 63 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 64 |
import torch
|
| 65 |
from torch.utils.data import DataLoader, TensorDataset
|
| 66 |
+
|
| 67 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 68 |
+
print(f"Using device: {device}")
|
| 69 |
|
| 70 |
# Load model and tokenizer from Hugging Face Hub
|
| 71 |
MODEL_REPO = "ClimateDebunk/FineTunedDistilBert4SeqClass"
|
|
|
|
| 74 |
MAX_LENGTH = 365
|
| 75 |
|
| 76 |
model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO)
|
| 77 |
+
model.to(device)
|
| 78 |
model.eval() # Set to evaluation mode
|
| 79 |
|
| 80 |
|
|
|
|
| 88 |
predictions = []
|
| 89 |
with torch.no_grad():
|
| 90 |
for batch in test_loader:
|
| 91 |
+
input_ids, attention_mask, labels = [x.to(device) for x in batch]
|
| 92 |
outputs = model(input_ids, attention_mask=attention_mask)
|
| 93 |
preds = torch.argmax(outputs.logits, dim=1)
|
| 94 |
predictions.extend(preds.cpu().numpy())
|