Sin2pi commited on
Commit
8b434a5
·
verified ·
1 Parent(s): 8e82b27

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +2 -2
model.py CHANGED
@@ -392,7 +392,7 @@ class rotary(nn.Module):
392
  else:
393
  f0 = f0.view(-1)
394
 
395
- if f0 is not None and layer == "encoder": #rethink this
396
  f0_mean = f0.mean()
397
  theta = f0_mean + self.theta
398
  else:
@@ -404,7 +404,7 @@ class rotary(nn.Module):
404
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
405
 
406
  freqs = t[:, None] * freqs[None, :]
407
- if self.radii and f0 is not None and layer == "encoder": #this too
408
  radius = f0.to(device, dtype)
409
  L = radius.shape[0]
410
  if L != ctx:
 
392
  else:
393
  f0 = f0.view(-1)
394
 
395
+ if f0 is not None:
396
  f0_mean = f0.mean()
397
  theta = f0_mean + self.theta
398
  else:
 
404
  print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
405
 
406
  freqs = t[:, None] * freqs[None, :]
407
+ if self.radii and f0 is not None and layer == "encoder":
408
  radius = f0.to(device, dtype)
409
  L = radius.shape[0]
410
  if L != ctx: