{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "WQpNapZNWuXP" }, "source": [ "\n", "**Best-of-n sampling as an alternative to RLHF**\n", "\n", "This notebook compares reward-model scores of prompt based responses from \n", "1. a base model (`gpt2-imdb`)\n", "2. `RLHF` tuned model based on this base-model \n", "3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n", "\n", "Import dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vDA6qayz692w" }, "outputs": [], "source": [ "%pip install transformers trl" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "M1s_iNm773hM" }, "outputs": [], "source": [ "import torch\n", "import pandas as pd\n", "\n", "from transformers import pipeline, AutoTokenizer\n", "from datasets import load_dataset\n", "\n", "from trl import AutoModelForCausalLMWithValueHead\n", "from trl.core import LengthSampler\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "metadata": { "id": "Y7hyrIrO8tcY" }, "source": [ "Various constants" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "MqS3OM6Q8x6g" }, "outputs": [], "source": [ "ref_model_name = \"lvwerra/gpt2-imdb\"\n", "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n", "reward_model = \"lvwerra/distilbert-imdb\"\n", "\n", "N_BEST_OF = 4" ] }, { "cell_type": "markdown", "metadata": { "id": "c1YcXeElg6or" }, "source": [ "Models and tokenizers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "b855NrL181Hh" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/kashif/Github/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "AutoModelForCausalLMWithValueHead(\n", " (pretrained_model): GPT2LMHeadModel(\n", " (transformer): GPT2Model(\n", " (wte): Embedding(50257, 768)\n", " (wpe): Embedding(1024, 768)\n", " (drop): Dropout(p=0.1, inplace=False)\n", " (h): ModuleList(\n", " (0-11): 12 x GPT2Block(\n", " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (attn): GPT2SdpaAttention(\n", " (c_attn): Conv1D(nf=2304, nx=768)\n", " (c_proj): Conv1D(nf=768, nx=768)\n", " (attn_dropout): Dropout(p=0.1, inplace=False)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): GPT2MLP(\n", " (c_fc): Conv1D(nf=3072, nx=768)\n", " (c_proj): Conv1D(nf=768, nx=3072)\n", " (act): NewGELUActivation()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", " )\n", " (v_head): ValueHead(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (summary): Linear(in_features=768, out_features=1, bias=True)\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n", " )\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n", "\n", "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", "\n", "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "# cuda-ize models\n", "model.to(device)\n", "ref_model.to(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z1Cz0gCFhZYJ" }, "source": [ "Dataset building" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "LqLVEp5p_8XM" }, "outputs": [], "source": [ "def build_dataset(\n", " tokenizer,\n", " dataset_name=\"stanfordnlp/imdb\",\n", " input_min_text_length=2,\n", " input_max_text_length=8,\n", "):\n", " # load imdb with datasets\n", " ds = load_dataset(dataset_name, split=\"train\")\n", " ds = ds.rename_columns({\"text\": \"review\"})\n", " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", "\n", " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", "\n", " def tokenize(sample):\n", " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", " return sample\n", "\n", " ds = ds.map(tokenize, batched=False)\n", " ds.set_format(type=\"torch\")\n", " return ds\n", "\n", "\n", "dataset = build_dataset(tokenizer)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "AqA2McjMAxNw" }, "outputs": [], "source": [ "gen_kwargs = {\n", " \"min_length\": -1,\n", " \"top_k\": 0.0,\n", " \"top_p\": 1.0,\n", " \"do_sample\": True,\n", " \"pad_token_id\": tokenizer.eos_token_id,\n", "}\n", "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "L_q4qs35AxcR" }, "outputs": [], "source": [ "output_min_length = 4\n", "output_max_length = 16\n", "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", "\n", "#### get a batch from the dataset\n", "bs = 16\n", "output_data = dict()\n", "dataset.set_format(\"pandas\")\n", "df_batch = dataset[:].sample(bs)\n", "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", "query_tensors = df_batch[\"input_ids\"].tolist()\n", "\n", "# :: [Resp]\n", "response_tensors_ref, response_tensors = [], []\n", "# :: [[Resp]]\n", "response_tensors_best_of = []" ] }, { "cell_type": "markdown", "metadata": { "id": "QVfpyHnZBLKY" }, "source": [ "\n", "Generation using various models" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "-imZ7uEFBNbw" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" ] } ], "source": [ "for i in range(bs):\n", " gen_len = output_length_sampler()\n", "\n", " query = torch.tensor(query_tensors[i])\n", "\n", " output = ref_model.generate(\n", " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors_ref.append(tokenizer.decode(output))\n", "\n", " output = model.generate(\n", " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors.append(tokenizer.decode(output))\n", "\n", " # generating copies of the same query for the Best-of-n sampling\n", " queries = query.repeat((N_BEST_OF, 1))\n", " output = ref_model.generate(\n", " queries.to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors_best_of.append(tokenizer.batch_decode(output))" ] }, { "cell_type": "markdown", "metadata": { "id": "Jp5FC0Y5h_Sf" }, "source": [ "Scoring" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "PyDbbAQ0F_h7" }, "outputs": [], "source": [ "scores_ref = [\n", " output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)\n", "]\n", "scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n", "scores_best_of = []\n", "for i, response in enumerate(response_tensors_best_of):\n", " # base_score = scores_ref[i]\n", " scores_best_of.append(\n", " torch.tensor(\n", " [output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 682 }, "id": "nA1GDNJEiGm-", "outputId": "1389c686-0751-4304-dea2-b71fd68748e1" }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
queryresponse (ref)scores (ref)response (RLHF)scores (RLHF)response (best_of)scores (best_of)
0This movieThis movie should have read some books, and1.411889This movie has plenty of extraordinary feature...2.735337This movie was unexpectedly funny and funny, you2.405301
1OK where do i begin?OK where do i begin? *** Acting is decent (not...1.555380OK where do i begin? For all of you who are no...0.019694OK where do i begin? i just wanted to add some...0.622912
2I watchedI watched one can compare themselves upon view...1.380120I watched it because of its excellent cast. Th...2.498309I watched the trial trial for teaches us a goo...2.057187
3It's been 19 years since GordonIt's been 19 years since Gordon finally left c...1.554914It's been 19 years since Gordon Tree has becom...1.632266It's been 19 years since Gordon Clarke put me ...2.783458
4Just kiddingJust kidding; I know a lot-0.069533Just kidding \"Third World Snopes0.944632Just kidding, I didn't even1.945202
5shakespeare's plays have a wayshakespeare's plays have a way of weaving into...1.656927shakespeare's plays have a way. It's the look ...1.444803shakespeare's plays have a way of getting back...1.834373
6This movie is wonderful. WhatThis movie is wonderful. What could have been ...2.749068This movie is wonderful. What someone likes ab...2.759510This movie is wonderful. What a different look,2.695312
7I lovedI loved this film. <br /><2.576181I loved it, and I really loved Audrey2.578412I loved this film. Reading reviews of it2.751773
8A superb andA superb and very cool drama. The novel is2.910374A superb and super fun movie that removes all the2.783201A superb and most finely acted role that I will2.894923
9I rememberI remember.Very poor execution but good movies0.923775I remember when Shelter saw some girls on TV0.825408I remember thinking to myself how SOMEONE who1.634163
10This su*kThis su*k camel down your kidd1.605957This su*k Dress! I loved it2.345865This su*k like a roll of crap2.422874
11One StinkOne Stink Act...<br /><br1.456476One Stinkl was a great actor, particularly1.782818One Stink?: Invisible of Saint Barbara, poor1.667756
12I pulled down a VHSI pulled down a VHS copy and watched it with m...0.756151I pulled down a VHS looking a good looking, and a-0.008258I pulled down a VHS copy the other day and all I0.992919
13For someFor some alone no more Buddy Trumbull would ha...0.790762For some enthraled time, the film will impress...2.455694For some reason, a bomb crashed on the rear of...0.857423
14This one features allThis one features all the good elements of spi...1.452079This one features all kinds of wit and humor r...2.743043This one features all the best Birdprogram sup...2.343950
15Somehow a woman working withSomehow a woman working with Jim Wynorski prof...0.242172Somehow a woman working with her daughter play...0.092226Somehow a woman working with an overweight ins...1.415525
\n", "
" ], "text/plain": [ " query \\\n", "0 This movie \n", "1 OK where do i begin? \n", "2 I watched \n", "3 It's been 19 years since Gordon \n", "4 Just kidding \n", "5 shakespeare's plays have a way \n", "6 This movie is wonderful. What \n", "7 I loved \n", "8 A superb and \n", "9 I remember \n", "10 This su*k \n", "11 One Stink \n", "12 I pulled down a VHS \n", "13 For some \n", "14 This one features all \n", "15 Somehow a woman working with \n", "\n", " response (ref) scores (ref) \\\n", "0 This movie should have read some books, and 1.411889 \n", "1 OK where do i begin? *** Acting is decent (not... 1.555380 \n", "2 I watched one can compare themselves upon view... 1.380120 \n", "3 It's been 19 years since Gordon finally left c... 1.554914 \n", "4 Just kidding; I know a lot -0.069533 \n", "5 shakespeare's plays have a way of weaving into... 1.656927 \n", "6 This movie is wonderful. What could have been ... 2.749068 \n", "7 I loved this film.
< 2.576181 \n", "8 A superb and very cool drama. The novel is 2.910374 \n", "9 I remember.Very poor execution but good movies 0.923775 \n", "10 This su*k camel down your kidd 1.605957 \n", "11 One Stink Act...