Upload 4 files
- libs/jit/__init__.py +163 -0
- libs/jit/get_hubert.py +342 -0
- libs/jit/get_rmvpe.py +12 -0
- libs/jit/get_synthesizer.py +38 -0
libs/jit/__init__.py
ADDED
@@ -0,0 +1,163 @@
import os
import pickle
import time
from collections import OrderedDict
from io import BytesIO

import torch
from tqdm import tqdm


def load_inputs(path, device, is_half=False):
    """Load a dict of recorded example tensors and move them to the target device/dtype."""
    parm = torch.load(path, map_location=torch.device("cpu"))
    for key in parm.keys():
        parm[key] = parm[key].to(device)
        if is_half and parm[key].dtype == torch.float32:
            parm[key] = parm[key].half()
        elif not is_half and parm[key].dtype == torch.float16:
            parm[key] = parm[key].float()
    return parm


def benchmark(
    model, inputs_path, device=torch.device("cpu"), epoch=1000, is_half=False
):
    parm = load_inputs(inputs_path, device, is_half)
    total_ts = 0.0
    bar = tqdm(range(epoch))
    for _ in bar:
        start_time = time.perf_counter()
        model(**parm)
        total_ts += time.perf_counter() - start_time
    print(f"num_epoch: {epoch} | avg time(ms): {(total_ts * 1000) / epoch}")


def jit_warm_up(model, inputs_path, device=torch.device("cpu"), epoch=5, is_half=False):
    benchmark(model, inputs_path, device, epoch=epoch, is_half=is_half)


def to_jit_model(
    model_path,
    model_type: str,
    mode: str = "trace",
    inputs_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    model = None
    if model_type.lower() == "synthesizer":
        from .get_synthesizer import get_synthesizer

        model, _ = get_synthesizer(model_path, device)
        model.forward = model.infer
    elif model_type.lower() == "rmvpe":
        from .get_rmvpe import get_rmvpe

        model = get_rmvpe(model_path, device)
    elif model_type.lower() == "hubert":
        from .get_hubert import get_hubert_model

        model = get_hubert_model(model_path, device)
        model.forward = model.infer
    else:
        raise ValueError(f"No model type named {model_type}")
    model = model.eval()
    model = model.half() if is_half else model.float()
    if mode == "trace":
        # tracing requires recorded example inputs
        # (fixed: was `assert not inputs_path`, which rejected exactly the case tracing needs)
        assert inputs_path is not None
        inputs = load_inputs(inputs_path, device, is_half)
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    elif mode == "script":
        model_jit = torch.jit.script(model)
    else:
        raise ValueError(f"No mode named {mode}")
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    return (model, model_jit)


def export(
    model: torch.nn.Module,
    mode: str = "trace",
    inputs: dict = None,
    device=torch.device("cpu"),
    is_half: bool = False,
) -> dict:
    """Serialize a model to TorchScript and wrap the bytes in a checkpoint dict."""
    model = model.half() if is_half else model.float()
    model.eval()
    if mode == "trace":
        assert inputs is not None
        model_jit = torch.jit.trace(model, example_kwarg_inputs=inputs)
    elif mode == "script":
        model_jit = torch.jit.script(model)
    else:
        raise ValueError(f"No mode named {mode}")
    model_jit.to(device)
    model_jit = model_jit.half() if is_half else model_jit.float()
    buffer = BytesIO()
    torch.jit.save(model_jit, buffer)
    del model_jit
    cpt = OrderedDict()
    cpt["model"] = buffer.getvalue()
    cpt["is_half"] = is_half
    return cpt


def load(path: str):
    with open(path, "rb") as f:
        return pickle.load(f)


def save(ckpt: dict, save_path: str):
    with open(save_path, "wb") as f:
        pickle.dump(ckpt, f)


def rmvpe_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    if not save_path:
        # drop the extension properly (fixed: str.rstrip(".pth") strips characters, not a suffix)
        save_path = os.path.splitext(model_path)[0]
        save_path += ".half.jit" if is_half else ".jit"
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_rmvpe import get_rmvpe

    model = get_rmvpe(model_path, device)
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    ckpt["device"] = str(device)
    save(ckpt, save_path)
    return ckpt


def synthesizer_jit_export(
    model_path: str,
    mode: str = "script",
    inputs_path: str = None,
    save_path: str = None,
    device=torch.device("cpu"),
    is_half=False,
):
    if not save_path:
        save_path = os.path.splitext(model_path)[0]
        save_path += ".half.jit" if is_half else ".jit"
    if "cuda" in str(device) and ":" not in str(device):
        device = torch.device("cuda:0")
    from .get_synthesizer import get_synthesizer

    model, cpt = get_synthesizer(model_path, device)
    assert isinstance(cpt, dict)
    model.forward = model.infer
    inputs = None
    if mode == "trace":
        inputs = load_inputs(inputs_path, device, is_half)
    ckpt = export(model, mode, inputs, device, is_half)
    cpt.pop("weight")  # swap the raw weights for the serialized TorchScript program
    cpt["model"] = ckpt["model"]
    cpt["device"] = str(device)  # store as str, matching rmvpe_jit_export
    save(cpt, save_path)
    return cpt
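
Usage note (not part of the diff): a minimal sketch of the export/load round trip these helpers enable. The paths are hypothetical placeholders, not files in this commit.

import torch
from io import BytesIO
from libs.jit import rmvpe_jit_export, load

# Export a scripted, half-precision RMVPE model (placeholder paths).
rmvpe_jit_export(
    model_path="assets/rmvpe/rmvpe.pt",
    mode="script",
    save_path="assets/rmvpe/rmvpe.half.jit",
    device=torch.device("cuda:0"),
    is_half=True,
)

# Later: restore the TorchScript program from the pickled checkpoint dict.
cpt = load("assets/rmvpe/rmvpe.half.jit")
model = torch.jit.load(BytesIO(cpt["model"]), map_location=cpt["device"])
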
libs/jit/get_hubert.py
ADDED
@@ -0,0 +1,342 @@
import math
import random
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from fairseq.checkpoint_utils import load_model_ensemble_and_task

# from fairseq.data.data_utils import compute_mask_indices
# (replaced below by a torch-only port so the patched model stays exportable)
from fairseq.utils import index_put


# @torch.jit.script
def pad_to_multiple(x, multiple, dim=-1, value=0):
    # Inspired from https://github.com/lucidrains/local-attention/blob/master/local_attention/local_attention.py#L41
    if x is None:
        return None, 0
    tsz = x.size(dim)
    m = tsz / multiple
    remainder = math.ceil(m) * multiple - tsz
    if int(tsz % multiple) == 0:
        return x, 0
    pad_offset = (0,) * (-1 - dim) * 2

    return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder


def extract_features(
    self,
    x,
    padding_mask=None,
    tgt_layer=None,
    min_layer=0,
):
    if padding_mask is not None:
        x = index_put(x, padding_mask, 0)

    x_conv = self.pos_conv(x.transpose(1, 2))
    x_conv = x_conv.transpose(1, 2)
    x = x + x_conv

    if not self.layer_norm_first:
        x = self.layer_norm(x)

    # pad to the sequence length dimension
    x, pad_length = pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
    if pad_length > 0 and padding_mask is None:
        padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
        padding_mask[:, -pad_length:] = True
    else:
        padding_mask, _ = pad_to_multiple(
            padding_mask, self.required_seq_len_multiple, dim=-1, value=True
        )
    x = F.dropout(x, p=self.dropout, training=self.training)

    # B x T x C -> T x B x C
    x = x.transpose(0, 1)

    layer_results = []
    r = None
    for i, layer in enumerate(self.layers):
        dropout_probability = np.random.random() if self.layerdrop > 0 else 1
        if not self.training or (dropout_probability > self.layerdrop):
            x, (z, lr) = layer(
                x, self_attn_padding_mask=padding_mask, need_weights=False
            )
            if i >= min_layer:
                layer_results.append((x, z, lr))
        if i == tgt_layer:
            r = x
            break

    if r is not None:
        x = r

    # T x B x C -> B x T x C
    x = x.transpose(0, 1)

    # undo padding
    if pad_length > 0:
        x = x[:, :-pad_length]

        def undo_pad(a, b, c):
            return (
                a[:-pad_length],
                b[:-pad_length] if b is not None else b,
                c[:-pad_length],
            )

        layer_results = [undo_pad(*u) for u in layer_results]

    return x, layer_results


def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> torch.Tensor:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from Poisson distribution with lambda = mask length (not implemented in this torch-only port)
        min_masks: minimum number of masked spans
        no_overlap: if true, uses an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = torch.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + torch.rand([1]).item()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = torch.full([num_mask], mask_length)
        elif mask_type == "uniform":
            lengths = torch.randint(int(mask_other), mask_length * 2 + 1, size=[num_mask])
        elif mask_type == "normal":
            lengths = torch.normal(mask_length, mask_other, size=[num_mask])
            lengths = [max(1, int(round(x))) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = torch.randint(low=s, high=e - length, size=[1]).item()
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                t = [e - s if e - s >= length + min_space else 0 for s, e in parts]
                lens = torch.asarray(t, dtype=torch.int)
                l_sum = torch.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / torch.sum(lens)
                # draw a single part, weighted by its usable length
                # (fixed: was `torch.multinomial(probs.float(), len(parts)).item()`,
                # which returns several samples and breaks `.item()`)
                c = torch.multinomial(probs.float(), 1).item()
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = torch.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1
            mask_idc = torch.asarray(
                random.sample([i for i in range(sz - min_len)], num_mask)
            )
            mask_idc = torch.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(torch.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if isinstance(mask_idc, torch.Tensor):
            mask_idc = torch.asarray(mask_idc, dtype=torch.float)
        if len(mask_idc) > min_len and require_same_masks:
            # keep min_len of the chosen indices
            # (fixed: was `random.sample([i for i in range(mask_idc)], ...)` over a tensor)
            mask_idc = torch.asarray(random.sample(mask_idc.tolist(), min_len))
        if mask_dropout > 0:
            num_holes = int(round(len(mask_idc) * mask_dropout))
            mask_idc = torch.asarray(
                random.sample(mask_idc.tolist(), len(mask_idc) - num_holes)
            )

        mask[i, mask_idc.int()] = True

    return mask


def apply_mask(self, x, padding_mask, target_list):
    B, T, C = x.shape
    if self.mask_prob > 0:
        mask_indices = compute_mask_indices(
            (B, T),
            padding_mask,
            self.mask_prob,
            self.mask_length,
            self.mask_selection,
            self.mask_other,
            min_masks=2,
            no_overlap=self.no_mask_overlap,
            min_space=self.mask_min_space,
        )
        mask_indices = mask_indices.to(x.device)
        x[mask_indices] = self.mask_emb
    else:
        mask_indices = None

    if self.mask_channel_prob > 0:
        mask_channel_indices = compute_mask_indices(
            (B, C),
            None,
            self.mask_channel_prob,
            self.mask_channel_length,
            self.mask_channel_selection,
            self.mask_channel_other,
            no_overlap=self.no_mask_channel_overlap,
            min_space=self.mask_channel_min_space,
        )
        mask_channel_indices = (
            mask_channel_indices.to(x.device).unsqueeze(1).expand(-1, T, -1)
        )
        x[mask_channel_indices] = 0

    return x, mask_indices


def get_hubert_model(
    model_path="assets/hubert/hubert_base.pt", device=torch.device("cpu")
):
    models, _, _ = load_model_ensemble_and_task(
        [model_path],
        suffix="",
    )
    hubert_model = models[0]
    hubert_model = hubert_model.to(device)

    def _apply_mask(x, padding_mask, target_list):
        return apply_mask(hubert_model, x, padding_mask, target_list)

    hubert_model.apply_mask = _apply_mask

    def _extract_features(
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        return extract_features(
            hubert_model.encoder,
            x,
            padding_mask=padding_mask,
            tgt_layer=tgt_layer,
            min_layer=min_layer,
        )

    hubert_model.encoder.extract_features = _extract_features

    hubert_model._forward = hubert_model.forward

    def hubert_extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        res = self._forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )
        feature = res["features"] if ret_conv else res["x"]
        return feature, res["padding_mask"]

    def _hubert_extract_features(
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hubert_extract_features(
            hubert_model, source, padding_mask, mask, ret_conv, output_layer
        )

    hubert_model.extract_features = _hubert_extract_features

    def infer(source, padding_mask, output_layer: torch.Tensor):
        output_layer = output_layer.item()
        logits = hubert_model.extract_features(
            source=source, padding_mask=padding_mask, output_layer=output_layer
        )
        feats = hubert_model.final_proj(logits[0]) if output_layer == 9 else logits[0]
        return feats

    hubert_model.infer = infer
    # forward is rebound to infer by the caller (see to_jit_model in libs/jit/__init__.py)

    return hubert_model
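
Usage note: the patched model exposes a traceable infer(source, padding_mask, output_layer) entry point. A sketch with dummy inputs follows; the path, shapes, and layer choice are illustrative assumptions.

import torch
from libs.jit.get_hubert import get_hubert_model

device = torch.device("cpu")
model = get_hubert_model("assets/hubert/hubert_base.pt", device).float().eval()

wav = torch.randn(1, 16000)  # stand-in for one second of 16 kHz mono audio
padding_mask = torch.zeros_like(wav, dtype=torch.bool)
with torch.no_grad():
    # output_layer travels as a tensor so the call stays traceable;
    # RVC's v1 pipeline reads layer 9 (with final_proj), v2 reads layer 12 raw
    feats = model.infer(wav, padding_mask, torch.LongTensor([9]).to(device))
print(feats.shape)
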
libs/jit/get_rmvpe.py
ADDED
@@ -0,0 +1,12 @@
import torch


def get_rmvpe(model_path="assets/rmvpe/rmvpe.pt", device=torch.device("cpu")):
    from infer.lib.rmvpe import E2E

    model = E2E(4, 1, (2, 2))
    ckpt = torch.load(model_path, map_location=device)
    model.load_state_dict(ckpt)
    model.eval()
    model = model.to(device)
    return model
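
Usage note: a quick smoke test, assuming the (batch, 128 mel bins, frames) input layout and the multiple-of-32 frame padding used by the RMVPE inference pipeline.

import torch
from libs.jit.get_rmvpe import get_rmvpe

model = get_rmvpe("assets/rmvpe/rmvpe.pt", torch.device("cpu"))
mel = torch.randn(1, 128, 32)  # dummy mel spectrogram, 32 frames
with torch.no_grad():
    hidden = model(mel)  # per-frame salience over the pitch bins
print(hidden.shape)  # expected: (1, 32, 360)
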
libs/jit/get_synthesizer.py
ADDED
@@ -0,0 +1,38 @@
import torch


def get_synthesizer(pth_path, device=torch.device("cpu")):
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs256NSFsid,
        SynthesizerTrnMs256NSFsid_nono,
        SynthesizerTrnMs768NSFsid,
        SynthesizerTrnMs768NSFsid_nono,
    )

    cpt = torch.load(pth_path, map_location=torch.device("cpu"))
    # cpt["config"][-1] is the target sample rate; patch the speaker count
    # from the embedding table before instantiating the network
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")
    if version == "v1":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
    elif version == "v2":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False)
        else:
            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
    else:
        raise ValueError(f"Unknown model version {version}")
    del net_g.enc_q  # the posterior encoder is only needed for training
    # forward is rebound to infer by the callers in libs/jit/__init__.py
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g = net_g.float()
    net_g.eval().to(device)
    net_g.remove_weight_norm()
    return net_g, cpt
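
Usage note: tying the pieces together, a hedged sketch of exporting a trained voice model; "my_voice.pth" is a placeholder checkpoint, not part of this commit.

import torch
from libs.jit import synthesizer_jit_export

cpt = synthesizer_jit_export(
    model_path="assets/weights/my_voice.pth",  # hypothetical checkpoint
    mode="script",
    device=torch.device("cpu"),
    is_half=False,
)
# The saved dict keeps the original metadata ("config", "f0", "version", ...)
# but swaps the raw weights for the serialized TorchScript program.
print(cpt.get("version", "v1"), cpt["device"], len(cpt["model"]))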