File size: 14,878 Bytes
9fd1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
import os
import random
import unittest

import numpy as np
import torch
from torch.nn.functional import scaled_dot_product_attention

from finetrainers.models.attention_dispatch import (
    AttentionProvider,
    _AttentionProviderRegistry,
    _set_context_parallel_options,
    attention_dispatch,
    attention_provider,
    flash_attn_flash_attention,
    native_cudnn_attention,
    native_efficient_attention,
    native_flash_attention,
)
from finetrainers.parallel.ptd import _EquipartitionSharder


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators with ``seed``."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_world_size():
    """Return the number of participating processes.

    Uses the initialized ``torch.distributed`` process group when there is one; otherwise falls
    back to the ``WORLD_SIZE`` environment variable, defaulting to 1 when it is unset.
    """
    # ``is_available()`` must be checked first: on PyTorch builds without distributed support the
    # rest of the ``torch.distributed`` API is not exposed. This function runs at import time
    # (inside the ``skipIf`` decorators below), so an unguarded call would break the whole module.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return int(os.environ.get("WORLD_SIZE", 1))


class AttentionDispatchTest(unittest.TestCase):
    """Compare the outputs of the attention providers against a reference provider on CUDA.

    A provider that raises while being invoked is reported with a printed warning and skipped.
    Malformed outputs and numerical mismatches are real assertion failures: only the provider
    invocation is guarded by ``try``/``except``, so ``AssertionError`` is never swallowed.
    """

    @classmethod
    def setUpClass(cls):
        set_seed(0)

    def test_forward(self):
        """Check every provider's forward output against the ``_NATIVE_MATH`` output."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")
        cuda_capability = torch.cuda.get_device_capability()

        query, key, value = self._create_dummy_inputs()

        # (provider, absolute tolerance against the reference output). The first entry is the
        # reference itself, so its tolerance is never used.
        all_providers = [
            (AttentionProvider._NATIVE_MATH, 0),
            (AttentionProvider.NATIVE, 5e-3),
            (AttentionProvider.FLASH, 5e-3),
            (AttentionProvider.FLASH_VARLEN, 5e-3),
            (AttentionProvider.FLEX, 2e-2),
            (AttentionProvider._NATIVE_CUDNN, 5e-3),
            (AttentionProvider._NATIVE_EFFICIENT, 5e-3),
            (AttentionProvider._NATIVE_FLASH, 5e-3),
            (AttentionProvider.SAGE, 1e-1),
            (AttentionProvider.SAGE_VARLEN, 2e-0),
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA, 2e-0),  # TODO: look into the high difference threshold
            (AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON, 2e-0),
            (AttentionProvider.XFORMERS, 5e-3),
        ]

        if cuda_capability >= (8, 9):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA, 2e-0))
        if cuda_capability >= (9, 0):
            all_providers.append((AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA_SM90, 2e-0))

        # All tolerances above are relative to the first provider, so there is no fallback
        # reference: if it cannot run, the error propagates instead of being reduced to a warning.
        (ref_provider, _), *candidates = all_providers
        ref_output = self._check_forward_pass(ref_provider, query, key, value).detach().clone()

        for provider, threshold in candidates:
            with self.subTest(provider=provider):
                try:
                    output = self._run_forward(provider, query, key, value)
                except Exception as e:
                    # Best effort: the provider could not be run (NOTE(review): presumably an
                    # optional backend that is not installed or not supported on this GPU).
                    print(f"Warning: Forward pass test failed for {provider} with error: {e}")
                    continue
                self._assert_valid_output(output, query)
                self.assertTrue(
                    torch.allclose(output, ref_output, atol=threshold), f"Forward pass mismatch for {provider}"
                )

    def test_backward(self):
        """Run forward + backward through selected providers and compare their forward outputs.

        Gradients are only checked for presence; the numerical comparison is done on the
        output returned by each run. The first provider that runs successfully becomes the
        reference for the remaining ones.
        """
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        query, key, value = self._create_dummy_inputs()

        selected_providers = [
            AttentionProvider.FLASH,
            AttentionProvider.FLASH_VARLEN,
            AttentionProvider.FLEX,
            AttentionProvider.NATIVE,
            AttentionProvider.XFORMERS,
        ]

        ref_provider = None
        ref_output = None
        for provider in selected_providers:
            with self.subTest(provider=provider):
                try:
                    output = self._run_backward(provider, query, key, value)
                except Exception as e:
                    print(f"Warning: Backward pass test failed for {provider} with error: {e}")
                    continue
                self._assert_grads_populated(provider, query, key, value)

                if ref_output is None:
                    ref_provider, ref_output = provider, output.detach().clone()
                    continue

                # FLEX uses the looser bound; apply it as well when FLEX ended up being the
                # reference because the providers listed before it could not run.
                if AttentionProvider.FLEX in (provider, ref_provider):
                    threshold = 1e-2
                else:
                    threshold = 1e-3
                self.assertTrue(
                    torch.allclose(output, ref_output, atol=threshold), f"Backward pass mismatch for {provider}"
                )

    def _create_dummy_inputs(
        self, batch_size=2, num_heads=8, seq_len=256, head_dim=64, dtype=torch.bfloat16, device="cuda"
    ):
        """Return random ``(query, key, value)`` tensors of shape ``(batch_size, num_heads, seq_len, head_dim)``.

        The torch RNG is re-seeded with 0 first, so repeated calls produce identical tensors.
        """
        torch.manual_seed(0)
        query = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        key = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        value = torch.randn(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
        return query, key, value

    def _run_forward(self, provider: AttentionProvider, query, key, value):
        """Dispatch attention through ``provider`` and return the raw output (no assertions)."""
        kwargs = {}
        if provider == AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA:
            kwargs["pv_accum_dtype"] = "fp32"
        with attention_provider(provider):
            return attention_dispatch(query, key, value, attention_kwargs=kwargs)

    def _run_backward(self, provider: AttentionProvider, query, key, value):
        """Run forward and backward (loss = output mean) through ``provider``; return the output.

        Makes ``query``/``key``/``value`` require grad in place and clears any gradient left
        behind by a previous provider, so a later "grad is not None" check reflects this run only.
        """
        for tensor in (query, key, value):
            tensor.requires_grad_(True)
            tensor.grad = None

        with attention_provider(provider):
            output = attention_dispatch(query, key, value)
            loss = output.mean()
            loss.backward()
        return output

    def _assert_valid_output(self, output, query):
        """Assert that ``output`` exists and has the same shape as ``query``."""
        self.assertIsNotNone(output)
        self.assertEqual(output.shape, query.shape)

    def _assert_grads_populated(self, provider: AttentionProvider, query, key, value):
        """Assert that the backward pass produced a gradient for each of the three inputs."""
        for name, tensor in (("query", query), ("key", key), ("value", value)):
            self.assertIsNotNone(tensor.grad, f"No gradient for {name} with {provider}")

    def _check_forward_pass(self, provider: AttentionProvider, query, key, value):
        """Run the forward pass for ``provider``, validate the output and return it."""
        output = self._run_forward(provider, query, key, value)
        self._assert_valid_output(output, query)
        return output

    def _check_backward_pass(self, provider: AttentionProvider, query, key, value):
        """Run forward + backward for ``provider``, assert all input grads exist, return the output."""
        output = self._run_backward(provider, query, key, value)
        self._assert_grads_populated(provider, query, key, value)
        return output


class RingAttentionTest(unittest.TestCase):
    """Shared fixture and checks for context-parallel attention over a NCCL process group.

    ``setUpClass`` builds full-length query/key/value tensors, computes a reference output and
    reference gradients with PyTorch SDPA (math backend) on the full sequence, and keeps this
    rank's shard (along dim 2) of each input. The ``_test_*`` methods take the tolerance as an
    argument and are meant to be called from ``test_*`` methods.
    """

    @classmethod
    def setUpClass(cls):
        torch.distributed.init_process_group(backend="nccl")
        rank, world_size = torch.distributed.get_rank(), torch.distributed.get_world_size()

        cls.rank = rank
        cls.world_size = world_size
        # Pick the GPU by node-local rank (torchrun exports LOCAL_RANK). The global rank is only a
        # valid device index on a single node, so it is kept as the fallback.
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))
        cls.mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,))

        set_seed(0)
        cls.batch_size = 2
        cls.num_heads = 8
        cls.seq_len = 256  # per-rank length; the full sequence is seq_len * world_size
        cls.head_dim = 64
        cls.dtype = torch.bfloat16
        cls.device = "cuda"

        _AttentionProviderRegistry._set_context_parallel(
            mesh=cls.mesh, convert_to_fp32=True, rotate_method="allgather"
        )
        _set_context_parallel_options(is_causal=False)

        full_shape = (cls.batch_size, cls.num_heads, cls.seq_len * cls.world_size, cls.head_dim)
        # Drawn in query, key, value order.
        cls.full_query, cls.full_key, cls.full_value = (
            torch.randn(full_shape, dtype=cls.dtype, device=cls.device, requires_grad=True) for _ in range(3)
        )

        # Ensure all ranks have the same data
        with torch.no_grad():
            for tensor in (cls.full_query, cls.full_key, cls.full_value):
                torch.distributed.broadcast(tensor, src=0)

        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            reference_output = scaled_dot_product_attention(cls.full_query, cls.full_key, cls.full_value)

        cls.reference_output = reference_output.detach().clone()
        # Populates ``.grad`` on the full tensors; the backward checks compare against these.
        reference_output.sum().backward()

        cls.query, cls.key, cls.value = (
            _EquipartitionSharder.shard(x, dim=2, mesh=cls.mesh).detach().clone()
            for x in (cls.full_query, cls.full_key, cls.full_value)
        )

    @classmethod
    def tearDownClass(cls):
        # Guarded so teardown never raises when no process group is initialized.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    def _assert_forward_matches(self, attention_fn, atol: float):
        """Run ``attention_fn`` on this rank's shards and compare the unsharded output to the reference.

        Args:
            attention_fn: Callable taking ``(query, key, value)`` and returning the local output.
            atol: Absolute tolerance passed to ``torch.allclose``.
        """
        output = attention_fn(self.query, self.key, self.value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        self.assertEqual(output.shape, self.reference_output.shape)
        max_diff = (output.float() - self.reference_output.float()).abs().max().item()
        fn_name = getattr(attention_fn, "__name__", repr(attention_fn))
        self.assertTrue(
            torch.allclose(output, self.reference_output, atol=atol),
            f"{fn_name}: forward output mismatch (max abs diff {max_diff}, atol {atol})",
        )

    def _assert_backward_matches(self, attention_fn, atol: float):
        """Backpropagate through ``attention_fn`` and compare input grads to the reference grads.

        Fresh leaf copies of this rank's shards are used, so the class-level tensors are not
        modified. The expected gradients are the reference gradients of the full tensors,
        sharded along dim 2 with the same mesh.

        Args:
            attention_fn: Callable taking ``(query, key, value)`` and returning the local output.
            atol: Absolute tolerance passed to ``torch.allclose``.
        """
        query, key, value = (x.detach().clone().requires_grad_(True) for x in (self.query, self.key, self.value))
        output = attention_fn(query, key, value)
        output = _EquipartitionSharder.unshard(output, dim=2, mesh=self.mesh)
        output.sum().backward()
        with torch.no_grad():
            expected_grads = tuple(
                _EquipartitionSharder.shard(x, dim=2, mesh=self.mesh)
                for x in (self.full_query.grad, self.full_key.grad, self.full_value.grad)
            )
        fn_name = getattr(attention_fn, "__name__", repr(attention_fn))
        for name, tensor, expected in zip(("query", "key", "value"), (query, key, value), expected_grads):
            max_diff = (tensor.grad.float() - expected.float()).abs().max().item()
            self.assertTrue(
                torch.allclose(tensor.grad, expected, atol=atol),
                f"{fn_name}: {name} grad mismatch (max abs diff {max_diff}, atol {atol})",
            )

    def _test_forward_native_cudnn_attention(self, atol: float = 1e-3):
        self._assert_forward_matches(native_cudnn_attention, atol)

    def _test_forward_native_efficient_attention(self, atol: float = 1e-3):
        self._assert_forward_matches(native_efficient_attention, atol)

    def _test_forward_native_flash_attention(self, atol: float = 1e-3):
        self._assert_forward_matches(native_flash_attention, atol)

    def _test_forward_flash_attn_flash_attention(self, atol: float = 1e-3):
        self._assert_forward_matches(flash_attn_flash_attention, atol)

    def _test_backward_native_cudnn_attention(self, atol: float = 1e-3):
        self._assert_backward_matches(native_cudnn_attention, atol)

    def _test_backward_native_efficient_attention(self, atol: float = 1e-3):
        self._assert_backward_matches(native_efficient_attention, atol)

    def _test_backward_native_flash_attention(self, atol: float = 1e-3):
        self._assert_backward_matches(native_flash_attention, atol)

    def _test_backward_flash_attn_flash_attention(self, atol: float = 1e-3):
        self._assert_backward_matches(flash_attn_flash_attention, atol)


class RingAttentionCPTesterMixin:
    """``test_*`` entry points that call the host class's ``_test_*`` checks with fixed tolerances.

    The host class must provide ``world_size`` and the ``_test_forward_*`` / ``_test_backward_*``
    methods invoked below. Forward checks use ``atol=1e-2``; backward checks scale the tolerance
    with ``world_size``.
    """

    def test_forward_native_cudnn_attention(self):
        """Forward check for ``native_cudnn_attention``."""
        self._test_forward_native_cudnn_attention(atol=1e-2)

    def test_forward_native_efficient_attention(self):
        """Forward check for ``native_efficient_attention``."""
        self._test_forward_native_efficient_attention(atol=1e-2)

    def test_forward_native_flash_attention(self):
        """Forward check for ``native_flash_attention``."""
        self._test_forward_native_flash_attention(atol=1e-2)

    def test_forward_flash_attn_flash_attention(self):
        """Forward check for ``flash_attn_flash_attention``."""
        self._test_forward_flash_attn_flash_attention(atol=1e-2)

    def test_backward_native_cudnn_attention(self):
        """Backward check for ``native_cudnn_attention``."""
        # TODO: make bounds more strict
        self._test_backward_native_cudnn_attention(atol=1e-2 * self.world_size)

    def test_backward_native_efficient_attention(self):
        """Backward check for ``native_efficient_attention``."""
        # TODO: make bounds more strict
        self._test_backward_native_efficient_attention(atol=1e-2 * self.world_size)

    def test_backward_native_flash_attention(self):
        """Backward check for ``native_flash_attention``."""
        # TODO: make bounds more strict
        self._test_backward_native_flash_attention(atol=1e-2 * self.world_size)

    @unittest.skip(
        """query diff: 0.298828125, key diff: 2.09375, value diff: 0.68359375; Needs further investigation"""
    )
    def test_backward_flash_attn_flash_attention(self):
        """Backward check for ``flash_attn_flash_attention`` (currently skipped)."""
        # Seems to require much higher bound for some reason
        # TODO: make bounds more strict
        self._test_backward_flash_attn_flash_attention(atol=1.5e-1 * self.world_size)


@unittest.skipIf(
    not (torch.cuda.is_available() and get_world_size() == 2),
    "CUDA is not available or world size is not 2",
)
class RingAttentionCP2Test(RingAttentionTest, RingAttentionCPTesterMixin):
    """Ring-attention tests that run only when CUDA is available and the world size is 2."""


@unittest.skipIf(
    not (torch.cuda.is_available() and get_world_size() == 4),
    "CUDA is not available or world size is not 4",
)
class RingAttentionCP4Test(RingAttentionTest, RingAttentionCPTesterMixin):
    """Ring-attention tests that run only when CUDA is available and the world size is 4."""