chore(debug): move device check to top of script for clearer startup logs
Browse files- train_abuse_model.py +4 -4
train_abuse_model.py
CHANGED
@@ -8,10 +8,6 @@ import numpy as np
|
|
8 |
import torch
|
9 |
from torch.utils.data import Dataset
|
10 |
|
11 |
-
print("torch.cuda.is_available():", torch.cuda.is_available())
|
12 |
-
print("Using device:", device)
|
13 |
-
print("PyTorch version:", torch.__version__)
|
14 |
-
|
15 |
from sklearn.model_selection import train_test_split
|
16 |
from sklearn.metrics import classification_report, precision_recall_fscore_support
|
17 |
|
@@ -34,8 +30,12 @@ print("Transformers version:", transformers.__version__)
|
|
34 |
|
35 |
# Check for GPU availability
|
36 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
37 |
|
38 |
# Custom Dataset class
|
|
|
39 |
class AbuseDataset(Dataset):
|
40 |
def __init__(self, texts, labels):
|
41 |
self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
|
|
|
8 |
import torch
|
9 |
from torch.utils.data import Dataset
|
10 |
|
|
|
|
|
|
|
|
|
11 |
from sklearn.model_selection import train_test_split
|
12 |
from sklearn.metrics import classification_report, precision_recall_fscore_support
|
13 |
|
|
|
30 |
|
31 |
# Check for GPU availability
|
32 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
+
print("torch.cuda.is_available():", torch.cuda.is_available())
|
34 |
+
print("Using device:", device)
|
35 |
+
print("PyTorch version:", torch.__version__)
|
36 |
|
37 |
# Custom Dataset class
|
38 |
+
|
39 |
class AbuseDataset(Dataset):
|
40 |
def __init__(self, texts, labels):
|
41 |
self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
|