lym0302 committed
Commit e3d8696 · 1 Parent(s): 46ab663
Files changed (1)
  1. mmaudio/ext/rotary_embeddings.py +4 -4
mmaudio/ext/rotary_embeddings.py CHANGED
@@ -7,7 +7,7 @@ from torch import Tensor
 # Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
 # Ref: https://github.com/lucidrains/rotary-embedding-torch
 
-
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 def compute_rope_rotations(length: int,
                            dim: int,
                            theta: int,
@@ -17,7 +17,7 @@ def compute_rope_rotations(length: int,
     assert dim % 2 == 0
 
     # with torch.amp.autocast(device_type='cuda', enabled=False):
-    with torch.amp.autocast(device_type=device, enabled=False):
+    with torch.amp.autocast(device_type=DEVICE, enabled=False):
         pos = torch.arange(length, dtype=torch.float32, device=device)
         freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
         freqs *= freq_scaling
@@ -28,9 +28,9 @@
     return rot
 
 
-def apply_rope(x: Tensor, rot: Tensor, device: Union[torch.device, str] = 'cpu') -> tuple[Tensor, Tensor]:
+def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
     # with torch.amp.autocast(device_type='cuda', enabled=False):
-    with torch.amp.autocast(device_type=device, enabled=False):
+    with torch.amp.autocast(device_type=DEVICE, enabled=False):
         _x = x.float()
         _x = _x.view(*_x.shape[:-1], -1, 1, 2)
         x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
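For context on the pattern this commit moves to, a module-level device string reused as the `device_type` of a disabled autocast region, here is a minimal, self-contained sketch. It is not the repository's code: `rotate_pairs` and all shapes are made up for illustration, and only the `DEVICE` constant and the autocast usage mirror the diff above.

import torch

# Pick the autocast device_type once at import time, mirroring the commit,
# instead of threading a `device` argument through every call.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def rotate_pairs(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Toy RoPE-style rotation of (..., d) feature pairs, kept in float32 by
    # disabling autocast so mixed-precision callers do not downcast it.
    with torch.amp.autocast(device_type=DEVICE, enabled=False):
        xp = x.float().view(*x.shape[:-1], -1, 2)   # (..., d//2, 2)
        x1, x2 = xp[..., 0], xp[..., 1]
        out = torch.stack([x1 * cos - x2 * sin,
                           x1 * sin + x2 * cos], dim=-1)
        return out.reshape(*x.shape).to(dtype=x.dtype)


# Illustration-only shapes: 16 positions, 8-dim features, theta = 10000.
pos = torch.arange(16, dtype=torch.float32)
freqs = 1.0 / (10000 ** (torch.arange(0, 8, 2, dtype=torch.float32) / 8))
angles = torch.outer(pos, freqs)                    # (16, 4)
x = torch.randn(16, 8)
y = rotate_pairs(x, torch.cos(angles), torch.sin(angles))
print(y.shape)  # torch.Size([16, 8])

Since autocast is disabled inside the block, `device_type` only selects which autocast context is being turned off; the tensors' own device is unaffected, which is presumably why the `device` parameter could be dropped from `apply_rope` without changing its behavior.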