Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import re | |
| import json | |
| import yaml | |
| import logging | |
| from pathlib import Path | |
| from typing import Any, Literal, Optional | |
| from dataclasses import dataclass | |
| from .constants import Keys | |
| import gguf | |
| logger = logging.getLogger("metadata") | |
@dataclass
class Metadata:
    """Authorship Metadata to be written to GGUF KV Store.

    All fields default to None (i.e. "unknown"); heuristics and override
    files fill them in incrementally.  The `@dataclass` decorator (whose
    import was already present but unused) provides the keyword-argument
    constructor, `__repr__` and `__eq__` the annotated fields imply.
    """
    name: Optional[str] = None
    author: Optional[str] = None
    version: Optional[str] = None
    organization: Optional[str] = None
    finetune: Optional[str] = None
    basename: Optional[str] = None
    description: Optional[str] = None
    quantized_by: Optional[str] = None
    size_label: Optional[str] = None
    url: Optional[str] = None
    doi: Optional[str] = None
    uuid: Optional[str] = None
    repo_url: Optional[str] = None
    # Upstream/source model provenance (where this model was converted from)
    source_url: Optional[str] = None
    source_doi: Optional[str] = None
    source_uuid: Optional[str] = None
    source_repo_url: Optional[str] = None
    license: Optional[str] = None
    license_name: Optional[str] = None
    license_link: Optional[str] = None
    # One dict per parent model (for merges there can be several)
    base_models: Optional[list[dict]] = None
    tags: Optional[list[str]] = None
    languages: Optional[list[str]] = None
    datasets: Optional[list[str]] = None
@staticmethod
def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata:
    """Build a Metadata instance for a model repository.

    Precedence, lowest to highest: heuristics derived from the model card,
    HF config and directory name; the JSON metadata override file; and
    finally the directly supplied model_name.
    """
    # This grabs as many contextual authorship metadata as possible from the model repository
    # making any conversion as required to match the gguf kv store metadata format
    # as well as giving users the ability to override any authorship metadata that may be incorrect

    # Create a new Metadata instance
    metadata = Metadata()

    model_card = Metadata.load_model_card(model_path)
    hf_params = Metadata.load_hf_parameters(model_path)
    # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter

    # heuristics
    metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)

    # Metadata Override File Provided
    # This is based on LLM_KV_NAMES mapping in llama.cpp
    metadata_override = Metadata.load_metadata_override(metadata_override_path)

    metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
    metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
    metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
    metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization)

    metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune)
    metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename)

    metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description)
    metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by)

    metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label)
    metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name)
    metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link)

    metadata.url = metadata_override.get(Keys.General.URL, metadata.url)
    metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi)
    metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid)
    metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url)

    metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url)
    metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi)
    metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid)
    metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url)

    # Base Models is received here as an array of models
    metadata.base_models = metadata_override.get("general.base_models", metadata.base_models)

    metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
    metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages)
    metadata.datasets = metadata_override.get(Keys.General.DATASETS, metadata.datasets)

    # Direct Metadata Override (via direct cli argument)
    if model_name is not None:
        metadata.name = model_name

    return metadata
