Spaces:
Running
Running
lym0302
commited on
Commit
·
46ab663
1
Parent(s):
146e231
device->cpu
Browse files
mmaudio/ext/rotary_embeddings.py
CHANGED
@@ -16,7 +16,8 @@ def compute_rope_rotations(length: int,
|
|
16 |
device: Union[torch.device, str] = 'cpu') -> Tensor:
|
17 |
assert dim % 2 == 0
|
18 |
|
19 |
-
with torch.amp.autocast(device_type='cuda', enabled=False):
|
|
|
20 |
pos = torch.arange(length, dtype=torch.float32, device=device)
|
21 |
freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
22 |
freqs *= freq_scaling
|
@@ -27,8 +28,9 @@ def compute_rope_rotations(length: int,
|
|
27 |
return rot
|
28 |
|
29 |
|
30 |
-
def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
|
31 |
-
with torch.amp.autocast(device_type='cuda', enabled=False):
|
|
|
32 |
_x = x.float()
|
33 |
_x = _x.view(*_x.shape[:-1], -1, 1, 2)
|
34 |
x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
|
|
|
16 |
device: Union[torch.device, str] = 'cpu') -> Tensor:
|
17 |
assert dim % 2 == 0
|
18 |
|
19 |
+
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
20 |
+
with torch.amp.autocast(device_type=device, enabled=False):
|
21 |
pos = torch.arange(length, dtype=torch.float32, device=device)
|
22 |
freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
|
23 |
freqs *= freq_scaling
|
|
|
28 |
return rot
|
29 |
|
30 |
|
31 |
+
def apply_rope(x: Tensor, rot: Tensor, device: Union[torch.device, str] = 'cpu') -> tuple[Tensor, Tensor]:
|
32 |
+
# with torch.amp.autocast(device_type='cuda', enabled=False):
|
33 |
+
with torch.amp.autocast(device_type=device, enabled=False):
|
34 |
_x = x.float()
|
35 |
_x = _x.view(*_x.shape[:-1], -1, 1, 2)
|
36 |
x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
|
mmaudio/ext/synchformer/synchformer.py
CHANGED
@@ -41,14 +41,14 @@ class Synchformer(nn.Module):
|
|
41 |
return super().load_state_dict(sd, strict)
|
42 |
|
43 |
|
44 |
-
if __name__ == "__main__":
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
|
53 |
# extract and save the state dict only
|
54 |
# sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
|
|
|
41 |
return super().load_state_dict(sd, strict)
|
42 |
|
43 |
|
44 |
+
# if __name__ == "__main__":
|
45 |
+
# model = Synchformer().cuda().eval()
|
46 |
+
# sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
|
47 |
+
# model.load_state_dict(sd)
|
48 |
+
|
49 |
+
# vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
|
50 |
+
# features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
|
51 |
+
# print(features.shape)
|
52 |
|
53 |
# extract and save the state dict only
|
54 |
# sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
|