Spaces:
Running
Running
lym0302
commited on
Commit
·
e3d8696
1
Parent(s):
46ab663
DEVICE
Browse files
mmaudio/ext/rotary_embeddings.py
CHANGED
@@ -7,7 +7,7 @@ from torch import Tensor
|
|
7 |
# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
8 |
# Ref: https://github.com/lucidrains/rotary-embedding-torch
|
9 |
|
10 |
-
|
11 |
def compute_rope_rotations(length: int,
|
12 |
dim: int,
|
13 |
theta: int,
|
@@ -17,7 +17,7 @@ def compute_rope_rotations(length: int,
|
|
17 |
assert dim % 2 == 0
|
18 |
|
19 |
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
20 |
-
with torch.amp.autocast(device_type=
|
21 |
pos = torch.arange(length, dtype=torch.float32, device=device)
|
22 |
freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
23 |
freqs *= freq_scaling
|
@@ -28,9 +28,9 @@ def compute_rope_rotations(length: int,
|
|
28 |
return rot
|
29 |
|
30 |
|
31 |
-
def apply_rope(x: Tensor, rot: Tensor
|
32 |
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
33 |
-
with torch.amp.autocast(device_type=
|
34 |
_x = x.float()
|
35 |
_x = _x.view(*_x.shape[:-1], -1, 1, 2)
|
36 |
x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
|
|
|
7 |
# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
8 |
# Ref: https://github.com/lucidrains/rotary-embedding-torch
|
9 |
|
10 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
11 |
def compute_rope_rotations(length: int,
|
12 |
dim: int,
|
13 |
theta: int,
|
|
|
17 |
assert dim % 2 == 0
|
18 |
|
19 |
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
20 |
+
with torch.amp.autocast(device_type=DEVICE, enabled=False):
|
21 |
pos = torch.arange(length, dtype=torch.float32, device=device)
|
22 |
freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
23 |
freqs *= freq_scaling
|
|
|
28 |
return rot
|
29 |
|
30 |
|
31 |
+
def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
|
32 |
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
33 |
+
with torch.amp.autocast(device_type=DEVICE, enabled=False):
|
34 |
_x = x.float()
|
35 |
_x = _x.view(*_x.shape[:-1], -1, 1, 2)
|
36 |
x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
|