blitz1809 commited on
Commit
e43130f
·
1 Parent(s): c616409

fix: switch to bf16 for HF Jobs (newer transformers + bnb prefers bf16, no GradScaler needed)

Browse files
configs/colab_demo.yaml CHANGED
@@ -9,7 +9,7 @@ per_generation:
9
  model:
10
  name: "Qwen/Qwen2.5-3B-Instruct"
11
  load_in_4bit: true
12
- bnb_4bit_compute_dtype: "float16"
13
 
14
  lora:
15
  r: 16
 
9
  model:
10
  name: "Qwen/Qwen2.5-3B-Instruct"
11
  load_in_4bit: true
12
+ bnb_4bit_compute_dtype: "bfloat16"
13
 
14
  lora:
15
  r: 16
configs/l4_training.yaml CHANGED
@@ -9,7 +9,7 @@ per_generation:
9
  model:
10
  name: "Qwen/Qwen2.5-3B-Instruct"
11
  load_in_4bit: true
12
- bnb_4bit_compute_dtype: "float16"
13
 
14
  lora:
15
  r: 16
 
9
  model:
10
  name: "Qwen/Qwen2.5-3B-Instruct"
11
  load_in_4bit: true
12
+ bnb_4bit_compute_dtype: "bfloat16"
13
 
14
  lora:
15
  r: 16
training/train_attacker.py CHANGED
@@ -160,7 +160,7 @@ def train_attacker(
160
  model_name,
161
  quantization_config=bnb_config,
162
  device_map="auto",
163
- dtype=torch.float16,
164
  )
165
  model = prepare_model_for_kbit_training(model)
166
 
@@ -198,7 +198,7 @@ def train_attacker(
198
  num_generations=tr["rollouts_per_episode"],
199
  temperature=tr["temperature"],
200
  top_p=tr["top_p"],
201
- fp16=True,
202
  )
203
 
204
  if opponent is not None:
 
160
  model_name,
161
  quantization_config=bnb_config,
162
  device_map="auto",
163
+ dtype=torch.bfloat16,
164
  )
165
  model = prepare_model_for_kbit_training(model)
166
 
 
198
  num_generations=tr["rollouts_per_episode"],
199
  temperature=tr["temperature"],
200
  top_p=tr["top_p"],
201
+ bf16=True,
202
  )
203
 
204
  if opponent is not None:
training/train_defender.py CHANGED
@@ -179,7 +179,7 @@ def train_defender(
179
  model_name,
180
  quantization_config=bnb_config,
181
  device_map="auto",
182
- dtype=torch.float16,
183
  )
184
  model = prepare_model_for_kbit_training(model)
185
 
@@ -217,7 +217,7 @@ def train_defender(
217
  num_generations=tr["rollouts_per_episode"],
218
  temperature=tr["temperature"],
219
  top_p=tr["top_p"],
220
- fp16=True,
221
  )
222
 
223
  reward_fn = make_reward_function(task_id=cfg["env"]["task_id"])
 
179
  model_name,
180
  quantization_config=bnb_config,
181
  device_map="auto",
182
+ dtype=torch.bfloat16,
183
  )
184
  model = prepare_model_for_kbit_training(model)
185
 
 
217
  num_generations=tr["rollouts_per_episode"],
218
  temperature=tr["temperature"],
219
  top_p=tr["top_p"],
220
+ bf16=True,
221
  )
222
 
223
  reward_fn = make_reward_function(task_id=cfg["env"]["task_id"])