Flux9665 committed
Commit b5649bf · 1 Parent(s): e9f478c

Delete Utility/utils.py

Files changed (1)
  Utility/utils.py  +0 -320
Utility/utils.py DELETED
@@ -1,320 +0,0 @@
"""
Taken from ESPnet, modified by Florian Lux
"""

import os
from abc import ABC

import torch


def cumsum_durations(durations):
    """Turn a sequence of durations into cumulative boundaries and the center of each segment."""
    out = [0]
    for duration in durations:
        out.append(duration + out[-1])
    centers = list()
    for index, _ in enumerate(out):
        if index + 1 < len(out):
            centers.append((out[index] + out[index + 1]) / 2)
    return out, centers
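
# Worked example (illustrative, not part of the original file):
# >>> cumsum_durations([2, 3, 1])
# ([0, 2, 5, 6], [1.0, 3.5, 5.5])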


def delete_old_checkpoints(checkpoint_dir, keep=5):
    """Remove all but the newest `keep` checkpoints named checkpoint_<step>.pt; best.pt is always kept."""
    checkpoint_list = list()
    for el in os.listdir(checkpoint_dir):
        if el.endswith(".pt") and el != "best.pt":
            checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
    if len(checkpoint_list) <= keep:
        return
    checkpoint_list.sort(reverse=False)
    checkpoints_to_delete = [os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(step)) for step in checkpoint_list[:-keep]]
    for old_checkpoint in checkpoints_to_delete:
        os.remove(old_checkpoint)


def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    """Return the path of the checkpoint with the highest step number, or None if there is none."""
    checkpoint_list = list()
    for el in os.listdir(checkpoint_dir):
        if el.endswith(".pt") and el != "best.pt":
            checkpoint_list.append(int(el.split(".")[0].split("_")[1]))
    if len(checkpoint_list) == 0:
        print("No previous checkpoints found, cannot reload.")
        return None
    checkpoint_list.sort(reverse=True)
    if verbose:
        print("Reloading checkpoint_{}.pt".format(checkpoint_list[0]))
    return os.path.join(checkpoint_dir, "checkpoint_{}.pt".format(checkpoint_list[0]))
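
# Usage sketch (illustrative; the directory name and step number are assumptions,
# the helpers only expect files named checkpoint_<step>.pt):
# >>> delete_old_checkpoints("Models/MyModel", keep=5)
# >>> get_most_recent_checkpoint("Models/MyModel")
# Reloading checkpoint_80000.pt
# 'Models/MyModel/checkpoint_80000.pt'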


def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of the padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
        device (torch.device or str, optional): Device on which to allocate the mask.

    Returns:
        Tensor: Mask tensor containing indices of the padded part.
            dtype=torch.uint8 in PyTorch versions before 1.2,
            dtype=torch.bool in PyTorch 1.2+
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    if device is not None:
        seq_range = torch.arange(0, maxlen, dtype=torch.int64, device=device)
    else:
        seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
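
# Example (illustrative): positions past each sequence length are masked.
# >>> make_pad_mask([5, 3, 2])
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])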


def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of the non-padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
        device (torch.device or str, optional): Device on which to allocate the mask.

    Returns:
        Tensor: Mask tensor containing indices of the non-padded part.
            dtype=torch.uint8 in PyTorch versions before 1.2,
            dtype=torch.bool in PyTorch 1.2+
    """
    return ~make_pad_mask(lengths, xs, length_dim, device=device)
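
# Example (illustrative): the exact complement of make_pad_mask.
# >>> make_non_pad_mask([5, 3, 2])
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False],
#         [ True,  True, False, False, False]])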


def initialize(model, init):
    """
    Initialize weights of a neural network module.

    Parameters are initialized using the given method or distribution.

    Args:
        model: Target.
        init: Method of initialization.
    """
    # weight init
    for p in model.parameters():
        if p.dim() > 1:
            if init == "xavier_uniform":
                torch.nn.init.xavier_uniform_(p.data)
            elif init == "xavier_normal":
                torch.nn.init.xavier_normal_(p.data)
            elif init == "kaiming_uniform":
                torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
            elif init == "kaiming_normal":
                torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
            else:
                raise ValueError("Unknown initialization: " + init)
    # bias init
    for p in model.parameters():
        if p.dim() == 1:
            p.data.zero_()

    # reset some modules with default init
    for m in model.modules():
        if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)):
            m.reset_parameters()
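
# Usage sketch (illustrative; the module is a stand-in):
# >>> net = torch.nn.Linear(10, 10)
# >>> initialize(net, "xavier_uniform")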


def pad_list(xs, pad_value):
    """
    Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).
    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)

    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]

    return pad
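
# Example (illustrative):
# >>> pad_list([torch.ones(3), torch.ones(2)], 0)
# tensor([[1., 1., 1.],
#         [1., 1., 0.]])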


def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    """
    ret = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(ret, out=ret)
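
# Example (illustrative): a causal mask for three decoding steps.
# >>> subsequent_mask(3)
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])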


class ScorerInterface:
    """
    Scorer interface for beam search.

    The scorer performs scoring of all tokens in the vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        return None

    def select_state(self, state, i, new_id=None):
        """
        Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary

        Returns:
            state: pruned state
        """
        return None if state is None else state[i]

    def score(self, y, state, x):
        """
        Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for the next token with shape `(n_vocab,)`
                and the next state for ys
        """
        raise NotImplementedError

    def final_score(self, state):
        """
        Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score
        """
        return 0.0
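
# Minimal sketch of a concrete scorer (illustrative, not part of the original
# file): a constant bonus over the whole vocabulary, similar in spirit to
# ESPnet's LengthBonus. `n_vocab` is an assumed constructor argument.
class UniformLengthBonus(ScorerInterface):

    def __init__(self, n_vocab):
        self.n_vocab = n_vocab

    def score(self, y, state, x):
        # every token gets the same bonus of 1.0; the scorer is stateless
        return torch.ones(self.n_vocab, device=x.device), None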


class BatchScorerInterface(ScorerInterface, ABC):
    """Batched variant of ScorerInterface that scores a whole beam at once."""

    def batch_init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for the next token with shape `(n_batch, n_vocab)`
                and the next state list for ys.
        """
        scores = list()
        outstates = list()
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        scores = torch.cat(scores, 0).view(ys.shape[0], -1)
        return scores, outstates
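
# Note (illustrative): the default batch_score just loops over score and stacks
# the per-hypothesis (n_vocab,) vectors into a (n_batch, n_vocab) tensor, so
# subclasses only need to override it when a truly batched forward pass is
# faster than scoring hypotheses one at a time.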


def to_device(m, x):
    """
    Send tensor into the device of the module.

    Args:
        m (torch.nn.Module): Torch module.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located on the same device as the torch module.
    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(f"Expected torch.nn.Module or torch.Tensor, but got: {type(m)}")
    return x.to(device)
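
# Usage sketch (illustrative; assumes a CUDA device is available):
# >>> net = torch.nn.Linear(4, 4).cuda()
# >>> to_device(net, torch.zeros(4)).device
# device(type='cuda', index=0)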