rshakked commited on
Commit
285a433
·
1 Parent(s): d4ac0ac

chore(debug): move device check to top of script for clearer startup logs

Browse files
Files changed (1) hide show
  1. train_abuse_model.py +4 -4
train_abuse_model.py CHANGED
@@ -8,10 +8,6 @@ import numpy as np
8
  import torch
9
  from torch.utils.data import Dataset
10
 
11
- print("torch.cuda.is_available():", torch.cuda.is_available())
12
- print("Using device:", device)
13
- print("PyTorch version:", torch.__version__)
14
-
15
  from sklearn.model_selection import train_test_split
16
  from sklearn.metrics import classification_report, precision_recall_fscore_support
17
 
@@ -34,8 +30,12 @@ print("Transformers version:", transformers.__version__)
34
 
35
  # Check for GPU availability
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
37
 
38
  # Custom Dataset class
 
39
  class AbuseDataset(Dataset):
40
  def __init__(self, texts, labels):
41
  self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)
 
8
  import torch
9
  from torch.utils.data import Dataset
10
 
 
 
 
 
11
  from sklearn.model_selection import train_test_split
12
  from sklearn.metrics import classification_report, precision_recall_fscore_support
13
 
 
30
 
31
  # Check for GPU availability
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ print("torch.cuda.is_available():", torch.cuda.is_available())
34
+ print("Using device:", device)
35
+ print("PyTorch version:", torch.__version__)
36
 
37
  # Custom Dataset class
38
+
39
  class AbuseDataset(Dataset):
40
  def __init__(self, texts, labels):
41
  self.encodings = tokenizer(texts, truncation=True, padding=True, max_length=512)