lym0302 committed
Commit 46ab663 · 1 Parent(s): 146e231

device->cpu
mmaudio/ext/rotary_embeddings.py CHANGED
@@ -16,7 +16,8 @@ def compute_rope_rotations(length: int,
                            device: Union[torch.device, str] = 'cpu') -> Tensor:
     assert dim % 2 == 0
 
-    with torch.amp.autocast(device_type='cuda', enabled=False):
+    # with torch.amp.autocast(device_type='cuda', enabled=False):
+    with torch.amp.autocast(device_type=device, enabled=False):
         pos = torch.arange(length, dtype=torch.float32, device=device)
         freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
         freqs *= freq_scaling
@@ -27,8 +28,9 @@
     return rot
 
 
-def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]:
-    with torch.amp.autocast(device_type='cuda', enabled=False):
+def apply_rope(x: Tensor, rot: Tensor, device: Union[torch.device, str] = 'cpu') -> tuple[Tensor, Tensor]:
+    # with torch.amp.autocast(device_type='cuda', enabled=False):
+    with torch.amp.autocast(device_type=device, enabled=False):
         _x = x.float()
         _x = _x.view(*_x.shape[:-1], -1, 1, 2)
         x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
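The patch replaces the hard-coded device_type='cuda' with the device argument both functions now receive. One caveat worth noting: torch.amp.autocast expects device_type as a string ('cpu' or 'cuda'), so a caller holding a torch.device object would need to normalize it before passing it through. A minimal sketch of such a shim (the as_device_type helper is hypothetical, not part of the commit):

import torch
from typing import Union

def as_device_type(device: Union[torch.device, str]) -> str:
    # Normalize a torch.device or a plain string to the 'cpu'/'cuda'
    # string that torch.amp.autocast(device_type=...) accepts.
    return device.type if isinstance(device, torch.device) else device

# Usage against the patched functions:
assert as_device_type(torch.device('cpu')) == 'cpu'
assert as_device_type('cuda') == 'cuda'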
mmaudio/ext/synchformer/synchformer.py CHANGED
@@ -41,14 +41,14 @@ class Synchformer(nn.Module):
         return super().load_state_dict(sd, strict)
 
 
-if __name__ == "__main__":
-    model = Synchformer().cuda().eval()
-    sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
-    model.load_state_dict(sd)
-
-    vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
-    features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
-    print(features.shape)
+# if __name__ == "__main__":
+#     model = Synchformer().cuda().eval()
+#     sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
+#     model.load_state_dict(sd)
+
+#     vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
+#     features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
+#     print(features.shape)
 
 # extract and save the state dict only
 # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
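Rather than commenting out the CUDA-only smoke test entirely, a device-agnostic variant would keep it runnable on CPU hosts. A minimal sketch reusing the paths and calls from the block above (an assumption, not part of the commit):

import torch
from mmaudio.ext.synchformer.synchformer import Synchformer

if __name__ == "__main__":
    # Fall back to CPU when no GPU is present.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Synchformer().to(device).eval()
    sd = torch.load('./ext_weights/synchformer_state_dict.pth',
                    map_location=device, weights_only=True)
    model.load_state_dict(sd)

    vid = torch.randn(2, 7, 16, 3, 224, 224, device=device)
    features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
    print(features.shape)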