Sin2pi commited on
Commit
5219f06
·
verified ·
1 Parent(s): 2bc1697

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +47 -27
model.py CHANGED
@@ -133,29 +133,35 @@ class rotary(nn.Module):
133
  def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
134
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
135
  super().__init__()
136
-
137
- self.dims = dims
138
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
139
- dtype = torch.float32
140
  self.device = device
 
141
  self.dtype = dtype
142
  self.debug = debug
143
  self._counter = 0
144
-
145
- self.use_pbias = False
146
  self.max_ctx = max_ctx
147
  self.variable_radius = variable_radius
148
 
149
- self.inv_freq = nn.Parameter(1.0 / (theta ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
 
150
  requires_grad=learned_freq)
151
-
152
- self.theta = nn.Parameter(torch.tensor(float(theta)),
153
- requires_grad=learned_theta)
154
-
155
- self.pitch_scale = nn.Parameter(torch.tensor(1.0), requires_grad=learned_pitch)
 
 
 
 
156
 
157
  if variable_radius:
158
- self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)
 
 
159
 
160
  def get_pitch_bias(self, f0):
161
  if f0 is None:
@@ -194,27 +200,38 @@ class rotary(nn.Module):
194
  t = torch.arange(x, device=self.device).float()
195
  else:
196
  t = x.float().to(self.inv_freq.device)
 
197
  if f0 is not None:
198
  f0_mean = f0.mean()
199
- f0_theta = (f0_mean**2) * self.pitch_scale
200
- #f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
201
- inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
 
202
  else:
203
  inv_freq = self.inv_freq
204
  freqs = torch.einsum('i,j->ij', t, inv_freq)
 
205
  freqs = freqs.float()
206
  if self.variable_radius:
207
- if f0 is not None:
208
- f0 = f0[0]
209
- seq_len = x
210
- f0 = torch.tensor(f0, device=device if isinstance(x, torch.Tensor) else device)
211
- f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
212
- radius = 1.0 / (f0 + 1)
213
- freqs = torch.polar(radius, freqs)
214
- else:
215
- freqs = torch.polar(torch.ones_like(freqs), freqs)
216
- freqs = freqs.unsqueeze(0)
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  if "rotary" in self.debug:
219
  if f0 is not None:
220
  key = f"{self._counter}_{f0_theta:.2f}"
@@ -222,12 +239,13 @@ class rotary(nn.Module):
222
  if not hasattr(self, '_prev_f0_theta'):
223
  self._prev_f0_theta = f0_theta
224
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
225
- elif abs(self._prev_f0_theta - f0_theta) > 0.0:
226
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
227
  self._prev_f0_theta = f0_theta
228
  rotary._seen.add(key)
229
  self._counter += 1
230
- return freqs
 
231
 
232
  @staticmethod
233
  def apply_rotary(x, freqs):
@@ -240,11 +258,13 @@ class rotary(nn.Module):
240
  x1 = x1 * freqs
241
  x1 = torch.view_as_real(x1).flatten(-2)
242
  return torch.cat([x1.type_as(x), x2], dim=-1)
 
243
  else:
244
  x1 = x[..., :freqs.shape[-1]*2]
245
  x2 = x[..., freqs.shape[-1]*2:]
246
 
247
  if x.ndim == 2:
 
248
  x1 = x1.unsqueeze(0)
249
  x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
250
  x1 = torch.view_as_complex(x1)
 
133
  def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
134
  learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
135
  super().__init__()
136
+ self.use_pbias = False
137
+
138
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
139
  self.device = device
140
+ dtype = torch.float32
141
  self.dtype = dtype
142
  self.debug = debug
143
  self._counter = 0
144
+ self.dims = dims
 
145
  self.max_ctx = max_ctx
146
  self.variable_radius = variable_radius
147
 
148
+ self.inv_freq = nn.Parameter(
149
+ 1.0 / (10000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
150
  requires_grad=learned_freq)
151
+ self.theta = nn.Parameter(
152
+ torch.tensor(float(theta)), requires_grad=learned_theta)
153
+ self.min_theta = nn.Parameter(
154
+ torch.tensor(600.0), requires_grad=learned_theta)
155
+ self.max_theta = nn.Parameter(
156
+ torch.tensor(2400.0), requires_grad=learned_theta)
157
+
158
+ self.pitch_scale = nn.Parameter(torch.tensor(1.0),
159
+ requires_grad=learned_pitch)
160
 
161
  if variable_radius:
162
+ self.radius = nn.Parameter(
163
+ torch.ones(dims // 2),
164
+ requires_grad=learned_radius)
165
 
166
  def get_pitch_bias(self, f0):
167
  if f0 is None:
 
200
  t = torch.arange(x, device=self.device).float()
201
  else:
202
  t = x.float().to(self.inv_freq.device)
203
+
204
  if f0 is not None:
205
  f0_mean = f0.mean()
206
+ f0_mean = torch.clamp(f0_mean, min=80.0, max=600.0)
207
+ perceptual_factor = torch.log(1 + f0_mean / 700.0) / torch.log(torch.tensor(1 + 300.0 / 700.0))
208
+ f0_theta = self.min_theta + perceptual_factor * (self.max_theta - self.min_theta)
209
+ inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
210
  else:
211
  inv_freq = self.inv_freq
212
  freqs = torch.einsum('i,j->ij', t, inv_freq)
213
+
214
  freqs = freqs.float()
215
  if self.variable_radius:
 
 
 
 
 
 
 
 
 
 
216
 
217
+ # if f0 is not None:
218
+ # f0 = f0[0]
219
+ # seq_len = x
220
+ # f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
221
+ # radius = f0
222
+
223
+ # freqs = torch.polar(radius, freqs)
224
+ # else:
225
+
226
+ # freqs = torch.polar(torch.ones_like(freqs), freqs)
227
+ # freqs = freqs.unsqueeze(0)
228
+
229
+ radius = F.softplus(self.radius)
230
+ freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
231
+ else:
232
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
233
+ freqs = freqs.unsqueeze(0)
234
+
235
  if "rotary" in self.debug:
236
  if f0 is not None:
237
  key = f"{self._counter}_{f0_theta:.2f}"
 
239
  if not hasattr(self, '_prev_f0_theta'):
240
  self._prev_f0_theta = f0_theta
241
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
242
+ elif abs(self._prev_f0_theta - f0_theta) > 200.0:
243
  print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
244
  self._prev_f0_theta = f0_theta
245
  rotary._seen.add(key)
246
  self._counter += 1
247
+
248
+ return freqs
249
 
250
  @staticmethod
251
  def apply_rotary(x, freqs):
 
258
  x1 = x1 * freqs
259
  x1 = torch.view_as_real(x1).flatten(-2)
260
  return torch.cat([x1.type_as(x), x2], dim=-1)
261
+
262
  else:
263
  x1 = x[..., :freqs.shape[-1]*2]
264
  x2 = x[..., freqs.shape[-1]*2:]
265
 
266
  if x.ndim == 2:
267
+
268
  x1 = x1.unsqueeze(0)
269
  x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
270
  x1 = torch.view_as_complex(x1)