Update model.py
Browse files
model.py
CHANGED
@@ -43,7 +43,8 @@ class SALMONN(nn.Module):
|
|
43 |
lora_dropout=0.1,
|
44 |
second_per_frame=0.333333,
|
45 |
second_stride=0.333333,
|
46 |
-
low_resource=False
|
|
|
47 |
):
|
48 |
|
49 |
super().__init__()
|
@@ -115,7 +116,7 @@ class SALMONN(nn.Module):
|
|
115 |
|
116 |
# load ckpt
|
117 |
ckpt_dict = torch.load(ckpt)['model']
|
118 |
-
self.load_state_dict(ckpt_dict, strict=False)
|
119 |
|
120 |
def generate(
|
121 |
self,
|
|
|
43 |
lora_dropout=0.1,
|
44 |
second_per_frame=0.333333,
|
45 |
second_stride=0.333333,
|
46 |
+
low_resource=False,
|
47 |
+
device=None,
|
48 |
):
|
49 |
|
50 |
super().__init__()
|
|
|
116 |
|
117 |
# load ckpt
|
118 |
ckpt_dict = torch.load(ckpt)['model']
|
119 |
+
self.load_state_dict(ckpt_dict, strict=False, map_location=device)
|
120 |
|
121 |
def generate(
|
122 |
self,
|