Upload catalog.py with huggingface_hub
Browse files- catalog.py +26 -23
catalog.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
from pathlib import Path
|
|
|
|
| 4 |
import requests
|
| 5 |
-
import json
|
| 6 |
-
from .artifact import Artifact, Artifactory
|
| 7 |
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
COLLECTION_SEPARATOR =
|
| 10 |
-
PATHS_SEP =
|
| 11 |
|
| 12 |
|
| 13 |
class Catalog(Artifactory):
|
|
@@ -18,9 +19,14 @@ class Catalog(Artifactory):
|
|
| 18 |
try:
|
| 19 |
import unitxt
|
| 20 |
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
| 22 |
except ImportError:
|
| 23 |
-
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
class LocalCatalog(Catalog):
|
|
@@ -28,7 +34,7 @@ class LocalCatalog(Catalog):
|
|
| 28 |
location: str = default_catalog_path
|
| 29 |
|
| 30 |
def path(self, artifact_identifier: str):
|
| 31 |
-
assert artifact_identifier.strip(),
|
| 32 |
parts = artifact_identifier.split(COLLECTION_SEPARATOR)
|
| 33 |
parts[-1] = parts[-1] + ".json"
|
| 34 |
return os.path.join(self.location, *parts)
|
|
@@ -50,8 +56,6 @@ class LocalCatalog(Catalog):
|
|
| 50 |
return False
|
| 51 |
return os.path.exists(path) and os.path.isfile(path)
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
|
| 56 |
assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
|
| 57 |
if not overwrite:
|
|
@@ -61,6 +65,7 @@ class LocalCatalog(Catalog):
|
|
| 61 |
path = self.path(artifact_identifier)
|
| 62 |
os.makedirs(Path(path).parent.absolute(), exist_ok=True)
|
| 63 |
artifact.save(path)
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
class GithubCatalog(LocalCatalog):
|
|
@@ -68,38 +73,36 @@ class GithubCatalog(LocalCatalog):
|
|
| 68 |
repo = "unitxt"
|
| 69 |
repo_dir = "src/unitxt/catalog"
|
| 70 |
user = "IBM"
|
| 71 |
-
|
| 72 |
-
|
| 73 |
def prepare(self):
|
| 74 |
-
|
| 75 |
-
|
|
|
|
| 76 |
def load(self, artifact_identifier: str):
|
| 77 |
url = self.path(artifact_identifier)
|
| 78 |
response = requests.get(url)
|
| 79 |
data = response.json()
|
| 80 |
return Artifact.from_dict(data)
|
| 81 |
-
|
| 82 |
def __contains__(self, artifact_identifier: str):
|
| 83 |
url = self.path(artifact_identifier)
|
| 84 |
response = requests.head(url)
|
| 85 |
return response.status_code == 200
|
| 86 |
-
|
| 87 |
-
|
| 88 |
|
| 89 |
|
| 90 |
def verify_legal_catalog_name(name):
|
| 91 |
-
assert re.match(
|
| 92 |
-
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
-
def add_to_catalog(
|
| 96 |
-
|
|
|
|
| 97 |
if catalog is None:
|
| 98 |
if catalog_path is None:
|
| 99 |
catalog_path = default_catalog_path
|
| 100 |
catalog = LocalCatalog(location=catalog_path)
|
| 101 |
verify_legal_catalog_name(name)
|
| 102 |
-
catalog.save_artifact(artifact, name, overwrite=overwrite)
|
| 103 |
# verify name
|
| 104 |
-
|
| 105 |
-
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
from pathlib import Path
|
| 4 |
+
|
| 5 |
import requests
|
|
|
|
|
|
|
| 6 |
|
| 7 |
+
from ._version import get_current_version
|
| 8 |
+
from .artifact import Artifact, Artifactory
|
| 9 |
|
| 10 |
+
COLLECTION_SEPARATOR = "."
|
| 11 |
+
PATHS_SEP = ":"
|
| 12 |
|
| 13 |
|
| 14 |
class Catalog(Artifactory):
|
|
|
|
| 19 |
# Locate the directory that holds the bundled "catalog" folder.
# Start from this module's own directory and switch to the installed
# unitxt package directory only when the package can be imported and
# reports a usable __file__.
lib_dir = os.path.dirname(__file__)
try:
    import unitxt
except ImportError:
    pass
else:
    if unitxt.__file__:
        lib_dir = os.path.dirname(unitxt.__file__)

default_catalog_path = os.path.join(lib_dir, "catalog")
|
| 30 |
|
| 31 |
|
| 32 |
class LocalCatalog(Catalog):
|
|
|
|
| 34 |
location: str = default_catalog_path
|
| 35 |
|
| 36 |
def path(self, artifact_identifier: str):
    """Build the path of the JSON file that stores ``artifact_identifier``.

    The identifier is split on COLLECTION_SEPARATOR; every component becomes
    one path segment below ``self.location`` and ".json" is appended to the
    last component.
    """
    assert artifact_identifier.strip(), "artifact_identifier should not be an empty string."
    # str.split always yields at least one element, so ``leaf`` is always bound.
    *collections, leaf = artifact_identifier.split(COLLECTION_SEPARATOR)
    return os.path.join(self.location, *collections, leaf + ".json")
|
|
|
|
| 56 |
return False
|
| 57 |
return os.path.exists(path) and os.path.isfile(path)
|
| 58 |
|
|
|
|
|
|
|
| 59 |
def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
|
| 60 |
assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
|
| 61 |
if not overwrite:
|
|
|
|
| 65 |
path = self.path(artifact_identifier)
|
| 66 |
os.makedirs(Path(path).parent.absolute(), exist_ok=True)
|
| 67 |
artifact.save(path)
|
| 68 |
+
print(f"Artifact {artifact_identifier} saved to {path}")
|
| 69 |
|
| 70 |
|
| 71 |
class GithubCatalog(LocalCatalog):
|
|
|
|
| 73 |
repo = "unitxt"
|
| 74 |
repo_dir = "src/unitxt/catalog"
|
| 75 |
user = "IBM"
|
| 76 |
+
|
|
|
|
| 77 |
def prepare(self):
    """Set ``self.location`` to the raw.githubusercontent.com URL of the catalog directory."""
    version = get_current_version()
    # Keep only the part of the version string before the first "+";
    # that part is used as the ref segment of the URL.
    tag, _, _ = version.partition("+")
    self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"
|
| 80 |
+
|
| 81 |
def load(self, artifact_identifier: str):
    """Download an artifact's JSON document and build an Artifact from it.

    Args:
        artifact_identifier: identifier resolved to a URL via ``self.path``.

    Returns:
        The object produced by ``Artifact.from_dict`` on the decoded JSON.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status
            (for example when the artifact does not exist).
    """
    url = self.path(artifact_identifier)
    response = requests.get(url)
    # Surface a missing artifact or server failure as an explicit HTTP error
    # instead of a JSON decode error raised while parsing the error page.
    response.raise_for_status()
    return Artifact.from_dict(response.json())
|
| 86 |
+
|
| 87 |
def __contains__(self, artifact_identifier: str):
    """Return True when a HEAD request for the artifact's URL answers with status 200."""
    status = requests.head(self.path(artifact_identifier)).status_code
    return status == 200
|
|
|
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
def verify_legal_catalog_name(name):
    """Assert that ``name`` is a legal catalog name.

    A legal name consists only of word characters (``\\w``) and
    COLLECTION_SEPARATOR, which delimits the directory levels of the catalog.

    Raises:
        AssertionError: if ``name`` contains any other character.
    """
    # re.escape keeps the separator literal inside the character class even if
    # it is ever changed to a character that is special there (e.g. "-" or "]").
    pattern = r"[\w" + re.escape(COLLECTION_SEPARATOR) + r"]+"
    # fullmatch, unlike match with "$", does not accept a trailing newline.
    assert re.fullmatch(
        pattern, name
    ), f'Catalog name should be alphanumeric, "{COLLECTION_SEPARATOR}" should specify dirs (instead of "/").'
|
| 97 |
|
| 98 |
|
| 99 |
+
def add_to_catalog(
    artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False, catalog_path: str = None
):
    """Save ``artifact`` under ``name`` in a catalog.

    When ``catalog`` is None a LocalCatalog is created at ``catalog_path``,
    falling back to ``default_catalog_path`` if that is None as well. The name
    is checked with ``verify_legal_catalog_name`` before the artifact is saved;
    ``overwrite`` is forwarded to the catalog's ``save_artifact``.
    """
    if catalog is None:
        location = default_catalog_path if catalog_path is None else catalog_path
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite)
|
|
|
|
|
|