Sin2pi committed
Commit 87df44f · verified · 1 parent: be63dd3

Update model.py

Files changed (1): model.py (+36 -38)
model.py CHANGED
@@ -130,38 +130,32 @@ def sinusoids(length, channels, max_timescale=10000):
 
 class rotary(nn.Module):
     _seen = set()
-    def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=True, variable_radius=False,
+    def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
                  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
         super().__init__()
-        self.use_pbias = False
-
+
+        self.dims = dims
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        dtype = torch.float32
         self.device = device
-        dtype = torch.float32
         self.dtype = dtype
         self.debug = debug
         self._counter = 0
-        self.dims = dims
+
+        self.use_pbias = False
         self.max_ctx = max_ctx
         self.variable_radius = variable_radius
 
-        self.inv_freq = nn.Parameter(
-            1.0 / (19000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
+        self.inv_freq = nn.Parameter(1.0 / (theta ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
             requires_grad=learned_freq)
-        self.theta = nn.Parameter(
-            torch.tensor(float(theta)), requires_grad=learned_theta)
-        self.min_theta = nn.Parameter(
-            torch.tensor(600.0), requires_grad=learned_theta)
-        self.max_theta = nn.Parameter(
-            torch.tensor(2400.0), requires_grad=learned_theta)
-
-        self.pitch_scale = nn.Parameter(torch.tensor(1.0),
-            requires_grad=learned_pitch)
+
+        self.theta = nn.Parameter(torch.tensor(float(theta)),
+            requires_grad=learned_theta)
+
+        self.pitch_scale = nn.Parameter(torch.tensor(1.0), requires_grad=learned_pitch)
 
         if variable_radius:
-            self.radius = nn.Parameter(
-                torch.ones(dims // 2),
-                requires_grad=learned_radius)
+            self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)
 
     def get_pitch_bias(self, f0):
         if f0 is None:
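Review note (not part of the diff): inv_freq is now built from the theta argument instead of the hard-coded base 19000, and the unused min_theta/max_theta parameters are gone. A minimal standalone sketch of how the base sets the per-pair rotation rates (illustrative values only):

import torch

def rope_inv_freq(dims: int, theta: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: theta ** (-2i / dims) for i = 0 .. dims//2 - 1,
    # one rotation rate per feature pair.
    return 1.0 / (theta ** (torch.arange(0, dims, 2, dtype=torch.float32) / dims))

# A larger base slows the high-dimension rotations (longer wavelengths):
print(rope_inv_freq(8, 10000.0))  # tensor([1.0000, 0.1000, 0.0100, 0.0010])
print(rope_inv_freq(8, 19000.0))  # same head frequency, smaller tail frequencies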
@@ -189,31 +183,38 @@ class rotary(nn.Module):
     rotary.get_sim = get_sim
     rotary.fwd_sim = fwd_sim
 
-    def forward(self, x = None, f0=None) -> Tensor:
+    def align_f0_to_tokens(self, f0, token_length):
+        ratio = len(f0) / token_length
+        indices = [int(i * ratio) for i in range(token_length)]
+        indices = [min(i, len(f0) - 1) for i in indices]
+        return f0[indices]
+
+    def forward(self, x=None, f0=None, stage=None) -> Tensor:
         if isinstance(x, int):
             t = torch.arange(x, device=self.device).float()
         else:
             t = x.float().to(self.inv_freq.device)
-
         if f0 is not None:
-
             f0_mean = f0.mean()
-            perceptual_factor = torch.log(1 + f0_mean / 700.0) / torch.log(torch.tensor(1 + 300.0 / 700.0))
-            min_theta, max_theta = 800.0, 10000.0
-            f0_theta = self.theta + perceptual_factor * (max_theta - min_theta)
+            f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
             inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
         else:
             inv_freq = self.inv_freq
         freqs = torch.einsum('i,j->ij', t, inv_freq)
-
         freqs = freqs.float()
         if self.variable_radius:
-            radius = F.softplus(self.radius)
-            freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
-        else:
-            freqs = torch.polar(torch.ones_like(freqs), freqs)
+
+            if f0 is not None:
+                f0 = f0[0]
+                seq_len = x
+                f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else device)
+                f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
+                radius = 1.0 / (f0 + 1)
+                freqs = torch.polar(radius, freqs)
+            else:
+                freqs = torch.polar(torch.ones_like(freqs), freqs)
         freqs = freqs.unsqueeze(0)
 
         if "rotary" in self.debug:
             if f0 is not None:
                 key = f"{self._counter}_{f0_theta:.2f}"
@@ -221,13 +222,12 @@ class rotary(nn.Module):
                 if not hasattr(self, '_prev_f0_theta'):
                     self._prev_f0_theta = f0_theta
                     print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
-                elif abs(self._prev_f0_theta - f0_theta) > 200.0:
+                elif abs(self._prev_f0_theta - f0_theta) > 0.0:
                     print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
                     self._prev_f0_theta = f0_theta
                 rotary._seen.add(key)
         self._counter += 1
-
         return freqs
 
     @staticmethod
     def apply_rotary(x, freqs):
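Review note (not part of the diff): the debug guard now fires on any change in f0_theta (the old 200 Hz dead-band is gone), so with per-batch F0 means it will likely print almost every step. Note also that the new expression f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale reduces algebraically to f0_mean ** 2 * self.pitch_scale, because the theta factors cancel:

import torch

theta, pitch_scale = torch.tensor(10000.0), torch.tensor(1.0)
f0_mean = torch.tensor(220.0)             # illustrative mean F0 in Hz

f0_theta = f0_mean * (f0_mean / theta) * theta * pitch_scale
print(f0_theta)                           # tensor(48400.)
print(f0_mean ** 2 * pitch_scale)         # tensor(48400.) -- theta cancels out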
@@ -240,13 +240,11 @@ class rotary(nn.Module):
             x1 = x1 * freqs
             x1 = torch.view_as_real(x1).flatten(-2)
             return torch.cat([x1.type_as(x), x2], dim=-1)
-
         else:
             x1 = x[..., :freqs.shape[-1]*2]
             x2 = x[..., freqs.shape[-1]*2:]
 
             if x.ndim == 2:
-
                 x1 = x1.unsqueeze(0)
             x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
             x1 = torch.view_as_complex(x1)
@@ -260,7 +258,7 @@ class rotary(nn.Module):
             x1 = x1 * freqs
             x1 = torch.view_as_real(x1).flatten(-2)
             return torch.cat([x1.type_as(x), x2], dim=-1)
 
 class SliceAttention(nn.Module):
     def __init__(self, dims, heads, dropout=0.0):
         super().__init__()
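Review note (not part of the diff): apply_rotary rotates adjacent feature pairs by viewing them as complex numbers and multiplying by freqs. A standalone sketch of that core step, assuming unit-radius freqs and illustrative shapes:

import torch

def apply_rotary_sketch(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (..., seq, dims) real; freqs: (seq, dims // 2) complex, e.g. from torch.polar.
    x1 = x.float().reshape(*x.shape[:-1], -1, 2).contiguous()  # group features into pairs
    x1 = torch.view_as_complex(x1)                             # (..., seq, dims // 2) complex
    x1 = x1 * freqs                                            # rotate each pair by its angle
    return torch.view_as_real(x1).flatten(-2).type_as(x)

seq, dims = 6, 8
angles = torch.arange(seq).float().unsqueeze(-1) / (10000.0 ** (torch.arange(0, dims, 2) / dims))
freqs = torch.polar(torch.ones_like(angles), angles)           # unit radius: pure rotation

x = torch.randn(2, seq, dims)
out = apply_rotary_sketch(x, freqs)
print(out.shape)                                               # torch.Size([2, 6, 8])
print(torch.allclose(out.norm(dim=-1), x.norm(dim=-1), atol=1e-5))  # True: norms preserved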
 