rshakked commited on
Commit
92846e1
·
1 Parent(s): fea5367

fix: remove .to(device) from dataset to avoid pin_memory error in Trainer

Browse files
Files changed (1) hide show
  1. train_abuse_model.py +4 -3
train_abuse_model.py CHANGED
@@ -41,13 +41,14 @@ class AbuseDataset(Dataset):
41
 
42
  def __len__(self):
43
  return len(self.labels)
44
-
45
  def __getitem__(self, idx):
46
- item = {key: torch.tensor(val[idx]).to(device) for key, val in self.encodings.items()}
47
- item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float).to(device)
48
  return item
49
 
50
 
 
51
  # Convert label values to soft scores: "yes" = 1.0, "plausibly" = 0.5, others = 0.0
52
  def label_row_soft(row):
53
  labels = []
 
41
 
42
  def __len__(self):
43
  return len(self.labels)
44
+
45
  def __getitem__(self, idx):
46
+ item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
47
+ item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
48
  return item
49
 
50
 
51
+
52
  # Convert label values to soft scores: "yes" = 1.0, "plausibly" = 0.5, others = 0.0
53
  def label_row_soft(row):
54
  labels = []