rshakked committed
Commit: fedc8f2
Parent(s): 285a433

fix: reduce batch size and enable gradient checkpointing to prevent GPU OOM crashes

Files changed (1):
  1. train_abuse_model.py (+5 / -8)

train_abuse_model.py
@@ -185,11 +185,8 @@ model = AutoModelForSequenceClassification.from_pretrained(
     problem_type="multi_label_classification"
 ).to(device)  # Move model to GPU
 
-# # Optional: Freeze base model layers (only train classifier head)
-# freeze_base = False
-# if freeze_base:
-#     for name, param in model.bert.named_parameters():
-#         param.requires_grad = False
+# gradient checkpointing helps cut memory use:
+model.gradient_checkpointing_enable()
 
 # Freeze bottom 6 layers of DeBERTa encoder
 for name, param in model.named_parameters():
@@ -215,12 +212,12 @@ test_dataset = AbuseDataset(test_texts, test_labels)
 training_args = TrainingArguments(
     output_dir="./results",
     num_train_epochs=3,
-    per_device_train_batch_size=8,
-    per_device_eval_batch_size=8,
+    per_device_train_batch_size=4,
+    per_device_eval_batch_size=4,
     evaluation_strategy="epoch",
     save_strategy="epoch",
     logging_dir="./logs",
-    logging_steps=10,
+    logging_steps=100,
 )
 
 # Train using HuggingFace Trainer
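
For context, a minimal sketch of how the two changes fit together in the training setup. The dataset variable names (`train_dataset`, `test_dataset`) are assumed from the surrounding script, and the commented-out `gradient_accumulation_steps` line is not part of this commit; it is only shown as an optional way to keep the effective batch size at 8 after halving the per-device batch size.

# Sketch of the post-commit setup (names assumed from train_abuse_model.py;
# gradient_accumulation_steps is illustrative, not part of this commit).
from transformers import Trainer, TrainingArguments

# Recompute activations in the backward pass instead of storing them all:
model.gradient_checkpointing_enable()

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,    # halved from 8 to fit in GPU memory
    per_device_eval_batch_size=4,
    # gradient_accumulation_steps=2,  # optional: restore an effective batch of 8
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,                # log less frequently
)

trainer = Trainer(
    model=model,                      # DeBERTa multi-label classifier defined above
    args=training_args,
    train_dataset=train_dataset,      # assumed name, built like test_dataset
    eval_dataset=test_dataset,
)
trainer.train()

Gradient checkpointing trades extra forward-pass compute for a large reduction in activation memory, so together with the smaller batch it addresses the OOM without touching the model itself; if the smaller batch were to hurt convergence, gradient accumulation is the usual way to recover the original effective batch size.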