Spaces:
Running
Running
Commit
·
99fff5c
1
Parent(s):
e0c68fc
Add export key error telling user to set function mappings
Browse files- pysr/export_jax.py +8 -1
- pysr/export_torch.py +8 -1
pysr/export_jax.py
CHANGED
|
@@ -63,7 +63,14 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
|
| 63 |
)
|
| 64 |
if extra_jax_mappings is None:
|
| 65 |
extra_jax_mappings = {}
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
args = [
|
| 68 |
sympy2jaxtext(
|
| 69 |
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
|
|
|
| 63 |
)
|
| 64 |
if extra_jax_mappings is None:
|
| 65 |
extra_jax_mappings = {}
|
| 66 |
+
try:
|
| 67 |
+
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
|
| 68 |
+
except KeyError:
|
| 69 |
+
raise KeyError(
|
| 70 |
+
f"Function {expr.func} was not found in JAX function mappings."
|
| 71 |
+
"Please add it to extra_jax_mappings in the format, e.g., "
|
| 72 |
+
"{sympy.sqrt: 'jnp.sqrt'}."
|
| 73 |
+
)
|
| 74 |
args = [
|
| 75 |
sympy2jaxtext(
|
| 76 |
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
pysr/export_torch.py
CHANGED
|
@@ -117,7 +117,14 @@ def _initialize_torch():
|
|
| 117 |
self._torch_func = lambda value: value
|
| 118 |
self._args = ((lambda memodict: memodict[expr.name]),)
|
| 119 |
else:
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
args = []
|
| 122 |
for arg in expr.args:
|
| 123 |
try:
|
|
|
|
| 117 |
self._torch_func = lambda value: value
|
| 118 |
self._args = ((lambda memodict: memodict[expr.name]),)
|
| 119 |
else:
|
| 120 |
+
try:
|
| 121 |
+
self._torch_func = _func_lookup[expr.func]
|
| 122 |
+
except KeyError:
|
| 123 |
+
raise KeyError(
|
| 124 |
+
f"Function {expr.func} was not found in Torch function mappings."
|
| 125 |
+
"Please add it to extra_torch_mappings in the format, e.g., "
|
| 126 |
+
"{sympy.sqrt: torch.sqrt}."
|
| 127 |
+
)
|
| 128 |
args = []
|
| 129 |
for arg in expr.args:
|
| 130 |
try:
|