import tempfile
import unittest
from unittest.mock import patch

import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict

from .testing_utils import require_vllm


if is_peft_available():
    from peft import LoraConfig, PeftModel


class SplitTensorDictTester(unittest.TestCase):
    def test_split_equal_chunks(self):
        x = torch.arange(12).reshape(6, 2)
        y = torch.arange(6).reshape(6, 1)
        tensor_dict = {"x": x, "y": y}

        result = split_tensor_dict(tensor_dict, 3)

        expected_x_chunks = torch.chunk(x, 3, dim=0)
        expected_y_chunks = torch.chunk(y, 3, dim=0)
        self.assertEqual(len(result), 3)
        for i in range(3):
            self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
            self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i]))

    def test_with_none_tensor(self):
        x = torch.arange(12).reshape(6, 2)
        tensor_dict = {"x": x, "y": None}

        result = split_tensor_dict(tensor_dict, 2)

        expected_x_chunks = torch.chunk(x, 2, dim=0)
        self.assertEqual(len(result), 2)
        for i in range(2):
            self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
            self.assertIsNone(result[i]["y"])


class ShuffleTensorDictTester(unittest.TestCase):
    def test_shuffle_preserves_shape(self):
        x = torch.arange(6).reshape(3, 2)
        y = torch.arange(3).reshape(3, 1)
        tensor_dict = {"x": x.clone(), "y": y.clone()}

        shuffled = shuffle_tensor_dict(tensor_dict)

        self.assertEqual(shuffled["x"].shape, x.shape)
        self.assertEqual(shuffled["y"].shape, y.shape)

    def test_shuffle_consistent_across_tensors(self):
        # Use rows with distinct values so we can check that the same permutation is applied to x and y
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        y = torch.tensor([[1], [2], [3]])
        tensor_dict = {"x": x.clone(), "y": y.clone()}

        shuffled = shuffle_tensor_dict(tensor_dict)

        # Each x row must still be paired with its original y value after shuffling
        for i in range(3):
            x_row = shuffled["x"][i]
            y_val = shuffled["y"][i].item()

            if torch.equal(x_row, torch.tensor([10, 11])):
                self.assertEqual(y_val, 1)
            elif torch.equal(x_row, torch.tensor([20, 21])):
                self.assertEqual(y_val, 2)
            elif torch.equal(x_row, torch.tensor([30, 31])):
                self.assertEqual(y_val, 3)
            else:
                self.fail("Unexpected x row in shuffled output.")

    def test_none_tensor_remains_none(self):
        x = torch.arange(6).reshape(3, 2)
        tensor_dict = {"x": x.clone(), "y": None}

        shuffled = shuffle_tensor_dict(tensor_dict)

        self.assertIsNone(shuffled["y"])
        self.assertEqual(shuffled["x"].shape, x.shape)


class RepeatRandomSamplerTester(unittest.TestCase):
    def test_sampler(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2)
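        # Illustration only (the sampler shuffles, so the exact order varies): each dataset index should appear
        # twice in a row, e.g. [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5].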
        sampled = list(sampler)

        # Check that the length is doubled
        assert len(sampled) == 2 * len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))
        # Check that each index is repeated twice in a row
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))

    def test_sampler_no_shuffle(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False)
        sampled = list(sampler)
        expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6]
        self.assertEqual(sampled, expected)
    def test_sampler_no_repeat(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1)

        sampled = list(sampler)

        # Check that the length matches the dataset
        assert len(sampled) == len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))

    def test_sampler_with_batch_size(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)

        sampled = list(sampler)

        # Check that the length is doubled
        assert len(sampled) == 2 * len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))
        # Check that each batch of 2 indexes is repeated (the first element of each group of 4 matches the element
        # two positions later)
        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

    def test_sampler_with_batch_size_and_drop(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)

        sampled = list(sampler)

        # Check that the length is doubled, after one index is dropped to make the dataset divisible by batch_size
        assert len(sampled) == 2 * (len(dataset) - 1)
        # Check that the sampled indexes are a subset of the dataset indexes (one index is dropped)
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each batch of 2 indexes is repeated (the first element of each group of 4 matches the element
        # two positions later)
        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

    def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
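        # Illustration only (order depends on shuffling): each batch of 3 indexes is yielded with every index
        # repeated twice, and the whole batch is then yielded a second time, e.g.
        # [4, 4, 3, 3, 0, 0,  4, 4, 3, 3, 0, 0,  1, 1, 2, 2, 6, 6,  1, 1, 2, 2, 6, 6].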
        sampled = list(sampler)

        # Check the length: 6 effective indexes (one is dropped), each repeated 2 * 2 times
        assert len(sampled) == 4 * (len(dataset) - 1)
        # Check that the sampled indexes are a subset of the dataset indexes (one index is dropped)
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each index is repeated twice in a row
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Check that each group of 6 samples (one batch) is repeated
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]

    def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
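        # Illustration only (order depends on shuffling): each batch of 2 indexes is yielded with every index
        # repeated three times, and the whole batch is then yielded a second time, e.g.
        # [4, 4, 4, 3, 3, 3,  4, 4, 4, 3, 3, 3,  0, 0, 0, 1, 1, 1,  0, 0, 0, 1, 1, 1, ...].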
        sampled = list(sampler)

        # Check the length: 6 effective indexes (one is dropped), each repeated 3 * 2 times
        assert len(sampled) == 6 * (len(dataset) - 1)
        # Check that the sampled indexes are a subset of the dataset indexes (one index is dropped)
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each index is repeated three times in a row
        assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
        # Check that each group of 6 samples (one batch) is repeated
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]
        assert sampled[24:30] == sampled[30:36]

    def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
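        # Illustration only (order depends on shuffling): each batch of 2 indexes is yielded with every index
        # repeated twice, and the whole batch is then yielded three times in total, e.g.
        # [4, 4, 3, 3,  4, 4, 3, 3,  4, 4, 3, 3,  0, 0, 1, 1,  0, 0, 1, 1,  0, 0, 1, 1, ...].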
        sampled = list(sampler)

        # Check the length: 6 effective indexes (one is dropped), each repeated 2 * 3 times
        assert len(sampled) == 6 * (len(dataset) - 1)
        # Check that the sampled indexes are a subset of the dataset indexes (one index is dropped)
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each index is repeated twice in a row
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Check that each group of 4 samples (one batch) is repeated three times
        assert sampled[0:4] == sampled[4:8] == sampled[8:12]
        assert sampled[12:16] == sampled[16:20] == sampled[20:24]
        assert sampled[24:28] == sampled[28:32] == sampled[32:36]


class GRPOTrainerTester(unittest.TestCase):
    def test_init_minimal(self):
        # Test that GRPOTrainer can be instantiated with only the mandatory arguments
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            train_dataset=dataset,
        )

    @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
    def test_training(self, config_name):
        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @parameterized.expand([("bnpo",), ("dr_grpo",)])
    def test_training_loss_types(self, loss_type):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=32,
                loss_type=loss_type,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_with_eval(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=3,
                per_device_eval_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                eval_strategy="steps",
                eval_steps=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
            )

            trainer.train()
    def test_training_multiple_iterations(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                num_iterations=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    @require_peft
    def test_training_peft(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the base model params have not changed, while the LoRA params have
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # base model params should be unchanged
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
                elif "base_layer" not in n:  # LoRA params (except the frozen base_layer copies) should have changed
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
    @require_peft
    def test_training_peft_with_gradient_checkpointing(self):
        """Test that training works with PEFT and gradient checkpointing enabled."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        model = AutoModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            torch_dtype=torch.float32,
            use_cache=False,  # the cache must be disabled when gradient checkpointing is enabled
        )

        lora_config = LoraConfig(
            r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                gradient_checkpointing=True,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=lora_config,
            )

            # Verify that the model is wrapped as a PeftModel
            self.assertIsInstance(trainer.model, PeftModel)

            # Store the initial parameters to check which ones change during training
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that only the LoRA params have changed and the base model params are unchanged
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if "lora" in n.lower():
                    self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
                else:
                    self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
    def test_training_different_reward_model(self):
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
        reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
        reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
        reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)
        # The Llama tokenizer used by this reward model does not define a pad token by default, so we set one
        # explicitly here, using one of its reserved special tokens
        reward_tokenizer.pad_token = "<|finetune_right_pad_id|>"

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_model,
                args=training_args,
                train_dataset=dataset,
                reward_processing_classes=reward_tokenizer,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_reward_func_standard(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_reward_func_conversational(self):
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that gives higher scores to longer completion content."""
            completion_contents = [completion[0]["content"] for completion in completions]
            return [float(len(content)) for content in completion_contents]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_multiple_reward_funcs(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func1(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def reward_func2(completions, **kwargs):
            """Reward function that rewards completions with more unique letters."""
            return [float(len(set(completion))) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func1, reward_func2],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_multiple_reward_funcs_with_None_output(self):
        """Test that training works when one reward function returns valid rewards while another returns None for
        every sample (i.e., it does not apply to this batch)."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def applicable_reward_func(completions, **kwargs):
            """A reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def non_applicable_reward_func(completions, **kwargs):
            """A reward function that returns None for all inputs, as it is not applicable to this sample."""
            return [None] * len(completions)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )

            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[
                    applicable_reward_func,
                    non_applicable_reward_func,
                ],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {
                n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
            }

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_multiple_reward_funcs_with_weights(self):
        """Test that GRPOTrainer can handle multiple reward functions with weights."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func1(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def reward_func2(completions, **kwargs):
            """Reward function that rewards completions with more unique letters."""
            return [float(len(set(completion))) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                reward_weights=[0.7, 0.3],  # the total reward is 0.7 * reward_func1 + 0.3 * reward_func2
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func1, reward_func2],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            # Check that training logs include the loss and the per-reward-function metrics
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_multiple_mixed_reward_funcs(self):
        # Test that the trainer can handle a mix of a custom reward function and a reward model id
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_reward_func_additional_column(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Add an extra column to the dataset; extra columns are forwarded to the reward function as keyword
        # arguments
        some_values = list(range(len(dataset)))
        dataset = dataset.add_column("some_values", some_values)

        def reward_func(completions, some_values, **kwargs):
            """Reward function that scores completions by how far their length is from the corresponding value in
            some_values."""
            return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm(self):
        """Test that training works with vLLM for generation."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                use_vllm=True,
            )
            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_with_sync_ref_model(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                sync_ref_model=True,
                ref_model_sync_steps=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_beta_non_zero(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                beta=0.1,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    @unittest.skip("We should add a mock for the vLLM server.")
    @require_peft
    @require_vllm
    def test_training_vllm_and_peft(self):
        """Test that training works with vLLM for generation."""
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                use_vllm=True,
            )
            lora_config = LoraConfig(
                target_modules="all-linear",
                # also fully train and save the embedding and output layers, to cover modules_to_save handling
                modules_to_save=["embed_tokens", "lm_head"],
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=lora_config,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the base model params have not changed, while the trainable params have
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # base model params should be unchanged
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
                elif "base_layer" not in n and "original_module" not in n:
                    # trainable params (LoRA weights and the modules_to_save copies) should have changed
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm_guided_decoding(self):
        """Test that training works with vLLM for generation with guided decoding."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                use_vllm=True,
                vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
            )
            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_with_additional_generation_kwargs(self):
        """Test that training works with additional generation kwargs."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                top_p=0.9,
                top_k=10,
                min_p=0.01,
                repetition_penalty=1.1,
            )

            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm_with_additional_generation_kwargs(self):
        """Test that training works with vLLM and additional generation kwargs."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
                use_vllm=True,
                top_p=0.9,
                top_k=10,
                min_p=0.01,
                repetition_penalty=1.1,
            )

            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_no_scale_rewards(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                scale_rewards=False,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    @patch("transformers.generation.utils.GenerationMixin.generate")
    def test_training_with_mask_truncated_completions(self, mock_generate):
        """Test that training works with mask_truncated_completions=True parameter."""

        # Mock the generate method so that some completions contain an EOS token: with the tiny model's random
        # weights an EOS token would almost never be produced within max_completion_length, all completions would
        # be masked, and the parameters would not update.
        def fake_generate(prompt_ids, **kwargs):
            # Return three completions covering the relevant cases. The token ids 151645 and 151643 are assumed to
            # be the EOS and pad token ids of the tiny Qwen2 test model.
            completions_ids = torch.tensor(
                [
                    [1, 2, 3, 4, 5, 6, 7, 8],  # truncated (no EOS), so it should be masked
                    [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # EOS in the middle, then padding
                    [12, 13, 14, 15, 16, 17, 18, 151645],  # ends exactly with EOS
                ],
                device=prompt_ids.device,
            )
            return torch.cat([prompt_ids, completions_ids], dim=1)

        mock_generate.side_effect = fake_generate

        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                mask_truncated_completions=True,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_with_mask_truncated_completions_all_masked(self):
        """
        Test that when all generated completions are truncated (i.e., none contain an EOS token) and
        mask_truncated_completions=True, the model receives no effective learning signal and therefore does not
        update its parameters.

        Here, we don't mock the generate method; instead we rely on the fact that the probability of the model
        generating the EOS token is extremely low, so all generated completions end up truncated.
        """
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                mask_truncated_completions=True,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have NOT changed: all completions are masked, so there is no learning signal
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
    def test_training_num_generations_larger_than_batch_size(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                max_completion_length=8,
                num_generations=6,  # the effective batch size (3 * 2 accumulation steps) must be divisible by this
                gradient_accumulation_steps=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_delta_clipping(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                delta=2.0,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
    def test_training_multiple_dataloader_workers(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                dataloader_num_workers=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")