Sin2pi commited on
Commit
2bc1697
·
verified ·
1 Parent(s): 71bcdd6

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +4 -10
model.py CHANGED
@@ -197,6 +197,7 @@ class rotary(nn.Module):
197
  if f0 is not None:
198
  f0_mean = f0.mean()
199
  f0_theta = (f0_mean**2) * self.pitch_scale
 
200
  inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
201
  else:
202
  inv_freq = self.inv_freq
@@ -206,20 +207,13 @@ class rotary(nn.Module):
206
  if f0 is not None:
207
  f0 = f0[0]
208
  seq_len = x
209
- f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else self.device)
210
  f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
211
- max_f0 = torch.max(f0)
212
-
213
- if max_f0 > 0:
214
- radius = f0 / max_f0
215
- else:
216
- radius = torch.ones_like(f0)
217
-
218
  freqs = torch.polar(radius, freqs)
219
  else:
220
  freqs = torch.polar(torch.ones_like(freqs), freqs)
221
- freqs = freqs.unsqueeze(0)
222
- # print(f"radius, {radius}")
223
 
224
  if "rotary" in self.debug:
225
  if f0 is not None:
 
197
  if f0 is not None:
198
  f0_mean = f0.mean()
199
  f0_theta = (f0_mean**2) * self.pitch_scale
200
+ #f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
201
  inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
202
  else:
203
  inv_freq = self.inv_freq
 
207
  if f0 is not None:
208
  f0 = f0[0]
209
  seq_len = x
210
+ f0 = torch.tensor(f0, device=device if isinstance(x, torch.Tensor) else device)
211
  f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
212
+ radius = 1.0 / (f0 + 1)
 
 
 
 
 
 
213
  freqs = torch.polar(radius, freqs)
214
  else:
215
  freqs = torch.polar(torch.ones_like(freqs), freqs)
216
+ freqs = freqs.unsqueeze(0)
 
217
 
218
  if "rotary" in self.debug:
219
  if f0 is not None: