petil777 committed
Commit 74a5e35 · 1 Parent(s): 1e9783d

Upload 2 files

Files changed (2)
  1. dist.py +80 -0
  2. layers.py +256 -0
dist.py ADDED
@@ -0,0 +1,80 @@
import os
import torch

from datetime import timedelta

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))


class FakeBarrier:
    def wait(self):
        pass


class FakeGroup:
    def __init__(self, rank, size):
        self._rank = rank
        self._size = size

    def allreduce(self, *args, **kwargs):
        return FakeBarrier()

    def allgather(self, inputs, local_tensor, **kwargs):
        assert (
            len(inputs[0]) == len(local_tensor) == 1
        ), f"{len(inputs[0])} != {len(local_tensor)} != 1, and the FakeGroup is supposed to join on simple tensors"
        for input_ in inputs:
            input_[0].data = local_tensor[0].data
        return FakeBarrier()

    def barrier(self, *args, **kwargs):
        return FakeBarrier()

    def size(self):
        return self._size

    def rank(self):
        return self._rank


def initialize_torch_distributed():
    if torch.cuda.is_available():
        from torch.distributed import ProcessGroupNCCL

        # Set the device id.
        assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
        device = RANK % torch.cuda.device_count()
        torch.cuda.set_device(device)
        torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
        backend = "nccl"
        options = ProcessGroupNCCL.Options()
        options.is_high_priority_stream = True
        options._timeout = timedelta(seconds=60)
    else:
        backend = "gloo"
        options = None

    if WORLD_SIZE == 1:
        return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
    else:
        if os.getenv("DEBUG", None) == "1":
            return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE

        if not torch.distributed.is_initialized():
            # Call the init process.
            torch.distributed.init_process_group(
                backend=backend,
                world_size=WORLD_SIZE,
                rank=RANK,
                timeout=timedelta(seconds=60),
                pg_options=options,
            )
        else:
            print("torch.distributed is already initialized.")

        return torch.distributed.group.WORLD, RANK, WORLD_SIZE
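
For context, a minimal sketch of how this module would typically be driven; the server.py entry point and the main() wrapper below are hypothetical and not part of this commit. Each process launched by torchrun reads RANK and WORLD_SIZE from its environment at import time, so a single call to initialize_torch_distributed() per process suffices, and at WORLD_SIZE=1 (or with DEBUG=1) the collectives degrade to the FakeGroup no-ops.

# Hypothetical entry point, e.g. launched as:
#   torchrun --nproc_per_node=2 server.py
from dist import initialize_torch_distributed


def main():
    # One call per process; returns a FakeGroup when WORLD_SIZE == 1 or DEBUG=1 is set.
    process_group, rank, world_size = initialize_torch_distributed()
    print(f"rank {rank}/{world_size}, group size {process_group.size()}")
    # Model sharding / generation code would go here, passing process_group down.


if __name__ == "__main__":
    main()
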
layers.py ADDED
@@ -0,0 +1,256 @@
# Copy and modify https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/layers.py
from typing import List

import torch
import torch.distributed
from accelerate import init_empty_weights
from torch import nn
from torch.nn import functional as F


# Monkey patching: attach weight-loading helpers onto torch.nn.LayerNorm
@classmethod
def load_layer_norm(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    bias = weights.get_tensor(f"{prefix}.bias")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = nn.Parameter(bias)
    return ln


@classmethod
def load_layer_norm_no_bias(cls, prefix, weights, eps):
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        ln = cls(weight.shape, eps=eps)

    ln.weight = nn.Parameter(weight)
    ln.bias = None
    return ln


torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias


class FastLinear(nn.Module):
    def __init__(
        self,
        weight,
        bias,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(weight)
        if bias is not None:
            self.bias = nn.Parameter(bias)
        else:
            self.bias = None

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_tensor(f"{prefix}.weight")
        if bias:
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(weight, bias)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


def get_linear(weight, bias):
    linear = FastLinear(weight, bias)
    return linear


class SuperLayer(nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return self.linear.forward(x)


class TensorParallelHead(SuperLayer):
    def __init__(self, linear, process_group, should_gather: bool):
        super().__init__(linear)
        self.process_group = process_group
        self.should_gather = should_gather

    @staticmethod
    def load(config, prefix: str, weights):
        if weights.process_group.size() > 1:
            try:
                weight = weights.get_sharded(f"{prefix}.weight", dim=0)
                should_gather = True
            except AssertionError:
                # If the vocab size is not divisible by the number of shards,
                # just load the entire thing.
                weight = weights.get_tensor(f"{prefix}.weight")
                should_gather = False
        else:
            weight = weights.get_tensor(f"{prefix}.weight")
            should_gather = False

        return TensorParallelHead(
            get_linear(weight, bias=None),
            process_group=weights.process_group,
            should_gather=should_gather,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if not self.should_gather:
            return super().forward(input)

        world_size = self.process_group.size()
        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
            out_dim = self.linear.weight.shape[0]

            if input.shape[0] == 1:
                world_out = input.new_empty(1, out_dim * world_size)
                local_out = input.new_empty(1, out_dim)
                gather_input = local_out
            else:
                world_out = input.new_empty(out_dim * world_size, input.shape[0])
                gather_input = input.new_empty(out_dim, input.shape[0])
                local_out = gather_input.T

            torch.mm(input, self.linear.weight.T, out=local_out)

            torch.distributed.all_gather_into_tensor(world_out, gather_input, group=self.process_group)

            if input.shape[0] == 1:
                return world_out
            return world_out.T

        output = super().forward(input)
        world_output = [torch.empty_like(output) for _ in range(self.process_group.size())]
        torch.distributed.all_gather(world_output, output, group=self.process_group)
        world_output = torch.cat(world_output, dim=-1)
        return world_output


class TensorParallelColumnLinear(SuperLayer):
    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        return cls.load_multi(config, [prefix], weights, bias, dim=0)

    @classmethod
    def load_multi(cls, config, prefixes: List[str], weights, bias: bool, dim: int):
        weight = weights.get_multi_weights_col(prefixes, dim=dim, quantize=config.quantize)

        if bias:
            b = [weights.get_sharded(f"{p}.bias", dim=0) for p in prefixes]
            bias = torch.cat(b, dim=dim)
        else:
            bias = None
        linear = get_linear(weight, bias)
        return cls(linear)


class TensorParallelRowLinear(SuperLayer):
    def __init__(self, linear, process_group):
        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def load(cls, config, prefix: str, weights, bias: bool):
        weight = weights.get_multi_weights_row(prefix, quantize=config.quantize)

        if bias and weights.process_group.rank() == 0:
            # The bias is only loaded on the first rank, so that after the
            # all-reduce in forward() it is added exactly once.
            bias = weights.get_tensor(f"{prefix}.bias")
        else:
            bias = None
        return cls(
            get_linear(weight, bias),
            process_group=weights.process_group,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super().forward(input)
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


class TensorParallelEmbedding(nn.Module):
    def __init__(self, prefix: str, weights, reduce=True):
        super().__init__()
        weight = weights.get_partial_sharded(f"{prefix}.weight", dim=0)
        num_embeddings = weights.get_shape(f"{prefix}.weight")[0]

        process_group = weights.process_group

        world_size = process_group.size()
        rank = process_group.rank()

        block_size = num_embeddings // world_size
        self.min_id = rank * block_size
        self.max_id = min(num_embeddings, (rank + 1) * block_size)
        self.null_idx = block_size
        self.process_group = weights.process_group
        self.reduce = reduce

        # Additional zero entry appended to the weight, used for masking
        # out-of-shard ids.
        self.weight = nn.Parameter(F.pad(weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Map all out-of-shard ids to `self.null_idx` (the padded zero row) and
        # translate in-shard ids into the local range [0, self.max_id - self.min_id).
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = torch.nn.functional.embedding(input, self.weight)
        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out


try:
    import dropout_layer_norm

    class FastLayerNorm(nn.LayerNorm):
        def forward(self, hidden_states, residual=None):
            if hidden_states.shape[-1] > 8192:
                if residual is not None:
                    hidden_states += residual
                residual = hidden_states

                return super(FastLayerNorm, self).forward(hidden_states), residual
            else:
                (
                    normed_hidden_states,
                    residual,
                    *rest,
                ) = dropout_layer_norm.dropout_add_ln_fwd(
                    hidden_states,
                    residual,
                    self.weight,
                    self.bias,
                    None,
                    None,
                    None,
                    None,
                    0.0,
                    self.eps,
                    1.0,
                    0,
                    None,
                    False,
                    False,
                )
                if residual is None:
                    residual = hidden_states

                return normed_hidden_states, residual

except ImportError:
    pass
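
For reference, a minimal single-process smoke test of how these layers might be wired together with dist.py. The DummyWeights stub, the tensor names, the SimpleNamespace config, and the flat `dist`/`layers` import paths are hypothetical placeholders for the real Weights/config plumbing in text-generation-inference; at world size 1 no real sharding or collective communication happens.

import torch
from types import SimpleNamespace

from dist import initialize_torch_distributed
from layers import TensorParallelColumnLinear, TensorParallelRowLinear


class DummyWeights:
    """Maps tensor names to in-memory tensors; a single shard covers everything."""

    def __init__(self, tensors, process_group):
        self.tensors = tensors
        self.process_group = process_group

    def get_tensor(self, name):
        return self.tensors[name]

    def get_sharded(self, name, dim):
        # World size 1: the "shard" is the full tensor.
        return self.tensors[name]

    def get_multi_weights_col(self, prefixes, dim, quantize):
        return torch.cat([self.tensors[f"{p}.weight"] for p in prefixes], dim=dim)

    def get_multi_weights_row(self, prefix, quantize):
        return self.tensors[f"{prefix}.weight"]


process_group, rank, world_size = initialize_torch_distributed()
config = SimpleNamespace(quantize=None)
weights = DummyWeights(
    {
        "proj.weight": torch.randn(16, 8),
        "proj.bias": torch.zeros(16),
        "out.weight": torch.randn(8, 16),
        "out.bias": torch.zeros(8),
    },
    process_group,
)

col = TensorParallelColumnLinear.load(config, "proj", weights, bias=True)
row = TensorParallelRowLinear.load(config, "out", weights, bias=True)
y = row(col(torch.randn(2, 8)))  # (2, 8) -> (2, 16) -> (2, 8)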