Abhishek Thakur
custom snapshot download
1fffe05
import os
from pathlib import Path
from typing import Dict, List, Optional, Union
from huggingface_hub import HfApi
from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE, REPO_TYPES
from huggingface_hub.file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
from huggingface_hub.utils import filter_repo_objects, validate_hf_hub_args
from joblib import Parallel, delayed
class _CachedSnapshotNotFoundError(FileNotFoundError, ValueError):
    """Raised when `local_files_only=True` and no cached snapshot is usable.

    Derives from both `FileNotFoundError` and `ValueError` so that callers
    catching either exception type keep working.
    """


@validate_hf_hub_args
def snapshot_download(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    proxies: Optional[Dict] = None,
    etag_timeout: Optional[float] = 10,
    resume_download: Optional[bool] = False,
    use_auth_token: Optional[Union[bool, str]] = None,
    local_files_only: Optional[bool] = False,
    allow_regex: Optional[Union[List[str], str]] = None,
    ignore_regex: Optional[Union[List[str], str]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 10,
) -> str:
    """Download all files of a repo, fetching them concurrently.

    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        repo_id (`str`):
            A user or an organization name and a repo name separated by a `/`.
        revision (`str`, *optional*):
            An optional Git revision id which can be a branch name, a tag, or a
            commit hash. Defaults to `DEFAULT_REVISION`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if downloading from a dataset or
            space, `None` or `"model"` if downloading from a model. Default is
            `None`.
        cache_dir (`str`, `Path`, *optional*):
            Path to the folder where cached files are stored. Defaults to
            `HUGGINGFACE_HUB_CACHE`.
        library_name (`str`, *optional*):
            The name of the library to which the object corresponds.
        library_version (`str`, *optional*):
            The version of the library.
        user_agent (`str`, `dict`, *optional*):
            The user-agent info in the form of a dictionary or a string.
        proxies (`dict`, *optional*):
            Dictionary mapping protocol to the URL of the proxy passed to
            `requests.request`.
        etag_timeout (`float`, *optional*, defaults to `10`):
            When fetching ETag, how many seconds to wait for the server to send
            data before giving up which is passed to `requests.request`.
        resume_download (`bool`, *optional*, defaults to `False`):
            If `True`, resume a previously interrupted download.
        use_auth_token (`str`, `bool`, *optional*):
            A token to be used for the download.
                - If `True`, the token is read from the HuggingFace config
                  folder.
                - If a string, it's used as the authentication token.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, avoid downloading the file and return the path to the
            local cached file if it exists.
        allow_regex (`List[str]` or `str`, *optional*):
            Deprecated alias of `allow_patterns`. When provided, it takes
            precedence over `allow_patterns`.
        ignore_regex (`List[str]` or `str`, *optional*):
            Deprecated alias of `ignore_patterns`. When provided, it takes
            precedence over `ignore_patterns`.
        allow_patterns (`List[str]` or `str`, *optional*):
            If provided, only files matching at least one pattern are downloaded.
        ignore_patterns (`List[str]` or `str`, *optional*):
            If provided, files matching any of the patterns are not downloaded.
        max_workers (`int`, *optional*, defaults to `10`):
            Number of threads used to download files concurrently. Passed to
            `joblib.Parallel` as `n_jobs`.

    Returns:
        Local folder path (string) of repo snapshot

    <Tip>

    Raises the following errors:

    - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
      if `use_auth_token=True` and the token cannot be found.
    - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
      ETag cannot be determined.
    - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
      if some parameter value is invalid, or if `local_files_only=True` and no
      cached snapshot exists for the revision (that error is also a
      `FileNotFoundError`).

    </Tip>
    """
    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE
    if revision is None:
        revision = DEFAULT_REVISION
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if repo_type is None:
        repo_type = "model"
    if repo_type not in REPO_TYPES:
        raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")

    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

    # TODO: remove these 4 lines in version 0.12
    #       Deprecated code to ensure backward compatibility.
    if allow_regex is not None:
        allow_patterns = allow_regex
    if ignore_regex is not None:
        ignore_patterns = ignore_regex

    # If we have no internet connection we will look for an appropriate folder
    # in the cache.
    # If the specified revision is a commit hash, look inside "snapshots".
    # If the specified revision is a branch or tag, look inside "refs".
    if local_files_only:
        not_cached_message = (
            "Cannot find an appropriate cached snapshot folder for the specified"
            " revision on the local disk and outgoing traffic has been disabled. To"
            " enable repo look-ups and downloads online, set 'local_files_only' to"
            " False."
        )

        if REGEX_COMMIT_HASH.match(revision):
            commit_hash = revision
        else:
            # Retrieve the commit hash from the ref file. A missing ref simply
            # means the revision was never cached, so report it with the same
            # explanatory error as a missing snapshot folder.
            ref_path = os.path.join(storage_folder, "refs", revision)
            try:
                with open(ref_path) as f:
                    # Strip so that a trailing newline does not end up in the path.
                    commit_hash = f.read().strip()
            except FileNotFoundError as err:
                raise _CachedSnapshotNotFoundError(not_cached_message) from err
            if not commit_hash:
                # An empty ref would otherwise resolve to the bare "snapshots"
                # directory, which is not a snapshot.
                raise _CachedSnapshotNotFoundError(not_cached_message)

        snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
        if os.path.exists(snapshot_folder):
            return snapshot_folder

        raise _CachedSnapshotNotFoundError(not_cached_message)

    # If we have internet connection we retrieve the correct folder name from
    # the huggingface api.
    _api = HfApi()
    repo_info = _api.repo_info(
        repo_id=repo_id,
        repo_type=repo_type,
        revision=revision,
        use_auth_token=use_auth_token,
    )
    filtered_repo_files = list(
        filter_repo_objects(
            items=[f.rfilename for f in repo_info.siblings],
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    )
    commit_hash = repo_info.sha
    snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)

    # If the passed revision is not identical to commit_hash then revision has
    # to be a branch name or tag name. In that case store a ref so that a later
    # `local_files_only=True` call can resolve it.
    if revision != commit_hash:
        ref_path = os.path.join(storage_folder, "refs", revision)
        os.makedirs(os.path.dirname(ref_path), exist_ok=True)
        with open(ref_path, "w") as f:
            f.write(commit_hash)

    # We pass the commit_hash (not the branch/tag) to hf_hub_download so that no
    # network call happens if we already have the file locally. Downloads are
    # dispatched to a thread pool so that several files are fetched at once.
    Parallel(n_jobs=max_workers, backend="threading")(
        delayed(hf_hub_download)(
            repo_id,
            filename=repo_file,
            repo_type=repo_type,
            revision=commit_hash,
            cache_dir=cache_dir,
            library_name=library_name,
            library_version=library_version,
            user_agent=user_agent,
            proxies=proxies,
            etag_timeout=etag_timeout,
            resume_download=resume_download,
            use_auth_token=use_auth_token,
        )
        for repo_file in filtered_repo_files
    )

    return snapshot_folder