
**Best-of-n sampling as an alternative to RLHF**

This notebook compares reward-model scores of prompt based responses from 
1. a base model (`gpt2-imdb`)
2. `RLHF` tuned model based on this base-model 
3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model

Import dependencies

In [None]:
%pip install transformers trl

In [2]:
import torch
import pandas as pd

from transformers import pipeline, AutoTokenizer
from datasets import load_dataset

from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

device = "cuda" if torch.cuda.is_available() else "cpu"

Various constants

In [3]:
ref_model_name = "lvwerra/gpt2-imdb"
model_name = "lvwerra/gpt2-imdb-pos-v2"
reward_model = "lvwerra/distilbert-imdb"

N_BEST_OF = 4

Models and tokenizers

In [4]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)

ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)

reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)

tokenizer = AutoTokenizer.from_pretrained(ref_model_name)

tokenizer.pad_token = tokenizer.eos_token

# cuda-ize models
model.to(device)
ref_model.to(device)



AutoModelForCausalLMWithValueHead(
  (pretrained_model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2SdpaAttention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): 

Dataset building

In [5]:
def build_dataset(
    tokenizer,
    dataset_name="stanfordnlp/imdb",
    input_min_text_length=2,
    input_max_text_length=8,
):
    # load imdb with datasets
    ds = load_dataset(dataset_name, split="train")
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds


dataset = build_dataset(tokenizer)

In [6]:
gen_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}
sent_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}

In [7]:
output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)

#### get a batch from the dataset
bs = 16
output_data = dict()
dataset.set_format("pandas")
df_batch = dataset[:].sample(bs)
output_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

# :: [Resp]
response_tensors_ref, response_tensors = [], []
# :: [[Resp]]
response_tensors_best_of = []


Generation using various models

In [8]:
for i in range(bs):
    gen_len = output_length_sampler()

    query = torch.tensor(query_tensors[i])

    output = ref_model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_ref.append(tokenizer.decode(output))

    output = model.generate(
        query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors.append(tokenizer.decode(output))

    # generating copies of the same query for the Best-of-n sampling
    queries = query.repeat((N_BEST_OF, 1))
    output = ref_model.generate(
        queries.to(device), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()
    response_tensors_best_of.append(tokenizer.batch_decode(output))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Scoring

In [9]:
scores_ref = [
    output[0]["score"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)
]
scores = [output[0]["score"] for output in reward_pipe(response_tensors, **sent_kwargs)]
scores_best_of = []
for i, response in enumerate(response_tensors_best_of):
    # base_score = scores_ref[i]
    scores_best_of.append(
        torch.tensor(
            [output[0]["score"] for output in reward_pipe(response, **sent_kwargs)]
        )
    )

In [10]:
output_data["response (ref)"] = response_tensors_ref
output_data["scores (ref)"] = scores_ref
output_data["response (RLHF)"] = response_tensors
output_data["scores (RLHF)"] = scores
output_data["response (best_of)"] = [
    response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)
]
output_data["scores (best_of)"] = [a.max().item() for a in scores_best_of]


# store results in a dataframe
df_results = pd.DataFrame(output_data)
df_results

Unnamed: 0,query,response (ref),scores (ref),response (RLHF),scores (RLHF),response (best_of),scores (best_of)
0,This movie,"This movie should have read some books, and",1.411889,This movie has plenty of extraordinary feature...,2.735337,"This movie was unexpectedly funny and funny, you",2.405301
1,OK where do i begin?,OK where do i begin? *** Acting is decent (not...,1.55538,OK where do i begin? For all of you who are no...,0.019694,OK where do i begin? i just wanted to add some...,0.622912
2,I watched,I watched one can compare themselves upon view...,1.38012,I watched it because of its excellent cast. Th...,2.498309,I watched the trial trial for teaches us a goo...,2.057187
3,It's been 19 years since Gordon,It's been 19 years since Gordon finally left c...,1.554914,It's been 19 years since Gordon Tree has becom...,1.632266,It's been 19 years since Gordon Clarke put me ...,2.783458
4,Just kidding,Just kidding; I know a lot,-0.069533,"Just kidding ""Third World Snopes",0.944632,"Just kidding, I didn't even",1.945202
5,shakespeare's plays have a way,shakespeare's plays have a way of weaving into...,1.656927,shakespeare's plays have a way. It's the look ...,1.444803,shakespeare's plays have a way of getting back...,1.834373
6,This movie is wonderful. What,This movie is wonderful. What could have been ...,2.749068,This movie is wonderful. What someone likes ab...,2.75951,"This movie is wonderful. What a different look,",2.695312
7,I loved,I loved this film. <br /><,2.576181,"I loved it, and I really loved Audrey",2.578412,I loved this film. Reading reviews of it,2.751773
8,A superb and,A superb and very cool drama. The novel is,2.910374,A superb and super fun movie that removes all the,2.783201,A superb and most finely acted role that I will,2.894923
9,I remember,I remember.Very poor execution but good movies,0.923775,I remember when Shelter saw some girls on TV,0.825408,I remember thinking to myself how SOMEONE who,1.634163
