File size: 23,036 Bytes
a74ead5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "rToK0Tku8PPn"
},
"source": [
"## makemore: becoming a backprop ninja"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8sFElPqq8PPp"
},
"outputs": [],
"source": [
"# there is no change in the first several cells from last lecture"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ChBbac4y8PPq"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt # for making figures\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"source": [
"# download the names.txt file from github\n",
"# -nc (--no-clobber) skips the download when names.txt already exists, so re-running\n",
"# this cell doesn't pile up names.txt.1, names.txt.2, ... next to the original\n",
"!wget -nc https://raw.githubusercontent.com/karpathy/makemore/master/names.txt"
],
"metadata": {
"id": "x6GhEWW18aCS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "klmu3ZG08PPr"
},
"outputs": [],
"source": [
"# read in all the words (one word per line); `with` closes the file handle afterwards\n",
"with open('names.txt', 'r', encoding='utf-8') as f:\n",
"  words = f.read().splitlines()\n",
"print(len(words)) # number of words\n",
"print(max(len(w) for w in words)) # length of the longest word\n",
"print(words[:8])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BCQomLE_8PPs"
},
"outputs": [],
"source": [
"# build the vocabulary of characters and mappings to/from integers:\n",
"# characters get ids 1..len(chars) in sorted order, and '.' (the start/end token) gets id 0\n",
"chars = sorted(set(''.join(words)))\n",
"stoi = dict(zip(chars, range(1, len(chars) + 1)))\n",
"stoi['.'] = 0\n",
"itos = {ix: ch for ch, ix in stoi.items()} # inverse mapping: id -> character\n",
"vocab_size = len(itos)\n",
"print(itos)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V_zt2QHr8PPs"
},
"outputs": [],
"source": [
"# build the dataset\n",
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
"\n",
"def build_dataset(words):\n",
"  \"\"\"Build (X, Y) tensors of next-character examples from a list of words.\n",
"\n",
"  Each word is terminated with '.', and every character is paired with the\n",
"  `block_size` character ids preceding it (left-padded with id 0, i.e. '.').\n",
"  Returns X with one context row of length block_size per example, and Y with\n",
"  the corresponding target character id.\n",
"  \"\"\"\n",
"  X, Y = [], []\n",
"  for w in words:\n",
"    context = [0] * block_size\n",
"    for ch in w + '.':\n",
"      ix = stoi[ch]\n",
"      X.append(context)\n",
"      Y.append(ix)\n",
"      context = context[1:] + [ix] # crop and append\n",
"  X = torch.tensor(X)\n",
"  Y = torch.tensor(Y)\n",
"  print(X.shape, Y.shape)\n",
"  return X, Y\n",
"\n",
"import random\n",
"# shuffle a copy: shuffling `words` in place would re-shuffle an already shuffled list\n",
"# every time this cell is re-run, silently producing a different train/dev/test split\n",
"words_shuffled = list(words)\n",
"random.seed(42)\n",
"random.shuffle(words_shuffled)\n",
"n1 = int(0.8*len(words_shuffled))\n",
"n2 = int(0.9*len(words_shuffled))\n",
"\n",
"Xtr, Ytr = build_dataset(words_shuffled[:n1])     # 80%\n",
"Xdev, Ydev = build_dataset(words_shuffled[n1:n2]) # 10%\n",
"Xte, Yte = build_dataset(words_shuffled[n2:])     # 10%"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eg20-vsg8PPt"
},
"outputs": [],
"source": [
"# ok, boilerplate done, now we get to the action:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "MJPU8HT08PPu"
},
"outputs": [],
"source": [
"# utility function we will use later when comparing manual gradients to PyTorch gradients\n",
"def cmp(s, dt, t):\n",
"  \"\"\"Print how closely a manual gradient `dt` matches PyTorch's gradient `t.grad`.\n",
"\n",
"  s:  label shown in the first column of the report\n",
"  dt: the manually computed gradient, or None if it hasn't been written yet\n",
"  t:  the tensor whose .grad was populated by loss.backward()\n",
"\n",
"  Reports bitwise equality, torch.allclose (default tolerances) and the max absolute\n",
"  difference. Instead of crashing, it prints a 'skipped' line when dt or t.grad is None.\n",
"  \"\"\"\n",
"  if dt is None:\n",
"    print(f'{s:15s} | skipped: manual gradient is None (not implemented yet)')\n",
"    return\n",
"  if t.grad is None:\n",
"    print(f'{s:15s} | skipped: t.grad is None (was retain_grad() called before backward()?)')\n",
"    return\n",
"  ex = torch.all(dt == t.grad).item()\n",
"  app = torch.allclose(dt, t.grad)\n",
"  maxdiff = (dt - t.grad).abs().max().item()\n",
"  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZlFLjQyT8PPu"
},
"outputs": [],
"source": [
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
"n_hidden = 64 # the number of neurons in the hidden layer of the MLP\n",
"\n",
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
"# Layer 1\n",
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
"b1 = torch.randn(n_hidden, generator=g) * 0.1 # using b1 just for fun, it's useless because of BN\n",
"# Layer 2\n",
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
"# BatchNorm parameters (drawn from g too: without generator=g they would come from the\n",
"# unseeded global RNG and differ on every run)\n",
"bngain = torch.randn((1, n_hidden), generator=g)*0.1 + 1.0\n",
"bnbias = torch.randn((1, n_hidden), generator=g)*0.1\n",
"\n",
"# Note: I am initializing many of these parameters in non-standard ways\n",
"# because sometimes initializing with e.g. all zeros could mask an incorrect\n",
"# implementation of the backward pass.\n",
"\n",
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
"for p in parameters:\n",
"  p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QY-y96Y48PPv"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"n = batch_size # a shorter variable also, for convenience\n",
"# construct a minibatch: batch_size random row indices into the training set\n",
"ix = torch.randint(Xtr.shape[0], (batch_size,), generator=g)\n",
"Xb = Xtr[ix] # batch X\n",
"Yb = Ytr[ix] # batch Y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8ofj1s6d8PPv"
},
"outputs": [],
"source": [
"# forward pass, \"chunkated\" into smaller steps that are possible to backward one at a time\n",
"\n",
"emb = C[Xb] # embed the characters into vectors\n",
"embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
"# Linear layer 1\n",
"hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
"# BatchNorm layer\n",
"bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
"bndiff = hprebn - bnmeani\n",
"bndiff2 = bndiff**2\n",
"bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
"bnvar_inv = (bnvar + 1e-5)**-0.5\n",
"bnraw = bndiff * bnvar_inv\n",
"hpreact = bngain * bnraw + bnbias\n",
"# Non-linearity\n",
"h = torch.tanh(hpreact) # hidden layer\n",
"# Linear layer 2\n",
"logits = h @ W2 + b2 # output layer\n",
"# cross entropy loss (same as F.cross_entropy(logits, Yb))\n",
"logit_maxes = logits.max(1, keepdim=True).values\n",
"norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
"counts = norm_logits.exp()\n",
"counts_sum = counts.sum(1, keepdims=True)\n",
"counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
"probs = counts * counts_sum_inv\n",
"logprobs = probs.log()\n",
"loss = -logprobs[range(n), Yb].mean()\n",
"\n",
"# PyTorch backward pass\n",
"for p in parameters:\n",
" p.grad = None\n",
"for t in [logprobs, probs, counts, counts_sum, counts_sum_inv, # afaik there is no cleaner way\n",
" norm_logits, logit_maxes, logits, h, hpreact, bnraw,\n",
" bnvar_inv, bnvar, bndiff2, bndiff, hprebn, bnmeani,\n",
" embcat, emb]:\n",
" t.retain_grad()\n",
"loss.backward()\n",
"loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "mO-8aqxK8PPw"
},
"outputs": [],
"source": [
"# Exercise 1: backprop through the whole thing manually,\n",
"# backpropagating through exactly all of the variables\n",
"# as they are defined in the forward pass above, one by one\n",
"\n",
"# -----------------\n",
"# YOUR CODE HERE :)\n",
"# -----------------\n",
"\n",
"# cmp('logprobs', dlogprobs, logprobs)\n",
"# cmp('probs', dprobs, probs)\n",
"# cmp('counts_sum_inv', dcounts_sum_inv, counts_sum_inv)\n",
"# cmp('counts_sum', dcounts_sum, counts_sum)\n",
"# cmp('counts', dcounts, counts)\n",
"# cmp('norm_logits', dnorm_logits, norm_logits)\n",
"# cmp('logit_maxes', dlogit_maxes, logit_maxes)\n",
"# cmp('logits', dlogits, logits)\n",
"# cmp('h', dh, h)\n",
"# cmp('W2', dW2, W2)\n",
"# cmp('b2', db2, b2)\n",
"# cmp('hpreact', dhpreact, hpreact)\n",
"# cmp('bngain', dbngain, bngain)\n",
"# cmp('bnbias', dbnbias, bnbias)\n",
"# cmp('bnraw', dbnraw, bnraw)\n",
"# cmp('bnvar_inv', dbnvar_inv, bnvar_inv)\n",
"# cmp('bnvar', dbnvar, bnvar)\n",
"# cmp('bndiff2', dbndiff2, bndiff2)\n",
"# cmp('bndiff', dbndiff, bndiff)\n",
"# cmp('bnmeani', dbnmeani, bnmeani)\n",
"# cmp('hprebn', dhprebn, hprebn)\n",
"# cmp('embcat', dembcat, embcat)\n",
"# cmp('W1', dW1, W1)\n",
"# cmp('b1', db1, b1)\n",
"# cmp('emb', demb, emb)\n",
"# cmp('C', dC, C)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ebLtYji_8PPw"
},
"outputs": [],
"source": [
"# Exercise 2: backprop through cross_entropy but all in one go\n",
"# to complete this challenge look at the mathematical expression of the loss,\n",
"# take the derivative, simplify the expression, and just write it out\n",
"\n",
"# forward pass\n",
"\n",
"# before:\n",
"# logit_maxes = logits.max(1, keepdim=True).values\n",
"# norm_logits = logits - logit_maxes # subtract max for numerical stability\n",
"# counts = norm_logits.exp()\n",
"# counts_sum = counts.sum(1, keepdims=True)\n",
"# counts_sum_inv = counts_sum**-1 # if I use (1.0 / counts_sum) instead then I can't get backprop to be bit exact...\n",
"# probs = counts * counts_sum_inv\n",
"# logprobs = probs.log()\n",
"# loss = -logprobs[range(n), Yb].mean()\n",
"\n",
"# now:\n",
"loss_fast = F.cross_entropy(logits, Yb)\n",
"loss_diff = (loss_fast - loss).item() # expected ~0: the chunked loss above matches F.cross_entropy\n",
"print(loss_fast.item(), 'diff:', loss_diff)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-gCXbB4C8PPx"
},
"outputs": [],
"source": [
"# backward pass\n",
"\n",
"# -----------------\n",
"# YOUR CODE HERE :)\n",
"dlogits = None # TODO. my solution is 3 lines\n",
"# -----------------\n",
"\n",
"# the check runs automatically once dlogits is implemented, and is skipped while it's still None\n",
"if dlogits is not None:\n",
"  cmp('logits', dlogits, logits) # I can only get approximate to be true, my maxdiff is 6e-9"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hd-MkhB68PPy"
},
"outputs": [],
"source": [
"# Exercise 3: backprop through batchnorm but all in one go\n",
"# to complete this challenge look at the mathematical expression of the output of batchnorm,\n",
"# take the derivative w.r.t. its input, simplify the expression, and just write it out\n",
"# BatchNorm paper: https://arxiv.org/abs/1502.03167\n",
"\n",
"# forward pass\n",
"\n",
"# before:\n",
"# bnmeani = 1/n*hprebn.sum(0, keepdim=True)\n",
"# bndiff = hprebn - bnmeani\n",
"# bndiff2 = bndiff**2\n",
"# bnvar = 1/(n-1)*(bndiff2).sum(0, keepdim=True) # note: Bessel's correction (dividing by n-1, not n)\n",
"# bnvar_inv = (bnvar + 1e-5)**-0.5\n",
"# bnraw = bndiff * bnvar_inv\n",
"# hpreact = bngain * bnraw + bnbias\n",
"\n",
"# now:\n",
"hprebn_centered = hprebn - hprebn.mean(0, keepdim=True)\n",
"hprebn_std = torch.sqrt(hprebn.var(0, keepdim=True, unbiased=True) + 1e-5) # sqrt(var + eps), var with Bessel's correction\n",
"hpreact_fast = bngain * hprebn_centered / hprebn_std + bnbias\n",
"print('max diff:', (hpreact_fast - hpreact).abs().max())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "POdeZSKT8PPy"
},
"outputs": [],
"source": [
"# backward pass\n",
"\n",
"# before we had:\n",
"# dbnraw = bngain * dhpreact\n",
"# dbndiff = bnvar_inv * dbnraw\n",
"# dbnvar_inv = (bndiff * dbnraw).sum(0, keepdim=True)\n",
"# dbnvar = (-0.5*(bnvar + 1e-5)**-1.5) * dbnvar_inv\n",
"# dbndiff2 = (1.0/(n-1))*torch.ones_like(bndiff2) * dbnvar\n",
"# dbndiff += (2*bndiff) * dbndiff2\n",
"# dhprebn = dbndiff.clone()\n",
"# dbnmeani = (-dbndiff).sum(0)\n",
"# dhprebn += 1.0/n * (torch.ones_like(hprebn) * dbnmeani)\n",
"\n",
"# calculate dhprebn given dhpreact (i.e. backprop through the batchnorm)\n",
"# (you'll also need to use some of the variables from the forward pass up above)\n",
"\n",
"# -----------------\n",
"# YOUR CODE HERE :)\n",
"dhprebn = None # TODO. my solution is 1 (long) line\n",
"# -----------------\n",
"\n",
"# guarded: calling cmp while dhprebn is still None would fail and break Restart & Run All\n",
"if dhprebn is not None:\n",
"  cmp('hprebn', dhprebn, hprebn) # I can only get approximate to be true, my maxdiff is 9e-10"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wPy8DhqB8PPz"
},
"outputs": [],
"source": [
"# Exercise 4: putting it all together!\n",
"# Train the MLP neural net with your own backward pass\n",
"\n",
"# init\n",
"n_embd = 10 # the dimensionality of the character embedding vectors\n",
"n_hidden = 200 # the number of neurons in the hidden layer of the MLP\n",
"\n",
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
"# Layer 1\n",
"W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * (5/3)/((n_embd * block_size)**0.5)\n",
"b1 = torch.randn(n_hidden, generator=g) * 0.1\n",
"# Layer 2\n",
"W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.1\n",
"b2 = torch.randn(vocab_size, generator=g) * 0.1\n",
"# BatchNorm parameters (drawn from g too, otherwise they come from the unseeded global RNG)\n",
"bngain = torch.randn((1, n_hidden), generator=g)*0.1 + 1.0\n",
"bnbias = torch.randn((1, n_hidden), generator=g)*0.1\n",
"\n",
"parameters = [C, W1, b1, W2, b2, bngain, bnbias]\n",
"print(sum(p.nelement() for p in parameters)) # number of parameters in total\n",
"for p in parameters:\n",
"  p.requires_grad = True\n",
"\n",
"# same optimization as last time\n",
"max_steps = 200000\n",
"batch_size = 32\n",
"n = batch_size # convenience\n",
"lossi = []\n",
"\n",
"# use this context manager for efficiency once your backward pass is written (TODO)\n",
"#with torch.no_grad():\n",
"\n",
"# kick off optimization\n",
"for i in range(max_steps):\n",
"\n",
"  # minibatch construct\n",
"  ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)\n",
"  Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y\n",
"\n",
"  # forward pass\n",
"  emb = C[Xb] # embed the characters into vectors\n",
"  embcat = emb.view(emb.shape[0], -1) # concatenate the vectors\n",
"  # Linear layer\n",
"  hprebn = embcat @ W1 + b1 # hidden layer pre-activation\n",
"  # BatchNorm layer\n",
"  # -------------------------------------------------------------\n",
"  bnmean = hprebn.mean(0, keepdim=True)\n",
"  bnvar = hprebn.var(0, keepdim=True, unbiased=True)\n",
"  bnvar_inv = (bnvar + 1e-5)**-0.5\n",
"  bnraw = (hprebn - bnmean) * bnvar_inv\n",
"  hpreact = bngain * bnraw + bnbias\n",
"  # -------------------------------------------------------------\n",
"  # Non-linearity\n",
"  h = torch.tanh(hpreact) # hidden layer\n",
"  logits = h @ W2 + b2 # output layer\n",
"  loss = F.cross_entropy(logits, Yb) # loss function\n",
"\n",
"  # backward pass\n",
"  for p in parameters:\n",
"    p.grad = None\n",
"  loss.backward() # use this for correctness comparisons, delete it later!\n",
"\n",
"  # manual backprop! #swole_doge_meme\n",
"  # -----------------\n",
"  # YOUR CODE HERE :)\n",
"  dC, dW1, db1, dW2, db2, dbngain, dbnbias = None, None, None, None, None, None, None\n",
"  grads = [dC, dW1, db1, dW2, db2, dbngain, dbnbias]\n",
"  # -----------------\n",
"\n",
"  # update\n",
"  lr = 0.1 if i < 100000 else 0.01 # step learning rate decay\n",
"  for p, grad in zip(parameters, grads):\n",
"    p.data += -lr * p.grad # old way of cheems doge (using PyTorch grad from .backward())\n",
"    #p.data += -lr * grad # new way of swole doge TODO: enable\n",
"\n",
"  # track stats\n",
"  if i % 10000 == 0: # print every once in a while\n",
"    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')\n",
"  lossi.append(loss.log10().item())\n",
"\n",
"  if i >= 100: # TODO: delete early breaking when you're ready to train the full net\n",
"    break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZEpI0hMW8PPz"
},
"outputs": [],
"source": [
"# useful for checking your gradients\n",
"# for p,g in zip(parameters, grads):\n",
"# cmp(str(tuple(p.shape)), g, p)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KImLWNoh8PP0"
},
"outputs": [],
"source": [
"# calibrate the batch norm at the end of training\n",
"\n",
"with torch.no_grad():\n",
"  # pass the training set through\n",
"  emb = C[Xtr]\n",
"  embcat = emb.view(emb.shape[0], -1)\n",
"  hprebn = embcat @ W1 + b1 # pre-BatchNorm activations (named as in the training loop)\n",
"  # measure the mean/var over the entire training set; split_loss and the sampling cell read these globals\n",
"  bnmean = hprebn.mean(0, keepdim=True)\n",
"  bnvar = hprebn.var(0, keepdim=True, unbiased=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6aFnP_Zc8PP0"
},
"outputs": [],
"source": [
"# evaluate train and val loss\n",
"\n",
"@torch.no_grad() # this decorator disables gradient tracking\n",
"def split_loss(split):\n",
"  \"\"\"Print the model's cross-entropy loss on the 'train', 'val' or 'test' split.\n",
"\n",
"  BatchNorm uses the global bnmean / bnvar (set by the calibration cell) instead of\n",
"  batch statistics. An unknown split name raises KeyError.\n",
"  \"\"\"\n",
"  datasets = {\n",
"    'train': (Xtr, Ytr),\n",
"    'val': (Xdev, Ydev),\n",
"    'test': (Xte, Yte),\n",
"  }\n",
"  xs, ys = datasets[split]\n",
"  embcat = C[xs].view(xs.shape[0], -1) # (N, block_size * n_embd)\n",
"  hprebn = embcat @ W1 + b1\n",
"  hpreact = bngain * (hprebn - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
"  h = torch.tanh(hpreact) # (N, n_hidden)\n",
"  logits = h @ W2 + b2 # (N, vocab_size)\n",
"  loss = F.cross_entropy(logits, ys)\n",
"  print(split, loss.item())\n",
"\n",
"split_loss('train')\n",
"split_loss('val')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "esWqmhyj8PP1"
},
"outputs": [],
"source": [
"# I achieved:\n",
"# train 2.0718822479248047\n",
"# val 2.1162495613098145"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xHeQNv3s8PP1"
},
"outputs": [],
"source": [
"# sample from the model\n",
"g = torch.Generator().manual_seed(2147483647 + 10)\n",
"\n",
"for _ in range(20):\n",
"  out = []\n",
"  context = [0] * block_size # start from an all-'.' context (id 0)\n",
"  ix = None\n",
"  while ix != 0: # id 0 is '.', which ends the name (it is still appended to out)\n",
"    # forward pass on a single example, BatchNorm using the calibrated bnmean / bnvar\n",
"    xctx = torch.tensor([context]) # (1, block_size)\n",
"    embcat = C[xctx].view(xctx.shape[0], -1) # (1, block_size * n_embd)\n",
"    hprebn = embcat @ W1 + b1\n",
"    hpreact = bngain * (hprebn - bnmean) * (bnvar + 1e-5)**-0.5 + bnbias\n",
"    h = torch.tanh(hpreact) # (1, n_hidden)\n",
"    logits = h @ W2 + b2 # (1, vocab_size)\n",
"    # sample the next character id from the softmax distribution\n",
"    probs = F.softmax(logits, dim=1)\n",
"    ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
"    context = context[1:] + [ix]\n",
"    out.append(ix)\n",
"\n",
"  print(''.join([itos[t] for t in out]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"colab": {
"provenance": []
}
},
"nbformat": 4,
"nbformat_minor": 0
} |