Spaces:
Sleeping
Sleeping
File size: 105,328 Bytes
306b4ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 |
# Copyright (c) 2024, Tri Dao, Albert Gu.
"""We want triton==2.1.0 or 2.2.0 for this
"""
import math
from packaging import version
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange, repeat
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
# True when the installed Triton is version 2.2.0 or newer.
# NOTE(review): presumably forwarded to kernels as their IS_TRITON_22 constexpr
# argument; the launch code is not visible in this part of the file -- confirm.
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def init_to_zero(names):
    """Build a hook that zeroes the named kernel arguments in place.

    The returned callable takes a mapping of argument name -> value, calls
    ``zero_()`` on every entry listed in ``names`` whose value is not ``None``,
    and returns the list of ``zero_()`` results in the order of ``names``.
    """
    def _zero_named_args(nargs):
        zeroed = []
        for name in names:
            # Optional arguments may be passed as None; leave those alone.
            if nargs[name] is not None:
                zeroed.append(nargs[name].zero_())
        return zeroed

    return _zero_named_args
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile of the chunked scan.

    Grid layout: axis 0 enumerates (row-tile, column-tile) pairs with
    ``cdiv(hdim, BLOCK_SIZE_N)`` column tiles per row tile; axis 1 encodes
    ``chunk * batch + batch_index``; axis 2 is the head. ``C`` and ``cb`` are
    indexed by group, i.e. ``head // nheads_ngroups_ratio``.

    For row ``m`` (position inside the chunk) and column ``n`` (index into hdim)
    the tile accumulates, in float32:

    1. ``(C[m, :] @ prev_states[:, n]) * exp(dA_cumsum[m])``; with HAS_SEQ_IDX the
       scale is 0 for rows whose seq_idx differs from the seq_idx stored just
       before this chunk (0 is used for the first chunk);
    2. ``sum_k cb[m, k] * exp(dA_cumsum[m] - dA_cumsum[k]) * dt[k] * x[k, n]``,
       restricted to ``k <= m`` when IS_CAUSAL;
    3. ``x[m, n] * D`` when HAS_D (``D`` is per-head, or per-(head, hdim) when
       D_HAS_HDIM).

    When HAS_Z the un-gated tile is first stored to ``out_x`` and the tile is then
    multiplied by ``z * sigmoid(z)``. The final tile is stored to ``out``.
    ``out_x`` is addressed with the ``stride_out_*`` strides, so it must share
    ``out``'s layout. Rows at or beyond ``min(chunk_size, seqlen - chunk * chunk_size)``
    are masked out of the x/C/z loads and of all stores.
    """
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Move every base pointer to this program's (batch, chunk, head) slice.
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    # Number of valid positions in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        # seq_idx of the position immediately before this chunk (0 for the first chunk).
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Without the if (pid_c > -1), with Triton 2.1.0, I get
    # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
    # With Triton 2.2.0, this works
    if IS_TRITON_22 or pid_c > -1:
        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
        if not HAS_SEQ_IDX:
            scale_m = tl.exp(dA_cs_m)
        else:
            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        if BLOCK_SIZE_DSTATE <= 128:
            C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
            prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
            prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc = tl.dot(C, prev_states) * scale_m[:, None]
        else:
            for k in range(0, dstate, BLOCK_SIZE_K):
                C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0)
                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
                prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0)
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
                acc += tl.dot(C, prev_states)
                # Step along dstate with each tensor's own stride; a bare
                # BLOCK_SIZE_K is only right when the dstate stride is 1.
                C_ptrs += BLOCK_SIZE_K * stride_C_dstate
                prev_states_ptrs += BLOCK_SIZE_K * stride_states_dstate
            acc *= scale_m[:, None]

    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    # In the causal case columns past the last row of this row tile contribute nothing.
    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :]))
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0)
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        if D_HAS_HDIM:
            # No hdim stride is passed for D, so it is read with unit stride along hdim.
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        # Store the un-gated result; out_x is addressed with out's strides,
        # including the hdim stride so it matches out_ptrs below.
        out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
        tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
    tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4),
        # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim', 'dstate'],
)
@triton.jit
def _chunk_scan_fwd_kernel_wip(
    # Pointers to matrices
    cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr,
    # Matrix dimensions
    chunk_size, hdim, dstate,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate,
    stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate,
    stride_D_head,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    """Work-in-progress forward scan that walks the row tiles of a chunk sequentially.

    Grid layout: axis 0 is the hdim column tile, axis 1 encodes
    ``chunk * batch + batch_index``, axis 2 is the head. Each program loads the
    ``prev_states`` tile once, then for every BLOCK_SIZE_M row tile of the chunk
    stores ``(C @ prev_states) * scale_m`` (plus ``x * D`` when HAS_D) to ``out``
    and adds ``B @ x`` to the in-register ``prev_states`` tile before moving on.

    The author's TODO below marks the state update as not correct. As written:
    the ``cb`` contribution and the ``z`` gating are commented out, so ``cb_ptrs``
    is only advanced and never read; ``num_pid_n`` and ``scale`` are computed but
    unused; ``z_ptr``, ``out_x_ptr`` and ``HAS_Z`` are accepted but not used.
    """
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_n = tl.program_id(axis=0)
    # Move every base pointer to this program's (batch, chunk, head) slice;
    # cb, C and B are indexed by group (head // nheads_ngroups_ratio).
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head
    prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head

    # offs_m is relative to the current row tile; the tile pointers below are
    # advanced by BLOCK_SIZE_M rows at the end of every loop iteration.
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE)

    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)
    # B is addressed transposed: dstate along rows, sequence position along columns.
    B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)

    prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0)
    # if pid_c == 0:
    #     if pid_b == 0:
    #         if pid_h == 0:
    #             tl.device_print("", prev_states)

    # Number of valid positions in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # scale_m = tl.exp(dA_cs_m)
    # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0)
    # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
    # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32)
    # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
    # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # cb *= dt_m
    # mask = offs_m[:, None] >= offs_m[None, :]
    # cb = tl.where(mask, cb, 0.0)
    # cb = cb.to(x_ptr.dtype.element_ty)
    # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0)
    # acc += tl.dot(cb, x)
    # if HAS_D:
    #     if D_HAS_HDIM:
    #         D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
    #     else:
    #         D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
    #     acc += x.to(tl.float32) * D
    # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))

    for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M):
        start_m = tl.multiple_of(start_m, BLOCK_SIZE_M)
        dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            # NOTE(review): start_m is added here without being scaled by
            # stride_seq_idx_seqlen, unlike the seq_idx_m load below -- confirm intent.
            seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
            seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1)
        if not HAS_SEQ_IDX:
            scale_m = tl.exp(dA_cs_m)
        else:
            scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0)
        acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None]
        # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32)
        # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :]))
        dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32)
        # cb *= dt_m
        # mask = offs_m[:, None] >= offs_m[None, :]
        # cb = tl.where(mask, cb, 0.0)
        # cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0)
        # acc += tl.dot(cb, x)
        if HAS_D:
            if D_HAS_HDIM:
                # No hdim stride is passed for D, so it is read with unit stride along hdim.
                D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
            else:
                D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
            acc += x.to(tl.float32) * D

        # if HAS_Z:
        #     out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        #     out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
        #     tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        #     z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        #     z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        #     z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32)
        #     acc *= z * tl.sigmoid(z)

        tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim))

        # TODO: this is not correct, and quite a bit slower
        # Fold this row tile into the running state, but only if another tile follows.
        if start_m + BLOCK_SIZE_M < chunk_size_limit:
            # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32)
            B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0)
            dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32)
            # TODO: seq_idx
            # scale is currently unused because the "B *= scale" line below is disabled.
            scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m
            # B *= scale
            B = B.to(x_ptr.dtype.element_ty)
            tmp = tl.dot(B, x)
            prev_states += tmp.to(prev_states.dtype)

        # Advance every tile pointer to the next BLOCK_SIZE_M rows of the chunk.
        C_ptrs += BLOCK_SIZE_M * stride_C_seqlen
        B_ptrs += BLOCK_SIZE_M * stride_B_seqlen
        cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_M * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_M * stride_dt_csize
        out_ptrs += BLOCK_SIZE_M * stride_out_seqlen
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32}),
        triton.Config({'BLOCK_SIZE_M': 64}),
        triton.Config({'BLOCK_SIZE_M': 128}),
        triton.Config({'BLOCK_SIZE_M': 256}),
    ],
    key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_dz_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_D_head,
    stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim,
    stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim,
    stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_DDACS: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    """Backward pass through the ``out * z * sigmoid(z)`` gating for one row tile.

    Grid layout: axis 0 is the row tile inside the chunk, axis 1 encodes
    ``chunk * batch + batch_index``, axis 2 is the head. A single column tile of
    width BLOCK_SIZE_N, masked to ``hdim``, covers the head dimension.

    For the tile this kernel writes:
      * ``dz = dout * out * sigmoid(z) * (1 + z * (1 - sigmoid(z)))``;
      * ``dout_x = dout * z * sigmoid(z)``;
      * ``outz = out * z * sigmoid(z)`` when RECOMPUTE_OUTPUT;
      * when HAS_D, a partial ``dD`` for this row tile: ``sum(dout_x * x)`` over
        rows (kept per hdim column when D_HAS_HDIM, otherwise over the whole tile);
      * when HAS_DDACS, ``ddA_cumsum[m] = sum_n dout_x[m, n] * out'[m, n]`` where
        ``out'`` is ``out - x * D`` if HAS_D and ``out`` otherwise.
    """
    pid_m = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch

    rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tl.arange(0, BLOCK_SIZE_N)
    # The last chunk may be partial; every 2-D load and store shares this mask.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    tile_mask = (rows[:, None] < chunk_size_limit) & (cols[None, :] < hdim)

    # Inputs: upstream gradient, pre-gating output, and gate.
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dout_ptrs = dout_ptr + (rows[:, None] * stride_dout_seqlen + cols[None, :] * stride_dout_hdim)
    dout = tl.load(dout_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

    out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (rows[:, None] * stride_out_seqlen + cols[None, :] * stride_out_hdim)
    out = tl.load(out_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

    z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head
    z_ptrs = z_ptr + (rows[:, None] * stride_z_seqlen + cols[None, :] * stride_z_hdim)
    z = tl.load(z_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    sig_z = tl.sigmoid(z)

    if RECOMPUTE_OUTPUT:
        outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head
        outz_ptrs = outz_ptr + (rows[:, None] * stride_outz_seqlen + cols[None, :] * stride_outz_hdim)
        gated_out = out * z * sig_z
        tl.store(outz_ptrs, gated_out, mask=tile_mask)

    # Gradient with respect to the gate z.
    dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head
    dz_ptrs = dz_ptr + (rows[:, None] * stride_dz_seqlen + cols[None, :] * stride_dz_hdim)
    grad_z = dout * out * sig_z * (1 + z * (1 - sig_z))
    tl.store(dz_ptrs, grad_z, mask=tile_mask)

    # Gradient with respect to the pre-gating output; reused below as `dout`.
    dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head
    dout_x_ptrs = dout_x_ptr + (rows[:, None] * stride_doutx_seqlen + cols[None, :] * stride_doutx_hdim)
    dout *= z * sig_z
    tl.store(dout_x_ptrs, dout, mask=tile_mask)

    if HAS_D:
        x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
        # dD holds one partial result per row tile (indexed by pid_m).
        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize
        x_ptrs = x_ptr + (rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_hdim)
        x = tl.load(x_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            dD = tl.sum(dout * x, axis=0)
            tl.store(dD_ptr + cols * stride_dD_hdim, dD, mask=cols < hdim)
            D = tl.load(D_ptr + pid_h * stride_D_head + cols, mask=cols < hdim, other=0.0).to(tl.float32)
        else:
            dD = tl.sum(dout * x)
            tl.store(dD_ptr, dD)
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        # Remove the x * D term so only the remaining part of `out` feeds ddA_cumsum.
        out -= x * D

    if HAS_DDACS:
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
        ddA_cs = tl.sum(dout * out, axis=1)
        tl.store(ddA_cumsum_ptr + rows * stride_ddA_cs_csize, ddA_cs, mask=rows < chunk_size)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_scan_bwd_dstates_kernel(
    # Pointers to matrices
    dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr,
    # Matrix dimensions
    hdim, dstate, chunk_size,
    batch, seqlen, nchunks, nheads_ngroups_ratio,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate,
    stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``dprev_states``.

    For the (batch, chunk, head) selected by the launch grid the tile is

        dprev_states[m, n] = sum_k dout[k, m] * scale[k] * c[k, n]

    where k runs over the positions of the chunk that lie inside ``seqlen``,
    m indexes ``hdim`` and n indexes ``dstate``.  ``scale[k]`` is
    ``exp(dA_cumsum[k])``; when ``HAS_SEQ_IDX`` is set it is forced to 0 for
    positions whose ``seq_idx`` differs from the entry just before the chunk
    (taken as 0 for the first chunk).

    Grid: axis 0 enumerates (pid_m, pid_n) tiles, axis 1 packs
    ``chunk * batch + batch_index``, axis 2 is the head.  ``c`` is addressed
    with ``head // nheads_ngroups_ratio`` along its head stride.
    """
    # ---- decode the launch grid ------------------------------------------
    tile_id = tl.program_id(axis=0)
    batch_chunk_id = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = batch_chunk_id // batch
    pid_b = batch_chunk_id - pid_c * batch
    tiles_along_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tile_id // tiles_along_n
    pid_n = tile_id % tiles_along_n

    # ---- move every base pointer to this (batch, chunk, head) -------------
    chunk_start = pid_c * chunk_size
    group_idx = pid_h // nheads_ngroups_ratio
    c_ptr += pid_b * stride_c_batch + chunk_start * stride_c_seqlen + group_idx * stride_c_head
    dout_ptr += pid_b * stride_dout_batch + chunk_start * stride_dout_seqlen + pid_h * stride_dout_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_start * stride_seq_idx_seqlen

    # ---- tile coordinates and pointer grids -------------------------------
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # dout is read as an (hdim, position) tile, c as a (position, dstate) tile.
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen)
    c_ptrs = c_ptr + (offs_k[:, None] * stride_c_seqlen + offs_n[None, :] * stride_c_dstate)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

    # Number of positions of this chunk that are inside the sequence.
    chunk_size_limit = min(chunk_size, seqlen - chunk_start)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_SEQ_IDX:
        # seq_idx of the position right before the chunk; 0 when there is none.
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)

    # ---- reduction over the positions of the chunk -------------------------
    for k_start in range(0, chunk_size_limit, BLOCK_SIZE_K):
        k_in_chunk = offs_k < chunk_size_limit - k_start
        dout = tl.load(
            dout_ptrs,
            mask=(offs_m[:, None] < hdim) & k_in_chunk[None, :],
            other=0.0,
        ).to(tl.float32)
        # dA_cumsum is masked against chunk_size (not chunk_size_limit).
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k_start, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            seq_idx_k = tl.load(seq_idx_ptrs, mask=k_in_chunk, other=-1)
            scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0)
        else:
            scale_k = tl.exp(dA_cs_k)
        # Scale in fp32, then cast back to the dout element type for tl.dot.
        dout = (dout * scale_k[None, :]).to(dout_ptr.dtype.element_ty)
        c = tl.load(
            c_ptrs,
            mask=k_in_chunk[:, None] & (offs_n[None, :] < dstate),
            other=0.0,
        )
        acc += tl.dot(dout, c)

        # Advance every pointer grid by one K block.
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        c_ptrs += BLOCK_SIZE_K * stride_c_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen

    # ---- write the tile ----------------------------------------------------
    result = acc.to(dprev_states_ptr.dtype.element_ty)
    dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head
    # Tile coordinates are recomputed here, as in the original kernel.
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dprev_states_ptrs = dprev_states_ptr + (out_rows[:, None] * stride_dprev_states_hdim + out_cols[None, :] * stride_dprev_states_dstate)
    tl.store(dprev_states_ptrs, result, mask=(out_rows[:, None] < hdim) & (out_cols[None, :] < dstate))
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ],
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dc_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dc_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``dc`` for a head split.

    Each program owns a (batch, chunk, group, split) and loops over up to
    ``nheads_per_program`` consecutive heads of that group.  Per head:

        dc_h[m, n] = (dout_h[m, :] @ prev_states_h[:, n]) * scale_h[m]

    with ``scale_h[m] = exp(dA_cumsum_h[m])``, forced to 0 when
    ``HAS_SEQ_IDX`` is set and position m does not share the ``seq_idx`` of
    the entry just before the chunk.  The per-head results are summed and
    stored at the (split, group) slot of ``dc``.

    With ``HAS_DDA_CS`` the row sums of ``dc_h * C`` are additionally
    accumulated into ``ddA_cumsum`` with atomic adds.

    ``BLOCK_SIZE_K`` is not set by the autotune configs above, so it has to
    be supplied at launch; the single un-looped load along ``hdim`` means it
    presumably must be >= hdim (confirm against the caller).
    """
    # ---- decode the launch grid ------------------------------------------
    tile_id = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_sg = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    tiles_along_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tile_id // tiles_along_n
    pid_n = tile_id % tiles_along_n

    # First head handled by this program and start of the chunk in seqlen.
    heads_per_group = nheads // ngroups
    first_head = pid_g * heads_per_group + pid_s * nheads_per_program
    chunk_start = pid_c * chunk_size

    # ---- move every base pointer to this program's slice ------------------
    dout_ptr += pid_b * stride_dout_batch + chunk_start * stride_dout_seqlen + first_head * stride_dout_head
    dc_ptr += pid_b * stride_dc_batch + chunk_start * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split
    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + first_head * stride_prev_states_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + first_head * stride_dA_cs_head
    if HAS_DDA_CS:
        C_ptr += pid_b * stride_C_batch + chunk_start * stride_C_seqlen + pid_g * stride_C_head
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + first_head * stride_ddA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_start * stride_seq_idx_seqlen

    # ---- tile coordinates and pointer grids -------------------------------
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    prev_states_ptrs = prev_states_ptr + (offs_k[:, None] * stride_prev_states_hdim + offs_n[None, :] * stride_prev_states_dstate)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    if HAS_DDA_CS:
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize

    # Number of positions of this chunk that are inside the sequence.
    chunk_size_limit = min(chunk_size, seqlen - chunk_start)
    m_in_chunk = offs_m < chunk_size_limit

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        c = tl.load(C_ptrs, mask=m_in_chunk[:, None] & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    if HAS_SEQ_IDX:
        # seq_idx of the position right before the chunk; 0 when there is none.
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=m_in_chunk, other=-1)

    # The last split of a group may own fewer than nheads_per_program heads.
    nheads_iter = min(nheads_per_program, heads_per_group - pid_s * nheads_per_program)
    for head_iter in range(nheads_iter):
        dout = tl.load(dout_ptrs, mask=m_in_chunk[:, None] & (offs_k[None, :] < hdim), other=0.0)
        prev_states = tl.load(
            prev_states_ptrs,
            mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        )
        # tl.dot needs both operands in the same element type.
        prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
        dc = tl.dot(dout, prev_states)

        dA_cs_m = tl.load(dA_cumsum_ptrs, mask=m_in_chunk, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
        else:
            scale = tl.exp(dA_cs_m)
        dc = dc * scale[:, None]

        if HAS_DDA_CS:
            ddA_cs = tl.sum(dc * c, axis=1)
            # Several programs can target the same ddA_cumsum rows, hence atomic.
            tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
        acc += dc

        # Step every per-head pointer grid to the next head.
        dout_ptrs += stride_dout_head
        prev_states_ptrs += stride_prev_states_head
        dA_cumsum_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptrs += stride_ddA_cs_head

    # Disabled alternative kept from the original: mask the accumulated tile
    # by seq_idx once here instead of per head through `scale`.
    # if HAS_SEQ_IDX:
    #     seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
    #     seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
    #     acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0)

    # ---- write the tile (coordinates recomputed, as in the original) --------
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dc_ptrs = dc_ptr + (out_rows[:, None] * stride_dc_seqlen + out_cols[None, :] * stride_dc_dstate)
    tl.store(dc_ptrs, acc, mask=(out_rows[:, None] < chunk_size_limit) & (out_cols[None, :] < dstate))
# Every config installs a pre_hook on "ddt_ptr": the kernel accumulates into ddt
# with tl.atomic_add, so the buffer presumably has to start from zero for each
# launch (init_to_zero is defined elsewhere in this file -- confirm).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])),
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_dx_kernel(
    # Pointers to matrices
    x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr,
    dx_ptr, ddt_ptr, # dD_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_D_head,
    stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``dx`` and add to ``ddt``.

    For the (batch, chunk, head) selected by the grid, with m a position in
    the chunk and n an index along ``hdim``:

        acc[m, n] = sum_{k >= m} cb[m, k] * exp(dA_cs[k] - dA_cs[m]) * dout[k, n]
        dx[m, n]  = acc[m, n] * dt[m]  (+ dout[m, n] * D when HAS_D)
        ddt[m]   += sum_n acc[m, n] * x[m, n]        (atomic add)

    ``cb[m, k]`` means "addressed with stride_cb_csize_m / stride_cb_csize_k";
    which logical axis of the CB tensor each stride belongs to is decided by
    the caller.  ``cb`` is addressed with ``head // nheads_ngroups_ratio``
    along its head stride.  With ``D_HAS_HDIM`` a vector of D values is read
    per head (indexed by ``offs_n``), otherwise a single scalar.

    Grid: axis 0 enumerates (pid_m, pid_n) tiles, axis 1 packs
    ``chunk * batch + batch_index``, axis 2 is the head.
    """
    # Decode the launch grid.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Move the base pointers to this (batch, chunk, head).  Tensors indexed by
    # sequence position use pid_c * chunk_size along the seqlen stride; the
    # others have an explicit chunk stride.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    # if HAS_D:
    #     dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Number of positions of this chunk that are inside the sequence.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Idk why limiting K_MAX gives wrong results, is it a Triton bug?
    # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    # NOTE(review): the `mask` built inside the loop keeps k >= offs_m, so the
    # contributing k range runs up to chunk_size_limit.  Capping K_MAX at
    # (pid_m + 1) * BLOCK_SIZE_M would drop every k past this row tile, which
    # looks like the cause of the wrong results rather than a Triton bug --
    # confirm before changing.
    K_MAX = chunk_size_limit
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0)
        dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
        cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
        # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
        # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
        # Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
        # This will cause NaN in acc, and hence NaN in dx and ddt.
        mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
        cb = tl.where(mask, cb, 0.0)
        # Cast back to the dout element type so both tl.dot operands match.
        cb = cb.to(dout_ptr.dtype.element_ty)
        acc += tl.dot(cb, dout)
        # Advance the pointer grids by one K block.
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # Tile coordinates are recomputed after the loop (same values as above).
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dt_ptrs = dt_ptr + offs_m * stride_dt_csize
    dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    dx = acc * dt_m[:, None]
    dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head
    dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim)
    if HAS_D:
        # Extra term dout * D, using the dout rows of this tile (offs_m), not
        # the k-indexed rows read inside the loop.
        dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
        dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        dx += dout_res * D
    tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim))

    # ddt uses acc (before the dt scaling and the D term) times x, reduced over
    # this tile's hdim columns; other pid_n tiles add to the same rows, hence
    # the atomic add.
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32)
    ddt = tl.sum(acc * x, axis=1)
    ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize
    tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size)

    # if HAS_D:
    #     dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim)
    #     dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32)
    #     dD = tl.sum(x * dout, axis=0)
    #     tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N)
