Sin2pi committed (verified)
Commit 7520b9d · 1 parent: c05d8b0

Update model_simple.py

Files changed (1):
  1. model_simple.py  +24 −51
model_simple.py CHANGED
@@ -19,14 +19,6 @@ dtype = torch.float32
 warnings.filterwarnings("ignore")
 logging.basicConfig(level=logging.ERROR)
 
-PATH = 'E:/hf'
-os.environ['HF_HOME'] = PATH
-os.environ['HF_DATASETS_CACHE'] = PATH
-os.environ['TORCH_HOME'] = PATH
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-
-
 @dataclass
 class Dimensions:
     vocab: int
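This hunk removes the machine-specific cache paths (the E:/hf Windows drive) and allocator settings from the module body. If those settings are still wanted per machine, one option, not taken from this file, is to supply them from the shell and only fall back in code; a minimal sketch reusing the same environment variable names:

    import os

    # Fallbacks only; `HF_HOME=/data/hf python model_simple.py` from the shell takes precedence.
    os.environ.setdefault("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")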
@@ -43,7 +35,8 @@ class rotary(nn.Module):
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
+        self.theta = nn.Parameter((torch.tensor(36000, device=device, dtype=dtype)), requires_grad=True)
+        self.twotwenty = nn.Parameter((torch.tensor(220, device=device, dtype=dtype)), requires_grad=True)
 
     def forward(self, x=None) -> Tensor:
         freqs = (self.theta / 220.0) * 700 * (
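The rotary module keeps theta and gains twotwenty as learnable scalars, so the frequency base is trained along with the rest of the model. The hunk cuts the freqs expression off after `(self.theta / 220.0) * 700 * (`, so the curve below is only a placeholder; the sketch shows the registration pattern, not the file's actual formula:

    import torch
    import torch.nn as nn

    class LearnableRotaryBase(nn.Module):
        # Sketch: the mel-style curve in forward() is an assumption, not the file's formula.
        def __init__(self, head_dim: int, theta: float = 36000.0, twotwenty: float = 220.0):
            super().__init__()
            self.head_dim = head_dim
            self.theta = nn.Parameter(torch.tensor(theta))          # learnable frequency base
            self.twotwenty = nn.Parameter(torch.tensor(twotwenty))  # learnable 220 Hz-style reference

        def forward(self, ctx: int) -> torch.Tensor:
            scale = (self.theta / self.twotwenty) * 700.0
            idx = torch.arange(self.head_dim // 2, dtype=torch.float32)
            freqs = scale * (10 ** (idx / (self.head_dim // 2)) - 1.0) / 1000.0
            t = torch.arange(ctx, dtype=torch.float32)
            return torch.outer(t, freqs)  # (ctx, head_dim // 2) rotation angles

    print(LearnableRotaryBase(head_dim=64)(10).shape)  # torch.Size([10, 32])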
@@ -68,35 +61,33 @@ class rotary(nn.Module):
         return torch.cat([x1.type_as(x), x2], dim=-1)
 
 class MultiheadA(nn.Module):
-
     def __init__(self, dims: int, head: int):
         super(MultiheadA, self).__init__()
-
         self.dims = dims
         self.head = head
         self.head_dim = dims // head
-
         self.q = nn.Linear(dims, dims).to(device, dtype)
         self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
         self.v = nn.Linear(dims, dims).to(device, dtype)
         self.o = nn.Linear(dims, dims).to(device, dtype)
         self.rope = rotary(dims=dims, head=head)
-
+        self.lnq = nn.LayerNorm(self.head_dim, bias = False)
+        self.lnx = nn.LayerNorm(dims, bias = False)
     def forward(self, x: Tensor, xa = None, mask = None):
         scale = (self.dims // self.head) ** -0.25
-        q = self.q(x)
-        k = self.k(x if xa is None else xa)
-        v = self.v(x if xa is None else xa)
-        batch, ctx, dims = q.shape
+        q = self.q(self.lnx(x))
+        k = self.k(self.lnx(x if xa is None else xa))
+        v = self.v(self.lnx(x if xa is None else xa))
         q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
+        q = self.lnq(q)
+        k = self.lnq(k)
         q = self.rope.apply_rotary(q, (self.rope(q.shape[2]))) # type: ignore
         k = self.rope.apply_rotary(k, (self.rope(k.shape[2]))) # type: ignore
-        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and ctx > 1)
+        a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and q.shape[1] > 1)
         out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
-        qk = None
-        return self.o(out), qk
+        return self.o(out)
 
 class t_gate(nn.Module):
     def __init__(self, dims, num_types=4):
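The attention rewrite normalizes the input with lnx before the q/k/v projections, applies a shared per-head LayerNorm lnq to the query and key heads before RoPE (a QK-norm-style stabilizer), and returns a single tensor instead of the old (out, qk) pair. A self-contained sketch of that forward path with illustrative sizes and RoPE omitted; note that after the permute, q.shape[1] is the head count, so the causal check here uses the context axis q.shape[2]:

    import torch
    import torch.nn as nn
    from torch.nn.functional import scaled_dot_product_attention

    class QKNormAttention(nn.Module):
        # Sketch of the pattern in the hunk: pre-norm on the input, per-head LayerNorm on q and k.
        def __init__(self, dims: int, head: int):
            super().__init__()
            self.head, self.head_dim = head, dims // head
            self.q = nn.Linear(dims, dims)
            self.k = nn.Linear(dims, dims, bias=False)
            self.v = nn.Linear(dims, dims)
            self.o = nn.Linear(dims, dims)
            self.lnx = nn.LayerNorm(dims, bias=False)           # input norm before the projections
            self.lnq = nn.LayerNorm(self.head_dim, bias=False)  # shared q/k head norm

        def forward(self, x, xa=None, mask=None):
            q = self.q(self.lnx(x))
            k = self.k(self.lnx(x if xa is None else xa))
            v = self.v(self.lnx(x if xa is None else xa))
            # (batch, ctx, dims) -> (batch, head, ctx, head_dim)
            q, k, v = (t.view(*t.shape[:2], self.head, -1).permute(0, 2, 1, 3) for t in (q, k, v))
            q, k = self.lnq(q), self.lnq(k)                     # QK-norm before attention; RoPE would go here
            a = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and q.shape[2] > 1)
            return self.o(a.permute(0, 2, 1, 3).flatten(start_dim=2))

    x = torch.randn(2, 16, 64)
    print(QKNormAttention(dims=64, head=4)(x).shape)  # torch.Size([2, 16, 64])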
@@ -123,17 +114,16 @@ class Residual(nn.Module):
         self.head = head
         self.ctx = ctx
         self.head_dim = dims // head
-
-        self.blend = nn.Parameter(torch.tensor(0.5))
         act_fn = get_activation(act)
+        self.blend = nn.Parameter(torch.tensor(0.5))
         self.attn = MultiheadA(dims, head)
         mlp = dims * 4
         self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))
         self.t_gate = t_gate(dims=dims, num_types=4*2)
 
-        self.lna = RMSNorm(dims)
-        self.lnb = RMSNorm(dims)
-        self.lnc = RMSNorm(dims)
+        self.lna = nn.LayerNorm(dims, bias = False)
+        self.lnb = nn.LayerNorm(dims, bias = False)
+        self.lnc = nn.LayerNorm(dims, bias = False)
 
     def forward(self, x, xa=None, mask=None) -> Tensor:
         x = x + self.attn(self.lna(x), xa=None, mask=mask)[0]
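The residual block moves the blend gate next to its other submodules and swaps the three RMSNorm layers for nn.LayerNorm(dims, bias=False). The practical difference is mean subtraction: LayerNorm centers each token's features, RMSNorm only rescales them. A small check, assuming torch.nn.RMSNorm (PyTorch 2.4+) behaves like the file's RMSNorm:

    import torch
    import torch.nn as nn

    dims = 8
    x = torch.randn(2, 4, dims) + 3.0        # features with a deliberately non-zero mean

    ln = nn.LayerNorm(dims, bias=False)      # centers (subtracts the mean), then rescales
    rms = nn.RMSNorm(dims)                   # rescales by the RMS only, no centering

    print(ln(x).mean(dim=-1).abs().max())    # ~0: the per-token mean is removed
    print(rms(x).mean(dim=-1).abs().max())   # clearly non-zero: the mean is preserved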
@@ -165,17 +155,16 @@ class processor(nn.Module):
         self.positional_sin = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
 
         # pitch
-        # self.encoder = nn.Sequential(
-        #     Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
-        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
-
-
         self.encoder = nn.Sequential(
-            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
+            Conv1d(1, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
             Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
 
+        # self.encoder = nn.Sequential(
+        #     Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
+        #     Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)
+
         self.bA = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
         self.bB = nn.ModuleList([Residual(ctx=ctx, dims=dims, head=head, act=act_fn) for _ in range(layer)])
 
@@ -188,27 +177,11 @@
 
         xa = self.encoder(xa).permute(0, 2, 1)
         xa = xa + self.positional_sin(xa.shape[1], xa.shape[-1], 36000).to(device, dtype)
-
         for b in chain(self.bA or []):
             xa = b(x=xa, xa=None, mask=None)
-
         for b in chain(self.bB or []):
             x = b(x=x, xa=None, mask=self.mask)
-            xc = b(x, xa=xa, mask=None)
-            if sequential:
-                x = xc
-            else:
-                a = torch.sigmoid(self.blend)
-                x = a * xc + (1 - a) * x
-
-        # for b in chain(self.bB or []):
-        #     xd = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None)
-        #     xm = b(x=xd[:, :x.shape[1]], xa=xd[:, x.shape[1]:], mask=None)
-        #     if sequential:
-        #         x = xm
-        #     else:
-        #         a = torch.sigmoid(self.blend)
-        #         x = a * x + (1 - a) * xm
+            x = b(x, xa=xa, mask=None)
 
         x = nn.functional.dropout(x, p=self.dropout, training=self.training)
         x = self.norm(x)
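The decoder loop no longer gates the cross-attended branch with sequential / torch.sigmoid(self.blend); every block now runs a self-attention pass followed unconditionally by a cross-attention pass over xa (which leaves the blend parameter registered in Residual unused on this path). A sketch contrasting the two update rules, with block standing in for one Residual call and a toy stand-in used for the demo:

    import torch

    def sequential_update(block, x, xa, mask):
        # New behaviour: self-attention pass, then always a cross-attention pass over xa.
        x = block(x=x, xa=None, mask=mask)
        return block(x, xa=xa, mask=None)

    def blended_update(block, x, xa, mask, blend):
        # Removed behaviour: mix the cross-attended branch back in with a learnable sigmoid gate.
        x = block(x=x, xa=None, mask=mask)
        xc = block(x, xa=xa, mask=None)
        a = torch.sigmoid(blend)              # scalar gate in (0, 1)
        return a * xc + (1 - a) * x

    block = lambda x, xa=None, mask=None: x + (0.0 if xa is None else xa.mean())  # toy stand-in
    x, xa = torch.ones(2, 3), torch.full((2, 3), 2.0)
    print(sequential_update(block, x, xa, None))
    print(blended_update(block, x, xa, None, torch.tensor(0.5)))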
@@ -320,10 +293,10 @@ def main():
 
     extract_args = {
         "waveform": False,
-        "spec": True,
+        "spec": False,
        "f0": False,
        "f0t": False,
-        "pitch": False,
+        "pitch": True,
        "harmonics": False,
        "aperiodics": False,
        "phase_mod": False,
 