|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
import unittest |
|
|
|
import torch |
|
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available |
|
from transformers.testing_utils import torch_device |
|
from transformers.utils import is_rich_available |
|
|
|
from trl import BaseBinaryJudge, BasePairwiseJudge |
|
from trl.import_utils import ( |
|
is_diffusers_available, |
|
is_joblib_available, |
|
is_llm_blender_available, |
|
is_mergekit_available, |
|
is_vllm_available, |
|
) |
|
|
|
|
|
|
|
|
|
def require_bitsandbytes(test_case):
    """
    Decorator for tests that need bitsandbytes.

    Returns ``test_case`` unchanged when bitsandbytes is available; otherwise wraps it with
    ``unittest.skip`` so the runner reports it as skipped.
    """
    if is_bitsandbytes_available():
        return test_case
    return unittest.skip("test requires bitsandbytes")(test_case)
|
|
|
|
|
def require_comet(test_case):
    """
    Decorator for tests that need Comet (``comet_ml``).

    The test is skipped whenever ``is_comet_available()`` reports that Comet cannot be used.
    """
    comet_gate = unittest.skipUnless(is_comet_available(), "test requires comet_ml")
    return comet_gate(test_case)
|
|
|
|
|
def require_diffusers(test_case):
    """
    Decorator for tests that need diffusers; the test is skipped when diffusers is unavailable.
    """
    reason = "test requires diffusers"
    if not is_diffusers_available():
        return unittest.skip(reason)(test_case)
    return test_case
|
|
|
|
|
def require_llm_blender(test_case):
    """
    Decorator for tests that need llm-blender.

    Applies ``unittest.skipUnless`` gated on ``is_llm_blender_available()``.
    """
    blender_present = is_llm_blender_available()
    skip_unless_blender = unittest.skipUnless(blender_present, "test requires llm-blender")
    return skip_unless_blender(test_case)
|
|
|
|
|
def require_mergekit(test_case):
    """
    Decorator for tests that need mergekit.

    If mergekit is available the test is returned untouched; otherwise it is marked as skipped.
    """
    if is_mergekit_available():
        return test_case
    return unittest.skip("test requires mergekit")(test_case)
|
|
|
|
|
def require_rich(test_case):
    """
    Decorator for tests that need the ``rich`` package; skips them when it is unavailable.
    """
    has_rich = is_rich_available()
    decorator = unittest.skipUnless(has_rich, "test requires rich")
    return decorator(test_case)
|
|
|
|
|
def require_sklearn(test_case):
    """
    Decorator marking a test that requires sklearn and joblib.

    Skips the test unless both ``is_sklearn_available()`` and ``is_joblib_available()`` are true.
    """
    # Both packages are checked, so both are named in the skip reason. Otherwise a missing
    # joblib would be misreported as a missing sklearn.
    return unittest.skipUnless(
        is_sklearn_available() and is_joblib_available(), "test requires sklearn and joblib"
    )(test_case)
|
|
|
|
|
def require_vllm(test_case):
    """
    Decorator for tests that need vllm; the test is skipped when vllm is unavailable.
    """
    if not is_vllm_available():
        return unittest.skip("test requires vllm")(test_case)
    return test_case
|
|
|
|
|
def require_no_wandb(test_case):
    """
    Decorator for tests that must run without wandb.

    This is the inverse of the other ``require_*`` helpers: the test is skipped when wandb IS available.
    """
    return unittest.skipIf(is_wandb_available(), "test requires no wandb")(test_case)
|
|
|
|
|
def require_3_accelerators(test_case):
    """
    Decorator marking a test that requires at least 3 accelerators.

    Skips the test if fewer than 3 accelerators are available. The device module is looked up as the
    attribute of ``torch`` named by ``torch_device``, falling back to ``torch.cuda`` when no such
    attribute exists. Its ``device_count()`` is then compared against 3.
    """
    torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
    # ">=" (not ">"): exactly 3 devices satisfies "at least 3". The previous ">" wrongly skipped
    # machines with exactly 3 accelerators.
    return unittest.skipUnless(
        torch_accelerator_module.device_count() >= 3, f"test requires at least 3 {torch_device}s"
    )(test_case)
|
|
|
|
|
class RandomBinaryJudge(BaseBinaryJudge):
    """
    Binary judge that ignores its inputs and answers at random; for testing purposes only.
    """

    def judge(self, prompts, completions, gold_completions=None, shuffle_order=True):
        """
        Return one label per prompt, each drawn uniformly from ``(0, 1, -1)``.

        ``completions``, ``gold_completions`` and ``shuffle_order`` are accepted for interface
        compatibility but are not used.
        """
        num_prompts = len(prompts)
        outcomes = (0, 1, -1)
        return [random.choice(outcomes) for _ in range(num_prompts)]
|
|
|
|
|
class RandomPairwiseJudge(BasePairwiseJudge):
    """
    Pairwise judge that answers at random; for testing purposes only.
    """

    def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
        """
        Produce random judgments.

        With ``return_scores=True``, return one float in ``[0, 1)`` per prompt (from ``random.random``).
        Otherwise, return for each entry of ``completions`` a random index into that entry.
        ``shuffle_order`` is accepted for interface compatibility but is not used.
        """
        if return_scores:
            return [random.random() for _ in range(len(prompts))]
        # randint is inclusive at both ends, hence the "- 1".
        return [random.randint(0, len(candidates) - 1) for candidates in completions]
|
|