Sin2pi committed on
Commit
c924a71
·
verified ·
1 Parent(s): 1813939

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +12 -6
model.py CHANGED
@@ -196,24 +196,30 @@ class rotary(nn.Module):
196
  t = x.float().to(self.inv_freq.device)
197
  if f0 is not None:
198
  f0_mean = f0.mean()
199
- f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
200
  inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
201
  else:
202
  inv_freq = self.inv_freq
203
  freqs = torch.einsum('i,j->ij', t, inv_freq)
204
  freqs = freqs.float()
205
  if self.variable_radius:
206
-
207
  if f0 is not None:
208
  f0 = f0[0]
209
  seq_len = x
210
- f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else device)
211
  f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
212
- radius = 1.0 / (f0 + 1)
 
 
 
 
 
 
213
  freqs = torch.polar(radius, freqs)
214
  else:
215
  freqs = torch.polar(torch.ones_like(freqs), freqs)
216
- freqs = freqs.unsqueeze(0)
 
217
 
218
  if "rotary" in self.debug:
219
  if f0 is not None:
@@ -258,7 +264,7 @@ class rotary(nn.Module):
258
  x1 = x1 * freqs
259
  x1 = torch.view_as_real(x1).flatten(-2)
260
  return torch.cat([x1.type_as(x), x2], dim=-1)
261
-
262
  class SliceAttention(nn.Module):
263
  def __init__(self, dims, heads, dropout=0.0):
264
  super().__init__()
 
196
  t = x.float().to(self.inv_freq.device)
197
  if f0 is not None:
198
  f0_mean = f0.mean()
199
+ f0_theta = (f0_mean**2) * self.pitch_scale
200
  inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
201
  else:
202
  inv_freq = self.inv_freq
203
  freqs = torch.einsum('i,j->ij', t, inv_freq)
204
  freqs = freqs.float()
205
  if self.variable_radius:
 
206
  if f0 is not None:
207
  f0 = f0[0]
208
  seq_len = x
209
+ f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else self.device)
210
  f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
211
+ max_f0 = torch.max(f0)
212
+
213
+ if max_f0 > 0:
214
+ radius = f0 / max_f0
215
+ else:
216
+ radius = torch.ones_like(f0)
217
+
218
  freqs = torch.polar(radius, freqs)
219
  else:
220
  freqs = torch.polar(torch.ones_like(freqs), freqs)
221
+ freqs = freqs.unsqueeze(0)
222
+ print(f"radius, {radius}")
223
 
224
  if "rotary" in self.debug:
225
  if f0 is not None:
 
264
  x1 = x1 * freqs
265
  x1 = torch.view_as_real(x1).flatten(-2)
266
  return torch.cat([x1.type_as(x), x2], dim=-1)
267
+
268
  class SliceAttention(nn.Module):
269
  def __init__(self, dims, heads, dropout=0.0):
270
  super().__init__()