Spaces:
Running
on
Zero
Running
on
Zero
import json | |
import logging | |
import os | |
import pathlib | |
import re | |
from copy import deepcopy | |
from pathlib import Path | |
from typing import Optional, Tuple, Union, Dict, Any | |
import torch | |
# Locations searched for model config JSON files. Each entry may be a
# directory or a single file; the list starts with the ``model_configs``
# directory located next to this module.
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary (model_name: config) of model architecture configs
def _natural_key(string_):
    """Return a sort key that orders embedded numbers by numeric value.

    The lower-cased input is split on runs of decimal digits; each digit run
    becomes an ``int`` and every other piece stays a ``str``. For example,
    ``"RN50x4"`` maps to ``["rn", 50, "x", 4, ""]``, so ``"RN50"`` sorts
    before ``"RN101"``.

    Args:
        string_: The string to build a key for.

    Returns:
        A list alternating ``str`` and ``int`` items, starting with a ``str``.
    """
    # ``isdecimal`` (not ``isdigit``) mirrors what ``\d`` captures: characters
    # such as "²" satisfy ``isdigit`` but are rejected by ``int``.
    return [
        int(piece) if piece.isdecimal() else piece
        for piece in re.split(r"(\d+)", string_.lower())
    ]
def _rescan_model_configs():
    """Reload model configs from ``_MODEL_CONFIG_PATHS`` into ``_MODEL_CONFIGS``.

    Every registered path is visited in order: a ``.json`` file is taken
    directly, and a directory contributes the ``.json`` files matched by a
    non-recursive glob. A parsed config is registered under the file's stem
    only if it contains ``embed_dim``, ``vision_cfg`` and ``text_cfg``; a
    later file with the same stem replaces the earlier entry. Existing
    registry entries are not cleared first. Finally the registry is rebound
    to a new dict ordered by the natural sort key of the model name.
    """
    global _MODEL_CONFIGS

    suffixes = (".json",)
    required_keys = ("embed_dim", "vision_cfg", "text_cfg")

    # Pass 1: gather candidate config files, preserving path order.
    candidates = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in suffixes:
            candidates.append(location)
            continue
        if location.is_dir():
            for suffix in suffixes:
                candidates += location.glob(f"*{suffix}")

    # Pass 2: parse each file and keep only complete model configs.
    for candidate in candidates:
        with open(candidate, "r", encoding="utf8") as handle:
            parsed = json.load(handle)
        if not all(key in parsed for key in required_keys):
            continue
        _MODEL_CONFIGS[candidate.stem] = parsed

    # Rebind the registry so iteration follows natural name order.
    ordered_items = sorted(
        _MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])
    )
    _MODEL_CONFIGS = dict(ordered_items)
_rescan_model_configs()  # initial population of the model config registry, run at import time
def list_models():
    """Return the names of all registered model architectures.

    The result is a new list of the registry's keys, in registry order, so
    callers may mutate it without affecting the registry.
    """
    return [*_MODEL_CONFIGS]
def add_model_config(path):
    """Register an extra config location and refresh the registry.

    Args:
        path: A config file or a directory of config files. Anything that is
            not already a ``Path`` is converted with ``Path(path)``.
    """
    location = path if isinstance(path, Path) else Path(path)
    _MODEL_CONFIG_PATHS.append(location)
    # Re-read all registered locations so the new configs become visible.
    _rescan_model_configs()
def get_model_config(model_name):
    """Look up the config registered for ``model_name``.

    Args:
        model_name: Name of the model architecture to look up.

    Returns:
        A deep copy of the registered config, so the caller can modify it
        without touching the registry, or ``None`` if the name is unknown.
    """
    if model_name not in _MODEL_CONFIGS:
        return None
    return deepcopy(_MODEL_CONFIGS[model_name])