Update model.py
Browse files
model.py
CHANGED
@@ -351,7 +351,7 @@ class rotary(nn.Module):
|
|
351 |
batch, head, ctx, head_dim = x.shape
|
352 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
353 |
|
354 |
-
f0 =
|
355 |
if f0 is not None and f0.dim() == 2:
|
356 |
if f0.shape[0] == 1:
|
357 |
f0 = f0.squeeze(0)
|
|
|
351 |
batch, head, ctx, head_dim = x.shape
|
352 |
t = torch.arange(ctx, device=device, dtype=dtype)
|
353 |
|
354 |
+
f0 = enc.get("f0") if enc is not None else None
|
355 |
if f0 is not None and f0.dim() == 2:
|
356 |
if f0.shape[0] == 1:
|
357 |
f0 = f0.squeeze(0)
|