# The HAS_DDA_CS path is currently disabled because it is much slower.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        # Disabled configurations kept from the original:
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim'],
)
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)})
# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32})
@triton.jit
def _chunk_scan_bwd_dcb_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr,
    dcb_ptr, ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads, nheads_per_program, ngroups,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    HAS_DDA_CS: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``dcb`` for a head split.

    Each program owns a (batch, chunk, group, split) and loops over up to
    ``nheads_per_program`` consecutive heads of that group.  Per head, with m
    and n both positions inside the chunk:

        dcb_h[m, n] = (dout_h[m, :] @ x_h[n, :]) * dt_h[n] * exp(dA_cs_h[m] - dA_cs_h[n])

    The per-head tiles are summed, entries with ``m < n`` are zeroed (and, if
    ``HAS_SEQ_IDX``, entries whose two positions have different ``seq_idx``),
    and the result is stored at the (split, group) slot of ``dcb``.  Tiles
    lying entirely above the diagonal are written as zeros without any loads.

    With ``HAS_DDA_CS`` each head also stores a per-column reduction of
    ``dcb_h * cb`` into ``ddA_cumsum`` (row ``pid_m``, shifted one column to
    the right) and writes 0.0 to column 0 of that row.

    ``BLOCK_SIZE_K`` is not set by the autotune configs above, so it has to
    be supplied at launch; the single un-looped load along ``hdim`` means it
    presumably must be >= hdim (confirm against the caller).
    """
    # ---- decode the launch grid ------------------------------------------
    tile_id = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_sg = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_s = pid_sg // ngroups
    pid_g = pid_sg - pid_s * ngroups
    tiles_along_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tile_id // tiles_along_n
    pid_n = tile_id % tiles_along_n

    # First head handled by this program and start of the chunk in seqlen.
    heads_per_group = nheads // ngroups
    first_head = pid_g * heads_per_group + pid_s * nheads_per_program
    chunk_start = pid_c * chunk_size

    # ---- move every base pointer to this program's slice ------------------
    x_ptr += pid_b * stride_x_batch + chunk_start * stride_x_seqlen + first_head * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + chunk_start * stride_dout_seqlen + first_head * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + first_head * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + first_head * stride_dA_cs_head
    if HAS_DDA_CS:
        cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head
        ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + first_head * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + chunk_start * stride_seq_idx_seqlen
    # The output offset is the same on the early-exit path and the normal path.
    dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split

    # ---- tile coordinates and pointer grids -------------------------------
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_hdim + offs_n[None, :] * stride_x_seqlen)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    if HAS_DDA_CS:
        cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
        ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # ---- tiles strictly above the diagonal are all zero ---------------------
    if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
        zero_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n)
        zero_tile = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty)
        tl.store(zero_ptrs, zero_tile, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))
        return

    # Positions of this chunk inside the sequence; columns are further capped
    # at the last row of this tile since only m >= n survives the final mask.
    chunk_size_limit = min(chunk_size, seqlen - chunk_start)
    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    if HAS_DDA_CS:
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)

    # Per-head dA_cumsum pointer grids for the row and column positions.
    dA_cs_m_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize
    dA_cs_n_ptrs = dA_cumsum_ptr + offs_n * stride_dA_cs_csize

    # The last split of a group may own fewer than nheads_per_program heads.
    nheads_iter = min(nheads_per_program, heads_per_group - pid_s * nheads_per_program)
    for head_iter in range(nheads_iter):
        dout = tl.load(
            dout_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim),
            other=0.0,
        )
        x = tl.load(
            x_ptrs,
            mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n),
            other=0.0,
        )
        dcb = tl.dot(dout, x)

        # Scale the columns by dt, then apply the decay between positions.
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
        dcb = dcb * dt_n[None, :]
        dA_cs_m = tl.load(dA_cs_m_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
        dA_cs_n = tl.load(dA_cs_n_ptrs, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32)
        dcb = dcb * tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])

        if HAS_DDA_CS:
            tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet")
            # Keep the strictly-lower triangle, prefix-sum along n, re-mask,
            # then reduce over m.
            strictly_lower = offs_m[:, None] > offs_n[None, :]
            ddA_cs = tl.where(strictly_lower, dcb * cb, 0.0)
            ddA_cs = tl.cumsum(ddA_cs, axis=1)
            ddA_cs = tl.where(strictly_lower, ddA_cs, 0.0)
            ddA_cs = tl.sum(ddA_cs, axis=0)
            # Results land one column to the right; column 0 is set to zero.
            tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
            tl.store(ddA_cumsum_ptr, 0.0)
        acc += dcb

        # Step every per-head pointer to the next head.
        dout_ptrs += stride_dout_head
        x_ptrs += stride_x_head
        dt_ptrs += stride_dt_head
        dA_cs_m_ptrs += stride_dA_cs_head
        dA_cs_n_ptrs += stride_dA_cs_head
        if HAS_DDA_CS:
            ddA_cumsum_ptr += stride_ddA_cs_head
            ddA_cumsum_ptrs += stride_ddA_cs_head

    # ---- mask and write the tile (coordinates recomputed, as originally) ----
    out_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    out_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        # Different fill values (-1 / -2) keep out-of-range rows and columns
        # from ever comparing equal.
        seq_idx_m = tl.load(seq_idx_ptr + out_rows * stride_seq_idx_seqlen, mask=out_rows < chunk_size_limit, other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + out_cols * stride_seq_idx_seqlen, mask=out_cols < chunk_size_limit, other=-2)
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    lower_tri = out_rows[:, None] >= out_cols[None, :]
    acc = tl.where(lower_tri, acc, 0.0)
    dcb_ptrs = dcb_ptr + (out_rows[:, None] * stride_dcb_csize_m + out_cols[None, :] * stride_dcb_csize_n)
    tl.store(dcb_ptrs, acc, mask=(out_rows[:, None] < chunk_size) & (out_cols[None, :] < chunk_size))
# Numerically unstable; do not use. Kept only as a reference implementation.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32}),
        triton.Config({'BLOCK_SIZE_M': 64}),
        triton.Config({'BLOCK_SIZE_M': 128}),
        triton.Config({'BLOCK_SIZE_M': 256}),
    ],
    key=["chunk_size", "hdim"],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_unstable_kernel(
    # Pointers to matrices
    dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr,
    ddA_cumsum_ptr, dD_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize,
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_D_head,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim,
    # Meta-parameters
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    SUBTRACT_DDTDT: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    """Compute ``ddA_cumsum`` (and partial ``dD``) for BLOCK_SIZE_M positions.

    For the (batch, chunk, head) selected by the grid and each position m of
    this row block:

        ddA_cumsum[m] = sum_n dout[m, n] * (out[m, n] - x[m, n] * D)
                        (- dt[m] * ddt[m] when SUBTRACT_DDTDT)

    where the ``x * D`` correction is applied only when ``HAS_D``.  With
    ``HAS_D`` the kernel also stores this block's partial ``dD``: the sum of
    ``dout * x`` over positions (one value per ``hdim`` column when
    ``D_HAS_HDIM``, otherwise a single scalar) at slot ``pid_m`` of ``dD``.

    There is no tiling over ``hdim``: columns are ``arange(BLOCK_SIZE_N)``,
    and ``BLOCK_SIZE_N`` is not set by the autotune configs, so the caller
    presumably passes a value >= hdim (confirm against the launcher).
    """
    # ---- decode the launch grid ------------------------------------------
    pid_m = tl.program_id(axis=0)
    pid_bc = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch

    # ---- move every base pointer to this (batch, chunk, head) -------------
    chunk_start = pid_c * chunk_size
    dout_ptr += pid_b * stride_dout_batch + chunk_start * stride_dout_seqlen + pid_h * stride_dout_head
    out_ptr += pid_b * stride_out_batch + chunk_start * stride_out_seqlen + pid_h * stride_out_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
    if HAS_D:
        x_ptr += pid_b * stride_x_batch + chunk_start * stride_x_seqlen + pid_h * stride_x_head
        dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize

    # ---- tile coordinates and pointer grids -------------------------------
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim)
    out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim)
    if HAS_D:
        x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
        if D_HAS_HDIM:
            dD_ptrs = dD_ptr + offs_n * stride_dD_hdim

    # Number of positions of this chunk that are inside the sequence.
    chunk_size_limit = min(chunk_size, seqlen - chunk_start)
    tile_mask = (offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)
    dout = tl.load(dout_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
    out = tl.load(out_ptrs, mask=tile_mask, other=0.0).to(tl.float32)

    if HAS_D:
        x = tl.load(x_ptrs, mask=tile_mask, other=0.0).to(tl.float32)
        if D_HAS_HDIM:
            # One dD value and one D value per hdim column.
            dD = tl.sum(dout * x, axis=0)
            tl.store(dD_ptrs, dD, mask=offs_n < hdim)
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            # A single scalar dD / D for the whole head.
            dD = tl.sum(dout * x)
            tl.store(dD_ptr, dD)
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        # Remove the x * D term from out before forming ddA_cumsum.
        out = out - x * D

    ddA_cs = tl.sum(dout * out, axis=1)
    if SUBTRACT_DDTDT:
        in_chunk = offs_m < chunk_size
        dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=in_chunk, other=0.0).to(tl.float32)
        ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=in_chunk, other=0.0).to(tl.float32)
        ddA_cs = ddA_cs - dt * ddt
    tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size)
# The active configs only choose BLOCK_SIZE_M; BLOCK_SIZE_N and BLOCK_SIZE_K are
# constexpr parameters of the kernel, so they presumably come from the launch
# call (confirm against the caller).
@triton.autotune(
    configs=[
        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4),
        # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8),
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel_old(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddAcs_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    """Older variant of the ddA_cumsum kernel, computed per (row, column) tile.

    For the (batch, chunk, head) selected by the grid, with m and n both
    positions inside the chunk, the kernel forms

        acc[m, n] = (dout[m, :] @ x[n, :]) * cb[m, n] * dt[n] * exp(dA_cs[m] - dA_cs[n])

    keeps only the strictly-lower triangle (m > n), takes a cumulative sum
    along n, masks again, and reduces over m.  The resulting vector is stored
    into row ``pid_m`` of ``ddAcs`` shifted one column to the right, and
    column 0 of that row is set to 0.0.

    ``cb`` is addressed with ``head // nheads_ngroups_ratio`` along its head
    stride.  The whole ``hdim`` reduction is done by a single ``tl.dot`` (see
    the comment below), so BLOCK_SIZE_K presumably has to be >= hdim.
    """
    # Decode the launch grid.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Move the base pointers to this (batch, chunk, head).
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # dout is read as a (position m, hdim) tile and x as an (hdim, position n) tile.
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)

    # Positions of this chunk inside the sequence; the column limit is further
    # capped at the last row of this tile.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
    # Doing a matmul loop with cumsum later on will cause Triton to crash
    # Instead we do just one big matmul
    # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # for k in range(0, hdim, BLOCK_SIZE_K):
    #     dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0)
    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0)
    #     acc += tl.dot(dout, x)
    #     dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim
    #     x_ptrs += BLOCK_SIZE_K * stride_x_hdim
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0)
    acc = tl.dot(dout, x)
    cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32)
    acc *= cb
    # dt_n has one entry per column n and is broadcast over the rows of acc.
    dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    acc *= dt_n
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32)
    acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
    # Strictly-lower triangle (m > n): mask, prefix-sum along n, mask again,
    # then reduce over the rows of the tile.
    mask = offs_m[:, None] >= offs_n[None, :] + 1
    acc = tl.where(mask, acc, 0.0)
    acc = tl.cumsum(acc, axis=1)
    acc = tl.where(mask, acc, 0.0)
    ddA_cs = tl.sum(acc, axis=0)
    # The output row is selected by pid_m (not by offs_m).
    ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
    # Values are written one column to the right (n -> n + 1); column 0 of the
    # row is set to zero explicitly.
    tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1)
    tl.store(ddAcs_ptr, 0.0)

    # Disabled alternative kept from the original: loop over column blocks of
    # 64 instead of handling one (pid_m, pid_n) tile per program.
    # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64)
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    # dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)

    # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M)
    # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)
    # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m
    # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n
    # for n in range(0, chunk_size_limit_n, 64):
    #     x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0)
    #     acc = tl.dot(dout, x)
    #     cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32)
    #     acc *= cb
    #     dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
    #     acc *= dt_n
    #     dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32)
    #     acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
    #     mask = offs_m[:, None] >= offs_n[None, :] + 1 + n
    #     acc = tl.where(mask, acc, 0.0)
    #     acc = tl.cumsum(acc, axis=1)
    #     acc = tl.where(mask, acc, 0.0)
    #     ddA_cs = tl.sum(acc, axis=0)
    #     tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n)
    # # tl.store(ddAcs_ptr, 0.0)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4),
    ],
    key=['chunk_size', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_stable_kernel(
    # Pointers to matrices
    x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, hdim,
    batch, seqlen, nheads_ngroups_ratio,
    # Strides
    stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim,
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # Each program owns one (row-tile, batch*chunk, head) triple and writes one
    # row of partial results, indexed by pid_m along the "csize_m" axis of the
    # output. The host is expected to sum those rows over the row tiles.
    #
    # For its BLOCK_SIZE_M rows the program walks the chunk columns in tiles of
    # BLOCK_SIZE_N and forms
    #     (dout @ x) * dt[n] * cb[m, n] * exp(dA_cs[m] - dA_cs[n])
    # restricted to the strictly-lower triangle (m > n). It then takes a running
    # prefix sum along n (carried across tiles in `rowsum`), re-applies the
    # triangular mask, reduces over m, and stores the result for column n at
    # output column n + 1. Output column 0 is set to zero.

    # Axis 1 of the grid packs (chunk, batch) with batch varying fastest.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)

    # Move every base pointer to this program's (batch, chunk, head) slice.
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    # cb is shared by all heads of a group, hence the integer division.
    cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_n * stride_dt_csize
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n)
    ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n

    # Results are stored shifted by one column, so column 0 is written explicitly.
    tl.store(ddA_cumsum_ptr, 0.0)

    # The last chunk of a sequence may be shorter than chunk_size.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    # Only columns n < m contribute, so columns past this row tile can be skipped.
    # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower
    lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M
    for start_n in range(lo, hi, BLOCK_SIZE_N):
        start_n = tl.multiple_of(start_n, BLOCK_SIZE_N)
        # Doing a matmul loop with cumsum later on will cause Triton to crash
        # Instead we do just one big matmul (BLOCK_SIZE_K covers all of hdim).
        x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0)
        acc = tl.dot(dout, x)
        dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        acc *= dt_n
        # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j]
        cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32)
        acc *= cb
        # The column index (start_n + offs_n) must be scaled by the stride as a
        # whole; adding start_n unscaled is only correct when the stride is 1.
        dA_cs_n = tl.load(dA_cumsum_ptr + (start_n + offs_n) * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32)
        acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :])
        # Keep only the strictly-lower triangle: row m sees columns n < m.
        mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1
        acc = tl.where(mask, acc, 0.0)
        # Prefix sum along n, continued from the previous column tiles.
        rowsum_new = rowsum + tl.sum(acc, axis=1)
        acc = rowsum[:, None] + tl.cumsum(acc, axis=1)
        rowsum = rowsum_new
        # The carried-in rowsum leaked into masked positions; mask them again.
        acc = tl.where(mask, acc, 0.0)
        ddA_cs = tl.sum(acc, axis=0)
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1)
        x_ptrs += BLOCK_SIZE_N * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_N * stride_dt_csize
        cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n

    # Need to zero out the rest, since we'll be summing the rows together.
    # The loop above advanced ddAcs_ptrs in whole BLOCK_SIZE_N steps, so it now
    # points at column ceil(hi / BLOCK_SIZE_N) * BLOCK_SIZE_N, not at `hi`. The
    # bounds mask has to be computed from that same column; otherwise, whenever
    # hi is not a multiple of BLOCK_SIZE_N, the mask would allow stores beyond
    # the end of this row.
    zero_start = ((hi + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N) * BLOCK_SIZE_N
    for start_n in range(zero_start, chunk_size, BLOCK_SIZE_N):
        tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1)
        ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])),
    ],
    key=['chunk_size', 'dstate', 'hdim'],
)
@triton.jit
def _chunk_scan_bwd_ddAcs_prev_kernel(
    # Pointers to matrices
    dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr,
    ddA_cumsum_ptr,
    # Matrix dimensions
    chunk_size, dstate, hdim,
    batch, seqlen, nchunks, nheads_ngroups_ratio,
    # Strides
    stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim,
    stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate,
    stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate,
    stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
    stride_seq_idx_batch, stride_seq_idx_seqlen,
    stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
):
    # For one tile of chunk positions (M) and one tile of state columns (N):
    #     ddA_cs[m] += scale[m] * sum_n (dout @ prev_states)[m, n] * C[m, n]
    # with scale[m] = exp(dA_cs[m]), or 0 where the position belongs to a
    # different sequence than the last position of the previous chunk.
    # Partial sums of different N tiles are combined with atomic_add, so the
    # output must be zero before launch (the configs' pre_hook is presumably
    # what does that -- init_to_zero is defined outside this block).

    # Axis 1 of the grid packs (chunk, batch) with batch varying fastest.
    pid_bc = tl.program_id(axis=1)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    # Axis 0 packs (M tile, N tile) with the N tile varying fastest.
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    # Move every base pointer to this program's (batch, chunk, head) slice.
    dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head
    prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head
    # C is shared by all heads of a group, hence the integer division.
    C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head
    ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim)
    prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim)
    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize

    # The last chunk of a sequence may be shorter than chunk_size.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0)
    prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0)
    # Both tl.dot operands are brought to dout's element type.
    prev_states = prev_states.to(dout_ptrs.dtype.element_ty)
    acc = tl.dot(dout, prev_states)
    c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32)
    ddA_cs = tl.sum(acc * c, axis=1)
    dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32)
    if not HAS_SEQ_IDX:
        scale = tl.exp(dA_cs_m)
    else:
        # Sequence id of the last position of the previous chunk (0 for the first chunk).
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0)
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)
        scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
    ddA_cs *= scale
    ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize
    # Several N-tile programs contribute to the same positions, hence the atomic.
    tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size)
def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None):
    """Validate shapes, allocate outputs and launch ``_chunk_scan_fwd_kernel``.

    Expected shapes (all enforced by the assertions below):
        cb:        (batch, nchunks, ngroups, chunk_size, chunk_size)
        x:         (batch, seqlen, nheads, headdim)
        dt:        (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        C:         (batch, seqlen, ngroups, dstate), with nheads % ngroups == 0
        states:    (batch, nchunks, nheads, headdim, dstate)
        D:         optional, (nheads, headdim) or (nheads,)
        z:         optional, same shape as x
        seq_idx:   optional, (batch, seqlen)

    Returns:
        ``(out, out_x)``. ``out`` is a new tensor with the shape, dtype and
        device of ``x``. ``out_x`` is a second buffer of the same layout when
        ``z`` is given and ``None`` otherwise. Both are handed to the kernel,
        which is expected to fill them (the kernel is defined elsewhere).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Allocates output.
    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # The kernel receives a single set of output strides for both buffers.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']),
                         batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    # Launch on the device that holds the tensors, like the backward launchers
    # in this file do; without this the kernel targets the *current* CUDA
    # device, which breaks when x lives on a different GPU.
    with torch.cuda.device(x.device.index):
        _chunk_scan_fwd_kernel[grid](
            cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D,
            chunk_size, headdim, dstate,
            batch, seqlen, nheads // ngroups,
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            z_strides[0], z_strides[1], z_strides[2], z_strides[3],
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            # dt / dA_cumsum are stored (batch, head, chunk, csize) but passed
            # in (batch, chunk, head, csize) stride order.
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
            D.stride(0) if D is not None else 0,
            True,  # NOTE(review): positional constexpr flag; presumably IS_CAUSAL -- confirm against the kernel signature
            D is not None,
            D.dim() == 2 if D is not None else True,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            HAS_Z=z is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            IS_TRITON_22=TRITON_22,
        )
    return out, out_x
def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None):
    """Validate shapes, allocate outputs and launch ``_chunk_scan_fwd_kernel_wip``.

    Same contract as ``_chunk_scan_fwd`` with one extra input ``B`` that must
    have the same shape as ``C``. The grid has no tile axis over chunk
    positions and ``BLOCK_SIZE_M`` is fixed to 128 at the call site.

    Expected shapes (all enforced by the assertions below):
        cb:        (batch, nchunks, ngroups, chunk_size, chunk_size)
        x:         (batch, seqlen, nheads, headdim)
        dt:        (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        C, B:      (batch, seqlen, ngroups, dstate), with nheads % ngroups == 0
        states:    (batch, nchunks, nheads, headdim, dstate)
        D:         optional, (nheads, headdim) or (nheads,)
        z:         optional, same shape as x
        seq_idx:   optional, (batch, seqlen)

    Returns:
        ``(out, out_x)``. ``out`` is a new tensor with the shape, dtype and
        device of ``x``. ``out_x`` is a second buffer of the same layout when
        ``z`` is given and ``None`` otherwise. Both are handed to the kernel,
        which is expected to fill them (the kernel is defined elsewhere).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert B.shape == C.shape
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Allocates output.
    out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
    if z is not None:
        out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype)
        # The kernel receives a single set of output strides for both buffers.
        assert out_x.stride() == out.stride()
    else:
        out_x = None
    grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3))
                 if z is not None else (0, 0, 0, 0))
    # Launch on the device that holds the tensors, like the backward launchers
    # in this file do; without this the kernel targets the *current* CUDA
    # device, which breaks when x lives on a different GPU.
    with torch.cuda.device(x.device.index):
        _chunk_scan_fwd_kernel_wip[grid](
            cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D,
            chunk_size, headdim, dstate,
            batch, seqlen, nheads // ngroups,
            cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4),
            x.stride(0), x.stride(1), x.stride(2), x.stride(3),
            z_strides[0], z_strides[1], z_strides[2], z_strides[3],
            out.stride(0), out.stride(1), out.stride(2), out.stride(3),
            # dt / dA_cumsum are stored (batch, head, chunk, csize) but passed
            # in (batch, chunk, head, csize) stride order.
            dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3),
            dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
            *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            C.stride(0), C.stride(1), C.stride(2), C.stride(3),
            B.stride(0), B.stride(1), B.stride(2), B.stride(3),
            states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4),
            D.stride(0) if D is not None else 0,
            D is not None,
            D.dim() == 2 if D is not None else True,
            BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
            BLOCK_SIZE_M=128,
            HAS_Z=z is not None,
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return out, out_x
def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False):
    """Allocate buffers for and launch ``_chunk_scan_bwd_dz_kernel``.

    Args:
        x, z, out, dout: tensors of identical shape
            (batch, seqlen, nheads, headdim).
        chunk_size: chunk length; the number of chunks is ceil(seqlen / chunk_size).
        has_ddAcs: when true, a float32 (batch, nheads, nchunks, chunk_size)
            buffer is allocated, passed to the kernel and returned.
        D: optional, (nheads, headdim) or (nheads,), last stride must be 1.
        dz: optional preallocated buffer with z's shape; allocated when omitted.
        recompute_output: when true, an extra buffer shaped like x is
            allocated, passed to the kernel and appended to the result.

    Returns:
        ``(dz, dout_x, dD[, ddA_cumsum][, outz])``. ``dD`` is ``None`` when
        ``D`` is ``None``; otherwise it is the per-block partial buffer summed
        over its block, batch and chunk axes and cast to ``D.dtype``.
    """
    batch, seqlen, nheads, headdim = x.shape
    assert z.shape == x.shape
    assert out.shape == x.shape
    assert dout.shape == out.shape
    nchunks = math.ceil(seqlen / chunk_size)
    with_D = D is not None
    if with_D:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
        assert D.stride(-1) == 1

    ddA_cumsum = None
    ddA_strides = (0, 0, 0, 0)
    if has_ddAcs:
        ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32)
        s = ddA_cumsum.stride()
        # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
        ddA_strides = (s[0], s[2], s[1], s[3])

    if with_D:
        # One partial slot per M tile, sized for the smallest tile
        # (assumes no autotune config uses BLOCK_SIZE_M < 32 -- TODO confirm).
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
        dD_strides = dD.stride()
    else:
        dD = None
        dD_strides = (0, 0, 0, 0, 0)

    if dz is None:
        dz = torch.empty_like(z)
    else:
        assert dz.shape == z.shape

    outz = torch.empty_like(x) if recompute_output else None
    outz_strides = outz.stride() if recompute_output else (0, 0, 0, 0)
    dout_x = torch.empty_like(dout)

    launch_grid = lambda meta: (triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dz_kernel[launch_grid](
            dout, out, z, x, D, outz,
            dz, dout_x, dD, ddA_cumsum,
            chunk_size, headdim,
            batch, seqlen,
            *dout.stride(),
            *out.stride(),
            *z.stride(),
            *x.stride(),
            D.stride(0) if with_D else 0,
            *outz_strides,
            *dz.stride(),
            *dout_x.stride(),
            # The block axis is dD's leading dimension but is passed fourth.
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            *ddA_strides,
            with_D,
            D.dim() == 2 if with_D else True,
            has_ddAcs,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
            RECOMPUTE_OUTPUT=recompute_output,
        )

    if with_D:
        # Only the slots of the tile size the autotuner settled on were written.
        tile_m = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_written = triton.cdiv(chunk_size, tile_m)
        dD = dD[:n_written].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")

    result = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD)
    if recompute_output:
        return (*result, outz)
    return result
def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None):
    """Allocate ``dprev_states`` and launch ``_chunk_scan_bwd_dstates_kernel``.

    Args:
        C: (batch, seqlen, ngroups, dstate), with nheads % ngroups == 0.
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        dout: (batch, seqlen, nheads, headdim).
        seq_idx: optional, (batch, seqlen).
        dtype: dtype of the result; defaults to ``C.dtype``.

    Returns:
        A new (batch, nchunks, nheads, headdim, dstate) tensor on ``C``'s
        device, handed to the kernel as its output buffer.
    """
    batch, seqlen, nheads, headdim = dout.shape
    _, _, nchunks, chunk_size = dA_cumsum.shape
    _, _, ngroups, dstate = C.shape
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    use_seq_idx = seq_idx is not None
    if use_seq_idx:
        assert seq_idx.shape == (batch, seqlen)
    if dtype is None:
        dtype = C.dtype
    dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype)

    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if use_seq_idx else (0, 0)
    launch_grid = lambda meta: (
        triton.cdiv(headdim, meta['BLOCK_SIZE_M']) * triton.cdiv(dstate, meta['BLOCK_SIZE_N']),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(C.device.index):
        _chunk_scan_bwd_dstates_kernel[launch_grid](
            dout, C, dprev_states, dA_cumsum, seq_idx,
            headdim, dstate, chunk_size,
            batch, seqlen, nchunks, nheads // ngroups,
            *dout.stride(),
            *C.stride(),
            *dprev_states.stride(),
            # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            HAS_SEQ_IDX=use_seq_idx,
        )
    return dprev_states
def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1):
    """Allocate buffers for and launch ``_chunk_scan_bwd_dc_kernel``.

    Args:
        prev_states: (batch, nchunks, nheads, headdim, dstate).
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        dout: (batch, seqlen, nheads, headdim).
        seq_idx: optional, (batch, seqlen).
        C: optional, (batch, seqlen, ngroups, dstate). When given, a float32
            (batch, nheads, nchunks, chunk_size) buffer is also allocated,
            passed to the kernel and returned.
        ngroups: number of groups the heads are partitioned into.

    Returns:
        ``dC`` of shape (batch, seqlen, ngroups, dstate) in float32 (the split
        axis already summed out), or ``(dC, ddA_cumsum_prev)`` when ``C`` is given.
    """
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    use_seq_idx = seq_idx is not None
    if use_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    if C is None:
        C_strides = (0, 0, 0, 0)
        ddA_cumsum_prev = None
        ddA_cumsum_prev_strides = (0, 0, 0, 0)
    else:
        assert C.shape == (batch, seqlen, ngroups, dstate)
        C_strides = C.stride()
        ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)
        s = ddA_cumsum_prev.stride()
        # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
        ddA_cumsum_prev_strides = (s[0], s[2], s[1], s[3])

    # Each program loops over several heads of a group. Pick that count so the
    # total number of programs is on the order of the SM count, clamped to
    # [1, heads per group]; the heads of a group are then covered by nsplits programs.
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count
    heads_to_fill_sms = math.ceil(batch * nchunks * nheads / sm_count)
    nheads_per_program = max(min(heads_to_fill_sms, nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)

    # One partial result per split; reduced over the split axis after the launch.
    dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32)
    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if use_seq_idx else (0, 0)
    launch_grid = lambda meta: (
        triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']) * triton.cdiv(dstate, meta['BLOCK_SIZE_N']),
        batch * nchunks,
        nsplits * ngroups,
    )
    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_dc_kernel[launch_grid](
            dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev,
            chunk_size, dstate, headdim,
            batch, seqlen, nheads, nheads_per_program, ngroups,
            *dout.stride(),
            *prev_states.stride(),
            *C_strides,
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            *dC.stride(),
            *ddA_cumsum_prev_strides,
            HAS_DDA_CS=ddA_cumsum_prev is not None,
            HAS_SEQ_IDX=use_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dC = dC.sum(2)
    if C is None:
        return dC
    return dC, ddA_cumsum_prev
def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1):
    """Allocate buffers for and launch ``_chunk_scan_bwd_dcb_kernel``.

    Args:
        x, dout: (batch, seqlen, nheads, headdim).
        dt, dA_cumsum: (batch, nheads, nchunks, chunk_size).
        seq_idx: optional, (batch, seqlen).
        CB: optional, (batch, nchunks, ngroups, chunk_size, chunk_size). When
            given, a per-row-tile float32 partial buffer is also allocated,
            passed to the kernel, summed over its tile axis and returned.
        ngroups: number of groups the heads are partitioned into.

    Returns:
        ``dcb`` of shape (batch, nchunks, ngroups, chunk_size, chunk_size) in
        float32 (the split axis already summed out), or ``(dcb, ddA_cumsum)``
        with ``ddA_cumsum`` of shape (batch, nheads, nchunks, chunk_size) when
        ``CB`` is given.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape
    use_seq_idx = seq_idx is not None
    if use_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    if CB is None:
        CB_strides = (0, 0, 0, 0, 0)
        ddA_cumsum = None
        ddA_cumsum_strides = (0, 0, 0, 0, 0)
    else:
        assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
        CB_strides = CB.stride()
        # One partial row per M tile, sized for the smallest tile
        # (assumes no autotune config uses BLOCK_SIZE_M < 16 -- TODO confirm).
        BLOCK_SIZE_M_min = 16
        ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                                 chunk_size, device=x.device, dtype=torch.float32)
        s = ddA_cumsum.stride()
        # Stored (batch, head, chunk, tile, csize), passed as (batch, chunk, head, tile, csize).
        ddA_cumsum_strides = (s[0], s[2], s[1], s[3], s[4])

    # Each program loops over several heads of a group. Pick that count so the
    # total number of programs is on the order of the SM count, clamped to
    # [1, heads per group]; the heads of a group are then covered by nsplits programs.
    nheads_ngroups_ratio = nheads // ngroups
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    heads_to_fill_sms = math.ceil(batch * nchunks * nheads / sm_count)
    nheads_per_program = max(min(heads_to_fill_sms, nheads_ngroups_ratio), 1)
    nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program)

    # One partial result per split; reduced over the split axis after the launch.
    dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32)
    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    seq_idx_strides = seq_idx.stride() if use_seq_idx else (0, 0)
    launch_grid = lambda meta: (
        triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, meta['BLOCK_SIZE_N']),
        batch * nchunks,
        nsplits * ngroups,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dcb_kernel[launch_grid](
            x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum,
            chunk_size, headdim,
            batch, seqlen, nheads, nheads_per_program, ngroups,
            *x.stride(),
            *dout.stride(),
            *CB_strides,
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            *dcb.stride(),
            *ddA_cumsum_strides,
            HAS_DDA_CS=ddA_cumsum is not None,
            HAS_SEQ_IDX=use_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    dcb = dcb.sum(2)
    if CB is None:
        return dcb
    # Only the rows of the tile size the autotuner settled on were written.
    tile_m = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"]
    n_written = triton.cdiv(chunk_size, tile_m)
    ddA_cumsum = ddA_cumsum[:, :, :, :n_written].sum(dim=3)
    return dcb, ddA_cumsum
def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None):
    """Allocate ``dx`` / ``ddt`` and launch ``_chunk_scan_bwd_dx_kernel``.

    Args:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size), with
            nheads % ngroups == 0.
        x, dout: (batch, seqlen, nheads, headdim).
        dt, dA_cumsum: (batch, nheads, nchunks, chunk_size).
        D: optional; forwarded to the kernel as an input only. No gradient for
            D is produced by this launcher.

    Returns:
        ``(dx, ddt)``: ``dx`` has x's shape and dtype; ``ddt`` is accumulated
        in float32 and cast to ``dt.dtype`` on return.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert dout.shape == x.shape

    with_D = D is not None
    dx = torch.empty_like(x)
    ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)

    cb_s = cb.stride()
    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    ddt_s = ddt.stride()
    launch_grid = lambda meta: (
        triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']) * triton.cdiv(headdim, meta['BLOCK_SIZE_N']),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_dx_kernel[launch_grid](
            x, cb, dout, dt, dA_cumsum, D, dx, ddt,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            *x.stride(),
            # The last two cb strides are swapped on purpose: the kernel is
            # handed cb with its two chunk axes exchanged.
            cb_s[0], cb_s[1], cb_s[2], cb_s[4], cb_s[3],
            *dout.stride(),
            # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            D.stride(0) if with_D else 0,
            *dx.stride(),
            ddt_s[0], ddt_s[2], ddt_s[1], ddt_s[3],
            with_D,
            D.dim() == 2 if with_D else True,
        )
    return dx, ddt.to(dtype=dt.dtype)
def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True):
    """Not numerically stable and should not be used. Leaving here for reference.

    Launches ``_chunk_scan_bwd_ddAcs_unstable_kernel``.

    Args:
        x, out, dout: (batch, seqlen, nheads, headdim).
        dt, ddt: (batch, nheads, nchunks, chunk_size).
        D: optional, (nheads, headdim) or (nheads,).
        subtract_ddtdt: forwarded unchanged to the kernel.

    Returns:
        ``(ddA_cumsum, dD)``: ``ddA_cumsum`` is allocated like ``dt``; ``dD``
        is ``None`` when ``D`` is ``None``, otherwise the per-block partial
        buffer summed over its block, batch and chunk axes and cast to ``D.dtype``.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert ddt.shape == dt.shape
    assert out.shape == x.shape
    assert dout.shape == x.shape
    with_D = D is not None
    if with_D:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    ddA_cumsum = torch.empty_like(dt)

    if with_D:
        # Triton gives wrong results if we write to the same location, so
        # every M tile gets its own slot; sized for the smallest tile
        # (assumes no autotune config uses BLOCK_SIZE_M < 32 -- TODO confirm).
        BLOCK_SIZE_min = 32
        dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads,
                         headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32)
        dD_strides = dD.stride()
    else:
        dD = None
        dD_strides = (0, 0, 0, 0, 0)

    dt_s = dt.stride()
    ddt_s = ddt.stride()
    ddA_s = ddA_cumsum.stride()
    launch_grid = lambda meta: (triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_unstable_kernel[launch_grid](
            dout, out, dt, ddt, x, D, ddA_cumsum, dD,
            chunk_size, headdim,
            batch, seqlen,
            *dout.stride(),
            *out.stride(),
            # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            ddt_s[0], ddt_s[2], ddt_s[1], ddt_s[3],
            *x.stride(),
            D.stride(0) if with_D else 0,
            ddA_s[0], ddA_s[2], ddA_s[1], ddA_s[3],
            # The block axis is dD's leading dimension but is passed fourth.
            dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4],
            with_D,
            D.dim() == 2 if with_D else True,
            subtract_ddtdt,
            BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16),
        )

    if with_D:
        # Only the slots of the tile size the autotuner settled on were written.
        tile_m = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
        n_written = triton.cdiv(chunk_size, tile_m)
        dD = dD[:n_written].sum(dim=(0, 1, 2)).to(dtype=D.dtype)
        if D.dim() == 1:
            dD = rearrange(dD, "h 1 -> h")
    return ddA_cumsum, dD
