chore(debug): print tensor devices and CUDA availability for troubleshooting
Files changed: train_abuse_model.py (+12 -1)
train_abuse_model.py
CHANGED
@@ -8,6 +8,8 @@ import numpy as np
|
|
8 |
import torch
|
9 |
from torch.utils.data import Dataset
|
10 |
|
|
|
|
|
11 |
print("PyTorch version:", torch.__version__)
|
12 |
|
13 |
from sklearn.model_selection import train_test_split
|
@@ -46,7 +48,10 @@ class AbuseDataset(Dataset):
|
|
46 |
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
47 |
item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
|
48 |
return item
|
49 |
-
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
# Convert label values to soft scores: "yes" = 1.0, "plausibly" = 0.5, others = 0.0
|
@@ -226,6 +231,12 @@ trainer = Trainer(
|
|
226 |
eval_dataset=val_dataset
|
227 |
)
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
# Start training!
|
230 |
trainer.train()
|
231 |
|
|
|
8 |
import torch
|
9 |
from torch.utils.data import Dataset
|
10 |
|
11 |
+
print("torch.cuda.is_available():", torch.cuda.is_available())
|
12 |
+
print("Using device:", device)
|
13 |
print("PyTorch version:", torch.__version__)
|
14 |
|
15 |
from sklearn.model_selection import train_test_split
|
|
|
48 |
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
49 |
item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
|
50 |
return item
|
51 |
+
def __getitem__(self, idx):
|
52 |
+
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
53 |
+
item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
|
54 |
+
return item
|
55 |
|
56 |
|
57 |
# Convert label values to soft scores: "yes" = 1.0, "plausibly" = 0.5, others = 0.0
|
|
|
231 |
eval_dataset=val_dataset
|
232 |
)
|
233 |
|
234 |
+
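# Debug check: print the device of each tensor in one training sample, to confirm the dataset is not placing tensors on the GPU before training starts.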
|
235 |
+
print("🧪 Sample device check from train_dataset:")
|
236 |
+
sample = train_dataset[0]
|
237 |
+
for k, v in sample.items():
|
238 |
+
print(f"{k}: {v.device}")
|
239 |
+
|
240 |
# Start training!
|
241 |
trainer.train()
|
242 |
|