Update audiocraft/models/musicgen.py
Browse files
audiocraft/models/musicgen.py
CHANGED
@@ -100,7 +100,7 @@ class MusicGen:
|
|
100 |
return self.compression_model.channels
|
101 |
|
102 |
@staticmethod
|
103 |
-
def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
|
104 |
"""Return pretrained model, we provide four models:
|
105 |
- facebook/musicgen-small (300M), text to music,
|
106 |
# see: https://huggingface.co/facebook/musicgen-small
|
@@ -129,7 +129,7 @@ class MusicGen:
|
|
129 |
f"Please use full pre-trained id instead: facebook/musicgen-{name}")
|
130 |
name = _HF_MODEL_CHECKPOINTS_MAP[name]
|
131 |
|
132 |
-
lm = load_lm_model(name, device=device)
|
133 |
compression_model = load_compression_model(name, device=device)
|
134 |
if 'self_wav' in lm.condition_provider.conditioners:
|
135 |
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
|
|
100 |
return self.compression_model.channels
|
101 |
|
102 |
@staticmethod
|
103 |
+
def get_pretrained(name: str = 'facebook/musicgen-melody', device=None, torch_dtype=torch.float32):
|
104 |
"""Return pretrained model, we provide four models:
|
105 |
- facebook/musicgen-small (300M), text to music,
|
106 |
# see: https://huggingface.co/facebook/musicgen-small
|
|
|
129 |
f"Please use full pre-trained id instead: facebook/musicgen-{name}")
|
130 |
name = _HF_MODEL_CHECKPOINTS_MAP[name]
|
131 |
|
132 |
+
lm = load_lm_model(name, device=device, torch_dtype=torch_dtype)
|
133 |
compression_model = load_compression_model(name, device=device)
|
134 |
if 'self_wav' in lm.condition_provider.conditioners:
|
135 |
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|