cbensimon HF Staff commited on
Commit
b85bc94
·
verified ·
1 Parent(s): a66ed84

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -2
model.py CHANGED
@@ -43,7 +43,8 @@ class SALMONN(nn.Module):
43
  lora_dropout=0.1,
44
  second_per_frame=0.333333,
45
  second_stride=0.333333,
46
- low_resource=False
 
47
  ):
48
 
49
  super().__init__()
@@ -115,7 +116,7 @@ class SALMONN(nn.Module):
115
 
116
  # load ckpt
117
  ckpt_dict = torch.load(ckpt)['model']
118
- self.load_state_dict(ckpt_dict, strict=False)
119
 
120
  def generate(
121
  self,
 
43
  lora_dropout=0.1,
44
  second_per_frame=0.333333,
45
  second_stride=0.333333,
46
+ low_resource=False,
47
+ device=None,
48
  ):
49
 
50
  super().__init__()
 
116
 
117
  # load ckpt
118
  ckpt_dict = torch.load(ckpt)['model']
119
+ self.load_state_dict(ckpt_dict, strict=False, map_location=device)
120
 
121
  def generate(
122
  self,