ford442 commited on
Commit
b67c0d5
·
verified ·
1 Parent(s): 122eb50

Update audiocraft/models/musicgen.py

Browse files
Files changed (1) hide show
  1. audiocraft/models/musicgen.py +2 -2
audiocraft/models/musicgen.py CHANGED
@@ -100,7 +100,7 @@ class MusicGen:
100
  return self.compression_model.channels
101
 
102
  @staticmethod
103
- def get_pretrained(name: str = 'facebook/musicgen-melody', device=None, torch_dtype=torch.float32):
104
  """Return pretrained model, we provide four models:
105
  - facebook/musicgen-small (300M), text to music,
106
  # see: https://huggingface.co/facebook/musicgen-small
@@ -129,7 +129,7 @@ class MusicGen:
129
  f"Please use full pre-trained id instead: facebook/musicgen-{name}")
130
  name = _HF_MODEL_CHECKPOINTS_MAP[name]
131
 
132
- lm = load_lm_model(name, device=device, torch_dtype=torch_dtype)
133
  compression_model = load_compression_model(name, device=device)
134
  if 'self_wav' in lm.condition_provider.conditioners:
135
  lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
 
100
  return self.compression_model.channels
101
 
102
  @staticmethod
103
+ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None, depth='float32'):
104
  """Return pretrained model, we provide four models:
105
  - facebook/musicgen-small (300M), text to music,
106
  # see: https://huggingface.co/facebook/musicgen-small
 
129
  f"Please use full pre-trained id instead: facebook/musicgen-{name}")
130
  name = _HF_MODEL_CHECKPOINTS_MAP[name]
131
 
132
+ lm = load_lm_model(name, device=device, depth=depth)
133
  compression_model = load_compression_model(name, device=device)
134
  if 'self_wav' in lm.condition_provider.conditioners:
135
  lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True