File size: 3,846 Bytes
06638a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import os
from typing import Dict, Optional, Union
import numpy as np
import jax.numpy as jnp
from jax import Array
from safetensors import numpy, safe_open
def save(tensors: Dict[str, Array], metadata: Optional[Dict[str, str]] = None) -> bytes:
    """
    Serialize a mapping of JAX arrays into an in-memory safetensors payload.

    Args:
        tensors (`Dict[str, Array]`):
            Name-to-array mapping to serialize. Arrays need to be contiguous and dense.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Free-form, text-only key/value pairs stored in the header, e.g. to
            describe the underlying tensors. Purely informative; it plays no
            part in how tensors are loaded back.

    Returns:
        `bytes`: The serialized safetensors content.

    Example:

    ```python
    from safetensors.flax import save
    from jax import numpy as jnp

    tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
    byte_data = save(tensors)
    ```
    """
    # Serialization is delegated to the NumPy backend, so convert first.
    return numpy.save(_jnp2np(tensors), metadata=metadata)
def save_file(
    tensors: Dict[str, Array],
    filename: Union[str, os.PathLike],
    metadata: Optional[Dict[str, str]] = None,
) -> None:
    """
    Saves a dictionary of tensors into a file in safetensors format.

    Args:
        tensors (`Dict[str, Array]`):
            The incoming tensors. Tensors need to be contiguous and dense.
        filename (`str`, or `os.PathLike`):
            The filename we're saving into.
        metadata (`Dict[str, str]`, *optional*, defaults to `None`):
            Optional text only metadata you might want to save in your header.
            For instance it can be useful to specify more about the underlying
            tensors. This is purely informative and does not affect tensor loading.

    Returns:
        `None`

    Example:

    ```python
    from safetensors.flax import save_file
    from jax import numpy as jnp

    tensors = {"embedding": jnp.zeros((512, 1024)), "attention": jnp.zeros((256, 256))}
    save_file(tensors, "model.safetensors")
    ```
    """
    # Writing is delegated to the NumPy backend, so convert the JAX arrays first.
    np_tensors = _jnp2np(tensors)
    return numpy.save_file(np_tensors, filename, metadata=metadata)
def load(data: bytes) -> Dict[str, Array]:
    """
    Loads a safetensors file into flax format from pure bytes.

    Args:
        data (`bytes`):
            The content of a safetensors file.

    Returns:
        `Dict[str, Array]`: dictionary that contains name as key, value as a JAX
        `Array`. Values are created with `jnp.array`, so device placement follows
        JAX's default-device rules (not necessarily the CPU).

    Example:

    ```python
    from safetensors.flax import load

    file_path = "./my_folder/bert.safetensors"
    with open(file_path, "rb") as f:
        data = f.read()

    loaded = load(data)
    ```
    """
    # Decode with the NumPy backend, then lift every array into JAX.
    return _np2jnp(numpy.load(data))
def load_file(filename: Union[str, os.PathLike]) -> Dict[str, Array]:
    """
    Loads a safetensors file into flax format.

    Args:
        filename (`str`, or `os.PathLike`):
            The name of the file which contains the tensors.

    Returns:
        `Dict[str, Array]`: dictionary that contains name as key, value as `Array`

    Example:

    ```python
    from safetensors.flax import load_file

    file_path = "./my_folder/bert.safetensors"
    loaded = load_file(file_path)
    ```
    """
    # framework="flax" asks safe_open for JAX arrays, matching the return type;
    # the comprehension is fully evaluated before the handle is closed.
    with safe_open(filename, framework="flax") as f:
        return {k: f.get_tensor(k) for k in f.keys()}
def _np2jnp(numpy_dict: Dict[str, np.ndarray]) -> Dict[str, Array]:
    """
    Convert a mapping of NumPy arrays into JAX arrays.

    A new dictionary is returned and the input mapping is left unmodified,
    so it keeps holding NumPy arrays as its annotation promises.

    Args:
        numpy_dict (`Dict[str, np.ndarray]`): Name-to-array mapping to convert.

    Returns:
        `Dict[str, Array]`: Fresh dictionary with the same keys (same order),
        each value converted with `jnp.array`.
    """
    return {name: jnp.array(value) for name, value in numpy_dict.items()}
def _jnp2np(jnp_dict: Dict[str, Array]) -> Dict[str, np.ndarray]:
    """
    Convert a mapping of JAX arrays into NumPy arrays.

    A new dictionary is returned and the caller's mapping is left untouched,
    so serializing through this helper never swaps the user's JAX arrays
    for NumPy arrays as a hidden side effect.

    Args:
        jnp_dict (`Dict[str, Array]`): Name-to-array mapping to convert.

    Returns:
        `Dict[str, np.ndarray]`: Fresh dictionary with the same keys (same order),
        each value converted with `np.asarray`.
    """
    return {name: np.asarray(tensor) for name, tensor in jnp_dict.items()}
|