Update README.md
Browse files
README.md
CHANGED
@@ -1,100 +1,100 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
language:
|
4 |
-
- en
|
5 |
-
base_model:
|
6 |
-
- meta-llama/Llama-3.2-3B-Instruct
|
7 |
-
tags:
|
8 |
-
- rewardmodel
|
9 |
-
- GRAM
|
10 |
-
- RLHF
|
11 |
-
- reward
|
12 |
-
---
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
28 |
-
|
29 |
-
|
|
30 |
-
|
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
]
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
prompt = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
|
79 |
-
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
"
|
89 |
-
"
|
90 |
-
"
|
91 |
-
"
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
base_model:
|
6 |
+
- meta-llama/Llama-3.2-3B-Instruct
|
7 |
+
tags:
|
8 |
+
- rewardmodel
|
9 |
+
- GRAM
|
10 |
+
- RLHF
|
11 |
+
- reward
|
12 |
+
---
|
13 |
+
# Introduction
|
14 |
+
|
15 |
+
This repository contains the released models for the paper [GRAM: A Generative Foundation Reward Model for Reward Generalization 📝]().
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
This training process is introduced above. Traditionally, these models are trained using labeled data, which can limit their potential. In this study, we propose a new method that combines both labeled and unlabeled data for training reward models. We introduce a generative reward model that first learns from a large amount of unlabeled data and is then fine-tuned with supervised data. Additionally, we demonstrate that using label smoothing during training improves performance by optimizing a regularized ranking loss. This approach bridges generative and discriminative models, offering a new perspective on training reward models. Our model can be easily applied to various tasks without the need for extensive fine-tuning. This means that when aligning LLMs, there is no longer a need to train a reward model from scratch with large amounts of task-specific labeled data. Instead, **you can directly apply our reward model or adapt it to align your LLM based on our [code](https://github.com/wangclnlp/GRAM/tree/main)**.
|
20 |
+
|
21 |
+
This reward model is fine-tuned from [Llama-3.2-3B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct)
|
22 |
+
|
23 |
+
## Evaluation
|
24 |
+
|
25 |
+
We evaluate our reward model on the [JudgeBench](https://huggingface.co/datasets/ScalerLab/JudgeBench), a benchmark for evaluating LLM-as-a-Judge applications, and present the results as follows:
|
26 |
+
|
27 |
+
| Model | Param. | Chat | Code | Math | Safety | Avg. |
|
28 |
+
|:-|-:|:-:|:-:|:-:|:-:|:-:|
|
29 |
+
|[GRAM-Qwen3-14B-RewardBench](https://huggingface.co/wangclnlp/GRAM-Qwen3-14B-RewardModel) |14B|63.0|64.3|89.3|69.1|71.4|
|
30 |
+
|[GRAM-LLaMA3.2-3B-RewardBench](https://huggingface.co/wangclnlp/GRAM-LLaMA3.2-3B-RewardModel) |3B|59.7|64.3|84.0|71.4|69.9|
|
31 |
+
|[GRAM-Qwen3-8B-RewardBench](https://huggingface.co/wangclnlp/GRAM-Qwen3-8B-RewardModel) |8B|62.3|64.3|80.4|64.3|67.8|
|
32 |
+
|nvidia/Llama-3.1-Nemotron-70B-Reward|70B|62.3|72.5|76.8|57.1|67.2|
|
33 |
+
|[GRAM-Qwen3-4B-RewardBench](https://huggingface.co/wangclnlp/GRAM-Qwen3-4B-RewardModel) |4B|59.7|59.2|80.4|64.3|65.9|
|
34 |
+
|[GRAM-Qwen3-1.7B-RewardBench](https://huggingface.co/wangclnlp/GRAM-Qwen3-1.7B-RewardModel) |1.7B|60.4|65.3|78.6|57.1|65.4|
|
35 |
+
|Skywork/Skywork-Reward-Gemma-2-27B-v0.2|27B|59.7|66.3|83.9|50.0|65.0|
|
36 |
+
|Skywork/Skywork-Reward-Llama-3.1-8B-v0.2|8B|59.1|64.3|76.8|50.0|62.6|
|
37 |
+
|internlm/internlm2-20b-reward|20B|62.3|69.4|66.1|50.0|62.0|
|
38 |
+
|
39 |
+
|
40 |
+
## Usage
|
41 |
+
|
42 |
+
You can directly run the GRAM model using the demo provided below. You can also train GRAM using the code available [here](https://github.com/wangclnlp/GRAM/tree/main).
|
43 |
+
|
44 |
+
```python
|
45 |
+
import torch
|
46 |
+
import accelerate
|
47 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
48 |
+
prompt = """Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user\'s instructions and answers the user\'s question better.
|
49 |
+
Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible.
|
50 |
+
Please directly output your final verdict by strictly following this format: "A" if assistant A is better, "B" if assistant B is better.
|
51 |
+
[User Question]
|
52 |
+
{input}
|
53 |
+
[The Start of Assistant A's Answer]
|
54 |
+
{response_a}
|
55 |
+
[The End of Assistant A's Answer]
|
56 |
+
[The Start of Assistant B's Answer]
|
57 |
+
{response_b}
|
58 |
+
[The End of Assistant B's Answer]
|
59 |
+
"""
|
60 |
+
query = "What is the Russian word for frog?"
|
61 |
+
response1 = "The Russian word for frog is \"лягушка\" (pronounced \"lyagushka\")."
|
62 |
+
response2 = "The Russian word for frog is \"жаба\" (pronounced as \"zhaba\"). This word can also be written in Cyrillic as жа́ба. If you're learning Russian, here's a sentence with the word: Меня зовут Иван, и я люблю лезечку на спину жабы, which translates to \"My name is Ivan, and I like sitting on the back of a frog.\" (Keep in mind that in real life, it is best not to disturb or harm frogs.)"
|
63 |
+
model_name_or_path = "gram-open-source/GRAM-Qwen3-1.7B-RewardModel"
|
64 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
65 |
+
tokenizer.padding_side = "left"
|
66 |
+
if not tokenizer.pad_token:
|
67 |
+
tokenizer.pad_token = tokenizer.eos_token
|
68 |
+
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
|
69 |
+
messages = [
|
70 |
+
[{"role": "user", "content": prompt.format(input=query, response_a=response1, response_b=response2)}],
|
71 |
+
[{"role": "user", "content": prompt.format(input=query, response_a=response2, response_b=response1)}],
|
72 |
+
]
|
73 |
+
# target at response1, response2 respectively
|
74 |
+
target_choices_response1 = ["A", "B"]
|
75 |
+
target_choices_response1_token_ids = torch.tensor([tokenizer(item, add_special_tokens=False).input_ids for item in target_choices_response1], device=model.device)
|
76 |
+
target_choices_response2_token_ids = torch.flip(target_choices_response1_token_ids, dims=(0,))
|
77 |
+
target_choices_token_ids = torch.cat((target_choices_response1_token_ids, target_choices_response2_token_ids), dim=1)
|
78 |
+
prompt = [tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True) for message in messages]
|
79 |
+
inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
|
80 |
+
with torch.no_grad():
|
81 |
+
output = model(**inputs)
|
82 |
+
logits = torch.gather(output.logits[..., -1, :], 1, target_choices_token_ids)
|
83 |
+
p = torch.nn.Softmax(dim=0)(logits)
|
84 |
+
score_response1, score_response2 = torch.mean(p, dim=1).tolist()
|
85 |
+
print({
|
86 |
+
"query": query,
|
87 |
+
"response1": response1,
|
88 |
+
"response2": response2,
|
89 |
+
"score_response1": score_response1,
|
90 |
+
"score_response2": score_response2,
|
91 |
+
"response1_is_better": score_response1 > score_response2,
|
92 |
+
})
|
93 |
+
```
|
94 |
+
|
95 |
+
## Citation
|
96 |
+
|
97 |
+
If you find this model helpful for your research, please cite GRAM:
|
98 |
+
```bash
|
99 |
+
bib
|
100 |
+
```
|