fix: switch to bf16 for HF Jobs (newer transformers + bnb prefer bf16, no GradScaler needed)
Browse files
- configs/colab_demo.yaml +1 -1
- configs/l4_training.yaml +1 -1
- training/train_attacker.py +2 -2
- training/train_defender.py +2 -2
configs/colab_demo.yaml
CHANGED
|
@@ -9,7 +9,7 @@ per_generation:
|
|
| 9 |
model:
|
| 10 |
name: "Qwen/Qwen2.5-3B-Instruct"
|
| 11 |
load_in_4bit: true
|
| 12 |
-
bnb_4bit_compute_dtype: "float16"
|
| 13 |
|
| 14 |
lora:
|
| 15 |
r: 16
|
|
|
|
| 9 |
model:
|
| 10 |
name: "Qwen/Qwen2.5-3B-Instruct"
|
| 11 |
load_in_4bit: true
|
| 12 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 13 |
|
| 14 |
lora:
|
| 15 |
r: 16
|
configs/l4_training.yaml
CHANGED
|
@@ -9,7 +9,7 @@ per_generation:
|
|
| 9 |
model:
|
| 10 |
name: "Qwen/Qwen2.5-3B-Instruct"
|
| 11 |
load_in_4bit: true
|
| 12 |
-
bnb_4bit_compute_dtype: "float16"
|
| 13 |
|
| 14 |
lora:
|
| 15 |
r: 16
|
|
|
|
| 9 |
model:
|
| 10 |
name: "Qwen/Qwen2.5-3B-Instruct"
|
| 11 |
load_in_4bit: true
|
| 12 |
+
bnb_4bit_compute_dtype: "bfloat16"
|
| 13 |
|
| 14 |
lora:
|
| 15 |
r: 16
|
training/train_attacker.py
CHANGED
|
@@ -160,7 +160,7 @@ def train_attacker(
|
|
| 160 |
model_name,
|
| 161 |
quantization_config=bnb_config,
|
| 162 |
device_map="auto",
|
| 163 |
-
dtype=torch.float16,
|
| 164 |
)
|
| 165 |
model = prepare_model_for_kbit_training(model)
|
| 166 |
|
|
@@ -198,7 +198,7 @@ def train_attacker(
|
|
| 198 |
num_generations=tr["rollouts_per_episode"],
|
| 199 |
temperature=tr["temperature"],
|
| 200 |
top_p=tr["top_p"],
|
| 201 |
-
fp16=True,
|
| 202 |
)
|
| 203 |
|
| 204 |
if opponent is not None:
|
|
|
|
| 160 |
model_name,
|
| 161 |
quantization_config=bnb_config,
|
| 162 |
device_map="auto",
|
| 163 |
+
dtype=torch.bfloat16,
|
| 164 |
)
|
| 165 |
model = prepare_model_for_kbit_training(model)
|
| 166 |
|
|
|
|
| 198 |
num_generations=tr["rollouts_per_episode"],
|
| 199 |
temperature=tr["temperature"],
|
| 200 |
top_p=tr["top_p"],
|
| 201 |
+
bf16=True,
|
| 202 |
)
|
| 203 |
|
| 204 |
if opponent is not None:
|
training/train_defender.py
CHANGED
|
@@ -179,7 +179,7 @@ def train_defender(
|
|
| 179 |
model_name,
|
| 180 |
quantization_config=bnb_config,
|
| 181 |
device_map="auto",
|
| 182 |
-
dtype=torch.float16,
|
| 183 |
)
|
| 184 |
model = prepare_model_for_kbit_training(model)
|
| 185 |
|
|
@@ -217,7 +217,7 @@ def train_defender(
|
|
| 217 |
num_generations=tr["rollouts_per_episode"],
|
| 218 |
temperature=tr["temperature"],
|
| 219 |
top_p=tr["top_p"],
|
| 220 |
-
fp16=True,
|
| 221 |
)
|
| 222 |
|
| 223 |
reward_fn = make_reward_function(task_id=cfg["env"]["task_id"])
|
|
|
|
| 179 |
model_name,
|
| 180 |
quantization_config=bnb_config,
|
| 181 |
device_map="auto",
|
| 182 |
+
dtype=torch.bfloat16,
|
| 183 |
)
|
| 184 |
model = prepare_model_for_kbit_training(model)
|
| 185 |
|
|
|
|
| 217 |
num_generations=tr["rollouts_per_episode"],
|
| 218 |
temperature=tr["temperature"],
|
| 219 |
top_p=tr["top_p"],
|
| 220 |
+
bf16=True,
|
| 221 |
)
|
| 222 |
|
| 223 |
reward_fn = make_reward_function(task_id=cfg["env"]["task_id"])
|