{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "WQpNapZNWuXP" }, "source": [ "\n", "**Best-of-n sampling as an alternative to RLHF**\n", "\n", "This notebook compares reward-model scores of prompt based responses from \n", "1. a base model (`gpt2-imdb`)\n", "2. `RLHF` tuned model based on this base-model \n", "3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n", "\n", "Import dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vDA6qayz692w" }, "outputs": [], "source": [ "%pip install transformers trl" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "M1s_iNm773hM" }, "outputs": [], "source": [ "import torch\n", "import pandas as pd\n", "\n", "from transformers import pipeline, AutoTokenizer\n", "from datasets import load_dataset\n", "\n", "from trl import AutoModelForCausalLMWithValueHead\n", "from trl.core import LengthSampler\n", "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "metadata": { "id": "Y7hyrIrO8tcY" }, "source": [ "Various constants" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "MqS3OM6Q8x6g" }, "outputs": [], "source": [ "ref_model_name = \"lvwerra/gpt2-imdb\"\n", "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n", "reward_model = \"lvwerra/distilbert-imdb\"\n", "\n", "N_BEST_OF = 4" ] }, { "cell_type": "markdown", "metadata": { "id": "c1YcXeElg6or" }, "source": [ "Models and tokenizers" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "b855NrL181Hh" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/kashif/Github/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. 
For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", " warnings.warn(\n" ] }, { "data": { "text/plain": [ "AutoModelForCausalLMWithValueHead(\n", " (pretrained_model): GPT2LMHeadModel(\n", " (transformer): GPT2Model(\n", " (wte): Embedding(50257, 768)\n", " (wpe): Embedding(1024, 768)\n", " (drop): Dropout(p=0.1, inplace=False)\n", " (h): ModuleList(\n", " (0-11): 12 x GPT2Block(\n", " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (attn): GPT2SdpaAttention(\n", " (c_attn): Conv1D(nf=2304, nx=768)\n", " (c_proj): Conv1D(nf=768, nx=768)\n", " (attn_dropout): Dropout(p=0.1, inplace=False)\n", " (resid_dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " (mlp): GPT2MLP(\n", " (c_fc): Conv1D(nf=3072, nx=768)\n", " (c_proj): Conv1D(nf=768, nx=3072)\n", " (act): NewGELUActivation()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " )\n", " )\n", " )\n", " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", " )\n", " (v_head): ValueHead(\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (summary): Linear(in_features=768, out_features=1, bias=True)\n", " (flatten): Flatten(start_dim=1, end_dim=-1)\n", " )\n", ")" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n", "\n", "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", "\n", "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", "\n", "tokenizer.pad_token = tokenizer.eos_token\n", "\n", "# cuda-ize models\n", "model.to(device)\n", "ref_model.to(device)" ] }, { "cell_type": "markdown", "metadata": { "id": "Z1Cz0gCFhZYJ" }, "source": [ "Dataset building" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "LqLVEp5p_8XM" }, "outputs": [], "source": [ "def build_dataset(\n", " tokenizer,\n", " dataset_name=\"stanfordnlp/imdb\",\n", " input_min_text_length=2,\n", " input_max_text_length=8,\n", "):\n", " # load imdb with datasets\n", " ds = load_dataset(dataset_name, split=\"train\")\n", " ds = ds.rename_columns({\"text\": \"review\"})\n", " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", "\n", " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", "\n", " def tokenize(sample):\n", " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", " return sample\n", "\n", " ds = ds.map(tokenize, batched=False)\n", " ds.set_format(type=\"torch\")\n", " return ds\n", "\n", "\n", "dataset = build_dataset(tokenizer)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "AqA2McjMAxNw" }, "outputs": [], "source": [ "gen_kwargs = {\n", " \"min_length\": -1,\n", " \"top_k\": 0.0,\n", " \"top_p\": 1.0,\n", " \"do_sample\": True,\n", " \"pad_token_id\": tokenizer.eos_token_id,\n", "}\n", "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "L_q4qs35AxcR" }, "outputs": [], "source": [ "output_min_length = 4\n", "output_max_length = 16\n", "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", "\n", 
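"# For each of the bs prompts sampled below, the next cells generate one response from the\n", "# reference model, one from the RLHF-tuned model, and N_BEST_OF candidates from the\n", "# reference model; the reward model then scores them all and the highest-scoring\n", "# candidate is reported as the best-of-n response.\n", "\n",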
"#### get a batch from the dataset\n", "bs = 16\n", "output_data = dict()\n", "dataset.set_format(\"pandas\")\n", "df_batch = dataset[:].sample(bs)\n", "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", "query_tensors = df_batch[\"input_ids\"].tolist()\n", "\n", "# :: [Resp]\n", "response_tensors_ref, response_tensors = [], []\n", "# :: [[Resp]]\n", "response_tensors_best_of = []" ] }, { "cell_type": "markdown", "metadata": { "id": "QVfpyHnZBLKY" }, "source": [ "\n", "Generation using various models" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "-imZ7uEFBNbw" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" ] } ], "source": [ "for i in range(bs):\n", " gen_len = output_length_sampler()\n", "\n", " query = torch.tensor(query_tensors[i])\n", "\n", " output = ref_model.generate(\n", " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors_ref.append(tokenizer.decode(output))\n", "\n", " output = model.generate(\n", " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors.append(tokenizer.decode(output))\n", "\n", " # generating copies of the same query for the Best-of-n sampling\n", " queries = query.repeat((N_BEST_OF, 1))\n", " output = ref_model.generate(\n", " queries.to(device), max_new_tokens=gen_len, **gen_kwargs\n", " ).squeeze()\n", " response_tensors_best_of.append(tokenizer.batch_decode(output))" ] }, { "cell_type": "markdown", "metadata": { "id": "Jp5FC0Y5h_Sf" }, "source": [ "Scoring" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "PyDbbAQ0F_h7" }, "outputs": [], "source": [ "scores_ref = [\n", " output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)\n", "]\n", "scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n", "scores_best_of = []\n", "for i, response in enumerate(response_tensors_best_of):\n", " # base_score = scores_ref[i]\n", " scores_best_of.append(\n", " torch.tensor(\n", " [output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 682 }, "id": "nA1GDNJEiGm-", "outputId": "1389c686-0751-4304-dea2-b71fd68748e1" }, "outputs": [ { "data": { "text/html": [ "
\n", " | query | \n", "response (ref) | \n", "scores (ref) | \n", "response (RLHF) | \n", "scores (RLHF) | \n", "response (best_of) | \n", "scores (best_of) | \n", "
---|---|---|---|---|---|---|---|
0 | \n", "This movie | \n", "This movie should have read some books, and | \n", "1.411889 | \n", "This movie has plenty of extraordinary feature... | \n", "2.735337 | \n", "This movie was unexpectedly funny and funny, you | \n", "2.405301 | \n", "
1 | \n", "OK where do i begin? | \n", "OK where do i begin? *** Acting is decent (not... | \n", "1.555380 | \n", "OK where do i begin? For all of you who are no... | \n", "0.019694 | \n", "OK where do i begin? i just wanted to add some... | \n", "0.622912 | \n", "
2 | \n", "I watched | \n", "I watched one can compare themselves upon view... | \n", "1.380120 | \n", "I watched it because of its excellent cast. Th... | \n", "2.498309 | \n", "I watched the trial trial for teaches us a goo... | \n", "2.057187 | \n", "
3 | \n", "It's been 19 years since Gordon | \n", "It's been 19 years since Gordon finally left c... | \n", "1.554914 | \n", "It's been 19 years since Gordon Tree has becom... | \n", "1.632266 | \n", "It's been 19 years since Gordon Clarke put me ... | \n", "2.783458 | \n", "
4 | \n", "Just kidding | \n", "Just kidding; I know a lot | \n", "-0.069533 | \n", "Just kidding \"Third World Snopes | \n", "0.944632 | \n", "Just kidding, I didn't even | \n", "1.945202 | \n", "
5 | \n", "shakespeare's plays have a way | \n", "shakespeare's plays have a way of weaving into... | \n", "1.656927 | \n", "shakespeare's plays have a way. It's the look ... | \n", "1.444803 | \n", "shakespeare's plays have a way of getting back... | \n", "1.834373 | \n", "
6 | \n", "This movie is wonderful. What | \n", "This movie is wonderful. What could have been ... | \n", "2.749068 | \n", "This movie is wonderful. What someone likes ab... | \n", "2.759510 | \n", "This movie is wonderful. What a different look, | \n", "2.695312 | \n", "
7 | \n", "I loved | \n", "I loved this film. <br />< | \n", "2.576181 | \n", "I loved it, and I really loved Audrey | \n", "2.578412 | \n", "I loved this film. Reading reviews of it | \n", "2.751773 | \n", "
8 | \n", "A superb and | \n", "A superb and very cool drama. The novel is | \n", "2.910374 | \n", "A superb and super fun movie that removes all the | \n", "2.783201 | \n", "A superb and most finely acted role that I will | \n", "2.894923 | \n", "
9 | \n", "I remember | \n", "I remember.Very poor execution but good movies | \n", "0.923775 | \n", "I remember when Shelter saw some girls on TV | \n", "0.825408 | \n", "I remember thinking to myself how SOMEONE who | \n", "1.634163 | \n", "
10 | \n", "This su*k | \n", "This su*k camel down your kidd | \n", "1.605957 | \n", "This su*k Dress! I loved it | \n", "2.345865 | \n", "This su*k like a roll of crap | \n", "2.422874 | \n", "
11 | \n", "One Stink | \n", "One Stink Act...<br /><br | \n", "1.456476 | \n", "One Stinkl was a great actor, particularly | \n", "1.782818 | \n", "One Stink?: Invisible of Saint Barbara, poor | \n", "1.667756 | \n", "
12 | \n", "I pulled down a VHS | \n", "I pulled down a VHS copy and watched it with m... | \n", "0.756151 | \n", "I pulled down a VHS looking a good looking, and a | \n", "-0.008258 | \n", "I pulled down a VHS copy the other day and all I | \n", "0.992919 | \n", "
13 | \n", "For some | \n", "For some alone no more Buddy Trumbull would ha... | \n", "0.790762 | \n", "For some enthraled time, the film will impress... | \n", "2.455694 | \n", "For some reason, a bomb crashed on the rear of... | \n", "0.857423 | \n", "
14 | \n", "This one features all | \n", "This one features all the good elements of spi... | \n", "1.452079 | \n", "This one features all kinds of wit and humor r... | \n", "2.743043 | \n", "This one features all the best Birdprogram sup... | \n", "2.343950 | \n", "
15 | \n", "Somehow a woman working with | \n", "Somehow a woman working with Jim Wynorski prof... | \n", "0.242172 | \n", "Somehow a woman working with her daughter play... | \n", "0.092226 | \n", "Somehow a woman working with an overweight ins... | \n", "1.415525 | \n", "