File size: 4,782 Bytes
491eded |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import importlib
from typing import TYPE_CHECKING
# Lazy-export table: public class name -> name of the sibling submodule that
# defines it. The module-level __getattr__ imports the submodule on first
# access, so importing this package does not import every model eagerly.
__attributes = {
    # Sparse-structure models
    'SparseStructureEncoder': 'sparse_structure_vae',
    'SparseStructureDecoder': 'sparse_structure_vae',
    'SparseStructureFlowModel': 'sparse_structure_flow',
    # Structured-latent (SLat) models and their Elastic* counterparts
    'SLatEncoder': 'structured_latent_vae',
    'SLatGaussianDecoder': 'structured_latent_vae',
    'SLatRadianceFieldDecoder': 'structured_latent_vae',
    'SLatMeshDecoder': 'structured_latent_vae',
    'ElasticSLatEncoder': 'structured_latent_vae',
    'ElasticSLatGaussianDecoder': 'structured_latent_vae',
    'ElasticSLatRadianceFieldDecoder': 'structured_latent_vae',
    'ElasticSLatMeshDecoder': 'structured_latent_vae',
    'SLatFlowModel': 'structured_latent_flow',
    'ElasticSLatFlowModel': 'structured_latent_flow',
}

# Submodules that are exposed lazily under their own name (none at present).
__submodules = []

# Public API: every lazily exported class, followed by any lazy submodules.
__all__ = [*__attributes, *__submodules]
def __getattr__(name):
    """
    Resolve a lazily exported name the first time it is requested.

    Names already present in the module namespace are returned as-is. A name
    listed in ``__attributes`` is imported from its submodule, and a name
    listed in ``__submodules`` is imported as a submodule; either result is
    stored in the module namespace before being returned.

    Args:
        name: The attribute name being looked up.

    Raises:
        AttributeError: If the name is neither defined nor registered.
    """
    namespace = globals()
    if name in namespace:
        return namespace[name]

    if name in __attributes:
        owner = importlib.import_module(f".{__attributes[name]}", __name__)
        namespace[name] = getattr(owner, name)
    elif name in __submodules:
        namespace[name] = importlib.import_module(f".{name}", __name__)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    return namespace[name]
def from_pretrained(path: str, **kwargs):
    """
    Load a model from a pretrained checkpoint.

    A checkpoint is a pair of files: a JSON config, whose 'name' entry selects
    the model class and whose optional 'args' entry supplies constructor
    arguments, and a safetensors file holding the weights.

    Args:
        path: The path to the checkpoint. Can be either local path or a Hugging Face model name.
              NOTE: config file and model file should take the name f'{path}.json' and f'{path}.safetensors' respectively.
              A Hugging Face path has the form '<user>/<repo>/<file path without extension>'.
        **kwargs: Additional arguments for the model constructor.

    Returns:
        The constructed model with the checkpoint weights loaded. Loading is
        non-strict: missing and unexpected keys are printed, not raised.

    Raises:
        ValueError: If the checkpoint is not found locally and ``path`` is not
            a well-formed Hugging Face path, if the config has no 'name'
            field, or if the named model class cannot be found.
    """
    import os
    import json

    config_file = f"{path}.json"
    model_file = f"{path}.safetensors"
    if not (os.path.exists(config_file) and os.path.exists(model_file)):
        # Not a local checkpoint: interpret the path as
        # '<user>/<repo>/<model name>' and validate it before touching the hub,
        # so a mistyped local path fails with a clear message.
        path_parts = path.split('/')
        model_name = '/'.join(path_parts[2:])
        if len(path_parts) < 3 or not all(path_parts[:2]) or not model_name.strip('/'):
            raise ValueError(
                f"Checkpoint '{path}' was not found locally (expected '{path}.json' and "
                f"'{path}.safetensors') and is not a Hugging Face path of the form "
                f"'<user>/<repo>/<model name>'"
            )
        from huggingface_hub import hf_hub_download
        repo_id = f'{path_parts[0]}/{path_parts[1]}'
        config_file = hf_hub_download(repo_id, f"{model_name}.json")
        model_file = hf_hub_download(repo_id, f"{model_name}.safetensors")

    # Read the config as UTF-8 rather than the platform's default encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    if 'name' not in config:
        raise ValueError("Config file missing required 'name' field")

    # Match the configured name case-insensitively against the registered
    # classes; a name with no match is left as-is for the lookup below.
    requested = config['name']
    folded = requested.lower()
    model_class = next((k for k in __attributes if k.lower() == folded), requested)

    try:
        model_constructor = __getattr__(model_class)
    except AttributeError as e:
        print(f"Model lookup failed: {e}")
        raise ValueError(f"Model class '{model_class}' not found in available models: {list(__attributes.keys())}") from e

    # Imported only after the config has been validated.
    from safetensors.torch import load_file

    # 'args' may be absent or null in the config; both mean "no arguments".
    model = model_constructor(**(config.get('args') or {}), **kwargs)

    # Load state dict
    state_dict = load_file(model_file)

    # Report key mismatches explicitly, since the non-strict load below
    # would otherwise ignore them silently.
    model_keys = set(model.state_dict().keys())
    loaded_keys = set(state_dict.keys())
    missing_keys = model_keys - loaded_keys
    unexpected_keys = loaded_keys - model_keys
    if missing_keys:
        print(f"Missing keys in state dict: {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys in state dict: {unexpected_keys}")

    # Load state dict with strict=False to allow missing keys
    model.load_state_dict(state_dict, strict=False)
    return model
# For Pylance (and other static analysers): make the names listed in __all__
# resolvable. At runtime they are imported lazily by the module-level
# __getattr__, so these imports are guarded by TYPE_CHECKING and never execute.
# A `__name__ == '__main__'` guard is unsuitable here: it is only true when the
# file is run as a script, where relative imports cannot work.
if TYPE_CHECKING:
    from .sparse_structure_vae import (
        SparseStructureEncoder,
        SparseStructureDecoder,
    )
    from .sparse_structure_flow import SparseStructureFlowModel
    from .structured_latent_vae import (
        SLatEncoder,
        SLatGaussianDecoder,
        SLatRadianceFieldDecoder,
        SLatMeshDecoder,
        ElasticSLatEncoder,
        ElasticSLatGaussianDecoder,
        ElasticSLatRadianceFieldDecoder,
        ElasticSLatMeshDecoder,
    )
    from .structured_latent_flow import (
        SLatFlowModel,
        ElasticSLatFlowModel,
    )
|