aiqcamp committed
Commit eb6df04 · verified · 1 Parent(s): 223feb1

Delete utils.py

Files changed (1)
  1. utils.py +0 -371
utils.py DELETED
@@ -1,371 +0,0 @@
- import os
- import glob
- import sys
- import argparse
- import logging
- import json
- import subprocess
- import traceback
-
- import librosa
- import numpy as np
- from scipy.io.wavfile import read
- import torch
- import logging
-
- logging.getLogger("numba").setLevel(logging.ERROR)
- logging.getLogger("matplotlib").setLevel(logging.ERROR)
-
- MATPLOTLIB_FLAG = False
-
- logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
- logger = logging
-
-
- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
-     assert os.path.isfile(checkpoint_path)
-     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-     iteration = checkpoint_dict["iteration"]
-     learning_rate = checkpoint_dict["learning_rate"]
-     if (
-         optimizer is not None
-         and not skip_optimizer
-         and checkpoint_dict["optimizer"] is not None
-     ):
-         optimizer.load_state_dict(checkpoint_dict["optimizer"])
-     saved_state_dict = checkpoint_dict["model"]
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     new_state_dict = {}
-     for k, v in state_dict.items():
-         try:
-             # assert "quantizer" not in k
-             # print("load", k)
-             new_state_dict[k] = saved_state_dict[k]
-             assert saved_state_dict[k].shape == v.shape, (
-                 saved_state_dict[k].shape,
-                 v.shape,
-             )
-         except:
-             traceback.print_exc()
-             print(
-                 "error, %s is not in the checkpoint" % k
-             )  # a shape mismatch also ends up here, e.g. text_embedding when the cleaner is changed
-             new_state_dict[k] = v
-     if hasattr(model, "module"):
-         model.module.load_state_dict(new_state_dict)
-     else:
-         model.load_state_dict(new_state_dict)
-     print("load ")
-     logger.info(
-         "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
-     )
-     return model, optimizer, learning_rate, iteration
-
- from time import time as ttime
- import shutil
- def my_save(fea, path):  # fix issue: torch.save doesn't support Chinese paths
-     dir = os.path.dirname(path)
-     name = os.path.basename(path)
-     tmp_path = "%s.pth" % (ttime())
-     torch.save(fea, tmp_path)
-     shutil.move(tmp_path, "%s/%s" % (dir, name))
-
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
-     logger.info(
-         "Saving model and optimizer state at iteration {} to {}".format(
-             iteration, checkpoint_path
-         )
-     )
-     if hasattr(model, "module"):
-         state_dict = model.module.state_dict()
-     else:
-         state_dict = model.state_dict()
-     # torch.save(
-     my_save(
-         {
-             "model": state_dict,
-             "iteration": iteration,
-             "optimizer": optimizer.state_dict(),
-             "learning_rate": learning_rate,
-         },
-         checkpoint_path,
-     )
-
-
- def summarize(
-     writer,
-     global_step,
-     scalars={},
-     histograms={},
-     images={},
-     audios={},
-     audio_sampling_rate=22050,
- ):
-     for k, v in scalars.items():
-         writer.add_scalar(k, v, global_step)
-     for k, v in histograms.items():
-         writer.add_histogram(k, v, global_step)
-     for k, v in images.items():
-         writer.add_image(k, v, global_step, dataformats="HWC")
-     for k, v in audios.items():
-         writer.add_audio(k, v, global_step, audio_sampling_rate)
-
-
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
-     f_list = glob.glob(os.path.join(dir_path, regex))
-     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
-     x = f_list[-1]
-     print(x)
-     return x
-
-
- def plot_spectrogram_to_numpy(spectrogram):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(10, 2))
-     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-     plt.colorbar(im, ax=ax)
-     plt.xlabel("Frames")
-     plt.ylabel("Channels")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def plot_alignment_to_numpy(alignment, info=None):
-     global MATPLOTLIB_FLAG
-     if not MATPLOTLIB_FLAG:
-         import matplotlib
-
-         matplotlib.use("Agg")
-         MATPLOTLIB_FLAG = True
-         mpl_logger = logging.getLogger("matplotlib")
-         mpl_logger.setLevel(logging.WARNING)
-     import matplotlib.pylab as plt
-     import numpy as np
-
-     fig, ax = plt.subplots(figsize=(6, 4))
-     im = ax.imshow(
-         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
-     )
-     fig.colorbar(im, ax=ax)
-     xlabel = "Decoder timestep"
-     if info is not None:
-         xlabel += "\n\n" + info
-     plt.xlabel(xlabel)
-     plt.ylabel("Encoder timestep")
-     plt.tight_layout()
-
-     fig.canvas.draw()
-     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
-     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
-     plt.close()
-     return data
-
-
- def load_wav_to_torch(full_path):
-     data, sampling_rate = librosa.load(full_path, sr=None)
-     return torch.FloatTensor(data), sampling_rate
-
-
- def load_filepaths_and_text(filename, split="|"):
-     with open(filename, encoding="utf-8") as f:
-         filepaths_and_text = [line.strip().split(split) for line in f]
-     return filepaths_and_text
-
-
- def get_hparams(init=True, stage=1):
-     parser = argparse.ArgumentParser()
-     parser.add_argument(
-         "-c",
-         "--config",
-         type=str,
-         default="./configs/s2.json",
-         help="JSON file for configuration",
-     )
-     parser.add_argument(
-         "-p", "--pretrain", type=str, required=False, default=None, help="pretrain dir"
-     )
-     parser.add_argument(
-         "-rs",
-         "--resume_step",
-         type=int,
-         required=False,
-         default=None,
-         help="resume step",
-     )
-     # parser.add_argument('-e', '--exp_dir', type=str, required=False, default=None, help='experiment directory')
-     # parser.add_argument('-g', '--pretrained_s2G', type=str, required=False, default=None, help='pretrained sovits generator weights')
-     # parser.add_argument('-d', '--pretrained_s2D', type=str, required=False, default=None, help='pretrained sovits discriminator weights')
-
-     args = parser.parse_args()
-
-     config_path = args.config
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.pretrain = args.pretrain
-     hparams.resume_step = args.resume_step
-     # hparams.data.exp_dir = args.exp_dir
-     if stage == 1:
-         model_dir = hparams.s1_ckpt_dir
-     else:
-         model_dir = hparams.s2_ckpt_dir
-     config_save_path = os.path.join(model_dir, "config.json")
-
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-
-     with open(config_save_path, "w") as f:
-         f.write(data)
-     return hparams
-
-
- def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
-     """Free up space by deleting saved checkpoints.
-
-     Arguments:
-     path_to_models -- Path to the model directory
-     n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
-     sort_by_time -- True -> delete ckpts in chronological order
-                     False -> delete ckpts in lexicographical order
-     """
-     import re
-
-     ckpts_files = [
-         f
-         for f in os.listdir(path_to_models)
-         if os.path.isfile(os.path.join(path_to_models, f))
-     ]
-     name_key = lambda _f: int(re.compile(r"._(\d+)\.pth").match(_f).group(1))
-     time_key = lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))
-     sort_key = time_key if sort_by_time else name_key
-     x_sorted = lambda _x: sorted(
-         [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
-         key=sort_key,
-     )
-     to_del = [
-         os.path.join(path_to_models, fn)
-         for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
-     ]
-     del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
-     del_routine = lambda x: [os.remove(x), del_info(x)]
-     rs = [del_routine(fn) for fn in to_del]
-
-
- def get_hparams_from_dir(model_dir):
-     config_save_path = os.path.join(model_dir, "config.json")
-     with open(config_save_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     hparams.model_dir = model_dir
-     return hparams
-
-
- def get_hparams_from_file(config_path):
-     with open(config_path, "r") as f:
-         data = f.read()
-     config = json.loads(data)
-
-     hparams = HParams(**config)
-     return hparams
-
-
- def check_git_hash(model_dir):
-     source_dir = os.path.dirname(os.path.realpath(__file__))
-     if not os.path.exists(os.path.join(source_dir, ".git")):
-         logger.warn(
-             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                 source_dir
-             )
-         )
-         return
-
-     cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-     path = os.path.join(model_dir, "githash")
-     if os.path.exists(path):
-         saved_hash = open(path).read()
-         if saved_hash != cur_hash:
-             logger.warn(
-                 "git hash values are different. {}(saved) != {}(current)".format(
-                     saved_hash[:8], cur_hash[:8]
-                 )
-             )
-     else:
-         open(path, "w").write(cur_hash)
-
-
- def get_logger(model_dir, filename="train.log"):
-     global logger
-     logger = logging.getLogger(os.path.basename(model_dir))
-     logger.setLevel(logging.DEBUG)
-
-     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
-     if not os.path.exists(model_dir):
-         os.makedirs(model_dir)
-     h = logging.FileHandler(os.path.join(model_dir, filename))
-     h.setLevel(logging.DEBUG)
-     h.setFormatter(formatter)
-     logger.addHandler(h)
-     return logger
-
-
- class HParams:
-     def __init__(self, **kwargs):
-         for k, v in kwargs.items():
-             if type(v) == dict:
-                 v = HParams(**v)
-             self[k] = v
-
-     def keys(self):
-         return self.__dict__.keys()
-
-     def items(self):
-         return self.__dict__.items()
-
-     def values(self):
-         return self.__dict__.values()
-
-     def __len__(self):
-         return len(self.__dict__)
-
-     def __getitem__(self, key):
-         return getattr(self, key)
-
-     def __setitem__(self, key, value):
-         return setattr(self, key, value)
-
-     def __contains__(self, key):
-         return key in self.__dict__
-
-     def __repr__(self):
-         return self.__dict__.__repr__()
-
-
- if __name__ == "__main__":
-     print(
-         load_wav_to_torch(
-             "/home/fish/wenetspeech/dataset_vq/Y0000022499_wHFSeHEx9CM/S00261.flac"
-         )
-     )