Spaces:
Running
Running
Commit
·
e63cf2d
1
Parent(s):
a0c6429
Fix case of no extra mappings for jax/torch
Browse files- pysr/sr.py +4 -0
pysr/sr.py
CHANGED
|
@@ -546,6 +546,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 546 |
raise NotImplementedError(
|
| 547 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 548 |
)
|
|
|
|
|
|
|
| 549 |
|
| 550 |
if extra_torch_mappings is not None:
|
| 551 |
for value in extra_jax_mappings.values():
|
|
@@ -553,6 +555,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 553 |
raise NotImplementedError(
|
| 554 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 555 |
)
|
|
|
|
|
|
|
| 556 |
|
| 557 |
if maxsize > 40:
|
| 558 |
warnings.warn(
|
|
|
|
| 546 |
raise NotImplementedError(
|
| 547 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
| 548 |
)
|
| 549 |
+
else:
|
| 550 |
+
extra_jax_mappings = {}
|
| 551 |
|
| 552 |
if extra_torch_mappings is not None:
|
| 553 |
for value in extra_jax_mappings.values():
|
|
|
|
| 555 |
raise NotImplementedError(
|
| 556 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
| 557 |
)
|
| 558 |
+
else:
|
| 559 |
+
extra_torch_mappings = {}
|
| 560 |
|
| 561 |
if maxsize > 40:
|
| 562 |
warnings.warn(
|