Update model.py
Browse files
model.py
CHANGED
@@ -392,7 +392,7 @@ class rotary(nn.Module):
|
|
392 |
else:
|
393 |
f0 = f0.view(-1)
|
394 |
|
395 |
-
if f0 is not None
|
396 |
f0_mean = f0.mean()
|
397 |
theta = f0_mean + self.theta
|
398 |
else:
|
@@ -404,7 +404,7 @@ class rotary(nn.Module):
|
|
404 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
|
405 |
|
406 |
freqs = t[:, None] * freqs[None, :]
|
407 |
-
if self.radii and f0 is not None and layer == "encoder":
|
408 |
radius = f0.to(device, dtype)
|
409 |
L = radius.shape[0]
|
410 |
if L != ctx:
|
|
|
392 |
else:
|
393 |
f0 = f0.view(-1)
|
394 |
|
395 |
+
if f0 is not None:
|
396 |
f0_mean = f0.mean()
|
397 |
theta = f0_mean + self.theta
|
398 |
else:
|
|
|
404 |
print(f" [Rotary] {layer}{self.counter} --- [f0] {f0.shape if f0 is not None else None} [Theta] {theta.item():.2f} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
|
405 |
|
406 |
freqs = t[:, None] * freqs[None, :]
|
407 |
+
if self.radii and f0 is not None and layer == "encoder":
|
408 |
radius = f0.to(device, dtype)
|
409 |
L = radius.shape[0]
|
410 |
if L != ctx:
|