Update model_simple.py
model_simple.py  CHANGED  (+24 -51)
@@ -19,14 +19,6 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
-os.environ['TORCH_HOME'] = PATH
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
-
 @dataclass
 class Dimensions:
     vocab: int
@@ -43,7 +35,8 @@ class rotary(nn.Module):
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
+        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
+        self.twotwenty = nn.Parameter((torch.tensor(220, device=device, dtype=dtype)), requires_grad=True)
 
     def forward(self, x=None) -> Tensor:
         freqs = (self.theta / 220.0) * 700 * (
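Note: the patch keeps `theta` as a learnable scalar and adds a second learnable scalar, `twotwenty`, initialised to 220 (matching the 220.0 constant in the frequency expression above, which the diff cuts off after the opening parenthesis). Because `theta` is a parameter, the frequency table has to be recomputed each forward pass rather than cached as a buffer. A minimal sketch of that idea, with a standard inverse-frequency layout filled in as an assumption rather than the author's exact formula:

import torch
import torch.nn as nn

class RotarySketch(nn.Module):
    # Sketch only: forward() in model_simple.py is truncated in the diff, so the
    # frequency formula below is an assumed inverse-frequency layout.
    def __init__(self, dims: int, head: int):
        super().__init__()
        self.head_dim = dims // head
        self.theta = nn.Parameter(torch.tensor(36000.0))  # learnable rotary base, as in the patch

    def forward(self, ctx: int) -> torch.Tensor:
        idx = torch.arange(0, self.head_dim, 2, dtype=torch.float32)
        inv_freq = 1.0 / (self.theta ** (idx / self.head_dim))
        positions = torch.arange(ctx, dtype=torch.float32)
        return torch.outer(positions, inv_freq)  # (ctx, head_dim // 2) rotation angles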
@@ -68,35 +61,33 @@ class rotary(nn.Module):
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
 class MultiheadA(nn.Module):
-
     def __init__(self, dims: int, head: int):
         super(MultiheadA, self).__init__()
-
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-
         self.q = nn.Linear(dims, dims).to(device, dtype)
         self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
         self.v = nn.Linear(dims, dims).to(device, dtype)
         self.o = nn.Linear(dims, dims).to(device, dtype)
         self.rope = rotary(dims=dims, head=head)
-
+        self.lnq = nn.LayerNorm(self.head_dim, bias = False)
+        self.lnx = nn.LayerNorm(dims, bias = False)
     def forward(self, x: Tensor, xa = None, mask = None):
         scale = (self.dims // self.head) ** -0.25
-        q = self.q(x)
-        k = self.k(x if xa is None else xa)
-        v = self.v(x if xa is None else xa)
-        batch, ctx, dims = q.shape
+        q = self.q(self.lnx(x))
+        k = self.k(self.lnx(x if xa is None else xa))
+        v = self.v(self.lnx(x if xa is None else xa))
         q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        q = self.lnq(q)
+        k = self.lnq(k)
         q = self.rope.apply_rotary(q, (self.rope(q.shape[2]))) # type: ignore
         k = self.rope.apply_rotary(k, (self.rope(k.shape[2]))) # type: ignore
-        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and
+        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
-
-        return self.o(out), qk
+        return self.o(out)
 
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4):
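Note: the rewritten forward normalises the input with `lnx` before the projections, applies a per-head LayerNorm `lnq` to q and k before RoPE (a QK-norm style change), and runs attention through `torch.nn.functional.scaled_dot_product_attention`. The `mask` argument is now only a boolean switch for `is_causal` (the tensor itself is never passed), and after the permute `q.shape[1]` is the head count rather than the sequence length, so the `> 1` guard is worth double-checking for single-token decoding; the local `scale` value also appears unused, since SDPA applies its own 1/sqrt(d) scaling by default. Because the module now returns a bare tensor instead of the old `(output, qk)` tuple, call sites that still index with `[0]` (see the Residual.forward context line in the next hunk) would slice the batch dimension. A small standalone check of the `is_causal` semantics, with illustrative shapes not taken from the file:

import torch
import torch.nn.functional as F

batch, head, ctx, head_dim = 2, 4, 8, 16
q = torch.randn(batch, head, ctx, head_dim)
k = torch.randn(batch, head, ctx, head_dim)
v = torch.randn(batch, head, ctx, head_dim)

# is_causal=True makes SDPA build its own lower-triangular mask internally,
# so no attn_mask tensor needs to be passed alongside it.
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit boolean mask, for comparison.
tril = torch.tril(torch.ones(ctx, ctx, dtype=torch.bool))
explicit = F.scaled_dot_product_attention(q, k, v, attn_mask=tril)
print(torch.allclose(causal, explicit, atol=1e-6))  # True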
@@ -123,17 +114,16 @@ class Residual(nn.Module):
         self.head = head
         self.ctx = ctx
         self.head_dim = dims // head
-
-        self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
+        self.blend = nn.Parameter(torch.tensor(0.5))
         self.attn = MultiheadA(dims, head)
         mlp = dims * 4
         self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
         self.t_gate = t_gate(dims=dims, num_types=4*2)
 
-        self.lna =
-        self.lnb =
-        self.lnc =
+        self.lna = nn.LayerNorm(dims, bias = False)
+        self.lnb = nn.LayerNorm(dims, bias = False)
+        self.lnc = nn.LayerNorm(dims, bias = False)
 
     def forward(self, x, xa=None, mask=None) -> Tensor:
         x = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
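Note: the Residual block now owns three bias-free LayerNorms (`lna`, `lnb`, `lnc`) plus the learnable `blend` scalar that used to live in `processor`. Only the first line of its forward is visible in the diff; the sketch below shows the usual pre-norm layout such a block implies. The cross-attention and MLP steps are assumptions, `nn.MultiheadAttention` is a stand-in for `MultiheadA`, and the `[0]` index from the context line is dropped here because this stand-in returns a tuple by design:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    # Sketch of the pre-norm residual layout implied by lna/lnb/lnc.
    # Only the self-attention line is shown in the diff; the rest is assumed.
    def __init__(self, dims: int, head: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dims, head, batch_first=True)  # stand-in for MultiheadA
        self.mlp = nn.Sequential(nn.Linear(dims, dims * 4), nn.GELU(), nn.Linear(dims * 4, dims))
        self.lna = nn.LayerNorm(dims, bias=False)
        self.lnb = nn.LayerNorm(dims, bias=False)
        self.lnc = nn.LayerNorm(dims, bias=False)

    def forward(self, x, xa=None):
        h = self.lna(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]        # pre-norm self-attention
        if xa is not None:
            x = x + self.attn(self.lnb(x), xa, xa, need_weights=False)[0]  # cross-attention into xa
        return x + self.mlp(self.lnc(x))                          # pre-norm MLP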
@@ -165,17 +155,16 @@ class processor(nn.Module):
         self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
         # pitch
-        # self.encoder = nn.Sequential(
-        #     Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-
-
         self.encoder = nn.Sequential(
-            Conv1d(
+            Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
+        # self.encoder = nn.Sequential(
+        #     Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
         self.bA = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
         self.bB = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
 
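Note: the first convolution of the encoder now takes 1 input channel (the mel-spectrogram variant with `mels` input channels is kept commented out), so this branch consumes a single-channel pitch contour. A quick shape walk-through under assumed sizes (`dims=64` and 400 frames are illustrative); the final `groups=dims` convolution is depthwise, mixing only along time within each channel:

import torch
import torch.nn as nn

dims = 64
act_fn = nn.GELU()

encoder = nn.Sequential(
    nn.Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,       # 1-channel pitch in
    nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
    nn.Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn,  # depthwise
)

pitch = torch.randn(2, 1, 400)           # (batch, channels=1, frames)
feat = encoder(pitch).permute(0, 2, 1)   # (batch, frames, dims), matching processor.forward
print(feat.shape)                        # torch.Size([2, 400, 64])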
@@ -188,27 +177,11 @@ class processor(nn.Module):
 
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 36000).to(device, dtype)
-
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
-
         for b in chain(self.bB or []):
             x = b(x=x, xa=None, mask=self.mask)
-
-        if sequential:
-            x = xc
-        else:
-            a = torch.sigmoid(self.blend)
-            x = a * xc + (1 - a) * x
-
-        # for b in chain(self.bB or []):
-        #     xd = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None)
-        #     xm = b(x=xd[:, :x.shape[1]], xa=xd[:, x.shape[1]:], mask=None)
-        #     if sequential:
-        #         x = xm
-        #     else:
-        #         a = torch.sigmoid(self.blend)
-        #         x = a * x + (1 - a) * xm
+            x = b(x, xa=xa, mask=None)
 
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = self.norm(x)
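Note: this hunk drops the `sequential`/blend branch from `processor.forward`; each `bB` block now runs twice per layer, once with the causal mask for self-attention and once against `xa` for cross-attention, instead of mixing the stream with a gated copy. For reference, the deleted path was a learnable convex blend of two streams; a minimal illustration of that pattern with stand-in tensors, not the model's:

import torch
import torch.nn as nn

blend = nn.Parameter(torch.tensor(0.5))   # learnable mixing logit, as in the removed code
xc = torch.randn(2, 10, 64)               # stand-in for the gated stream
x = torch.randn(2, 10, 64)                # stand-in for the plain stream

a = torch.sigmoid(blend)                  # keeps the mixing weight in (0, 1)
mixed = a * xc + (1 - a) * x              # convex combination, as the deleted branch did
print(mixed.shape)                        # torch.Size([2, 10, 64])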
@@ -320,10 +293,10 @@ def main():
 
     extract_args = {
         "waveform": False,
-        "spec":
+        "spec": False,
         "f0": False,
         "f0t": False,
-        "pitch":
+        "pitch": True,
         "harmonics": False,
         "aperiodics": False,
         "phase_mod": False,
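Note: `extract_args` appears to be the feature-selection dict built in main(); after this change only "pitch" is enabled, in line with the single-channel pitch encoder above. A hypothetical reading of the flags (the gating loop below is illustrative, not code from the file):

extract_args = {
    "waveform": False, "spec": False, "f0": False, "f0t": False,
    "pitch": True, "harmonics": False, "aperiodics": False, "phase_mod": False,
}

# Illustrative only: pick out the features flagged for extraction.
enabled = [name for name, on in extract_args.items() if on]
print(enabled)  # ['pitch']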