| def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]: | |
| if metadata_override_path is None or not metadata_override_path.is_file(): | |
| return {} | |
| with open(metadata_override_path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: | |
| if model_path is None or not model_path.is_dir(): | |
| return {} | |
| model_card_path = model_path / "README.md" | |
| if not model_card_path.is_file(): | |
| return {} | |
| # The model card metadata is assumed to always be in YAML | |
| # ref: https://github.com/huggingface/transformers/blob/a5c642fe7a1f25d3bdcd76991443ba6ff7ee34b2/src/transformers/modelcard.py#L468-L473 | |
| with open(model_card_path, "r", encoding="utf-8") as f: | |
| if f.readline() == "---\n": | |
| raw = f.read().partition("---\n")[0] | |
| data = yaml.safe_load(raw) | |
| if isinstance(data, dict): | |
| return data | |
| else: | |
| logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") | |
| return {} | |
| else: | |
| return {} | |
| def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: | |
| if model_path is None or not model_path.is_dir(): | |
| return {} | |
| config_path = model_path / "config.json" | |
| if not config_path.is_file(): | |
| return {} | |
| with open(config_path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| def id_to_title(string): | |
| # Convert capitalization into title form unless acronym or version number | |
| return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) | |
| def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: | |
| # Huggingface often store model id as '<org>/<model name>' | |
| # so let's parse it and apply some heuristics if possible for model name components | |
| if model_id is None: | |
| # model ID missing | |
| return None, None, None, None, None, None | |
| if ' ' in model_id: | |
| # model ID is actually a normal human sentence | |
| # which means its most likely a normal model name only | |
| # not part of the hugging face naming standard, but whatever | |
| return model_id, None, None, None, None, None | |
| if '/' in model_id: | |
| # model ID (huggingface style) | |
| org_component, model_full_name_component = model_id.split('/', 1) | |
| else: | |
| # model ID but missing org components | |
| org_component, model_full_name_component = None, model_id | |
| # Check if we erroneously matched against './' or '../' etc... | |
| if org_component is not None and len(org_component) > 0 and org_component[0] == '.': | |
| org_component = None | |
| name_parts: list[str] = model_full_name_component.split('-') | |
| # Remove empty parts | |
| for i in reversed(range(len(name_parts))): | |
| if len(name_parts[i]) == 0: | |
| del name_parts[i] | |
| name_types: list[ | |
| set[Literal["basename", "size_label", "finetune", "version", "type"]] | |
| ] = [set() for _ in name_parts] | |
| # Annotate the name | |
| for i, part in enumerate(name_parts): | |
| # Version | |
| if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE): | |
| name_types[i].add("version") | |
| # Quant type (should not be there for base models, but still annotated) | |
| elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE): | |
| name_types[i].add("type") | |
| name_parts[i] = part.upper() | |
| # Model size | |
| elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE): | |
| part = part.replace("_", ".") | |
| # Handle weird bloom-7b1 notation | |
| if part[-1].isdecimal(): | |
| part = part[:-2] + "." + part[-1] + part[-2] | |
| # Normalize the size suffixes | |
| if len(part) > 1 and part[-2].isdecimal(): | |
| if part[-1] in "kmbt": | |
| part = part[:-1] + part[-1].upper() | |
| if total_params != 0: | |
| try: | |
| label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1])) | |
| # Only use it as a size label if it's close or bigger than the model size | |
| # Note that LoRA adapters don't necessarily include all layers, | |
| # so this is why bigger label sizes are accepted. | |
| # Do not use the size label when it's smaller than 1/8 of the model size | |
| if (total_params < 0 and label_params < abs(total_params) // 8) or ( | |
| # Check both directions when the current model isn't a LoRA adapter | |
| total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 | |
| ): | |
| # Likely a context length | |
| name_types[i].add("finetune") | |
| # Lowercase the size when it's a context length | |
| part = part[:-1] + part[-1].lower() | |
| except ValueError: | |
| # Failed to convert the size label to float, use it anyway | |
| pass | |
| if len(name_types[i]) == 0: | |
| name_types[i].add("size_label") | |
| name_parts[i] = part | |
| # Some easy to recognize finetune names | |
| elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE): | |
| if total_params < 0 and part.lower() == "lora": | |
| # ignore redundant "lora" in the finetune part when the output is a lora adapter | |
| name_types[i].add("type") | |
| else: | |
| name_types[i].add("finetune") | |
| # Ignore word-based size labels when there is at least a number-based one present | |
| # TODO: should word-based size labels always be removed instead? | |
| if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n): | |
| for n, t in zip(name_parts, name_types): | |
| if "size_label" in t: | |
| if all(c.isalpha() for c in n): | |
| t.remove("size_label") | |
| at_start = True | |
| # Find the basename through the annotated name | |
| for part, t in zip(name_parts, name_types): | |
| if at_start and ((len(t) == 0 and part[0].isalpha()) or "version" in t): | |
| t.add("basename") | |
| else: | |
| if at_start: | |
| at_start = False | |
| if len(t) == 0: | |
| t.add("finetune") | |
| # Remove the basename annotation from trailing version | |
| for part, t in zip(reversed(name_parts), reversed(name_types)): | |
| if "basename" in t and len(t) > 1: | |
| t.remove("basename") | |
| else: | |
| break | |
| basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None | |
| # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys) | |
| size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None | |
| finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None | |
| # TODO: should the basename version always be excluded? | |
| # NOTE: multiple finetune versions are joined together | |
| version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None | |
| if size_label is None and finetune is None and version is None: | |
| # Too ambiguous, output nothing | |
| basename = None | |
| return model_full_name_component, org_component, basename, finetune, version, size_label | |
| def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata: | |
| # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 | |
| # Model Card Heuristics | |
| ######################## | |
| if model_card is not None: | |
| def use_model_card_metadata(metadata_key: str, model_card_key: str): | |
| if model_card_key in model_card and getattr(metadata, metadata_key, None) is None: | |
| setattr(metadata, metadata_key, model_card.get(model_card_key)) | |
| def use_array_model_card_metadata(metadata_key: str, model_card_key: str): | |
| # Note: Will append rather than replace if already exist | |
| tags_value = model_card.get(model_card_key, None) | |
| if tags_value is None: | |
| return | |
| current_value = getattr(metadata, metadata_key, None) | |
| if current_value is None: | |
| current_value = [] | |
| if isinstance(tags_value, str): | |
| current_value.append(tags_value) | |
| elif isinstance(tags_value, list): | |
| current_value.extend(tags_value) | |
| setattr(metadata, metadata_key, current_value) | |
| # LLAMA.cpp's direct internal convention | |
| # (Definitely not part of hugging face formal/informal standard) | |
| ######################################### | |
| use_model_card_metadata("name", "name") | |
| use_model_card_metadata("author", "author") | |
| use_model_card_metadata("version", "version") | |
| use_model_card_metadata("organization", "organization") | |
| use_model_card_metadata("description", "description") | |
| use_model_card_metadata("finetune", "finetune") | |
| use_model_card_metadata("basename", "basename") | |
| use_model_card_metadata("size_label", "size_label") | |
| use_model_card_metadata("source_url", "url") | |
| use_model_card_metadata("source_doi", "doi") | |
| use_model_card_metadata("source_uuid", "uuid") | |
| use_model_card_metadata("source_repo_url", "repo_url") | |
| # LLAMA.cpp's huggingface style convention | |
| # (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style) | |
| ########################################### | |
| use_model_card_metadata("name", "model_name") | |
| use_model_card_metadata("author", "model_author") | |
| use_model_card_metadata("version", "model_version") | |
| use_model_card_metadata("organization", "model_organization") | |
| use_model_card_metadata("description", "model_description") | |
| use_model_card_metadata("finetune", "model_finetune") | |
| use_model_card_metadata("basename", "model_basename") | |
| use_model_card_metadata("size_label", "model_size_label") | |
| use_model_card_metadata("source_url", "model_url") | |
| use_model_card_metadata("source_doi", "model_doi") | |
| use_model_card_metadata("source_uuid", "model_uuid") | |
| use_model_card_metadata("source_repo_url", "model_repo_url") | |
| # Hugging Face Direct Convention | |
| ################################# | |
| # Not part of huggingface model card standard but notice some model creator using it | |
| # such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF' | |
| use_model_card_metadata("name", "model_name") | |
| use_model_card_metadata("author", "model_creator") | |
| use_model_card_metadata("basename", "model_type") | |
| if "base_model" in model_card: | |
| # This represents the parent models that this is based on | |
| # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) | |
| # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md | |
| metadata_base_models = [] | |
| base_model_value = model_card.get("base_model", None) | |
| if base_model_value is not None: | |
| if isinstance(base_model_value, str): | |
| metadata_base_models.append(base_model_value) | |
| elif isinstance(base_model_value, list): | |
| metadata_base_models.extend(base_model_value) | |
| if metadata.base_models is None: | |
| metadata.base_models = [] | |
| for model_id in metadata_base_models: | |
| # NOTE: model size of base model is assumed to be similar to the size of the current model | |
| model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) | |
| base_model = {} | |
| if model_full_name_component is not None: | |
| base_model["name"] = Metadata.id_to_title(model_full_name_component) | |
| if org_component is not None: | |
| base_model["organization"] = Metadata.id_to_title(org_component) | |
| if version is not None: | |
| base_model["version"] = version | |
| if org_component is not None and model_full_name_component is not None: | |
| base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" | |
| metadata.base_models.append(base_model) | |
| use_model_card_metadata("license", "license") | |
| use_model_card_metadata("license_name", "license_name") | |
| use_model_card_metadata("license_link", "license_link") | |
| use_array_model_card_metadata("tags", "tags") | |
| use_array_model_card_metadata("tags", "pipeline_tag") | |
| use_array_model_card_metadata("languages", "languages") | |
| use_array_model_card_metadata("languages", "language") | |
| use_array_model_card_metadata("datasets", "datasets") | |
| use_array_model_card_metadata("datasets", "dataset") | |
| # Hugging Face Parameter Heuristics | |
| #################################### | |
| if hf_params is not None: | |
| hf_name_or_path = hf_params.get("_name_or_path") | |
| if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1: | |
| # Use _name_or_path only if its actually a model name and not some computer path | |
| # e.g. 'meta-llama/Llama-2-7b-hf' | |
| model_id = hf_name_or_path | |
| model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) | |
| if metadata.name is None and model_full_name_component is not None: | |
| metadata.name = Metadata.id_to_title(model_full_name_component) | |
| if metadata.organization is None and org_component is not None: | |
| metadata.organization = Metadata.id_to_title(org_component) | |
| if metadata.basename is None and basename is not None: | |
| metadata.basename = basename | |
| if metadata.finetune is None and finetune is not None: | |
| metadata.finetune = finetune | |
| if metadata.version is None and version is not None: | |
| metadata.version = version | |
| if metadata.size_label is None and size_label is not None: | |
| metadata.size_label = size_label | |
| # Directory Folder Name Fallback Heuristics | |
| ############################################ | |
| if model_path is not None: | |
| model_id = model_path.name | |
| model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) | |
| if metadata.name is None and model_full_name_component is not None: | |
| metadata.name = Metadata.id_to_title(model_full_name_component) | |
| if metadata.organization is None and org_component is not None: | |
| metadata.organization = Metadata.id_to_title(org_component) | |
| if metadata.basename is None and basename is not None: | |
| metadata.basename = basename | |
| if metadata.finetune is None and finetune is not None: | |
| metadata.finetune = finetune | |
| if metadata.version is None and version is not None: | |
| metadata.version = version | |
| if metadata.size_label is None and size_label is not None: | |
| metadata.size_label = size_label | |
| return metadata | |
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
    """Write every populated authorship field to the GGUF KV store.

    `name` is the only mandatory field; all other fields are written
    only when they are set (not None).  Base models are written as an
    indexed group of KV entries preceded by their count.
    """
    # The model name is required by the GGUF format; callers must have
    # resolved it (via heuristics, override file, or CLI) before this point.
    assert self.name is not None
    gguf_writer.add_name(self.name)

    if self.author is not None:
        gguf_writer.add_author(self.author)
    if self.version is not None:
        gguf_writer.add_version(self.version)
    if self.organization is not None:
        gguf_writer.add_organization(self.organization)
    if self.finetune is not None:
        gguf_writer.add_finetune(self.finetune)
    if self.basename is not None:
        gguf_writer.add_basename(self.basename)
    if self.description is not None:
        gguf_writer.add_description(self.description)
    if self.quantized_by is not None:
        gguf_writer.add_quantized_by(self.quantized_by)
    if self.size_label is not None:
        gguf_writer.add_size_label(self.size_label)
    if self.license is not None:
        gguf_writer.add_license(self.license)
    if self.license_name is not None:
        gguf_writer.add_license_name(self.license_name)
    if self.license_link is not None:
        gguf_writer.add_license_link(self.license_link)
    if self.url is not None:
        gguf_writer.add_url(self.url)
    if self.doi is not None:
        gguf_writer.add_doi(self.doi)
    if self.uuid is not None:
        gguf_writer.add_uuid(self.uuid)
    if self.repo_url is not None:
        gguf_writer.add_repo_url(self.repo_url)
    if self.source_url is not None:
        gguf_writer.add_source_url(self.source_url)
    if self.source_doi is not None:
        gguf_writer.add_source_doi(self.source_doi)
    if self.source_uuid is not None:
        gguf_writer.add_source_uuid(self.source_uuid)
    if self.source_repo_url is not None:
        gguf_writer.add_source_repo_url(self.source_repo_url)

    if self.base_models is not None:
        # One indexed KV entry per known field of each base model dict.
        # NOTE(review): apply_metadata_heuristic only ever fills
        # name/organization/version/repo_url; the author/url/doi/uuid keys
        # presumably come from the metadata override file — confirm.
        gguf_writer.add_base_model_count(len(self.base_models))
        for key, base_model_entry in enumerate(self.base_models):
            if "name" in base_model_entry:
                gguf_writer.add_base_model_name(key, base_model_entry["name"])
            if "author" in base_model_entry:
                gguf_writer.add_base_model_author(key, base_model_entry["author"])
            if "version" in base_model_entry:
                gguf_writer.add_base_model_version(key, base_model_entry["version"])
            if "organization" in base_model_entry:
                gguf_writer.add_base_model_organization(key, base_model_entry["organization"])
            if "url" in base_model_entry:
                gguf_writer.add_base_model_url(key, base_model_entry["url"])
            if "doi" in base_model_entry:
                gguf_writer.add_base_model_doi(key, base_model_entry["doi"])
            if "uuid" in base_model_entry:
                gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"])
            if "repo_url" in base_model_entry:
                gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"])

    if self.tags is not None:
        gguf_writer.add_tags(self.tags)
    if self.languages is not None:
        gguf_writer.add_languages(self.languages)
    if self.datasets is not None:
        gguf_writer.add_datasets(self.datasets)