def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb):
    """Launch ``_chunk_scan_bwd_ddAcs_stable_kernel_old`` and reduce its partials.

    Args:
        x, dout: (batch, seqlen, nheads, headdim).
        dt, dA_cumsum: (batch, nheads, nchunks, chunk_size).
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size), with
            nheads % ngroups == 0.

    Returns:
        float32 tensor of shape (batch, nheads, nchunks, chunk_size): the
        per-row-tile partial buffer summed over its tile axis.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)

    # One partial row per M tile, sized for the smallest tile
    # (assumes no autotune config uses BLOCK_SIZE_M < 16 -- TODO confirm).
    BLOCK_SIZE_M_min = 16
    partials = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                           chunk_size, device=x.device, dtype=torch.float32)

    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    out_s = partials.stride()
    launch_grid = lambda meta: (triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel_old[launch_grid](
            x, dout, dt, dA_cumsum, cb, partials,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            *x.stride(),
            *dout.stride(),
            # Stored (batch, head, chunk, ...), passed as (batch, chunk, head, ...).
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *cb.stride(),
            out_s[0], out_s[2], out_s[1], out_s[3], out_s[4],
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
            BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16),
        )

    # Only the rows of the tile size the autotuner settled on were written.
    tile_m = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"]
    n_written = triton.cdiv(chunk_size, tile_m)
    return partials[:, :, :, :n_written].sum(dim=3)
def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb):
    """Launch ``_chunk_scan_bwd_ddAcs_stable_kernel`` and reduce its partials.

    Args:
        x, dout: (batch, seqlen, nheads, headdim).
        dt, dA_cumsum: (batch, nheads, nchunks, chunk_size).
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size), with
            nheads % ngroups == 0.

    Returns:
        float32 tensor of shape (batch, nheads, nchunks, chunk_size): the
        per-row-tile partial buffer summed over its tile axis.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == x.shape
    assert dA_cumsum.shape == dt.shape
    ngroups = cb.shape[2]
    assert nheads % ngroups == 0
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)

    # One partial row per M tile. 32 is the smallest BLOCK_SIZE_M among the
    # kernel's autotune configs, so this is the largest number of tiles.
    BLOCK_SIZE_M_min = 32
    partials = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min),
                           chunk_size, device=x.device, dtype=torch.float32)

    dt_s = dt.stride()
    dA_s = dA_cumsum.stride()
    out_s = partials.stride()
    launch_grid = lambda meta: (triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']), batch * nchunks, nheads)
    with torch.cuda.device(x.device.index):
        _chunk_scan_bwd_ddAcs_stable_kernel[launch_grid](
            x, dout, dt, dA_cumsum, cb, partials,
            chunk_size, headdim,
            batch, seqlen, nheads // ngroups,
            *x.stride(),
            *dout.stride(),
            # Stored (batch, head, chunk, ...), passed as (batch, chunk, head, ...).
            dt_s[0], dt_s[2], dt_s[1], dt_s[3],
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *cb.stride(),
            out_s[0], out_s[2], out_s[1], out_s[3], out_s[4],
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )

    # Only the rows of the tile size the autotuner settled on were written.
    tile_m = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"]
    n_written = triton.cdiv(chunk_size, tile_m)
    return partials[:, :, :, :n_written].sum(dim=3)
def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None):
    """Allocate ``ddA_cumsum_prev`` and launch ``_chunk_scan_bwd_ddAcs_prev_kernel``.

    Args:
        prev_states: (batch, nchunks, nheads, headdim, dstate).
        C: (batch, seqlen, ngroups, dstate), with nheads % ngroups == 0.
        dout: (batch, seqlen, nheads, headdim).
        dA_cumsum: (batch, nheads, nchunks, chunk_size).
        seq_idx: optional, (batch, seqlen).

    Returns:
        float32 tensor of shape (batch, nheads, nchunks, chunk_size). It is
        allocated uninitialised here; the kernel accumulates into it with
        atomic adds, and its autotune configs carry a pre_hook that presumably
        zeroes it first.
    """
    batch, nchunks, nheads, headdim, dstate = prev_states.shape
    _, seqlen, _, _ = dout.shape
    _, _, _, chunk_size = dA_cumsum.shape
    assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert dout.shape == (batch, seqlen, nheads, headdim)
    ngroups = C.shape[2]
    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    use_seq_idx = seq_idx is not None
    if use_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32)

    dA_s = dA_cumsum.stride()
    out_s = ddA_cumsum_prev.stride()
    seq_idx_strides = seq_idx.stride() if use_seq_idx else (0, 0)
    launch_grid = lambda meta: (
        triton.cdiv(chunk_size, meta['BLOCK_SIZE_M']) * triton.cdiv(dstate, meta['BLOCK_SIZE_N']),
        batch * nchunks,
        nheads,
    )
    with torch.cuda.device(dout.device.index):
        _chunk_scan_bwd_ddAcs_prev_kernel[launch_grid](
            dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev,
            chunk_size, dstate, headdim,
            batch, seqlen, nchunks, nheads // ngroups,
            *dout.stride(),
            *prev_states.stride(),
            *C.stride(),
            # Stored (batch, head, chunk, csize), passed as (batch, chunk, head, csize).
            dA_s[0], dA_s[2], dA_s[1], dA_s[3],
            *seq_idx_strides,
            out_s[0], out_s[2], out_s[1], out_s[3],
            HAS_SEQ_IDX=use_seq_idx,
            BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16),
        )
    return ddA_cumsum_prev
