File size: 1,976 Bytes
			
			5238467 9d7284e 5238467  | 
								1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62  | 
								"""Utility for loading the models from HF."""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
import torch
from audiocraft.models import builders, MusicGen
# Short model names accepted by `get_pretrained`, mapped to the Hugging Face
# Hub repository ids that checkpoints are downloaded from.
MODEL_CHECKPOINTS_MAP = {
    size: f"facebook/musicgen-{size}"
    for size in ("small", "medium", "large", "melody")
}
def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename="state_dict.bin", device='cpu'):
    """Deserialize a checkpoint package with ``torch.load``.

    ``file_or_url`` is resolved in this order:

    1. an existing local file: loaded directly (``filename`` is ignored);
    2. an existing local directory: ``<dir>/<filename>`` is loaded;
    3. an ``http://`` / ``https://`` URL: fetched through the torch hub cache;
    4. anything else: treated as a Hugging Face Hub repo id, and
       ``filename`` is downloaded from that repo.

    Args:
        file_or_url: local path, URL, or Hugging Face repo id.
        filename: file to read inside a local directory or HF repo.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.
    """
    import logging

    source = str(file_or_url)
    logging.getLogger(__name__).info("loading %s %s", source, filename)

    # NOTE(security): torch.load unpickles the file; only point this at
    # trusted checkpoints.
    local = Path(source)
    if local.is_file():
        return torch.load(local, map_location=device)
    if local.is_dir():
        return torch.load(local / filename, map_location=device)
    if source.startswith(("http://", "https://")):
        return torch.hub.load_state_dict_from_url(source, map_location=device)
    # Fallback (the only behavior before): a Hugging Face Hub repository id.
    return torch.load(
        hf_hub_download(repo_id=source, filename=filename), map_location=device)
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the compression model stored in a MusicGen checkpoint.

    Args:
        file_or_url: checkpoint source passed to ``_get_state_dict``; the
            package is read from its ``compression_state_dict.bin`` file.
        device: written (as a string) into the config as ``cfg.device``
            before the model is built.

    Returns:
        The model returned by ``builders.get_compression_model``, with the
        package's ``best_state`` weights loaded, switched to eval mode, and
        with its config attached as ``model.cfg``.
    """
    # No device is forwarded, so the package is deserialized with
    # _get_state_dict's default map_location.
    checkpoint = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    # Keep the resolved config on the model for later inspection.
    compression_model.cfg = config
    return compression_model
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Build the MusicGen language model stored in a checkpoint.

    Args:
        file_or_url: checkpoint source passed to ``_get_state_dict``; the
            package is read from its default ``state_dict.bin`` file.
        device: stored as a string in ``cfg.device``. It also selects the
            dtype (``'float32'`` on ``'cpu'``, ``'float16'`` otherwise) and,
            on ``'cpu'``, overrides two ``transformer_lm`` flags.

    Returns:
        The model returned by ``builders.get_lm_model``, with ``best_state``
        weights loaded, in eval mode, and with its config attached as
        ``model.cfg``.
    """
    checkpoint = _get_state_dict(file_or_url)
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)

    on_cpu = config.device == 'cpu'
    if on_cpu:
        # NOTE(review): presumably the memory-efficient / non-custom
        # transformer path needs a GPU — confirm against builders.
        config.transformer_lm.memory_efficient = False
        config.transformer_lm.custom = True
    config.dtype = 'float32' if on_cpu else 'float16'

    language_model = builders.get_lm_model(config)
    language_model.load_state_dict(checkpoint['best_state'])
    language_model.eval()
    language_model.cfg = config
    return language_model
def get_pretrained(name: str = 'small', device='cuda'):
    """Load a pretrained MusicGen model by its short name.

    Args:
        name: a key of ``MODEL_CHECKPOINTS_MAP`` (e.g. ``'small'``,
            ``'melody'``).
        device: device both sub-models are configured for. The default is
            ``'cuda'``; pass ``'cpu'`` to get ``load_lm_model``'s CPU settings.

    Returns:
        A ``MusicGen`` wrapping the compression model and language model
        loaded from the matching Hugging Face repository.

    Raises:
        KeyError: if ``name`` is not a known model; the message lists the
            valid names.
    """
    try:
        model_id = MODEL_CHECKPOINTS_MAP[name]
    except KeyError:
        known = ", ".join(repr(key) for key in MODEL_CHECKPOINTS_MAP)
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            f"Unknown MusicGen model {name!r}; expected one of: {known}") from None
    compression_model = load_compression_model(model_id, device=device)
    lm = load_lm_model(model_id, device=device)
    return MusicGen(name, compression_model, lm)
 |