File size: 96,765 Bytes
f0e5081 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 |
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
from contextlib import nullcontext
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional
from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from .module import MegatronModule
from .rmsnorm import RMSNorm
from megatron.core import parallel_state, tensor_parallel, mpu
from megatron.core.enums import ModelType
from megatron.model import LayerNorm
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
import deepspeed
from deepspeed.moe.layer import MoE
from deepspeed.accelerator import get_accelerator
try:
from deepspeed.sequence.layer import DistributedAttention
dist_attn_supported = True
except ImportError:
dist_attn_supported = False
try:
from einops import rearrange
except ImportError:
rearrange = None
try:
# FlashAttention (1.x)
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.flash_attn_triton import flash_attn_func
except ImportError:
flash_attn_unpadded_func = None
flash_attn_func = None
try:
# FlashAttention-2
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
flash_attn_varlen_func = None
# Accelerator-provided FlashAttention op builder. The load is best-effort:
# when the builder cannot be instantiated or loaded (TypeError / ValueError),
# flash_attn_builder stays None and callers must check for that.
FlashAttentionBuilder = get_accelerator().get_op_builder("FlashAttentionBuilder")
try:
    flash_attn_builder = FlashAttentionBuilder().load()
except (TypeError, ValueError):
    flash_attn_builder = None
try:
from apex.normalization import MixedFusedRMSNorm
except ImportError:
MixedFusedRMSNorm = RMSNorm
try:
import habana_frameworks.torch.hpu as hthpu
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
hthpu = None
FusedSDPA = None
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
    """Stochastic depth: drop the whole path for randomly chosen samples.

    Meant for the main path of residual blocks. The module is an identity
    when ``drop_prob`` is 0 or when it is not in training mode.
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        # Probability of zeroing a sample's path during training.
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        """Return ``hidden_state`` with some samples zeroed and the rest rescaled.

        Args:
            hidden_state: tensor laid out as [s, b, h]; dimension 1 is the one
                along which independent keep/drop decisions are drawn.

        Returns:
            The input itself when dropping is disabled, otherwise a tensor of
            the same shape scaled by ``1 / keep_prob`` and masked per sample.
        """
        dropping_active = self.training and self.drop_prob != 0.
        if not dropping_active:
            return hidden_state

        survival_prob = 1 - self.drop_prob
        # One random draw per entry of dim 1, broadcast over every other
        # dimension so the mask works for tensors of any rank >= 2.
        trailing_ones = (1,) * (hidden_state.ndim - 2)
        mask_shape = (1, hidden_state.shape[1]) + trailing_ones
        keep_mask = survival_prob + torch.rand(
            mask_shape, dtype=hidden_state.dtype, device=hidden_state.device)
        # floor() turns [survival_prob, 1 + survival_prob) into a 0/1 mask.
        keep_mask.floor_()
        return hidden_state.div(survival_prob) * keep_mask
class ParallelMLP(MegatronModule):
    """Tensor-parallel feed-forward block.

    The input with hidden size h is projected up to ``config.ffn_hidden_size``
    (doubled when ``config.gated_linear_unit`` is set), passed through the
    configured non-linearity, and projected back down to h.
    """

    def __init__(self, config, moe=False, enable_expert_tensor_parallelism=False):
        super().__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        # Gated units (e.g. swiglu, https://arxiv.org/pdf/2002.05202.pdf)
        # split the up-projection in two halves, so it must be twice as wide.
        up_projection_width = config.ffn_hidden_size
        if config.gated_linear_unit:
            up_projection_width = up_projection_width * 2

        # Up-projection: h -> ffn width. The bias is returned separately
        # (skip_bias_add) so forward() can fuse or add it explicitly.
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            up_projection_width,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
        )

        def swiglu(x):
            gate, value = torch.chunk(x, 2, dim=-1)
            return F.silu(gate) * value

        def squared_relu(x):
            return torch.pow(F.relu(x), 2)

        # Activation selection; the first matching flag wins. The fused
        # bias+gelu kernel is only eligible on the plain-gelu fallback.
        self.swiglu = args.swiglu
        self.bias_gelu_fusion = False
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            self.activation_func = swiglu
        elif args.squared_relu:
            self.activation_func = squared_relu
        else:
            self.activation_func = F.gelu
            self.bias_gelu_fusion = args.bias_gelu_fusion

        # Down-projection: ffn width -> h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            input_is_parallel=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism
        )

    def forward(self, hidden_states):
        """Apply up-projection, activation and down-projection.

        Returns:
            The ``(output, output_bias)`` pair produced by the down-projection.
        """
        # [s, b, 4hp]
        projected, projected_bias = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            assert self.add_bias is True
            # DeepSpeed FLOPS profiler temporarily substitutes functions like
            # F.gelu to calculate the throughput.
            assert hasattr(self, "__flops__") or self.activation_func == F.gelu
            activated = bias_gelu_impl(projected, projected_bias)
        else:
            if projected_bias is not None:
                projected = projected + projected_bias
            activated = self.activation_func(projected)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(activated)
        return output, output_bias
class SwitchMLP(MegatronModule):
    """Route every token to exactly one of N MLP "experts" (top-1 routing).

    A linear router scores the experts for each token; the token is processed
    by the highest-scoring expert and the result is scaled by that expert's
    softmax probability.
    """

    def __init__(self, config):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(config.hidden_size, args.num_experts_switch)
        self.experts = torch.nn.ModuleList()
        for _ in range(args.num_experts_switch):
            self.experts.append(ParallelMLP(config))

    def forward(self, hidden_states):
        """Run the routed experts.

        Args:
            hidden_states: tensor of shape [s, b, h].

        Returns:
            ``(output, output_bias)``, each reshaped to [s, b, h].
            ``output_bias`` is ``None`` when the experts return no bias.
        """
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)

        route = self.router(hidden_states)
        route = F.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]

        # Flatten [s, b, h] to [s*b, h]: each token may be routed differently.
        hidden_states = hidden_states.view(-1, h)          # [s*b h]
        max_prob = max_prob.view(-1, max_prob.size(2))     # [s*b 1]
        max_ind = max_ind.view(-1)                         # [s*b]

        output_total = torch.empty_like(hidden_states)
        # Allocated lazily: an expert may return no bias (None), in which case
        # there is nothing to scatter and the combined bias stays None. All
        # experts are built from the same config, so they are expected to
        # agree on whether a bias is returned.
        output_bias_total = None

        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if output_bias is not None:
                if output_bias_total is None:
                    output_bias_total = torch.empty_like(hidden_states)
                output_bias_total[local_indices, :] = output_bias.expand_as(output)

        # Scale by the winning router probability and restore [s, b, h].
        output_total = (output_total * max_prob).view(s, b, h)
        if output_bias_total is not None:
            output_bias_total = (output_bias_total * max_prob).view(s, b, h)

        return output_total, output_bias_total
class CoreAttention(MegatronModule):
    """Unfused scaled dot-product attention.

    Computes softmax(Q K^T / norm_factor) V with an optional dropout on the
    attention probabilities, using batched matmuls over flattened
    (batch * heads) slices.
    """

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        super().__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # Query/key layer scaling forces the softmax to run in fp32.
        softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            softmax_in_fp32 = True
        self.attention_softmax_in_fp32 = softmax_in_fp32

        # NOTE(review): the incoming layer number is incremented before the
        # clamp to >= 1 — confirm the offset is intended.
        self.layer_number = max(1, layer_number+1)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values. The sequence-parallel
        # world size takes precedence whenever it is larger than 1.
        if parallel_state.sequence_parallel_is_initialized():
            seq_parallel_world_size = parallel_state.get_sequence_parallel_world_size()
        else:
            seq_parallel_world_size = 1
        if seq_parallel_world_size > 1:
            world_size = seq_parallel_world_size
        else:
            world_size = parallel_state.get_tensor_model_parallel_world_size()

        self.hidden_size_per_partition = core.utils.divide(
            projection_size, world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        # Scores are divided by sqrt(hn); with layer scaling they are divided
        # by an extra factor of layer_number, which is also handed to the
        # softmax as its scale coefficient.
        scaling_coeff = None
        norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            scaling_coeff = self.layer_number
            norm_factor *= scaling_coeff
        self.norm_factor = norm_factor

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            scaling_coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        if config.attention_dropout != 0:
            self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
        else:
            self.attention_dropout = None

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute the attention context.

        Args:
            query_layer: [sq, b, np, hn]
            key_layer: [sk, b, np, hn]
            value_layer: [sk, b, np, hn]
            attention_mask: mask forwarded to the scale/mask softmax.

        Returns:
            Context tensor of shape [sq, b, hp].
        """
        # ---- Raw attention scores: [b, np, sq, sk] ----
        sq = query_layer.size(0)
        batch = query_layer.size(1)
        heads = query_layer.size(2)
        sk = key_layer.size(0)
        batch_heads = batch * heads

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(sq, batch_heads, -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(sk, batch_heads, -1)

        # Preallocated input tensor for baddbmm: [b * np, sq, sk]. With
        # beta=0.0 its contents do not contribute to the result.
        scores_buffer = parallel_state.get_global_memory_buffer().get_tensor(
            (batch_heads, sq, sk),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        raw_scores = torch.baddbmm(
            scores_buffer,
            query_layer.transpose(0, 1),                # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = raw_scores.view(batch, heads, sq, sk)

        # ---- Attention probs and dropout ----
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        if self.attention_dropout is not None:
            # This is actually dropping out entire tokens to attend to, which
            # might seem a bit unusual, but is taken from the original
            # Transformer paper. Without sequence parallelism the dropout runs
            # inside the forked tensor-parallel RNG tracker.
            if self.sequence_parallel:
                rng_scope = nullcontext()
            else:
                rng_scope = tensor_parallel.get_cuda_rng_tracker().fork()
            with rng_scope:
                attention_probs = self.attention_dropout(attention_probs)

        # ---- Context layer: [sq, b, hp] ----
        # value_layer: [sk, b, np, hn]; context layer shape: [b, np, sq, hn]
        value_batch = value_layer.size(1)
        value_heads = value_layer.size(2)
        head_dim = value_layer.size(3)
        value_batch_heads = value_batch * value_heads

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       value_batch_heads, -1)
        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(value_batch_heads, sq, -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
        context_layer = context_layer.view(value_batch, value_heads, sq, head_dim)
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        merged_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        return context_layer.view(*merged_shape)
class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.

    On a 'cuda' accelerator the FlashAttention varlen/unpadded kernel is
    used; on any other accelerator the DeepSpeed ``FlashAttentionBuilder``
    op is used instead.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None or flash_attn_varlen_func is not None or flash_attn_builder is not None, \
            ('Please install FlashAttention first, e.g., with pip install flash-attn or implement your own flash attention')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

        # Use FlashAttention-2 when args.use_flash_attn_v2 is True
        args = get_args()
        self.flash_attn_func = flash_attn_varlen_func if args.use_flash_attn_v2 else flash_attn_unpadded_func

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)

        Returns
        -------
            Attention output laid out as (B, S, H, D).
        """
        assert all(t.dtype in (torch.float16, torch.bfloat16) for t in (q, k, v))
        assert all(get_accelerator().on_accelerator(t) for t in (q, k, v))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]
        on_cuda = get_accelerator().device_name() == 'cuda'

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q
            is_causal = self.causal
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            # NOTE(review): self.causal is not consulted in eval mode — confirm
            # that is intended for non-causal instances.
            is_causal = seqlen_q == seqlen_k
            # Dropout must never be applied at inference time.
            dropout_p = 0

        if not on_cuda:
            # Builder op works on (B, H, S, D). Pass the mode-dependent
            # dropout_p (0 in eval), not the raw training rate.
            q, k, v = [rearrange(x, 'b s h d -> b h s d').contiguous() for x in (q, k, v)]
            output = flash_attn_builder.flash_attn_func(
                q, k, v, dropout_p, self.softmax_scale, is_causal
            )
            return rearrange(output, 'b h s d -> b s h d').contiguous()

        # CUDA path: batch and sequence are packed into one leading dimension,
        # with cumulative sequence lengths marking each sample's boundaries.
        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in (q, k, v)]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)
        if self.training:
            cu_seqlens_k = cu_seqlens_q
        else:
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                        device=q.device)

        output = self.flash_attn_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )
        return rearrange(output, '(b s) ... -> b s ...', b=batch_size)
class FlashSelfAttentionTriton(torch.nn.Module):
    """Scaled dot product attention backed by the Triton FlashAttention kernel.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)

    NOTE(review): ``softmax_scale`` and ``attention_dropout`` are stored on
    the module but are not passed to the kernel call in ``forward``.
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_func is not None, ('Triton version of FlashAttention is not installed.')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.dropout_p = attention_dropout
        self.softmax_scale = softmax_scale
        self.causal = causal

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. The
                first two dimensions are swapped before the kernel call
                ('s b ... -> b s ...').

        Returns
        -------
            Output rearranged from 'b s h d' to 's b (h d)'.
        """
        assert q.dtype in (torch.float16, torch.bfloat16)
        assert q.is_cuda

        def swap_leading_dims(tensor):
            return rearrange(tensor, 's b ... -> b s ...').contiguous()

        q = swap_leading_dims(q)
        k = swap_leading_dims(k)
        v = swap_leading_dims(v)

        # Positional arguments after v are (None, causal flag).
        attn_out = flash_attn_func(q, k, v, None, self.causal)
        return rearrange(attn_out, 'b s h d -> s b (h d)').contiguous()
class HabanaFlashSelfAttention(MegatronModule):
    """Self-attention backed by the Habana ``FusedSDPA`` kernel.

    The kernel is always invoked with a causal mask, no explicit attention
    mask and the default scale; the ``attention_mask`` argument of
    ``forward`` is accepted but not used.
    """

    def __init__(self, config, attn_mask_type=AttnMaskType.padding):
        super().__init__()
        assert hthpu is not None, "Failed to import hthpu"
        assert FusedSDPA is not None, "Failed to import FusedSDPA"

        self.attn_mask_type = attn_mask_type
        self.attention_dropout_p = config.attention_dropout
        self.use_fused_sdpa = config.use_fused_sdpa
        self.use_fused_sdpa_with_recompute = config.use_fused_sdpa_with_recompute

        # Per attention head and per partition values. The sequence-parallel
        # world size takes precedence whenever it is larger than 1.
        if parallel_state.sequence_parallel_is_initialized():
            seq_parallel_world_size = parallel_state.get_sequence_parallel_world_size()
        else:
            seq_parallel_world_size = 1
        if seq_parallel_world_size > 1:
            world_size = seq_parallel_world_size
        else:
            world_size = parallel_state.get_tensor_model_parallel_world_size()

        projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_partition = core.utils.divide(
            projection_size, world_size)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Run fused attention.

        Args:
            query_layer, key_layer, value_layer: [sq, b, np, hn] tensors.
            attention_mask: unused by this implementation.

        Returns:
            Context tensor of shape [sq, b, hp].
        """
        def to_kernel_layout(tensor):
            # [sq, b, np, hn] -> [b, np, sq, hn]
            return tensor.transpose(0, 1).transpose(1, 2)

        q = to_kernel_layout(query_layer)
        k = to_kernel_layout(key_layer)
        v = to_kernel_layout(value_layer)

        # Fixed kernel options: causal masking, no explicit mask, default scale.
        attn_mask = None
        causal = True
        scale = None

        with hthpu.sdp_kernel(enable_recompute=self.use_fused_sdpa_with_recompute):
            context_layer = FusedSDPA.apply(
                q, k, v, attn_mask, self.attention_dropout_p, causal, scale)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        merged_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        return context_layer.view(*merged_shape)
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, config, layer_number,
             attention_type=AttnType.self_attn,
             attn_mask_type=AttnMaskType.padding):
    """Build the QKV projections, the local attention kernel and the output projection.

    Args:
        config: transformer configuration object (sizes, init methods,
            dropout, recompute granularity, fused-SDPA flags).
        layer_number: index of this layer; clamped to be at least 1.
        attention_type: ``AttnType.self_attn`` or ``AttnType.cross_attn``.
        attn_mask_type: mask type; the FlashAttention paths are only
            selected for ``AttnMaskType.causal`` self-attention.

    Raises:
        AssertionError: if a requested FlashAttention flavour is not
            importable, or distributed attention is required but unsupported.
        ImportError: if a FlashAttention path is selected without einops.
    """
    super(ParallelAttention, self).__init__()
    args = get_args()
    self.layer_number = max(1, layer_number)
    self.attention_type = attention_type
    self.attn_mask_type = attn_mask_type
    self.params_dtype = config.params_dtype
    self.sequence_parallel = config.sequence_parallel
    self.num_attention_heads = config.num_attention_heads
    self.num_key_value_heads = config.num_key_value_heads
    # TODO - Remove self.attention_dropout usage when SW-172239 is solved
    self.attention_dropout = config.attention_dropout
    self.use_gqa = (self.num_attention_heads != self.num_key_value_heads)

    self.use_flash_attn = (args.use_flash_attn_v1 or args.use_flash_attn_triton or args.use_flash_attn_v2) \
        and attention_type == AttnType.self_attn \
        and self.attn_mask_type == AttnMaskType.causal
    self.use_flash_attn_triton = args.use_flash_attn_triton
    self.use_fused_sdpa = config.use_fused_sdpa
    if self.use_flash_attn:
        global flash_attn_builder
        # Retry loading the accelerator's FlashAttention op. The same
        # exception set as the module-level load is tolerated, so a failure
        # already handled at import time cannot crash construction here.
        try:
            flash_attn_builder = FlashAttentionBuilder().load()
        except (TypeError, ValueError):
            flash_attn_builder = None

        if args.use_flash_attn_v1:
            assert flash_attn_unpadded_func is not None or flash_attn_builder is not None, (
                "Cannot import FlashAttention v1 "
                "and Cannot find FlashAttention Builder")
        if args.use_flash_attn_v2:
            assert flash_attn_varlen_func is not None, "Cannot import FlashAttention v2 "
        if args.use_flash_attn_triton:
            assert flash_attn_func is not None, "Cannot import FlashAttention triton "

        assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                      'self-attention for now')
        assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                            'supports causal mask for now')
        if rearrange is None:
            raise ImportError('einops is not installed, please install with pip install einops')

    projection_size = config.kv_channels * config.num_attention_heads

    # Per attention head and per partition values.
    world_size = parallel_state.get_tensor_model_parallel_world_size()
    self.hidden_size_per_attention_head = core.utils.divide(
        projection_size, config.num_attention_heads)
    self.num_attention_heads_per_partition = core.utils.divide(
        config.num_attention_heads, world_size)

    # Per GQA head and per partition values
    self.num_key_value_heads_per_partition = core.utils.divide(
        config.num_key_value_heads, world_size)
    self.num_key_value_groups = core.utils.divide(
        config.num_attention_heads, config.num_key_value_heads)
    kv_projection_size = config.kv_channels * config.num_key_value_heads
    # Query heads and key/value heads must share the same head dimension.
    assert self.hidden_size_per_attention_head == core.utils.divide(
        kv_projection_size, config.num_key_value_heads)

    # Strided linear layer.
    if attention_type == AttnType.self_attn:
        # Single fused projection producing Q plus the (possibly narrower)
        # K and V blocks.
        self.query_key_value = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            projection_size + 2 * kv_projection_size,
            config=config,
            init_method=config.init_method,
            bias=args.add_bias_linear,
            gather_output=False)
    else:
        assert attention_type == AttnType.cross_attn
        self.query = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            projection_size,
            config=config,
            init_method=config.init_method,
            bias=config.add_bias_linear,
            gather_output=False)

        self.key_value = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            2 * projection_size,
            config=config,
            init_method=config.init_method,
            bias=config.add_bias_linear,
            gather_output=False)

    # Currently FlashAttention only works with causal mask
    if self.use_flash_attn_triton:
        local_attn = FlashSelfAttentionTriton(causal=True, attention_dropout=args.attention_dropout)
    elif self.use_flash_attn:
        local_attn = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout)
    elif self.use_fused_sdpa:
        local_attn = HabanaFlashSelfAttention(config, self.attn_mask_type)
    else:
        local_attn = CoreAttention(self.layer_number, config, self.attn_mask_type)

    self.enable_ds_sequence_parallel = parallel_state.get_sequence_parallel_world_size() > 1 \
        or args.force_ds_sequence_parallel
    if self.enable_ds_sequence_parallel:
        # DeepSpeed distributed attention wraps the local kernel.
        assert dist_attn_supported, 'Distributed attention is not supported in this DeepSpeed version'
        assert args.num_attention_heads % parallel_state.get_sequence_parallel_world_size() == 0
        self.dist_attn = DistributedAttention(local_attn, parallel_state.get_sequence_parallel_group())
    else:
        if self.use_flash_attn:
            self.core_attention_flash = local_attn
        else:
            self.core_attention = local_attn
            self.checkpoint_core_attention = config.recompute_granularity == 'selective'

    # Output.
    self.dense = tensor_parallel.RowParallelLinear(
        projection_size,
        config.hidden_size,
        config=config,
        init_method=config.output_layer_init_method,
        bias=args.add_bias_linear,
        input_is_parallel=True,
        skip_bias_add=True)
def _checkpointed_attention_forward(self, query_layer, key_layer,
                                    value_layer, attention_mask,
                                    rotary_pos_emb=None):
    """Run ``self.core_attention`` through ``tensor_parallel.checkpoint``.

    Args:
        query_layer, key_layer, value_layer: attention operands handed
            unchanged to ``self.core_attention``.
        attention_mask: mask handed unchanged to ``self.core_attention``.
        rotary_pos_emb: optional ``(q_pos_emb, k_pos_emb)`` pair. The two
            entries are passed to the checkpoint call as extra inputs but
            are not consumed by the wrapped function.

    Returns:
        Whatever ``tensor_parallel.checkpoint`` returns for the wrapped
        core-attention call.
    """
    def custom_forward(*inputs):
        # Only the first four checkpointed inputs are used; trailing
        # entries (the rotary embeddings) are ignored here.
        q, k, v, mask = inputs[0], inputs[1], inputs[2], inputs[3]
        return self.core_attention(q, k, v, mask)

    if rotary_pos_emb is None:
        q_pos_emb = k_pos_emb = None
    else:
        q_pos_emb, k_pos_emb = rotary_pos_emb

    # NOTE(review): the literal False is the second positional argument of
    # tensor_parallel.checkpoint -- presumably the "distribute saved
    # activations" switch; confirm against that helper's signature.
    return tensor_parallel.checkpoint(
        custom_forward,
        False,
        query_layer, key_layer, value_layer, attention_mask,
        q_pos_emb, k_pos_emb)
def _allocate_memory(self, inference_max_sequence_len, batch_size):
    """Allocate an uninitialised buffer for one inference key or value cache.

    The buffer has shape ``(inference_max_sequence_len, batch_size,
    num_attention_heads_per_partition, hidden_size_per_attention_head)``,
    uses ``self.params_dtype`` and lives on the accelerator's current device.
    """
    cache_shape = (
        inference_max_sequence_len,
        batch_size,
        self.num_attention_heads_per_partition,
        self.hidden_size_per_attention_head,
    )
    cache_dtype = self.params_dtype
    target_device = get_accelerator().current_device_name()
    return torch.empty(cache_shape, dtype=cache_dtype, device=target_device)
def repeat_kv(self, hidden_states, n_rep):
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Args:
        hidden_states: 4-D tensor laid out as
            ``(seq_len, batch, kv_heads, head_dim)``.
        n_rep: number of copies of each head to produce.

    Returns:
        ``hidden_states`` itself when ``n_rep == 1``; otherwise a tensor of
        shape ``(seq_len, batch, kv_heads * n_rep, head_dim)`` in which the
        copies of one head are adjacent.
    """
    seq_len, bsz, kv_heads, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    if kv_heads == 1:
        # A single KV head can be broadcast directly along the head axis,
        # no intermediate axis and reshape required.
        return hidden_states.expand(seq_len, bsz, n_rep, dim)
    # Insert a repeat axis after the head axis, broadcast it, then fold it
    # back into the head axis.
    tiled = hidden_states.unsqueeze(3).expand(seq_len, bsz, kv_heads, n_rep, dim)
    return tiled.reshape(seq_len, bsz, kv_heads * n_rep, dim)
def split_tensor(self, mixed_x_layer):
    """Split a fused QKV tensor into separate query, key and value tensors.

    The second-to-last axis of ``mixed_x_layer`` is cut into
    ``self.num_key_value_groups`` query slots followed by one key slot and
    one value slot.

    Returns:
        ``(query, key, value)`` where ``query`` is reshaped to
        ``(dim0, dim1, num_attention_heads_per_partition,
        hidden_size_per_attention_head)`` and ``key`` / ``value`` have the
        singleton slot axis squeezed out.
    """
    dim0, dim1 = mixed_x_layer.shape[:2]
    q_part, k_part, v_part = mixed_x_layer.split(
        [self.num_key_value_groups, 1, 1], dim=-2)
    query_layer = q_part.reshape(
        dim0, dim1,
        self.num_attention_heads_per_partition,
        self.hidden_size_per_attention_head)
    return query_layer, k_part.squeeze(-2), v_part.squeeze(-2)
def forward(self, hidden_states, attention_mask,
            encoder_output=None, inference_params=None,
            rotary_pos_emb=None):
    """Compute attention for one layer and project the result.

    Steps, in order: (1) fetch or allocate the per-layer key/value cache
    when ``inference_params`` is given; (2) project ``hidden_states``
    (and ``encoder_output`` for cross attention) into query/key/value;
    (3) write the new keys/values into the cache and read back the cached
    prefix; (4) apply rotary embeddings when provided; (5) run the
    configured attention backend; (6) apply the output projection.

    Args:
        hidden_states: input activations, ``[sq, b, h]`` per the inline
            shape comments.
        attention_mask: mask forwarded to the non-flash attention paths.
        encoder_output: key/value source for cross attention.
        inference_params: optional object exposing
            ``key_value_memory_dict``, ``max_sequence_len``,
            ``max_batch_size``, ``batch_size_offset`` and
            ``sequence_len_offset``.
        rotary_pos_emb: a single embedding (used for both query and key)
            or a ``(q_pos_emb, k_pos_emb)`` tuple.

    Returns:
        ``(output, bias)`` exactly as returned by ``self.dense``.
    """
    # hidden_states: [sq, b, h]

    # =================================================
    # Pre-allocate memory for key-values for inference.
    # =================================================
    # is_first_step records that the cache for this layer was created in
    # this call; it selects the rotary-embedding slice further down.
    is_first_step = False
    if inference_params:
        if self.layer_number not in inference_params.key_value_memory_dict:
            inf_max_seq_len = inference_params.max_sequence_len
            inf_max_batch_size = inference_params.max_batch_size
            inference_key_memory = self._allocate_memory(
                inf_max_seq_len, inf_max_batch_size)
            inference_value_memory = self._allocate_memory(
                inf_max_seq_len, inf_max_batch_size)
            inference_params.key_value_memory_dict[self.layer_number] = (
                inference_key_memory, inference_value_memory)
            is_first_step = True
        else:
            inference_key_memory, inference_value_memory = \
                inference_params.key_value_memory_dict[self.layer_number]

    # =====================
    # Query, Key, and Value
    # =====================

    if self.attention_type == AttnType.self_attn:
        # Attention heads [sq, b, h] --> [sq, b, ((nq + 2 * nkv) * hn)]
        mixed_x_layer, _ = self.query_key_value(hidden_states)

        # [sq, b, ((nq + 2 * nkv) * hn)] --> [sq, b, nkv, (nq // nkv + 2), hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (-1, (self.num_key_value_groups + 2),
             self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, nkv, (nq // nkv + 2), hn] --> 3 [sq, b, np, hn]
        (query_layer,
         key_layer,
         value_layer) = self.split_tensor(mixed_x_layer)

        # Repeat kv; the fused SDPA path handles it internally, so the
        # repeat is skipped there unless attention dropout is non-zero.
        # TODO - Remove self.attention_dropout check when SW-172239 is solved
        if self.use_gqa and (not self.use_fused_sdpa or self.attention_dropout != 0):
            key_layer = self.repeat_kv(key_layer, self.num_key_value_groups)
            value_layer = self.repeat_kv(value_layer,
                                         self.num_key_value_groups)
    else:
        assert not self.use_gqa, 'GQA + cross-attn not tested yet'

        # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
        mixed_kv_layer, _ = self.key_value(encoder_output)

        # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
        new_tensor_shape = mixed_kv_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             2 * self.hidden_size_per_attention_head)
        mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

        # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
        (key_layer,
         value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

        # Attention head [sq, b, h] --> [sq, b, hp]
        query_layer, _ = self.query(hidden_states)
        # [sq, b, hp] --> [sq, b, np, hn]
        new_tensor_shape = query_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_tensor_shape)

    # ==================================
    # Adjust key and value for inference
    # ==================================

    # duplicate the pos_emb for self attention
    if rotary_pos_emb is not None:
        if isinstance(rotary_pos_emb, tuple):
            # Already a (query, key) pair; the self-assignment is a no-op.
            rotary_pos_emb = rotary_pos_emb
        else:
            rotary_pos_emb = ((rotary_pos_emb,) * 2)

    if inference_params:
        batch_start = inference_params.batch_size_offset
        batch_end = batch_start + key_layer.size(1)
        assert batch_end <= inference_key_memory.size(1)
        sequence_start = inference_params.sequence_len_offset
        sequence_end = sequence_start + key_layer.size(0)
        assert sequence_end <= inference_key_memory.size(0)
        # Copy key and values.
        inference_key_memory[sequence_start:sequence_end,
                             batch_start:batch_end, ...] = key_layer
        inference_value_memory[sequence_start:sequence_end,
                               batch_start:batch_end, ...] = value_layer
        # Attend over everything cached so far, not only the new tokens.
        key_layer = inference_key_memory[
            :sequence_end, batch_start:batch_end, ...]
        value_layer = inference_value_memory[
            :sequence_end, batch_start:batch_end, ...]

        # adjust the key rotary positional embedding
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            # need to cross check this condition during inference
            # if not set_inference_key_value_memory:
            if not is_first_step:
                # In inference, we compute one token at a time.
                # Select the correct positional embedding
                # (only the last token in the sequence)
                q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
            else:
                # In the first forward pass of inference,
                # we use the entire provided prefix.
                # q_pos_emb here has the rope embeddings of the entire
                # prefix + to-be-generated output so
                # we slice to just the prefix.
                q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
            k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

    # ==================================
    # core attention computation
    # ==================================

    # apply relative positional encoding (rotary embedding)
    if rotary_pos_emb is not None:
        q_pos_emb, k_pos_emb = rotary_pos_emb
        query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
        # TODO, can apply positional embedding to value_layer so it has
        # absolute positional embedding.
        # otherwise, only relative positional embedding takes effect
        # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

    if self.enable_ds_sequence_parallel:
        # DeepSpeed sequence-parallel path: attention runs through
        # self.dist_attn, which wraps the local attention module.
        if self.use_flash_attn:
            if not self.use_flash_attn_triton:
                # Non-Triton flash attention takes batch-first operands.
                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                       for x in (query_layer, key_layer, value_layer)]

            context_layer = self.dist_attn(query_layer, key_layer, value_layer)

            if not self.use_flash_attn_triton:
                # Back to sequence-first with heads merged into the hidden axis.
                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
        else:
            context_layer = self.dist_attn(query_layer, key_layer, value_layer, attention_mask)
    else:
        if self.use_flash_attn:
            if not self.use_flash_attn_triton:
                query_layer, key_layer, value_layer = [rearrange(x, 's b ... -> b s ...').contiguous()
                                                       for x in (query_layer, key_layer, value_layer)]

            if self.sequence_parallel:
                context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)
            else:
                # Without sequence parallelism the call runs inside a
                # forked RNG-tracker context.
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(query_layer, key_layer, value_layer)

            if not self.use_flash_attn_triton:
                context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()
        else:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)

    # =================
    # Output. [sq, b, h]
    # =================

    output, bias = self.dense(context_layer)

    return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    """Return ``residual + dropout(x + bias)``.

    The bias add is skipped when ``bias`` is None and the dropout call is
    skipped entirely when ``prob`` is zero.
    """
    if bias is not None:
        x = x + bias
    if prob != 0:
        x = torch.nn.functional.dropout(x, p=prob, training=training)
    return residual + x
def get_bias_dropout_add(training):
    """Return a four-argument ``bias_dropout_add`` with ``training`` bound.

    The returned callable takes ``(x, bias, residual, prob)`` and forwards
    them, together with the captured ``training`` flag, to
    ``bias_dropout_add``.
    """
    return lambda x, bias, residual, prob: bias_dropout_add(
        x, bias, residual, prob, training)
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with ``training=True``."""
    return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: Optional[torch.Tensor],
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled ``bias_dropout_add`` with ``training=False``."""
    return bias_dropout_add(x, bias, residual, prob, False)
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.

    The layer is: input norm -> self attention -> residual ->
    post-attention norm -> optional cross attention (decoder / Retro
    layer types) -> MLP or MoE -> residual.
    """

    def __init__(self, config,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0., num_experts=1):
        """Build the sub-modules of the layer.

        Args:
            config: transformer configuration object (hidden size, norm
                epsilon, dropout, fusion switches, ...).
            layer_number: index of this layer, forwarded to the attention
                modules.
            layer_type: ``LayerType`` member; decoder and Retro types
                additionally get a cross-attention block.
            self_attn_mask_type: mask type forwarded to self attention.
            drop_path_rate: when > 0 a ``DropPath`` module is created and
                the un-fused residual path is used in ``forward``.
            num_experts: <= 1 builds a dense ``ParallelMLP``; larger
                values build a DeepSpeed ``MoE`` (unless
                ``args.num_experts_switch`` selects ``SwitchMLP``).
        """
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = config.apply_residual_connection_post_layernorm

        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection

        # Layernorm on the input data.
        if args.normalization == 'layernorm':
            if get_accelerator().device_name() in ['cuda', 'hpu']:
                self.input_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon,
                    no_persist_layer_norm=args.no_persist_layer_norm,
                    sequence_parallel=config.sequence_parallel,
                    apply_layernorm_1p=args.apply_layernorm_1p,
                    mem_efficient_ln=args.mem_efficient_ln)
            else:
                self.input_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon)
        else:
            self.input_layernorm = MixedFusedRMSNorm(config.hidden_size,
                                                     config.layernorm_epsilon,
                                                     sequence_parallel=config.sequence_parallel)
        # Self attention.
        self.self_attention = ParallelAttention(
            config,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = config.hidden_dropout
        self.bias_dropout_fusion = config.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output
        if args.normalization == 'layernorm':
            if get_accelerator().device_name() in ['cuda', 'hpu']:
                self.post_attention_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon,
                    no_persist_layer_norm=not config.persist_layer_norm,
                    sequence_parallel=config.sequence_parallel,
                    apply_layernorm_1p=args.apply_layernorm_1p,
                    mem_efficient_ln=args.mem_efficient_ln)
            else:
                self.post_attention_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon)
        else:
            self.post_attention_layernorm = MixedFusedRMSNorm(config.hidden_size,
                                                              config.layernorm_epsilon,
                                                              sequence_parallel=config.sequence_parallel)
        # Cross attention.
        if self.layer_type in (LayerType.decoder,
                               LayerType.retro_decoder,
                               LayerType.retro_decoder_with_retriever,
                               LayerType.retro_encoder):
            self.inter_attention = ParallelAttention(
                config,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            # NOTE(review): unlike the two norms above, this one passes the
            # fused-LayerNorm keyword arguments on every device -- confirm
            # that is intended for accelerators other than cuda/hpu.
            if args.normalization == 'layernorm':
                self.post_inter_attention_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon,
                    no_persist_layer_norm=not config.persist_layer_norm,
                    sequence_parallel=config.sequence_parallel,
                    apply_layernorm_1p=args.apply_layernorm_1p,
                    mem_efficient_ln=args.mem_efficient_ln)
            else:
                self.post_inter_attention_layernorm = MixedFusedRMSNorm(config.hidden_size,
                                                                        config.layernorm_epsilon,
                                                                        sequence_parallel=config.sequence_parallel)

        # MLP
        self.num_experts = num_experts
        if args.num_experts_switch is not None:
            self.mlp = SwitchMLP(config)  # Megatron-LM's MoE
        else:
            if self.num_experts <= 1:  # dense, not MoE
                self.mlp = ParallelMLP(config)
            else:  # DeepSpeed's MoE
                enable_expert_tensor_parallelism = args.enable_expert_tensor_parallelism
                self.mlp = MoE(args.hidden_size,
                               ParallelMLP(config,
                                           moe=True,
                                           enable_expert_tensor_parallelism=enable_expert_tensor_parallelism),
                               num_experts=self.num_experts,
                               ep_size=args.moe_expert_parallel_size,
                               k=args.topk,
                               use_residual=(args.mlp_type == 'residual'),
                               capacity_factor=args.moe_train_capacity_factor,
                               eval_capacity_factor=args.moe_eval_capacity_factor,
                               min_capacity=args.moe_min_capacity,
                               drop_tokens=args.moe_token_dropping, use_tutel=args.use_tutel,
                               enable_expert_tensor_parallelism=enable_expert_tensor_parallelism)

        # Set bias+dropout+add fusion grad_enable execution handler.
        # torch >= 1.10 uses a no-op context; older versions wrap the fused
        # call in torch.enable_grad.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
            nullcontext if use_nvfuser else torch.enable_grad

        if args.retro_add_retriever:
            retro_args = get_retro_args()
            self.retro_num_neighbors = args.retro_num_neighbors
            self.retro_chunk_length = retro_args.retro_gpt_chunk_length
            self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length

        # Retriever (bi-directional transformer with cross attention)
        if layer_type == LayerType.retro_decoder_with_retriever:
            self.retriever = ParallelTransformer(
                config=config,
                model_type=ModelType.retro_encoder,
                self_attn_mask_type=AttnMaskType.padding,
                pre_process=True,
                post_process=False,
            )
            self._retriever_key = 'retriever'
        else:
            self.retriever = None

    def default_decoder_cross_attention(self,
                                        encoder_output,
                                        enc_dec_attn_mask,
                                        layernorm_input,
                                        layernorm_output,
                                        bias_dropout_add_func):
        '''Cross attention for a standard encoder-decoder model.

        Runs ``self.inter_attention`` with ``layernorm_output`` as the query
        source and ``encoder_output`` as the key/value source, applies
        bias-dropout-add against the residual, then the post cross-attention
        norm.

        Returns:
            ``(layernorm_input, layernorm_output)``: the new residual stream
            and its normalised form.
        '''

        # Attention.
        attention_output, attention_bias = \
            self.inter_attention(layernorm_output,
                                 enc_dec_attn_mask,
                                 encoder_output=encoder_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if attention_bias is not None:
            attention_bias = attention_bias.expand_as(residual)

        # Bias-dropout-add.
        with self.bias_dropout_add_exec_handler():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout)

        # Layer norm.
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        return layernorm_input, layernorm_output

    def retro_encoder_cross_attention(self,
                                      retriever_output,
                                      layernorm_input,
                                      layernorm_output,
                                      bias_dropout_add_func):
        """Cross attention for Retro encoder.

        Each neighbor slice attends separately to ``retriever_output``; the
        per-neighbor results are stacked and reshaped back to the input
        shape.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).

        Returns:
            ``(layernorm_input, layernorm_output)``, both reshaped to
            ``(ns, bs, d)``.
        """

        ns, bs, d = layernorm_output.shape  # [r, bs * l * k, d]

        # Divide sequence dimension into chunks.
        chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length,
                                                   -1,
                                                   self.retro_num_neighbors,
                                                   d)
        chunked_outputs_before_layer_norm = \
            layernorm_input.reshape(self.retro_retrieved_length, -1,
                                    self.retro_num_neighbors, d)  # [r, bs*l, k, d]

        # Per-chunk attention.
        layernorm_inputs = []
        layernorm_outputs = []
        for k in range(self.retro_num_neighbors):

            # Attention.
            chunked_output = chunked_outputs[:, :, k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output,  # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output)  # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_layer_norm[:, :, k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    None if attention_bias is None else attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
                layernorm_inputs.append(layernorm_input)

            # Layer norm.
            layernorm_output = \
                self.post_inter_attention_layernorm(layernorm_input)
            layernorm_outputs.append(layernorm_output)

        # Concatenate layer norms.
        # layernorm_input : [r, k * bs * l, d]
        # layernorm_output : [r, k * bs * l, d]
        layernorm_input = \
            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
        layernorm_output = \
            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)

        return layernorm_input, layernorm_output

    def retro_decoder_cross_attention(self,
                                      retriever_input,
                                      retriever_output,
                                      retriever_attn_mask,
                                      layernorm_input,
                                      layernorm_output,
                                      inference_params,
                                      bias_dropout_add_func):
        """Cross attention for Retro decoder.

        For a ``retro_decoder_with_retriever`` layer the retriever is run
        first to produce ``retriever_output``; other Retro decoder layers
        use the ``retriever_output`` passed in. The decoder activations are
        then chunked, attend to the retrieved neighbors, and the result is
        shifted back into sequence order and added to the residual.

        Notation:
            ns : Sequence length.
            bs : Batch size.
            d  : Hidden size.
            l  : Number of chunks per sample (i.e., seq_length/chunk_length).
            m  : Number of tokens per chunk.
            k  : Number of neighbors.
            r  : Number of retrieved tokens (neighbors + continuation).

        Returns:
            ``(retriever_output, layernorm_input, layernorm_output)``.

        Raises:
            Exception: on a retriever layer when ``ns`` is not a multiple of
                the chunk length (that path is not validated yet).
        """

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / self.retro_chunk_length))

        # Retrieve neighbors.
        if self.layer_type == LayerType.retro_decoder_with_retriever:
            first_ns = ns % self.retro_chunk_length
            if first_ns > 0:
                raise Exception("test this case.")
                # NOTE: the statements below are unreachable until the
                # guard above is lifted; they sketch the intended padding
                # of a partial leading chunk.
                first_chunk, rest_chunk = \
                    layernorm_output[:first_ns], layernorm_output[first_ns:]
                first_chunk = torch.nn.functional.pad(
                    first_chunk,
                    (0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
                    'constant',
                    0)
                chunked_output = \
                    torch.cat((first_chunk, rest_chunk), dim=0)  # [l * m, bs, d]
            else:
                chunked_output = layernorm_output  # [l * m, bs, d]
            chunked_output = chunked_output \
                .reshape(l, self.retro_chunk_length, bs, d) \
                .permute(1, 2, 0, 3) \
                .reshape(self.retro_chunk_length, bs * l, d) \
                .contiguous()

            # Get Encoder Output
            retriever_output = self.retriever(
                hidden_states=retriever_input,
                attention_mask=retriever_attn_mask,
                retriever_output=chunked_output,
                retriever_attn_mask=retriever_attn_mask,
                inference_params=inference_params)  # [r, k * bs * l , d]
            retriever_output = retriever_output.reshape(
                self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d)  # [r * k, bs * l, d]

        # Chunks.
        pad = (ns - 1) % self.retro_chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, self.retro_chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, self.retro_chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            self.retro_chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            # The residual is added after un-chunking, so a zero tensor is
            # passed as the residual of the fused call.
            layernorm_input = bias_dropout_add_func(
                attention_output,
                None if attention_bias is None else attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)
            layernorm_input = layernorm_input \
                .reshape(self.retro_chunk_length, bs, l, d) \
                .permute(2, 0, 1, 3)  # [l, m, bs, d]
            layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
            layernorm_input = torch.nn.functional.pad(
                layernorm_input,
                (0, 0, 0, 0, pad, 0),
                'constant', 0)[:ns]  # [ns, b, d]
            layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        return retriever_output, layernorm_input, layernorm_output

    def forward(self, hidden_states, attention_mask=None,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        """Run the layer.

        Args:
            hidden_states: input activations, ``[s, b, h]``.
            attention_mask: self-attention mask.
            encoder_output, enc_dec_attn_mask: inputs for standard decoder
                cross attention.
            retriever_input, retriever_output, retriever_attn_mask: inputs
                for the Retro cross-attention variants.
            inference_params: forwarded to self attention (and the
                retriever).
            rotary_pos_emb: forwarded to self attention.

        Returns:
            ``(output, moe_loss)``, or
            ``(output, retriever_output, moe_loss)`` for a
            ``retro_decoder_with_retriever`` layer. ``moe_loss`` is a zero
            scalar tensor for dense layers.

        Raises:
            Exception: if ``self.layer_type`` is not a supported type.
        """
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params,
                rotary_pos_emb=rotary_pos_emb)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        #
        # The function is chosen before the drop-path branch because the
        # cross-attention helpers below need it regardless of drop-path.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        if self.drop_path is None:
            if attention_bias is not None:
                attention_bias = attention_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias,
                    residual,
                    self.hidden_dropout)
        else:
            # The attention projection returns no bias when linear biases
            # are disabled, so only add it when present.
            if attention_bias is not None:
                attention_output = attention_output + attention_bias
            out = torch.nn.functional.dropout(attention_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Cross attention.
        if self.layer_type == LayerType.encoder:
            pass
        elif self.layer_type == LayerType.decoder:
            layernorm_input, layernorm_output = \
                self.default_decoder_cross_attention(
                    encoder_output,
                    enc_dec_attn_mask,
                    layernorm_input,
                    layernorm_output,
                    bias_dropout_add_func)
        elif self.layer_type == LayerType.retro_encoder:
            layernorm_input, layernorm_output = \
                self.retro_encoder_cross_attention(
                    retriever_output,
                    layernorm_input,
                    layernorm_output,
                    bias_dropout_add_func)
        elif self.layer_type in (LayerType.retro_decoder,
                                 LayerType.retro_decoder_with_retriever):
            retriever_output, layernorm_input, layernorm_output = \
                self.retro_decoder_cross_attention(
                    retriever_input,
                    retriever_output,
                    retriever_attn_mask,
                    layernorm_input,
                    layernorm_output,
                    inference_params,
                    bias_dropout_add_func)
        else:
            raise Exception("Unsupported layer type, '%s'." %
                            self.layer_type.name)

        # MLP.
        moe_loss = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype)
        mlp_bias = torch.tensor(0.0, device=layernorm_output.device, dtype=layernorm_output.dtype)

        # Same threshold as the constructor: num_experts <= 1 means the
        # dense MLP, which returns (output, bias); the MoE layer returns
        # three values.
        if self.num_experts <= 1:
            mlp_output, mlp_bias = self.mlp(layernorm_output)
        else:
            mlp_output, moe_loss, _ = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            if mlp_bias is not None:
                mlp_bias = mlp_bias.expand_as(residual)
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias,
                    residual,
                    self.hidden_dropout)

            # Jit compiled function creates 'view' tensor. This tensor
            # potentially gets saved in the MPU checkpoint function context,
            # which rejects view tensors. While making a viewless tensor here
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
            output = core.utils.make_viewless_tensor(inp=output,
                                                     requires_grad=output.requires_grad,
                                                     keep_graph=True)

        else:
            if mlp_bias is not None:
                mlp_output = mlp_output + mlp_bias
            out = torch.nn.functional.dropout(mlp_output,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        if self.layer_type == LayerType.retro_decoder_with_retriever:
            return output, retriever_output, moe_loss
        else:
            return output, moe_loss
class ParallelTransformerLayerPipe(ParallelTransformerLayer):
    """Extends ParallelTransformerLayer to forward attention_mask through the pipeline.

    Forward has two usages that affect attention mask communication:

    1) forward((input, attn_mask) , **kwargs) -> (output, mask)
       When the attention mask is provided as the second positional
       argument, typical pipeline behavior is used and both the output
       *and* mask are returned in a tuple. This tuple is then forwarded
       to the next stage in the pipeline.

       This version is useful if masks are dynamic.

    2) forward(input, **kwargs) -> output
       When the mask is static over all samples, it is advantageous to
       cache the mask and avoid communicating it.

       If no mask is provided, the module will query `self._args.attn_mask`
       for the mask and only return `super().forward(...)`
    """

    def forward(self, inputs, **kwargs):
        """Run the layer on a pipeline activation.

        Args:
            inputs: a tensor, a 1-tuple ``(hidden_states,)`` or a 2-tuple
                ``(hidden_states, attention_mask)``.
            **kwargs: forwarded to ``ParallelTransformerLayer.forward``.

        Returns:
            The layer output alone when no mask was supplied, otherwise
            ``(output, attention_mask)``.

        Raises:
            RuntimeError: if ``inputs`` is a tuple with more than two
                entries.
        """
        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
        if not hasattr(self, '_args'):
            # Cache the global args on first use.
            self._args = get_args()
        rotary_pos_emb = self._args.rotary_pos_emb if self._args.use_rotary_position_embeddings else None
        if torch.is_tensor(inputs) or len(inputs) == 1:
            # No attention mask forwarded, search for args.attn_mask.
            # A 1-tuple is unwrapped so the parent always receives a
            # tensor as hidden_states.
            hidden_states = inputs if torch.is_tensor(inputs) else inputs[0]
            attention_mask = self._args.attn_mask
            # HACK: currently MoE model does not support pipeline parallel, so
            # here we just ignore the moe_loss returned by forward()
            return super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)[0]
        elif len(inputs) == 2:
            # Attention mask is an activation.
            hidden_states, attention_mask = inputs[0], inputs[1]
            # HACK: currently MoE model does not support pipeline parallel, so
            # here we just ignore the moe_loss returned by forward()
            return super().forward(hidden_states, attention_mask, **kwargs, rotary_pos_emb=rotary_pos_emb)[0], attention_mask
        else:
            raise RuntimeError('Received more inputs than understood.')
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer.

    This layer exists for the case where a standalone embedding layer is
    used (i.e., args.standalone_embedding_stage == True). Then zero
    transformer layers are assigned when pipeline rank == 0, and when
    virtual pipeline rank >= 1 zero total model parameters are created
    (virtual rank 0 contains the input embedding). The model's input and
    output tensors would therefore be the same object, which causes an
    error when performing certain memory optimizations on the output tensor
    (e.g., deallocating it). This layer disconnects the input from the
    output via a clone. Since ranks containing a no-op layer are generally
    under-utilized (both compute and memory), there's no worry of any
    performance degradation.
    """

    def __init__(self, layer_number):
        """Remember the layer index; the layer owns no parameters."""
        super(NoopTransformerLayer, self).__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        """Return a clone of ``hidden_states``; other arguments are ignored."""
        detached_copy = hidden_states.clone()
        return detached_copy
def _get_num_layers(args, model_type, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: parsed arguments holding the layer counts and pipeline
            settings read below.
        model_type: ``ModelType`` member of the model being built.
        is_decoder: without pipeline parallelism, selects the decoder
            layer count instead of the encoder one.

    Returns:
        The layer count for this rank (0 on the standalone-embedding rank).
    """
    is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)

    # Retro encoders use their own dedicated layer count.
    if model_type == ModelType.retro_encoder:
        return args.retro_encoder_layers

    # No pipeline parallelism: every rank holds the full stack.
    if not parallel_state.get_pipeline_model_parallel_world_size() > 1:
        if not is_decoder:
            return args.encoder_num_layers
        return args.decoder_num_layers

    if not is_encoder_and_decoder_model:
        assert args.num_layers == args.encoder_num_layers
        assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
            'num_layers must be divisible by transformer_pipeline_model_parallel_size'

        # When a standalone embedding stage is used, all transformer layers
        # are divided among pipeline rank >= 1, while on pipeline rank 0,
        # ranks either contain the input embedding layer (virtual pp rank 0),
        # or no layers at all (virtual pp rank >= 1).
        if args.standalone_embedding_stage \
                and parallel_state.get_pipeline_model_parallel_rank() == 0:
            return 0
        return args.num_layers // args.transformer_pipeline_model_parallel_size

    assert args.pipeline_model_parallel_split_rank is not None

    # When a standalone embedding stage is used, a rank is taken from
    # the encoder's ranks, to be used for the encoder's embedding
    # layer. This way, the rank referenced by the 'split rank' remains
    # the same whether or not a standalone embedding stage is used.
    if args.standalone_embedding_stage:
        num_ranks_in_encoder = args.pipeline_model_parallel_split_rank - 1
    else:
        num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
    num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder

    assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
        'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
    assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
        'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)

    if not parallel_state.is_pipeline_stage_before_split():
        return args.decoder_num_layers // num_ranks_in_decoder

    if args.standalone_embedding_stage \
            and parallel_state.get_pipeline_model_parallel_rank() == 0:
        return 0
    return args.encoder_num_layers // num_ranks_in_encoder
def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
                    layer_number):
    """Pick the ``LayerType`` for layer ``layer_number``.

    Layers not listed in ``retro_layer_numbers`` (or any layer when
    ``args.retro_add_retriever`` is off) get ``default_layer_type``. For a
    Retro decoder, the first listed layer is the one carrying the
    retriever; the other listed layers are plain Retro decoder layers.

    Raises:
        Exception: for a Retro layer when ``model_type`` is neither the
            Retro decoder nor the Retro encoder.
    """
    args = get_args()
    is_retro_layer = args.retro_add_retriever and layer_number in retro_layer_numbers
    if not is_retro_layer:
        return default_layer_type

    if model_type == ModelType.retro_decoder:
        if layer_number == retro_layer_numbers[0]:
            return LayerType.retro_decoder_with_retriever
        return LayerType.retro_decoder

    if model_type == ModelType.retro_encoder:
        return LayerType.retro_encoder

    raise Exception("Unsupported model type, '%s'." % model_type)
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, config,
             model_type, layer_type=LayerType.encoder,
             self_attn_mask_type=AttnMaskType.padding,
             post_layer_norm=True,
             pre_process=True,
             post_process=True,
             drop_path_rate=0.0,
             num_experts=None):
    """Build the transformer layers assigned to this pipeline (virtual) stage.

    Args:
        config: Transformer config object; precision, recompute,
            sequence-parallel and layer-shape settings are read from it.
        model_type: ``ModelType`` of the model being built. The Retro
            decoder and Retro encoder types get dedicated retro layer
            numbers.
        layer_type: Default ``LayerType`` of the layers created here.
        self_attn_mask_type: ``AttnMaskType`` passed to every layer.
        post_layer_norm: Create the final norm when ``post_process`` is
            also true.
        pre_process: Stored as ``self.pre_process``.
        post_process: Stored as ``self.post_process``; gates creation of
            the final norm.
        drop_path_rate: Largest drop-path rate; per-layer rates are spaced
            linearly from 0 to this value over ``config.num_layers``.
        num_experts: List with the number of experts for each MoE layer
            (every ``args.expert_interval``-th layer). A one-element list
            is repeated for all MoE layers. ``None`` means ``[1]``.
    """
    super(ParallelTransformer, self).__init__()
    args = get_args()

    # Avoid a shared mutable default argument; ``None`` stands for ``[1]``.
    if num_experts is None:
        num_experts = [1]

    self.layer_type = layer_type
    self.model_type = model_type
    self.bf16 = config.bf16
    self.fp32_residual_connection = config.fp32_residual_connection
    self.post_layer_norm = post_layer_norm
    self.pre_process = pre_process
    self.post_process = post_process
    self.input_tensor = None
    self.drop_path_rate = drop_path_rate
    self.transformer_impl = args.transformer_impl
    self.retro_add_retriever = args.retro_add_retriever
    self.ds_inference = args.ds_inference
    self.device_name = get_accelerator().device_name()

    # Store activation checkpointing flags.
    self.checkpoint_activations = args.checkpoint_activations
    self.checkpoint_num_layers = args.checkpoint_num_layers
    self.recompute_granularity = config.recompute_granularity
    self.recompute_method = config.recompute_method
    self.recompute_num_layers = config.recompute_num_layers
    self.distribute_saved_activations = \
        config.distribute_saved_activations and not config.sequence_parallel

    self.sequence_parallel = config.sequence_parallel

    # FP8 state is always defined so later reads never depend on which
    # transformer implementation was selected.
    self.use_fp8 = False
    self.fp8_recipe = None
    self.fp8_group = None

    # Transformer Engine Init.
    self.transformer_engine_rope_available = False
    if self.transformer_impl == 'transformer_engine' and self.device_name == 'cuda':
        # Imported lazily and published as a module global so the other
        # methods of this class can refer to ``transformer_engine``.
        global transformer_engine
        import transformer_engine
        from importlib.metadata import version
        from pkg_resources import packaging

        te_version = packaging.version.Version(version("transformer-engine"))
        if te_version >= packaging.version.Version("0.10.0"):
            self.transformer_engine_rope_available = True

        del version, packaging

        self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
        if self.use_fp8:
            self.fp8_group = parallel_state.get_data_parallel_group()
            if args.fp8_e4m3:
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8_hybrid:
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

    self.num_microbatches_in_previous_step = -1
    self.microbatch_count = 0
    self.checkpoint_core_attention = config.recompute_granularity == 'selective'

    # Number of layers.
    self.num_layers = _get_num_layers(args, model_type,
                                      layer_type==LayerType.decoder)

    self.drop_path_rates = [
        rate.item() for rate in
        torch.linspace(0, self.drop_path_rate, config.num_layers)]

    self.retro_layer_numbers = None
    if model_type == ModelType.retro_decoder:
        retro_layer_start = 6 if config.num_layers <= 15 else 9
        self.retro_layer_numbers = \
            np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
    if model_type == ModelType.retro_encoder:
        self.retro_layer_numbers = [1]

    # Transformer layers.
    if args.retro_add_retriever:
        assert self.recompute_granularity != 'full', \
            "Full recompute not supported for Retro."
        assert args.transformer_impl == 'local', \
            "Transformer engine does not support Retro layers."

    def build_layer(layer_number, n_e):
        """Create layer ``layer_number`` (1-based) with ``n_e`` experts."""
        if args.transformer_impl == 'local' or self.device_name == 'hpu':
            current_layer_type = _get_layer_type(
                model_type, layer_type, self.retro_layer_numbers,
                layer_number)
            return ParallelTransformerLayer(
                config,
                layer_number,
                layer_type=current_layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1],
                num_experts=n_e)
        else:
            assert config.num_attention_heads == config.num_key_value_heads, \
                'Transformer_engine does not support GQA'
            return transformer_engine.pytorch.TransformerLayer(
                config.hidden_size,
                config.ffn_hidden_size,
                config.num_attention_heads,
                layernorm_epsilon=config.layernorm_epsilon,
                hidden_dropout=config.hidden_dropout,
                attention_dropout=config.attention_dropout,
                init_method=config.init_method,
                output_layer_init_method=config.output_layer_init_method,
                layer_number=layer_number,
                kv_channels=config.kv_channels,
                self_attn_mask_type=self_attn_mask_type.name,
                tp_group=parallel_state.get_tensor_model_parallel_group(),
                get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
                attention_softmax_in_fp32=config.attention_softmax_in_fp32,
                seq_length=args.seq_length,
                micro_batch_size=args.micro_batch_size,
                sequence_parallel=config.sequence_parallel,
                params_dtype=config.params_dtype,
                apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                output_layernorm=False,
                layer_type="encoder",
                drop_path_rate=self.drop_path_rates[layer_number - 1],
                set_parallel_mode=True,
                fuse_qkv_params=True)

    if config.virtual_pipeline_model_parallel_size is not None:
        assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
            'num_layers_per_stage must be divisible by ' \
            'virtual_pipeline_model_parallel_size'
        assert args.model_type != ModelType.encoder_and_decoder
        # Number of layers in each model chunk is the number of layers in the stage,
        # divided by the number of model chunks in a stage.
        self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
        # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0]  [2]  [4]  [6]
        # Stage 1: [1]  [3]  [5]  [7]
        # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
        # layers to stages like (each list is a model chunk):
        # Stage 0: [0, 1]  [4, 5]
        # Stage 1: [2, 3]  [6, 7]
        offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
            config.num_layers // config.virtual_pipeline_model_parallel_size) + \
            (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
    else:
        # Each stage gets a contiguous set of layers.
        if args.model_type == ModelType.encoder_and_decoder and \
                parallel_state.get_pipeline_model_parallel_world_size() > 1:
            pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
            if layer_type == LayerType.encoder:
                offset = pipeline_rank * self.num_layers
            else:
                num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
        else:
            offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

    if self.num_layers == 0:
        # When a standalone embedding stage is used (e.g.,
        # args.standalone_embedding_stage == True), virtual pipeline ranks
        # on pipeline rank 0 will have zero transformer layers assigned to
        # them. This results in the model's input and output tensors to be
        # the same, which will cause failure for certain output tensor
        # optimizations (e.g., pipeline output deallocation). To remedy
        # this, we assign a 'no-op' layer on these ranks, which will
        # disconnect the input tensor from the output tensor.
        self.num_layers = 1
        self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
    else:
        num_moe_layers = args.num_layers // args.expert_interval
        assert len(num_experts) == 1 or len(num_experts) == num_moe_layers, \
            'num_experts must be either a single value or a list of the same length as the number of MoE layers'

        # Create the list of MoE experts (builds a new list; the caller's
        # list is never modified in place).
        if len(num_experts) == 1:
            num_experts = num_experts * num_moe_layers

        # Build the layers. Every ``expert_interval``-th layer (1-based
        # global numbering) is an MoE layer; all others get one expert.
        stage_layers = []
        for i in range(self.num_layers):
            layer_num = i + 1 + offset
            if layer_num % args.expert_interval == 0:
                n_e = num_experts[(layer_num-1) // args.expert_interval]
            else:
                n_e = 1
            stage_layers.append(build_layer(layer_num, n_e))
        self.layers = torch.nn.ModuleList(stage_layers)

        # Update dropout rate for Retro encoder.
        if model_type == ModelType.retro_encoder:
            for layer in self.layers:
                if layer.self_attention.use_flash_attn:
                    # NOTE(review): this stores an nn.Dropout module in
                    # ``dropout_p`` while the other branch stores the
                    # plain rate -- confirm what the flash-attention
                    # module expects.
                    layer.self_attention.core_attention_flash.dropout_p = \
                        torch.nn.Dropout(args.retro_encoder_attention_dropout)
                else:
                    layer.self_attention.core_attention.attention_dropout.p =\
                        args.retro_encoder_attention_dropout
                layer.hidden_dropout = args.retro_encoder_hidden_dropout

    if self.post_process and self.post_layer_norm:
        # Final layer norm before output.
        if args.normalization == 'layernorm':
            if get_accelerator().device_name() in ['cuda', 'hpu']:
                self.final_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon,
                    no_persist_layer_norm=args.no_persist_layer_norm,
                    sequence_parallel=config.sequence_parallel,
                    apply_layernorm_1p=args.apply_layernorm_1p,
                    mem_efficient_ln=args.mem_efficient_ln)
            else:
                self.final_layernorm = LayerNorm(
                    config.hidden_size,
                    eps=config.layernorm_epsilon)
        else:
            self.final_layernorm = MixedFusedRMSNorm(config.hidden_size,
                                                     config.layernorm_epsilon,
                                                     sequence_parallel=config.sequence_parallel)
def _get_layer(self, layer_number):
    """Return the entry of ``self.layers`` at index ``layer_number``."""
    selected_layer = self.layers[layer_number]
    return selected_layer
def _checkpointed_forward(self, hidden_states, attention_mask,
                          encoder_output, enc_dec_attn_mask,
                          rotary_pos_emb, is_first_microbatch):
    """Forward method with activation checkpointing.

    Runs every layer of this stage in chunks, wrapping the chunks selected
    by the active checkpointing mode in a checkpoint call.

    Args:
        hidden_states: Input to the first layer.
        attention_mask: Passed unchanged to every layer.
        encoder_output: Passed unchanged to every layer.
        enc_dec_attn_mask: Passed unchanged to every layer.
        rotary_pos_emb: Passed positionally to local layers; passed to
            Transformer Engine layers as a keyword only when
            ``self.transformer_engine_rope_available`` is set.
        is_first_microbatch: Forwarded to Transformer Engine layers only.

    Returns:
        ``(hidden_states, moe_losses)`` where ``moe_losses`` is a list
        with one entry per executed layer.

    Raises:
        ValueError: If ``self.recompute_method`` is neither ``'uniform'``
            nor ``'block'`` (checked only on the non-DeepSpeed path).
    """
    args = get_args()

    def custom(start, end):
        """Return a callable running layers ``[start, end)`` in sequence."""
        # Clamp so that a trailing partial chunk (layer count not a
        # multiple of the chunk size) never indexes past the last layer.
        end = min(end, self.num_layers)

        def custom_forward(*layer_inputs, **kwargs):
            x_, *extra_inputs = layer_inputs
            moe_losses = []
            for index in range(start, end):
                layer = self._get_layer(index)
                output = layer(x_, *extra_inputs, **kwargs)
                if isinstance(output, tuple):
                    x_, moe_loss = output
                else:
                    # Layers returning a bare tensor contribute a zero
                    # loss so every layer yields exactly one entry.
                    x_ = output
                    moe_loss = torch.tensor(0.0, device=x_.device, dtype=x_.dtype, requires_grad=True)
                moe_losses.append(moe_loss)
            return (x_, *moe_losses)
        return custom_forward

    if args.deepspeed and args.deepspeed_activation_checkpointing:
        moe_losses = []
        # Make sure memory is freed.
        tensor_parallel.reset_checkpointed_activations_memory_buffer()
        first = 0
        while first < self.num_layers:
            hidden_states, *local_moe_losses = tensor_parallel.checkpoint(
                custom(first, first + self.checkpoint_num_layers), False,
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask,
                None, None, None, None, rotary_pos_emb)
            moe_losses.extend(local_moe_losses)
            first += self.checkpoint_num_layers

        return hidden_states, moe_losses

    moe_losses = []
    use_te = self.transformer_impl == 'transformer_engine'
    te_forward_kwargs = {}
    if use_te:
        te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
        if self.transformer_engine_rope_available:
            te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

    def run_chunk(states, start, end, checkpointed):
        """Run layers ``[start, end)`` on ``states``, optionally checkpointed."""
        chunk_fn = custom(start, end)
        if use_te:
            if checkpointed:
                return transformer_engine.pytorch.distributed.checkpoint(
                    chunk_fn,
                    self.distribute_saved_activations,
                    tensor_parallel.get_cuda_rng_tracker,
                    mpu.get_tensor_model_parallel_group(),
                    states, attention_mask, encoder_output,
                    enc_dec_attn_mask, **te_forward_kwargs)
            return chunk_fn(
                states, attention_mask, encoder_output,
                enc_dec_attn_mask, **te_forward_kwargs)

        # Local layers take everything positionally; the four ``None``
        # values fill the slots between enc_dec_attn_mask and
        # rotary_pos_emb in the layer's forward signature.
        local_inputs = (states, attention_mask,
                        encoder_output, enc_dec_attn_mask,
                        None, None, None, None, rotary_pos_emb)
        if checkpointed:
            return tensor_parallel.checkpoint(
                chunk_fn,
                self.distribute_saved_activations,
                *local_inputs)
        return chunk_fn(*local_inputs)

    if self.recompute_method == 'uniform':
        # Uniformly divide the total number of Transformer layers and
        # checkpoint the input activation of each divided chunk.
        # A method to further reduce memory usage reducing checkpoints.
        first = 0
        while first < self.num_layers:
            hidden_states, *local_moe_losses = run_chunk(
                hidden_states, first, first + self.recompute_num_layers, True)
            moe_losses.extend(local_moe_losses)
            first += self.recompute_num_layers
    elif self.recompute_method == 'block':
        # Checkpoint the input activation of only a set number of individual
        # Transformer layers and skip the rest.
        # A method fully use the device memory removing redundant re-computation.
        for index in range(self.num_layers):
            hidden_states, *local_moe_losses = run_chunk(
                hidden_states, index, index + 1,
                index < self.recompute_num_layers)
            moe_losses.extend(local_moe_losses)
    else:
        raise ValueError("Invalid activation recompute method.")
    return hidden_states, moe_losses
def set_input_tensor(self, input_tensor):
    """Store the tensor that forward() should consume in place of its argument.

    Under pipeline parallelism a stage's input arrives from the previous
    stage through communication rather than from forward_step_func, so
    internal code calls this to hand over the received tensor and bypass
    whatever forward_step_func would otherwise supply.
    """
    self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
            encoder_output=None, enc_dec_attn_mask=None,
            retriever_input=None,
            retriever_output=None,
            retriever_attn_mask=None,
            inference_params=None,
            rotary_pos_emb=None):
    """Run the layers of this stage and, if configured, the final norm.

    Args:
        hidden_states: Input activations; replaced by ``self.input_tensor``
            when ``self.pre_process`` is false (see set_input_tensor()).
        attention_mask: Passed to every layer.
        encoder_output: Passed to every layer.
        enc_dec_attn_mask: Passed to every layer.
        retriever_input: Passed to local (non Transformer Engine) layers.
        retriever_output: Passed to local layers; replaced by the value a
            layer returns when that layer yields a 3-tuple.
        retriever_attn_mask: Passed to local layers.
        inference_params: Passed to every layer on the non-checkpointed
            path; requires ``self.recompute_granularity`` to be ``None``.
        rotary_pos_emb: Passed to every layer, except Transformer Engine
            layers when ``self.transformer_engine_rope_available`` is
            false.

    Returns:
        Tuple ``(hidden_states, *moe_losses)``.
    """
    # hidden_states: [s, b, h]

    # Checks.
    if inference_params:
        assert self.recompute_granularity is None, \
            'inference does not work with activation checkpointing'

    # TODO: Below old DeepSpeed code are commented because it's unsure whether
    # it is still relevant.
    # # Reza's note: DeepSpeed inference does not support transposes
    # if not self.ds_inference:
    #     if self.pre_process:
    #         # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
    #         # If the input flag for fp32 residual connection is set, convert for float.
    #         if self.fp32_residual_connection:
    #             hidden_states = hidden_states.transpose(0, 1).contiguous().float()
    #         # Otherwise, leave it as is.
    #         else:
    #             hidden_states = hidden_states.transpose(0, 1).contiguous()
    #     else:
    #         # See set_input_tensor()
    #         hidden_states = self.input_tensor
    #     if encoder_output is not None:
    #         encoder_output = encoder_output.transpose(0, 1).contiguous()

    if not self.pre_process:
        # See set_input_tensor()
        hidden_states = self.input_tensor

    # Viewless tensor.
    # - We only need to create a viewless tensor in the case of micro batch
    #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
    #   (in the commented-out code above) creates a view tensor, and
    #   '.contiguous()' is a pass-through.
    #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
    #   the need to make it viewless.
    #
    #   However, we don't explicitly check mbs == 1 here because
    #   make_viewless_tensor() has negligible overhead when its input
    #   is already viewless.
    #
    # - For the 'not self.pre_process' case above, calling
    #   make_viewless_tensor() here is likely redundant, since
    #   p2p_communication.py (likely originator) already creates viewless
    #   tensors. That said, make_viewless_tensor() is called here to be
    #   future-proof and corner-case-proof.
    hidden_states = core.utils.make_viewless_tensor(
        hidden_states,
        requires_grad=True,
        keep_graph=True,
    )

    # RNG context.
    if self.sequence_parallel:
        rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
    else:
        rng_context = nullcontext()

    # Forward layers.
    with rng_context:
        # fp8_autocast is only constructed when FP8 is enabled; otherwise a
        # nullcontext is used. The conditional expression also avoids
        # resolving the name ``transformer_engine`` on runs where it was
        # never imported.
        fp8_context = transformer_engine.pytorch.fp8_autocast(
            enabled=self.use_fp8,
            fp8_recipe=self.fp8_recipe,
            fp8_group=self.fp8_group
        ) if self.use_fp8 else nullcontext()
        with fp8_context:
            # Determine if the current iteration is first microbatch
            num_microbatches = get_num_microbatches()
            if self.num_microbatches_in_previous_step != num_microbatches:
                self.microbatch_count = 0 # Reset count on new batch size rampup interval
            self.num_microbatches_in_previous_step = num_microbatches
            is_first_microbatch = self.microbatch_count % num_microbatches == 0

            # Forward pass.
            moe_losses = []
            if self.checkpoint_activations or self.recompute_granularity == 'full':
                # Both the DeepSpeed flag and full recompute take the same
                # checkpointed path.
                hidden_states, moe_losses = self._checkpointed_forward(hidden_states,
                                                                       attention_mask,
                                                                       encoder_output,
                                                                       enc_dec_attn_mask,
                                                                       rotary_pos_emb,
                                                                       is_first_microbatch)
            else:
                forward_kwargs = {
                    'encoder_output': encoder_output,
                    'enc_dec_attn_mask': enc_dec_attn_mask,
                    'inference_params': inference_params,
                }

                if self.transformer_impl == 'transformer_engine':
                    forward_kwargs['is_first_microbatch'] = is_first_microbatch
                    forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
                    if self.transformer_engine_rope_available:
                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                else:
                    forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                    forward_kwargs['retriever_input'] = retriever_input
                    forward_kwargs['retriever_output'] = retriever_output
                    forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                for index in range(self.num_layers):
                    layer = self._get_layer(index)

                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        **forward_kwargs)

                    # A layer may return a tuple:
                    #   2 items -> (hidden_states, moe_loss)
                    #   3 items -> (hidden_states, retriever_output, moe_loss)
                    # First Retro decoder layer returns both hidden_states
                    # and retriever_output. Make retriever_output available
                    # to subsequent Retro layers.
                    # NOTE(review): with ds_inference set, a tuple result is
                    # left packed in hidden_states -- confirm layers never
                    # return tuples in that mode.
                    if isinstance(hidden_states, tuple):
                        assert (len(hidden_states) == 2 or len(hidden_states) == 3)
                        if len(hidden_states) == 2:
                            if not self.ds_inference:
                                hidden_states, moe_loss = hidden_states
                                moe_losses.append(moe_loss)
                        else:
                            forward_kwargs["retriever_output"] = hidden_states[1]
                            if not self.ds_inference:
                                hidden_states, _, moe_loss = hidden_states
                                moe_losses.append(moe_loss)

            # Skip counter update for eval and activation checkpointing
            if torch.is_grad_enabled() and self.training:
                self.microbatch_count += 1

    # Final layer norm.
    if self.post_process and self.post_layer_norm:
        # TODO: Below old DeepSpeed code are commented because it's unsure whether
        # it is still relevant.
        # if not self.ds_inference:
        #     # Reverting data format change [s b h] --> [b s h].
        #     hidden_states = hidden_states.transpose(0, 1).contiguous()
        hidden_states = self.final_layernorm(hidden_states)

    return (hidden_states, *moe_losses)
class LMHeadPipe(MegatronModule):
"""
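Language-model head: a bias-free ColumnParallelLinear projection from
hidden states to vocabulary logits.

Arguments:
hidden_size: hidden size (input size of the projection).
vocab_size: size of vocabulary (output size of the projection).
config: config object forwarded to ColumnParallelLinear; its
`init_method` is used for weight initialization.

Note: `gather_output` is not a constructor argument here; whether the
output logits are gathered is left to ColumnParallelLinear's default.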
"""
def __init__(self, hidden_size, vocab_size, config):
    """Create the bias-free column-parallel output projection.

    Args:
        hidden_size: Input size of the projection.
        vocab_size: Output size of the projection.
        config: Config object forwarded to ``ColumnParallelLinear``; its
            ``init_method`` initializes the weight.
    """
    super(LMHeadPipe, self).__init__()
    # The global args are not needed at construction time; forward()
    # fetches them on first use.
    self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=hidden_size,
                                                        output_size=vocab_size,
                                                        bias=False,
                                                        config=config,
                                                        init_method=config.init_method,)
def forward(self, inputs, **kwargs):
    """Project hidden states to logits.

    Args:
        inputs: Either the hidden-states tensor, or a tuple whose first
            item is the hidden states and whose second item (if present)
            is the attention mask.
        **kwargs: Accepted and ignored.

    Returns:
        ``logits`` when the global args carry an ``attn_mask`` attribute;
        otherwise ``(logits, attention_mask)``, where ``attention_mask``
        is ``None`` if ``inputs`` did not contain one.
    """
    assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
    packed = isinstance(inputs, tuple)
    hidden_states = inputs[0] if packed else inputs

    # Fetch the global args lazily and cache them on the module.
    if not hasattr(self, '_args'):
        self._args = get_args()

    logits, _ = self.lm_head(hidden_states)

    # If cmd args has attn_mask, we don't forward it as an activation.
    if hasattr(self._args, 'attn_mask'):
        return logits

    # Only a tuple input can carry a mask. Indexing a bare tensor here
    # would slice the hidden states and return that as the "mask".
    attention_mask = inputs[1] if packed and len(inputs) > 1 else None
    return logits, attention_mask
|