from dataclasses import dataclass

from tqdm import tqdm

from graphgen.models.text.text_pair import TextPair


@dataclass
class RewardEvaluator:
    """
    Reward Model Evaluator.
    OpenAssistant/reward-model-deberta-v3-large-v2: score range is
    [-inf, inf]; higher is better.
    """

    reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
    max_length: int = 2560
    # Cache of the scores from the most recent evaluate() call.
    results: list[float] = None

    def __post_init__(self):
        import torch

        self.num_gpus = torch.cuda.device_count()
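
    # One worker process per visible CUDA device: __post_init__ records the
    # device count, and evaluate() shards the input pairs across that many
    # workers.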
    @staticmethod
    def process_chunk(rank, pairs, reward_name, max_length, return_dict):
        # Score one shard of pairs on GPU `rank`. A staticmethod, so passing
        # self.process_chunk as an mp.Process target does not bind `self`
        # as the first (rank) argument.
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        device = f"cuda:{rank}"
        torch.cuda.set_device(rank)

        rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
        tokenizer = AutoTokenizer.from_pretrained(reward_name)
        rank_model.to(device)
        rank_model.eval()

        results = []
        with torch.no_grad():
            for pair in tqdm(pairs):
                inputs = tokenizer(
                    pair.question,
                    pair.answer,
                    return_tensors="pt",
                    max_length=max_length,
                    truncation=True,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                # The reward head emits a single logit; use it as the score.
                score = rank_model(**inputs).logits[0].item()
                results.append(score)
        return_dict[rank] = results
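
    # Design note: each worker loads its own model replica onto its own GPU,
    # so no CUDA state is shared between processes.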
    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        # NOTE: torch.multiprocessing inherits Python's default start method
        # ("fork" on Linux); PyTorch recommends "spawn" when subprocesses
        # use CUDA, as they do here.
        import torch.multiprocessing as mp

        # Split the pairs into one contiguous chunk per GPU; the last chunk
        # absorbs any remainder.
        chunk_size = len(pairs) // self.num_gpus
        chunks = []
        for i in range(self.num_gpus):
            start = i * chunk_size
            end = start + chunk_size
            if i == self.num_gpus - 1:
                end = len(pairs)
            chunks.append(pairs[start:end])

        # Score each chunk in its own worker process.
        manager = mp.Manager()
        return_dict = manager.dict()
        processes = []
        for rank, chunk in enumerate(chunks):
            p = mp.Process(
                target=self.process_chunk,
                args=(rank, chunk, self.reward_name, self.max_length, return_dict),
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Merge the per-rank results back in input order.
        results = []
        for rank in range(len(chunks)):
            results.extend(return_dict[rank])

        # Reap any workers still alive after join().
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()

        return results

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Get the average score of a batch of texts.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of texts.
        """
        # Reuse the scores cached by get_average_score() when available.
        if self.results is None:
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
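

# Minimal usage sketch (illustrative, not part of the module). TextPair is
# assumed to be a simple question/answer container, matching how
# process_chunk reads pair.question and pair.answer above; the keyword
# arguments here are illustrative.
if __name__ == "__main__":
    pairs = [
        TextPair(
            question="What is the capital of France?",
            answer="The capital of France is Paris.",
        ),
        TextPair(
            question="What is the capital of France?",
            answer="No idea.",
        ),
    ]
    evaluator = RewardEvaluator()  # defaults to the DeBERTa-v3 reward model
    print("average:", evaluator.get_average_score(pairs))
    print("min/max:", evaluator.get_min_max_score(pairs))  # reuses cached scores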