Bredvige commited on
Commit
e64573e
·
verified ·
1 Parent(s): 52e9757

Upload 4 files

Browse files
libs/jit/__init__.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+ import pickle
3
+ import time
4
+ import torch
5
+ from tqdm import tqdm
6
+ from collections import OrderedDict
7
+
8
+
9
+ def load_inputs(path, device, is_half=False):
10
+ parm = torch.load(path, map_location=torch.device("cpu"))
11
+ for key in parm.keys():
12
+ parm[key] = parm[key].to(device)
13
+ if is_half and parm[key].dtype == torch.float32:
14
+ parm[key] = parm[key].half()
15
+ elif not is_half and parm[key].dtype == torch.float16:
16
+ parm[key] = parm[key].float()
17
+ return parm
18
+
19
+
20
+ def benchmark(
21
+ model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
22
+ ):
23
+ parm = load_inputs(inputs_path, device, is_half)
24
+ total_ts = 0.0
25
+ bar = tqdm(range(epoch))
26
+ for i in bar:
27
+ start_time = time.perf_counter()
28
+ o = model(**parm)
29
+ total_ts += time.perf_counter() - start_time
30
+ print(f"num_epoch: {epoch} | avg time(ms): {(total_ts*1000)/epoch}")
31
+
32
+
33
+ def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
34
+ benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)
35
+
36
+
37
+ def to_jit_model(
38
+ model_path,
39
+ model_type: str,
40
+ mode: str = "trace",
41
+ inputs_path: str = None,
42
+ device=torch.device("cpu"),
43
+ is_half=False,
44
+ ):
45
+ model = None
46
+ if model_type.lower() == "synthesizer":
47
+ from .get_synthesizer import get_synthesizer
48
+
49
+ model, _ = get_synthesizer(model_path, device)
50
+ model.forward = model.infer
51
+ elif model_type.lower() == "rmvpe":
52
+ from .get_rmvpe import get_rmvpe
53
+
54
+ model = get_rmvpe(model_path, device)
55
+ elif model_type.lower() == "hubert":
56
+ from .get_hubert import get_hubert_model
57
+
58
+ model = get_hubert_model(model_path, device)
59
+ model.forward = model.infer
60
+ else:
61
+ raise ValueError(f"No model type named {model_type}")
62
+ model = model.eval()
63
+ model = model.half() if is_half else model.float()
64
+ if mode == "trace":
65
+ assert not inputs_path
66
+ inputs = load_inputs(inputs_path, device, is_half)
67
+ model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
68
+ elif mode == "script":
69
+ model_jit = torch.jit.script(model)
70
+ model_jit.to(device)
71
+ model_jit = model_jit.half() if is_half else model_jit.float()
72
+ # model = model.half() if is_half else model.float()
73
+ return (model, model_jit)
74
+
75
+
76
+ def export(
77
+ model: torch.nn.Module,
78
+ mode: str = "trace",
79
+ inputs: dict = None,
80
+ device=torch.device("cpu"),
81
+ is_half: bool = False,
82
+ ) -> dict:
83
+ model = model.half() if is_half else model.float()
84
+ model.eval()
85
+ if mode == "trace":
86
+ assert inputs is not None
87
+ model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
88
+ elif mode == "script":
89
+ model_jit = torch.jit.script(model)
90
+ model_jit.to(device)
91
+ model_jit = model_jit.half() if is_half else model_jit.float()
92
+ buffer = BytesIO()
93
+ # model_jit=model_jit.cpu()
94
+ torch.jit.save(model_jit, buffer)
95
+ del model_jit
96
+ cpt = OrderedDict()
97
+ cpt["model"] = buffer.getvalue()
98
+ cpt["is_half"] = is_half
99
+ return cpt
100
+
101
+
102
+ def load(path: str):
103
+ with open(path, "rb") as f:
104
+ return pickle.load(f)
105
+
106
+
107
+ def save(ckpt: dict, save_path: str):
108
+ with open(save_path, "wb") as f:
109
+ pickle.dump(ckpt, f)
110
+
111
+
112
+ def rmvpe_jit_export(
113
+ model_path: str,
114
+ mode: str = "script",
115
+ inputs_path: str = None,
116
+ save_path: str = None,
117
+ device=torch.device("cpu"),
118
+ is_half=False,
119
+ ):
120
+ if not save_path:
121
+ save_path = model_path.rstrip(".pth")
122
+ save_path += ".half.jit" if is_half else ".jit"
123
+ if "cuda" in str(device) and ":" not in str(device):
124
+ device = torch.device("cuda:0")
125
+ from .get_rmvpe import get_rmvpe
126
+
127
+ model = get_rmvpe(model_path, device)
128
+ inputs = None
129
+ if mode == "trace":
130
+ inputs = load_inputs(inputs_path, device, is_half)
131
+ ckpt = export(model, mode, inputs, device, is_half)
132
+ ckpt["device"] = str(device)
133
+ save(ckpt, save_path)
134
+ return ckpt
135
+
136
+
137
+ def synthesizer_jit_export(
138
+ model_path: str,
139
+ mode: str = "script",
140
+ inputs_path: str = None,
141
+ save_path: str = None,
142
+ device=torch.device("cpu"),
143
+ is_half=False,
144
+ ):
145
+ if not save_path:
146
+ save_path = model_path.rstrip(".pth")
147
+ save_path += ".half.jit" if is_half else ".jit"
148
+ if "cuda" in str(device) and ":" not in str(device):
149
+ device = torch.device("cuda:0")
150
+ from .get_synthesizer import get_synthesizer
151
+
152
+ model, cpt = get_synthesizer(model_path, device)
153
+ assert isinstance(cpt, dict)
154
+ model.forward = model.infer
155
+ inputs = None
156
+ if mode == "trace":
157
+ inputs = load_inputs(inputs_path, device, is_half)
158
+ ckpt = export(model, mode, inputs, device, is_half)
159
+ cpt.pop("weight")
160
+ cpt["model"] = ckpt["model"]
161
+ cpt["device"] = device
162
+ save(cpt, save_path)
163
+ return cpt
libs/jit/get_hubert.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ from typing import Optional, Tuple
4
+ from fairseq.checkpoint_utils import load_model_ensemble_and_task
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ # from fairseq.data.data_utils import compute_mask_indices
10
+ from fairseq.utils import index_put
11
+
12
+
13
+ # @torch.jit.script
14
+ def pad_to_multiple(x, multiple, dim=-1, value=0):
15
+ # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
16
+ if x is None:
17
+ return None, 0
18
+ tsz = x.size(dim)
19
+ m = tsz / multiple
20
+ remainder = math.ceil(m) * multiple - tsz
21
+ if int(tsz % multiple) == 0:
22
+ return x, 0
23
+ pad_offset = (0,) * (-1 - dim) * 2
24
+
25
+ return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder
26
+
27
+
28
+ def extract_features(
29
+ self,
30
+ x,
31
+ padding_mask=None,
32
+ tgt_layer=None,
33
+ min_layer=0,
34
+ ):
35
+ if padding_mask is not None:
36
+ x = index_put(x, padding_mask, 0)
37
+
38
+ x_conv = self.pos_conv(x.transpose(1, 2))
39
+ x_conv = x_conv.transpose(1, 2)
40
+ x = x + x_conv
41
+
42
+ if not self.layer_norm_first:
43
+ x = self.layer_norm(x)
44
+
45
+ # pad to the sequence length dimension
46
+ x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
47
+ if pad_length > 0 and padding_mask is None:
48
+ padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
49
+ padding_mask[:, -pad_length:] = True
50
+ else:
51
+ padding_mask, _ = pad_to_multiple(
52
+ padding_mask, self.required_seq_len_multiple, dim=-1, value=True
53
+ )
54
+ x = F.dropout(x, p=self.dropout, training=self.training)
55
+
56
+ # B x T x C -> T x B x C
57
+ x = x.transpose(0, 1)
58
+
59
+ layer_results = []
60
+ r = None
61
+ for i, layer in enumerate(self.layers):
62
+ dropout_probability = np.random.random() if self.layerdrop > 0 else 1
63
+ if not self.training or (dropout_probability > self.layerdrop):
64
+ x, (z, lr) = layer(
65
+ x, self_attn_padding_mask=padding_mask, need_weights=False
66
+ )
67
+ if i >= min_layer:
68
+ layer_results.append((x, z, lr))
69
+ if i == tgt_layer:
70
+ r = x
71
+ break
72
+
73
+ if r is not None:
74
+ x = r
75
+
76
+ # T x B x C -> B x T x C
77
+ x = x.transpose(0, 1)
78
+
79
+ # undo paddding
80
+ if pad_length > 0:
81
+ x = x[:, :-pad_length]
82
+
83
+ def undo_pad(a, b, c):
84
+ return (
85
+ a[:-pad_length],
86
+ b[:-pad_length] if b is not None else b,
87
+ c[:-pad_length],
88
+ )
89
+
90
+ layer_results = [undo_pad(*u) for u in layer_results]
91
+
92
+ return x, layer_results
93
+
94
+
95
+ def compute_mask_indices(
96
+ shape: Tuple[int, int],
97
+ padding_mask: Optional[torch.Tensor],
98
+ mask_prob: float,
99
+ mask_length: int,
100
+ mask_type: str = "static",
101
+ mask_other: float = 0.0,
102
+ min_masks: int = 0,
103
+ no_overlap: bool = False,
104
+ min_space: int = 0,
105
+ require_same_masks: bool = True,
106
+ mask_dropout: float = 0.0,
107
+ ) -> torch.Tensor:
108
+ """
109
+ Computes random mask spans for a given shape
110
+
111
+ Args:
112
+ shape: the the shape for which to compute masks.
113
+ should be of size 2 where first element is batch size and 2nd is timesteps
114
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
115
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
116
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
117
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
118
+ mask_type: how to compute mask lengths
119
+ static = fixed size
120
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
121
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
122
+ poisson = sample from possion distribution with lambda = mask length
123
+ min_masks: minimum number of masked spans
124
+ no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
125
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
126
+ require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
127
+ mask_dropout: randomly dropout this percentage of masks in each example
128
+ """
129
+
130
+ bsz, all_sz = shape
131
+ mask = torch.full((bsz, all_sz), False)
132
+
133
+ all_num_mask = int(
134
+ # add a random number for probabilistic rounding
135
+ mask_prob * all_sz / float(mask_length)
136
+ + torch.rand([1]).item()
137
+ )
138
+
139
+ all_num_mask = max(min_masks, all_num_mask)
140
+
141
+ mask_idcs = []
142
+ for i in range(bsz):
143
+ if padding_mask is not None:
144
+ sz = all_sz - padding_mask[i].long().sum().item()
145
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
146
+ num_mask = max(min_masks, num_mask)
147
+ else:
148
+ sz = all_sz
149
+ num_mask = all_num_mask
150
+
151
+ if mask_type == "static":
152
+ lengths = torch.full([num_mask], mask_length)
153
+ elif mask_type == "uniform":
154
+ lengths = torch.randint(mask_other, mask_length * 2 + 1, size=[num_mask])
155
+ elif mask_type == "normal":
156
+ lengths = torch.normal(mask_length, mask_other, size=[num_mask])
157
+ lengths = [max(1, int(round(x))) for x in lengths]
158
+ else:
159
+ raise Exception("unknown mask selection " + mask_type)
160
+
161
+ if sum(lengths) == 0:
162
+ lengths[0] = min(mask_length, sz - 1)
163
+
164
+ if no_overlap:
165
+ mask_idc = []
166
+
167
+ def arrange(s, e, length, keep_length):
168
+ span_start = torch.randint(low=s, high=e - length, size=[1]).item()
169
+ mask_idc.extend(span_start + i for i in range(length))
170
+
171
+ new_parts = []
172
+ if span_start - s - min_space >= keep_length:
173
+ new_parts.append((s, span_start - min_space + 1))
174
+ if e - span_start - length - min_space > keep_length:
175
+ new_parts.append((span_start + length + min_space, e))
176
+ return new_parts
177
+
178
+ parts = [(0, sz)]
179
+ min_length = min(lengths)
180
+ for length in sorted(lengths, reverse=True):
181
+ t = [e - s if e - s >= length + min_space else 0 for s, e in parts]
182
+ lens = torch.asarray(t, dtype=torch.int)
183
+ l_sum = torch.sum(lens)
184
+ if l_sum == 0:
185
+ break
186
+ probs = lens / torch.sum(lens)
187
+ c = torch.multinomial(probs.float(), len(parts)).item()
188
+ s, e = parts.pop(c)
189
+ parts.extend(arrange(s, e, length, min_length))
190
+ mask_idc = torch.asarray(mask_idc)
191
+ else:
192
+ min_len = min(lengths)
193
+ if sz - min_len <= num_mask:
194
+ min_len = sz - num_mask - 1
195
+ mask_idc = torch.asarray(
196
+ random.sample([i for i in range(sz - min_len)], num_mask)
197
+ )
198
+ mask_idc = torch.asarray(
199
+ [
200
+ mask_idc[j] + offset
201
+ for j in range(len(mask_idc))
202
+ for offset in range(lengths[j])
203
+ ]
204
+ )
205
+
206
+ mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))
207
+
208
+ min_len = min([len(m) for m in mask_idcs])
209
+ for i, mask_idc in enumerate(mask_idcs):
210
+ if isinstance(mask_idc, torch.Tensor):
211
+ mask_idc = torch.asarray(mask_idc, dtype=torch.float)
212
+ if len(mask_idc) > min_len and require_same_masks:
213
+ mask_idc = torch.asarray(
214
+ random.sample([i for i in range(mask_idc)], min_len)
215
+ )
216
+ if mask_dropout > 0:
217
+ num_holes = int(round(len(mask_idc) * mask_dropout))
218
+ mask_idc = torch.asarray(
219
+ random.sample([i for i in range(mask_idc)], len(mask_idc) - num_holes)
220
+ )
221
+
222
+ mask[i, mask_idc.int()] = True
223
+
224
+ return mask
225
+
226
+
227
+ def apply_mask(self, x, padding_mask, target_list):
228
+ B, T, C = x.shape
229
+ torch.zeros_like(x)
230
+ if self.mask_prob > 0:
231
+ mask_indices = compute_mask_indices(
232
+ (B, T),
233
+ padding_mask,
234
+ self.mask_prob,
235
+ self.mask_length,
236
+ self.mask_selection,
237
+ self.mask_other,
238
+ min_masks=2,
239
+ no_overlap=self.no_mask_overlap,
240
+ min_space=self.mask_min_space,
241
+ )
242
+ mask_indices = mask_indices.to(x.device)
243
+ x[mask_indices] = self.mask_emb
244
+ else:
245
+ mask_indices = None
246
+
247
+ if self.mask_channel_prob > 0:
248
+ mask_channel_indices = compute_mask_indices(
249
+ (B, C),
250
+ None,
251
+ self.mask_channel_prob,
252
+ self.mask_channel_length,
253
+ self.mask_channel_selection,
254
+ self.mask_channel_other,
255
+ no_overlap=self.no_mask_channel_overlap,
256
+ min_space=self.mask_channel_min_space,
257
+ )
258
+ mask_channel_indices = (
259
+ mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
260
+ )
261
+ x[mask_channel_indices] = 0
262
+
263
+ return x, mask_indices
264
+
265
+
266
+ def get_hubert_model(
267
+ model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")
268
+ ):
269
+ models, _, _ = load_model_ensemble_and_task(
270
+ [model_path],
271
+ suffix="",
272
+ )
273
+ hubert_model = models[0]
274
+ hubert_model = hubert_model.to(device)
275
+
276
+ def _apply_mask(x, padding_mask, target_list):
277
+ return apply_mask(hubert_model, x, padding_mask, target_list)
278
+
279
+ hubert_model.apply_mask = _apply_mask
280
+
281
+ def _extract_features(
282
+ x,
283
+ padding_mask=None,
284
+ tgt_layer=None,
285
+ min_layer=0,
286
+ ):
287
+ return extract_features(
288
+ hubert_model.encoder,
289
+ x,
290
+ padding_mask=padding_mask,
291
+ tgt_layer=tgt_layer,
292
+ min_layer=min_layer,
293
+ )
294
+
295
+ hubert_model.encoder.extract_features = _extract_features
296
+
297
+ hubert_model._forward = hubert_model.forward
298
+
299
+ def hubert_extract_features(
300
+ self,
301
+ source: torch.Tensor,
302
+ padding_mask: Optional[torch.Tensor] = None,
303
+ mask: bool = False,
304
+ ret_conv: bool = False,
305
+ output_layer: Optional[int] = None,
306
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
307
+ res = self._forward(
308
+ source,
309
+ padding_mask=padding_mask,
310
+ mask=mask,
311
+ features_only=True,
312
+ output_layer=output_layer,
313
+ )
314
+ feature = res["features"] if ret_conv else res["x"]
315
+ return feature, res["padding_mask"]
316
+
317
+ def _hubert_extract_features(
318
+ source: torch.Tensor,
319
+ padding_mask: Optional[torch.Tensor] = None,
320
+ mask: bool = False,
321
+ ret_conv: bool = False,
322
+ output_layer: Optional[int] = None,
323
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
324
+ return hubert_extract_features(
325
+ hubert_model, source, padding_mask, mask, ret_conv, output_layer
326
+ )
327
+
328
+ hubert_model.extract_features = _hubert_extract_features
329
+
330
+ def infer(source, padding_mask, output_layer: torch.Tensor):
331
+ output_layer = output_layer.item()
332
+ logits = hubert_model.extract_features(
333
+ source=source, padding_mask=padding_mask, output_layer=output_layer
334
+ )
335
+ feats = hubert_model.final_proj(logits[0]) if output_layer == 9 else logits[0]
336
+ return feats
337
+
338
+ hubert_model.infer = infer
339
+ # hubert_model.forward=infer
340
+ # hubert_model.forward
341
+
342
+ return hubert_model
libs/jit/get_rmvpe.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")):
5
+ from infer.lib.rmvpe import E2E
6
+
7
+ model = E2E(4, 1, (2, 2))
8
+ ckpt = torch.load(model_path, map_location=device)
9
+ model.load_state_dict(ckpt)
10
+ model.eval()
11
+ model = model.to(device)
12
+ return model
libs/jit/get_synthesizer.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ def get_synthesizer(pth_path, device=torch.device("cpu")):
5
+ from infer.lib.infer_pack.models import (
6
+ SynthesizerTrnMs256NSFsid,
7
+ SynthesizerTrnMs256NSFsid_nono,
8
+ SynthesizerTrnMs768NSFsid,
9
+ SynthesizerTrnMs768NSFsid_nono,
10
+ )
11
+
12
+ cpt = torch.load(pth_path, map_location=torch.device("cpu"))
13
+ # tgt_sr = cpt["config"][-1]
14
+ cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
15
+ if_f0 = cpt.get("f0", 1)
16
+ version = cpt.get("version", "v1")
17
+ if version == "v1":
18
+ if if_f0 == 1:
19
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
20
+ else:
21
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
22
+ elif version == "v2":
23
+ if if_f0 == 1:
24
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False)
25
+ else:
26
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
27
+ del net_g.enc_q
28
+ # net_g.forward = net_g.infer
29
+ # ckpt = {}
30
+ # ckpt["config"] = cpt["config"]
31
+ # ckpt["f0"] = if_f0
32
+ # ckpt["version"] = version
33
+ # ckpt["info"] = cpt.get("info", "0epoch")
34
+ net_g.load_state_dict(cpt["weight"], strict=False)
35
+ net_g = net_g.float()
36
+ net_g.eval().to(device)
37
+ net_g.remove_weight_norm()
38
+ return net_g, cpt