class ChunkScanFn(torch.autograd.Function):
    """Autograd function that runs the chunk-scan forward launchers and wires up their backward helpers."""

    @staticmethod
    def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
        """
        Validate shapes, repair memory layouts, compute the output and save tensors for backward.

        Argument:
            B: (batch, seqlen, ngroups, dstate)
            C: (batch, seqlen, ngroups, dstate)
            x: (batch, seqlen, nheads, headdim)
            dt: (batch, nheads, nchunks, chunk_size), with seqlen == nchunks * chunk_size
            dA_cumsum: (batch, nheads, nchunks, chunk_size)
            prev_states: (batch, nchunks, nheads, headdim, dstate)
            D: optional, (nheads, headdim) or (nheads,)
            z: optional, same shape as x
        Return:
            out: the first output of ``_chunk_scan_fwd``
        """
        # Check constraints.
        batch, seqlen, nheads, headdim = x.shape
        _, _, ngroups, dstate = B.shape
        assert B.shape == (batch, seqlen, ngroups, dstate)
        _, _, nchunks, chunk_size = dt.shape
        assert seqlen == nchunks * chunk_size
        assert C.shape == B.shape
        if z is not None:
            assert z.shape == x.shape
        if D is not None:
            assert D.shape in ((nheads, headdim), (nheads,))
        per_head_chunk_shape = (batch, nheads, nchunks, chunk_size)
        assert dt.shape == per_head_chunk_shape
        assert dA_cumsum.shape == per_head_chunk_shape
        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)

        def no_unit_stride(tensor, *dims):
            # True when none of the listed dimensions is contiguous.
            return all(tensor.stride(d) != 1 for d in dims)

        if no_unit_stride(B, -1):
            B = B.contiguous()
        if no_unit_stride(C, -1):
            C = C.contiguous()
        # Either M or K dimension should be contiguous
        if no_unit_stride(x, -1, 1):
            x = x.contiguous()
        if z is not None and no_unit_stride(z, -1, 1):
            z = z.contiguous()
        if D is not None and no_unit_stride(D, -1):
            D = D.contiguous()

        CB = _bmm_chunk_fwd(C, B, chunk_size)
        out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z)
        # With z present, backward consumes the second forward output instead of the returned one.
        saved_out = out_x if z is not None else out
        ctx.save_for_backward(saved_out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z)
        return out

    @staticmethod
    def backward(ctx, dout):
        """
        Compute gradients for every forward input.

        Argument:
            dout: (batch, seqlen, nheads, headdim), gradient w.r.t. the forward output
        Return:
            (dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz), matching the forward argument order
        """
        grad_out = dout if dout.stride(-1) == 1 else dout.contiguous()
        out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors
        batch, seqlen, nheads, headdim = x.shape
        chunk_size = dt.shape[-1]
        ngroups = B.shape[2]
        assert grad_out.shape == (batch, seqlen, nheads, headdim)

        has_z = z is not None
        dz = None
        if has_z:
            # This helper also replaces the output gradient used by every step below and
            # already yields dD plus a partial ddA_cumsum.
            dz, grad_out, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, grad_out, chunk_size=chunk_size, D=D)

        dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, grad_out, dtype=prev_states.dtype)
        dC_from_states = _chunk_scan_bwd_dC(prev_states, dA_cumsum, grad_out, ngroups=ngroups).to(C.dtype)
        dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, grad_out, ngroups=ngroups).to(CB.dtype)
        dB = _bmm_chunk_bwd(C, dCB)
        dCB_swapped = rearrange(dCB, "... l s -> ... s l")
        dC = _bmm_chunk_bwd(B, dCB_swapped, residual=dC_from_states)
        dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, grad_out, D=D)

        # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D.
        # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt
        if not has_z:
            ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, grad_out, ddt, D=D)
        else:
            # ddA_cumsum and dD were already produced alongside dz; only the ddt * dt term remains.
            ddA_cumsum -= ddt * dt
        ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype)
        return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz
def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Differentiable chunk scan; dispatches to ``ChunkScanFn``.

    prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1.
    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    # The optional D and z are forwarded positionally, in the forward() argument order.
    scan_args = (B, C, x, dt, dA_cumsum, prev_states, D, z)
    return ChunkScanFn.apply(*scan_args)
def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None):
    """
    Plain PyTorch/einops implementation of the chunk scan output.

    Each chunk's output is the sum of a causal within-chunk term (C·B scores scaled by
    exp(dA_cumsum[l] - dA_cumsum[s]) and dt, applied to x) and a carried-in term
    (C applied to prev_states, scaled by exp(dA_cumsum)). The optional skip term x * D is
    then added and, if z is given, the result is multiplied by silu(z).

    Argument:
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    assert B.shape == (batch, seqlen, ngroups, dstate)
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    assert C.shape == B.shape
    # Expand the group dimension so B and C carry one slice per head.
    B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups)
    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
    CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                      rearrange(B, "b (c s) h n -> b c s h n", c=nchunks))
    # (batch, nheads, nchunks, chunksize, chunksize)
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    # Mask the exponent before exponentiating. The anti-causal entries (l < s) are discarded
    # anyway, but if they are large and positive exp() overflows to inf; zeroing them only
    # afterwards leaves a 0 * inf = NaN product in the backward pass. With -inf in the
    # exponent the decay is exactly 0 there and its gradient stays finite.
    decay = torch.exp(dt_segment_sum.masked_fill(~causal_mask, float("-inf")))
    scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s")
    # Kept so masked entries are exactly 0 even if CB itself is non-finite there.
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                            prev_states.to(C.dtype)) * state_decay_out
    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            # Per-head scalar: add a trailing axis so it broadcasts over headdim.
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out if z is None else out * F.silu(z)
|