{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "rToK0Tku8PPn" }, "source": [ "## makemore: becoming a backprop ninja" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8sFElPqq8PPp" }, "outputs": [], "source": [ "# there no change change in the first several cells from last lecture" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ChBbac4y8PPq" }, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "import matplotlib.pyplot as plt # for making figures\n", "%matplotlib inline" ] }, { "cell_type": "code", "source": [ "# download the names.txt file from github\n", "!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt" ], "metadata": { "id": "x6GhEWW18aCS" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "klmu3ZG08PPr" }, "outputs": [], "source": [ "# read in all the words\n", "words = open('names.txt', 'r').read().splitlines()\n", "print(len(words))\n", "print(max(len(w) for w in words))\n", "print(words[:8])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "BCQomLE_8PPs" }, "outputs": [], "source": [ "# build the vocabulary of characters and mappings to/from integers\n", "chars = sorted(list(set(''.join(words))))\n", "stoi = {s:i+1 for i,s in enumerate(chars)}\n", "stoi['.'] = 0\n", "itos = {i:s for s,i in stoi.items()}\n", "vocab_size = len(itos)\n", "print(itos)\n", "print(vocab_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V_zt2QHr8PPs" }, "outputs": [], "source": [ "# build the dataset\n", "block_size = 3 # context length: how many characters do we take to predict the next one?\n", "\n", "def build_dataset(words):\n", " X, Y = [], []\n", "\n", " for w in words:\n", " context = [0] * block_size\n", " for ch in w + '.':\n", " ix = stoi[ch]\n", " X.append(context)\n", " Y.append(ix)\n", " context = context[1:] + [ix] # crop and append\n", "\n", " X = torch.tensor(X)\n", " Y = torch.tensor(Y)\n", " print(X.shape, Y.shape)\n", " return X, Y\n", "\n", "import random\n", "random.seed(42)\n", "random.shuffle(words)\n", "n1 = int(0.8*len(words))\n", "n2 = int(0.9*len(words))\n", "\n", "Xtr, Ytr = build_dataset(words[:n1]) # 80%\n", "Xdev, Ydev = build_dataset(words[n1:n2]) # 10%\n", "Xte, Yte = build_dataset(words[n2:]) # 10%" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eg20-vsg8PPt" }, "outputs": [], "source": [ "# ok biolerplate done, now we get to the action:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MJPU8HT08PPu" }, "outputs": [], "source": [ "# utility function we will use later when comparing manual gradients to PyTorch gradients\n", "def cmp(s, dt, t):\n", " ex = torch.all(dt == t.grad).item()\n", " app = torch.allclose(dt, t.grad)\n", " maxdiff = (dt - t.grad).abs().max().item()\n", " print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZlFLjQyT8PPu" }, "outputs": [], "source": [ "n_embd = 10 # the dimensionality of the character embedding vectors\n", "n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n", "\n", "g = torch.Generator().manual_seed(2147483647) # for reproducibility\n", "C = torch.randn((vocab_size, n_embd), generator=g)\n", "# Layer 1\n", "W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n", "b1 = 
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "ZlFLjQyT8PPu" }, "outputs": [], "source": [ "n_embd = 10 # the dimensionality of the character embedding vectors\n", "n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n", "\n", "g = torch.Generator().manual_seed(2147483647) # for reproducibility\n", "C = torch.randn((vocab_size, n_embd), generator=g)\n", "# Layer 1\n", "W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n", "b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n", "# Layer 2\n", "W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n", "b2 = torch.randn(vocab_size, generator=g) * 0.1\n", "# BatchNorm parameters\n", "bngain = torch.randn((1, n_hidden))*0.1 + 1.0\n", "bnbias = torch.randn((1, n_hidden))*0.1\n", "\n", "# Note: I am initializing many of these parameters in non-standard ways\n", "# because sometimes initializing with e.g. all zeros could mask an incorrect\n", "# implementation of the backward pass.\n", "\n", "parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n", "print(sum(p.nelement() for p in parameters)) # number of parameters in total\n", "for p in parameters:\n", " p.requires_grad = True" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QY-y96Y48PPv" }, "outputs": [], "source": [ "batch_size = 32\n", "n = batch_size # a shorter variable also, for convenience\n", "# construct a minibatch\n", "ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n", "Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8ofj1s6d8PPv" }, "outputs": [], "source": [ "# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n", "\n", "emb = C[Xb] # embed the characters into vectors\n", "embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n", "# Linear layer 1\n", "hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n", "# BatchNorm layer\n", "bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "bndiff = hprebn - bnmeani\n", "bndiff2 = bndiff**2\n", "bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n", "bnvar_inv = (bnvar + 1e-5)**-0.5\n", "bnraw = bndiff * bnvar_inv\n", "hpreact = bngain * bnraw + bnbias\n", "# Non-linearity\n", "h = torch.tanh(hpreact) # hidden layer\n", "# Linear layer 2\n", "logits = h @ W2 + b2 # output layer\n", "# cross entropy loss (same as F.cross_entropy(logits, Yb))\n", "logit_maxes = logits.max(1, keepdim=True).values\n", "norm_logits = logits - logit_maxes # subtract max for numerical stability\n", "counts = norm_logits.exp()\n", "counts_sum = counts.sum(1, keepdims=True)\n", "counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n", "probs = counts * counts_sum_inv\n", "logprobs = probs.log()\n", "loss = -logprobs[range(n), Yb].mean()\n", "\n", "# PyTorch backward pass\n", "for p in parameters:\n", " p.grad = None\n", "for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n", " norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n", " bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n", " embcat, emb]:\n", " t.retain_grad()\n", "loss.backward()\n", "loss" ] },
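{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# (added hint, not part of the original exercises) before Exercise 1 below, here is a hedged\n", "# sketch of what the very first manual backprop step could look like, just to show the pattern;\n", "# everything upstream of logprobs is left for you.\n", "# loss = -logprobs[range(n), Yb].mean(), so only the n plucked-out entries of logprobs\n", "# receive a gradient, and each of them gets -1.0/n.\n", "dlogprobs_hint = torch.zeros_like(logprobs)\n", "dlogprobs_hint[range(n), Yb] = -1.0/n\n", "cmp('logprobs (hint)', dlogprobs_hint, logprobs)" ] },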
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "mO-8aqxK8PPw" }, "outputs": [], "source": [ "# Exercise 1: backprop through the whole thing manually,\n", "# backpropagating through exactly all of the variables\n", "# as they are defined in the forward pass above, one by one\n", "\n", "# -----------------\n", "# YOUR CODE HERE :)\n", "# -----------------\n", "\n", "# cmp('logprobs', dlogprobs, logprobs)\n", "# cmp('probs', dprobs, probs)\n", "# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n", "# cmp('counts_sum', dcounts_sum, counts_sum)\n", "# cmp('counts', dcounts, counts)\n", "# cmp('norm_logits', dnorm_logits, norm_logits)\n", "# cmp('logit_maxes', dlogit_maxes, logit_maxes)\n", "# cmp('logits', dlogits, logits)\n", "# cmp('h', dh, h)\n", "# cmp('W2', dW2, W2)\n", "# cmp('b2', db2, b2)\n", "# cmp('hpreact', dhpreact, hpreact)\n", "# cmp('bngain', dbngain, bngain)\n", "# cmp('bnbias', dbnbias, bnbias)\n", "# cmp('bnraw', dbnraw, bnraw)\n", "# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n", "# cmp('bnvar', dbnvar, bnvar)\n", "# cmp('bndiff2', dbndiff2, bndiff2)\n", "# cmp('bndiff', dbndiff, bndiff)\n", "# cmp('bnmeani', dbnmeani, bnmeani)\n", "# cmp('hprebn', dhprebn, hprebn)\n", "# cmp('embcat', dembcat, embcat)\n", "# cmp('W1', dW1, W1)\n", "# cmp('b1', db1, b1)\n", "# cmp('emb', demb, emb)\n", "# cmp('C', dC, C)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ebLtYji_8PPw" }, "outputs": [], "source": [ "# Exercise 2: backprop through cross_entropy but all in one go\n", "# to complete this challenge, look at the mathematical expression of the loss,\n", "# take the derivative, simplify the expression, and just write it out\n", "\n", "# forward pass\n", "\n", "# before:\n", "# logit_maxes = logits.max(1, keepdim=True).values\n", "# norm_logits = logits - logit_maxes # subtract max for numerical stability\n", "# counts = norm_logits.exp()\n", "# counts_sum = counts.sum(1, keepdims=True)\n", "# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n", "# probs = counts * counts_sum_inv\n", "# logprobs = probs.log()\n", "# loss = -logprobs[range(n), Yb].mean()\n", "\n", "# now:\n", "loss_fast = F.cross_entropy(logits, Yb)\n", "print(loss_fast.item(), 'diff:', (loss_fast - loss).item())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-gCXbB4C8PPx" }, "outputs": [], "source": [ "# backward pass\n", "\n", "# -----------------\n", "# YOUR CODE HERE :)\n", "dlogits = None # TODO. my solution is 3 lines\n", "# -----------------\n", "\n", "# cmp('logits', dlogits, logits) # I can only get approximate to be true, my maxdiff is 6e-9" ] },
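{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# (added hint, not part of the original) one algebraic form the fused cross-entropy backward\n", "# can take, left commented out so the exercise above stays an exercise:\n", "# for the mean loss, dloss/dlogits = (softmax(logits) - one_hot(Yb)) / n, e.g.\n", "# dlogits = F.softmax(logits, 1)\n", "# dlogits[range(n), Yb] -= 1\n", "# dlogits /= n" ] },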
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "hd-MkhB68PPy" }, "outputs": [], "source": [ "# Exercise 3: backprop through batchnorm but all in one go\n", "# to complete this challenge, look at the mathematical expression of the output of batchnorm,\n", "# take the derivative w.r.t. its input, simplify the expression, and just write it out\n", "# BatchNorm paper: https://arxiv.org/abs/1502.03167\n", "\n", "# forward pass\n", "\n", "# before:\n", "# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n", "# bndiff = hprebn - bnmeani\n", "# bndiff2 = bndiff**2\n", "# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n", "# bnvar_inv = (bnvar + 1e-5)**-0.5\n", "# bnraw = bndiff * bnvar_inv\n", "# hpreact = bngain * bnraw + bnbias\n", "\n", "# now:\n", "hpreact_fast = bngain * (hprebn - hprebn.mean(0, keepdim=True)) / torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) + bnbias\n", "print('max diff:', (hpreact_fast - hpreact).abs().max())" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "POdeZSKT8PPy" }, "outputs": [], "source": [ "# backward pass\n", "\n", "# before we had:\n", "# dbnraw = bngain * dhpreact\n", "# dbndiff = bnvar_inv * dbnraw\n", "# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n", "# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n", "# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n", "# dbndiff += (2*bndiff) * dbndiff2\n", "# dhprebn = dbndiff.clone()\n", "# dbnmeani = (-dbndiff).sum(0)\n", "# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n", "\n", "# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)\n", "# (you'll also need to use some of the variables from the forward pass up above)\n", "\n", "# -----------------\n", "# YOUR CODE HERE :)\n", "dhprebn = None # TODO. my solution is 1 (long) line\n", "# -----------------\n", "\n", "cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10" ] },
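{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# (added hint, not part of the original) for reference, with the Bessel-corrected (n-1) variance\n", "# used above, the batchnorm backward can be collapsed into a single expression of roughly this\n", "# shape; left commented out so the exercise above stays an exercise:\n", "# dhprebn = bngain*bnvar_inv/n * (n*dhpreact - dhpreact.sum(0) - n/(n-1)*bnraw*(dhpreact*bnraw).sum(0))" ] },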
pre-activation\n", " # BatchNorm layer\n", " # -------------------------------------------------------------\n", " bnmean = hprebn.mean(0, keepdim=True)\n", " bnvar = hprebn.var(0, keepdim=True, unbiased=True)\n", " bnvar_inv = (bnvar + 1e-5)**-0.5\n", " bnraw = (hprebn - bnmean) * bnvar_inv\n", " hpreact = bngain * bnraw + bnbias\n", " # -------------------------------------------------------------\n", " # Non-linearity\n", " h = torch.tanh(hpreact) # hidden layer\n", " logits = h @ W2 + b2 # output layer\n", " loss = F.cross_entropy(logits, Yb) # loss function\n", "\n", " # backward pass\n", " for p in parameters:\n", " p.grad = None\n", " loss.backward() # use this for correctness comparisons, delete it later!\n", "\n", " # manual backprop! #swole_doge_meme\n", " # -----------------\n", " # YOUR CODE HERE :)\n", " dC, dW1, db1, dW2, db2, dbngain, dbnbias = None, None, None, None, None, None, None\n", " grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]\n", " # -----------------\n", "\n", " # update\n", " lr = 0.1 if i < 100000 else 0.01 # step learning rate decay\n", " for p, grad in zip(parameters, grads):\n", " p.data += -lr * p.grad # old way of cheems doge (using PyTorch grad from .backward())\n", " #p.data += -lr * grad # new way of swole doge TODO: enable\n", "\n", " # track stats\n", " if i % 10000 == 0: # print every once in a while\n", " print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')\n", " lossi.append(loss.log10().item())\n", "\n", " if i >= 100: # TODO: delete early breaking when you're ready to train the full net\n", " break" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZEpI0hMW8PPz" }, "outputs": [], "source": [ "# useful for checking your gradients\n", "# for p,g in zip(parameters, grads):\n", "# cmp(str(tuple(p.shape)), g, p)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "KImLWNoh8PP0" }, "outputs": [], "source": [ "# calibrate the batch norm at the end of training\n", "\n", "with torch.no_grad():\n", " # pass the training set through\n", " emb = C[Xtr]\n", " embcat = emb.view(emb.shape[0], -1)\n", " hpreact = embcat @ W1 + b1\n", " # measure the mean/std over the entire training set\n", " bnmean = hpreact.mean(0, keepdim=True)\n", " bnvar = hpreact.var(0, keepdim=True, unbiased=True)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6aFnP_Zc8PP0" }, "outputs": [], "source": [ "# evaluate train and val loss\n", "\n", "@torch.no_grad() # this decorator disables gradient tracking\n", "def split_loss(split):\n", " x,y = {\n", " 'train': (Xtr, Ytr),\n", " 'val': (Xdev, Ydev),\n", " 'test': (Xte, Yte),\n", " }[split]\n", " emb = C[x] # (N, block_size, n_embd)\n", " embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n", " hpreact = embcat @ W1 + b1\n", " hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n", " h = torch.tanh(hpreact) # (N, n_hidden)\n", " logits = h @ W2 + b2 # (N, vocab_size)\n", " loss = F.cross_entropy(logits, y)\n", " print(split, loss.item())\n", "\n", "split_loss('train')\n", "split_loss('val')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "esWqmhyj8PP1" }, "outputs": [], "source": [ "# I achieved:\n", "# train 2.0718822479248047\n", "# val 2.1162495613098145" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xHeQNv3s8PP1" }, "outputs": [], "source": [ "# sample from the model\n", "g = torch.Generator().manual_seed(2147483647 + 10)\n", "\n", "for _ in range(20):\n", "\n", " 
out = []\n", " context = [0] * block_size # initialize with all ...\n", " while True:\n", " # forward pass\n", " emb = C[torch.tensor([context])] # (1,block_size,d)\n", " embcat = emb.view(emb.shape[0], -1) # concat into (N, block_size * n_embd)\n", " hpreact = embcat @ W1 + b1\n", " hpreact = bngain * (hpreact - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n", " h = torch.tanh(hpreact) # (N, n_hidden)\n", " logits = h @ W2 + b2 # (N, vocab_size)\n", " # sample\n", " probs = F.softmax(logits, dim=1)\n", " ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n", " context = context[1:] + [ix]\n", " out.append(ix)\n", " if ix == 0:\n", " break\n", "\n", " print(''.join(itos[i] for i in out))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }