nyarunyarunya committed
Commit e48545f · verified · 1 Parent(s): 24e33d6

Create communications.py

Files changed (1): wan/modules/communications.py +516 -0
wan/modules/communications.py ADDED
@@ -0,0 +1,516 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
import torch.distributed as dist
from fastvideo.utils.parallel_states import nccl_info
from typing import Any, Tuple
from torch import Tensor


def broadcast(input_: torch.Tensor):
    # Broadcast from the first global rank of the current sequence-parallel group.
    src = nccl_info.group_id * nccl_info.sp_size
    dist.broadcast(input_, src=src, group=nccl_info.group)
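
# Usage sketch (illustrative only, not part of the committed file; assumes the
# fastvideo `nccl_info` sequence-parallel state has already been initialized):
#
#   noise = torch.randn(1, 16, 21, 60, 104, device="cuda")
#   broadcast(noise)   # every rank in the sp group now holds the src rank's tensor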


def _all_to_all_4D(
    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
) -> torch.tensor:
    """
    all-to-all for QKV

    Args:
        input (torch.tensor): a tensor sharded along the scatter dim
        scatter_idx (int): default 2
        gather_idx (int): default 1
        group : torch process group

    Returns:
        torch.tensor: resharded tensor, (bs, seqlen, hc/P, hs) when
            (scatter_idx, gather_idx) == (2, 1) and (bs, seqlen/P, hc, hs)
            when (scatter_idx, gather_idx) == (1, 2)
    """
    assert (
        input.dim() == 4
    ), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"

    seq_world_size = dist.get_world_size(group)

    if scatter_idx == 2 and gather_idx == 1:
        # input: a tensor sharded along dim 1 (bs, seqlen/P, hc, hs); output: (bs, seqlen, hc/P, hs)
        bs, shard_seqlen, hc, hs = input.shape
        seqlen = shard_seqlen * seq_world_size
        shard_hc = hc // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen/P, hc, hs) -reshape-> (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)-> (P, seq_len/P, bs, hc/P, hs)
        input_t = (
            input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs)
            .transpose(0, 2)
            .contiguous()
        )

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t
        # if scattering the seq-dim, transpose the heads back to the original dimension
        output = output.reshape(seqlen, bs, shard_hc, hs)

        # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
        output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)

        return output

    elif scatter_idx == 1 and gather_idx == 2:
        # input: a tensor sharded along dim 2 (bs, seqlen, hc/P, hs); output: (bs, seqlen/P, hc, hs)
        bs, seqlen, shard_hc, hs = input.shape
        hc = shard_hc * seq_world_size
        shard_seqlen = seqlen // seq_world_size

        # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
        # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)-> (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
        input_t = (
            input.reshape(bs, seq_world_size, shard_seqlen, shard_hc, hs)
            .transpose(0, 3)
            .transpose(0, 1)
            .contiguous()
            .reshape(seq_world_size, shard_hc, shard_seqlen, bs, hs)
        )

        output = torch.empty_like(input_t)
        # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
        # (P, hc/P, seqlen/P, bs, hs) scatter head -all2all-> (P, hc/P, seqlen/P, bs, hs) scatter seqlen
        if seq_world_size > 1:
            dist.all_to_all_single(output, input_t, group=group)
            torch.cuda.synchronize()
        else:
            output = input_t

        # if scattering the head-dim, transpose the sequence back to the original dimension
        output = output.reshape(hc, shard_seqlen, bs, hs)

        # (hc, seqlen/P, bs, hs) -transpose(0,2)-> (bs, seqlen/P, hc, hs)
        output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)

        return output
    else:
        raise RuntimeError("(scatter_idx, gather_idx) must be (2, 1) or (1, 2)")
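
# Shape walkthrough (illustrative only, not part of the committed file). With
# P = 2 sequence-parallel ranks, bs = 1, seqlen = 8, hc = 4 heads, hs = 64:
#
#   local q on each rank: (1, 4, 4, 64)                  # (bs, seqlen/P, hc, hs)
#   _all_to_all_4D(q, scatter_idx=2, gather_idx=1)
#       -> (1, 8, 2, 64)                                 # (bs, seqlen, hc/P, hs)
#   _all_to_all_4D(out, scatter_idx=1, gather_idx=2)
#       -> (1, 4, 4, 64)                                 # back to (bs, seqlen/P, hc, hs)
#
# i.e. heads are exchanged for sequence shards before attention and swapped
# back afterwards.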


class SeqAllToAll4D(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        input: Tensor,
        scatter_idx: int,
        gather_idx: int,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx

        return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        return (
            None,
            SeqAllToAll4D.apply(
                ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx
            ),
            None,
            None,
        )


def all_to_all_4D(
    input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1,
):
    return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
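
# Usage sketch for sequence-parallel attention (illustrative only, not part of
# the committed file; `attention` stands in for whatever attention kernel the
# model uses):
#
#   q, k, v = ...                                        # each (bs, seqlen/P, hc, hs)
#   q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)    # (bs, seqlen, hc/P, hs)
#   k = all_to_all_4D(k, scatter_dim=2, gather_dim=1)
#   v = all_to_all_4D(v, scatter_dim=2, gather_dim=1)
#   out = attention(q, k, v)                             # full sequence, local heads
#   out = all_to_all_4D(out, scatter_dim=1, gather_dim=2)  # (bs, seqlen/P, hc, hs)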


def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    input_list = [
        t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = dist.get_world_size(process_group)
        output = _all_to_all(
            input_, ctx.world_size, process_group, scatter_dim, gather_dim
        )
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (
            grad_output,
            None,
            None,
            None,
        )


def all_to_all(
    input_: torch.Tensor, scatter_dim: int = 2, gather_dim: int = 1,
):
    return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
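
# Note (illustrative only, not part of the committed file): unlike
# `all_to_all_4D`, this generic variant works on tensors of any rank by
# splitting along `scatter_dim` and concatenating the received chunks along
# `gather_dim`, e.g. with sp ranks:
#
#   x = torch.randn(b, s, d, device="cuda")           # local shard on every rank
#   y = all_to_all(x, scatter_dim=2, gather_dim=0)    # (b * sp, s, d // sp)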


class _AllGather(torch.autograd.Function):
    """All-gather communication with autograd support.

    Args:
        input_: input tensor
        dim: dimension along which to concatenate
    """

    @staticmethod
    def forward(ctx, input_, dim):
        ctx.dim = dim
        world_size = nccl_info.sp_size
        group = nccl_info.group
        input_size = list(input_.size())

        ctx.input_size = input_size[dim]

        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        input_ = input_.contiguous()
        dist.all_gather(tensor_list, input_, group=group)

        output = torch.cat(tensor_list, dim=dim)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        world_size = nccl_info.sp_size
        rank = nccl_info.rank_within_group
        dim = ctx.dim
        input_size = ctx.input_size

        sizes = [input_size] * world_size

        grad_input_list = torch.split(grad_output, sizes, dim=dim)
        grad_input = grad_input_list[rank]

        return grad_input, None


def all_gather(input_: torch.Tensor, dim: int = 1):
    """Performs an all-gather operation on the input tensor along the specified dimension.

    Args:
        input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
        dim (int, optional): Dimension along which to concatenate. Defaults to 1.

    Returns:
        torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
    """
    return _AllGather.apply(input_, dim)
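
# Usage sketch (illustrative only, not part of the committed file): gather the
# per-rank sequence shards back into the full sequence; in the backward pass
# each rank keeps only the gradient slice corresponding to its own shard.
#
#   x_local = torch.randn(b, s // sp, d, device="cuda", requires_grad=True)
#   x_full = all_gather(x_local, dim=1)    # (b, s, d) on every rank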


def prepare_sequence_parallel_data(
    hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask
):
    # Intentionally bypass fastvideo's default sequence-parallel data sharding:
    # the inputs are returned unchanged. The original implementation is kept
    # below but is unreachable.
    return (
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
    )

    if nccl_info.sp_size == 1:
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    def prepare(
        hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask
    ):
        hidden_states = all_to_all(hidden_states, scatter_dim=2, gather_dim=0)
        encoder_hidden_states = all_to_all(
            encoder_hidden_states, scatter_dim=1, gather_dim=0
        )
        attention_mask = all_to_all(attention_mask, scatter_dim=1, gather_dim=0)
        encoder_attention_mask = all_to_all(
            encoder_attention_mask, scatter_dim=1, gather_dim=0
        )
        return (
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            encoder_attention_mask,
        )

    sp_size = nccl_info.sp_size
    # frame = hidden_states.shape[2]
    # assert frame % sp_size == 0, "frame should be a multiple of sp_size"

    (
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        encoder_attention_mask,
    ) = prepare(
        hidden_states,
        encoder_hidden_states.repeat(1, sp_size, 1),
        attention_mask.repeat(1, sp_size, 1, 1),
        encoder_attention_mask.repeat(1, sp_size),
    )

    return hidden_states, encoder_hidden_states, attention_mask, encoder_attention_mask


def sp_parallel_dataloader_wrapper(
    dataloader, device, train_batch_size, sp_size, train_sp_batch_size
):
    while True:
        for data_item in dataloader:
            latents, cond, attn_mask, cond_mask = data_item
            latents = latents.to(device)
            cond = cond.to(device)
            attn_mask = attn_mask.to(device)
            cond_mask = cond_mask.to(device)
            frame = latents.shape[2]
            if frame == 1:
                yield latents, cond, attn_mask, cond_mask
            else:
                latents, cond, attn_mask, cond_mask = prepare_sequence_parallel_data(
                    latents, cond, attn_mask, cond_mask
                )
                assert (
                    train_batch_size * sp_size >= train_sp_batch_size
                ), "train_batch_size * sp_size should be greater than or equal to train_sp_batch_size"
                for batch_idx in range(train_batch_size * sp_size // train_sp_batch_size):
                    st_idx = batch_idx * train_sp_batch_size
                    ed_idx = (batch_idx + 1) * train_sp_batch_size
                    encoder_hidden_states = cond[st_idx:ed_idx]
                    attention_mask = attn_mask[st_idx:ed_idx]
                    encoder_attention_mask = cond_mask[st_idx:ed_idx]
                    yield (
                        latents[st_idx:ed_idx],
                        encoder_hidden_states,
                        attention_mask,
                        encoder_attention_mask,
                    )
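
# Usage sketch (illustrative only, not part of the committed file; the loader
# and batch sizes are placeholders):
#
#   loader = sp_parallel_dataloader_wrapper(
#       train_dataloader, device, train_batch_size=1,
#       sp_size=nccl_info.sp_size, train_sp_batch_size=1,
#   )
#   latents, cond, attn_mask, cond_mask = next(loader)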


def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)
    if world_size == 1:
        return input_

    if pad > 0:
        pad_size = list(input_.shape)
        pad_size[dim] = pad
        input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)

    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    output = tensor_list[rank].contiguous()
    return output


def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    # skip if only one rank involved
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)

    if world_size == 1:
        return input_

    # all gather
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    assert input_.device.type == "cuda"
    dist.all_gather(tensor_list, input_, group=pg)

    # concat
    output = torch.cat(tensor_list, dim=dim)

    if pad > 0:
        output = output.narrow(dim, 0, output.size(dim) - pad)

    return output


class _GatherForwardSplitBackward(torch.autograd.Function):
    """
    Gather the input sequence.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_, process_group, dim, grad_scale, pad):
        return _gather_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _gather_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)

        return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input sequence.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_, process_group, dim, grad_scale, pad):
        return _split_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        ctx.pad = pad
        return _split_sequence_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.process_group)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.process_group)
        return _gather_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None


def split_sequence(input_, dim, grad_scale=1.0, pad=0):
    process_group = nccl_info.group
    return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)


def gather_sequence(input_, dim, grad_scale=1.0, pad=0):
    process_group = nccl_info.group
    return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
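
# Round-trip sketch (illustrative only, not part of the committed file): shard
# a sequence across the sp group, compute on the local shard, then gather the
# full sequence back; `pad` trims any padding added to make the length
# divisible by the group size, and the "down"/"up" grad_scale pairing is an
# assumed convention, not something fixed by this module.
#
#   x = torch.randn(b, s, d, device="cuda")
#   pad = (-s) % nccl_info.sp_size
#   x_local = split_sequence(x, dim=1, grad_scale="down", pad=pad)
#   ...
#   x = gather_sequence(x_local, dim=1, grad_scale="up", pad=pad)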


def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class _AllToAll1(torch.autograd.Function):
    """All-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        world_size = dist.get_world_size(process_group)

        return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        process_group = ctx.process_group
        scatter_dim = ctx.gather_dim
        gather_dim = ctx.scatter_dim
        return_grad = _AllToAll1.apply(*grad_output, process_group, scatter_dim, gather_dim)
        return (return_grad, None, None, None)


def all_to_all_comm(input_, scatter_dim=2, gather_dim=1):
    process_group = nccl_info.group
    return _AllToAll1.apply(input_, process_group, scatter_dim, gather_dim)
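
# Usage sketch (illustrative only, not part of the committed file):
# `all_to_all_comm` mirrors `all_to_all` above but routes through `_AllToAll1`;
# with the defaults it trades a chunk of dim 2 for the sequence dimension on
# the fixed `nccl_info.group`, e.g.
#
#   x = torch.randn(b, s // sp, h, d, device="cuda")
#   y = all_to_all_comm(x)     # (b, s, h // sp, d): splits dim 2, concatenates on dim 1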