Spaces:
Sleeping
Sleeping
Commit
·
68b3673
1
Parent(s):
a06bfc4
Add torch format output; dont import jax/torch by default
Browse files- pysr/__init__.py +0 -2
- pysr/export_jax.py +47 -51
- pysr/export_torch.py +0 -2
- pysr/sr.py +21 -4
pysr/__init__.py
CHANGED
|
@@ -1,4 +1,2 @@
|
|
| 1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
| 2 |
from .feynman_problems import Problem, FeynmanProblem
|
| 3 |
-
from .export_jax import sympy2jax
|
| 4 |
-
from .export_torch import sympy2torch
|
|
|
|
| 1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
| 2 |
from .feynman_problems import Problem, FeynmanProblem
|
|
|
|
|
|
pysr/export_jax.py
CHANGED
|
@@ -2,60 +2,56 @@ import functools as ft
|
|
| 2 |
import sympy
|
| 3 |
import string
|
| 4 |
import random
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
from jax import numpy as jnp
|
| 9 |
-
from jax.scipy import special as jsp
|
| 10 |
|
| 11 |
# Special since need to reduce arguments.
|
| 12 |
-
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
except ImportError:
|
| 58 |
-
...
|
| 59 |
|
| 60 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
| 61 |
if issubclass(expr.func, sympy.Float):
|
|
|
|
| 2 |
import sympy
|
| 3 |
import string
|
| 4 |
import random
|
| 5 |
+
import jax
|
| 6 |
+
from jax import numpy as jnp
|
| 7 |
+
from jax.scipy import special as jsp
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# Special since need to reduce arguments.
|
| 10 |
+
MUL = 0
|
| 11 |
+
ADD = 1
|
| 12 |
|
| 13 |
+
_jnp_func_lookup = {
|
| 14 |
+
sympy.Mul: MUL,
|
| 15 |
+
sympy.Add: ADD,
|
| 16 |
+
sympy.div: "jnp.div",
|
| 17 |
+
sympy.Abs: "jnp.abs",
|
| 18 |
+
sympy.sign: "jnp.sign",
|
| 19 |
+
# Note: May raise error for ints.
|
| 20 |
+
sympy.ceiling: "jnp.ceil",
|
| 21 |
+
sympy.floor: "jnp.floor",
|
| 22 |
+
sympy.log: "jnp.log",
|
| 23 |
+
sympy.exp: "jnp.exp",
|
| 24 |
+
sympy.sqrt: "jnp.sqrt",
|
| 25 |
+
sympy.cos: "jnp.cos",
|
| 26 |
+
sympy.acos: "jnp.acos",
|
| 27 |
+
sympy.sin: "jnp.sin",
|
| 28 |
+
sympy.asin: "jnp.asin",
|
| 29 |
+
sympy.tan: "jnp.tan",
|
| 30 |
+
sympy.atan: "jnp.atan",
|
| 31 |
+
sympy.atan2: "jnp.atan2",
|
| 32 |
+
# Note: Also may give NaN for complex results.
|
| 33 |
+
sympy.cosh: "jnp.cosh",
|
| 34 |
+
sympy.acosh: "jnp.acosh",
|
| 35 |
+
sympy.sinh: "jnp.sinh",
|
| 36 |
+
sympy.asinh: "jnp.asinh",
|
| 37 |
+
sympy.tanh: "jnp.tanh",
|
| 38 |
+
sympy.atanh: "jnp.atanh",
|
| 39 |
+
sympy.Pow: "jnp.power",
|
| 40 |
+
sympy.re: "jnp.real",
|
| 41 |
+
sympy.im: "jnp.imag",
|
| 42 |
+
sympy.arg: "jnp.angle",
|
| 43 |
+
# Note: May raise error for ints and complexes
|
| 44 |
+
sympy.erf: "jsp.erf",
|
| 45 |
+
sympy.erfc: "jsp.erfc",
|
| 46 |
+
sympy.LessThan: "jnp.less",
|
| 47 |
+
sympy.GreaterThan: "jnp.greater",
|
| 48 |
+
sympy.And: "jnp.logical_and",
|
| 49 |
+
sympy.Or: "jnp.logical_or",
|
| 50 |
+
sympy.Not: "jnp.logical_not",
|
| 51 |
+
sympy.Max: "jnp.max",
|
| 52 |
+
sympy.Min: "jnp.min",
|
| 53 |
+
sympy.Mod: "jnp.mod",
|
| 54 |
+
}
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
| 57 |
if issubclass(expr.func, sympy.Float):
|
pysr/export_torch.py
CHANGED
|
@@ -8,7 +8,6 @@ import functools as ft
|
|
| 8 |
import sympy
|
| 9 |
import torch
|
| 10 |
|
| 11 |
-
|
| 12 |
def _reduce(fn):
|
| 13 |
def fn_(*args):
|
| 14 |
return ft.reduce(fn, args)
|
|
@@ -67,7 +66,6 @@ _global_func_lookup = {
|
|
| 67 |
sympy.Determinant: torch.det,
|
| 68 |
}
|
| 69 |
|
| 70 |
-
|
| 71 |
class _Node(torch.nn.Module):
|
| 72 |
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
|
| 73 |
super().__init__(**kwargs)
|
|
|
|
| 8 |
import sympy
|
| 9 |
import torch
|
| 10 |
|
|
|
|
| 11 |
def _reduce(fn):
|
| 12 |
def fn_(*args):
|
| 13 |
return ft.reduce(fn, args)
|
|
|
|
| 66 |
sympy.Determinant: torch.det,
|
| 67 |
}
|
| 68 |
|
|
|
|
| 69 |
class _Node(torch.nn.Module):
|
| 70 |
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
|
| 71 |
super().__init__(**kwargs)
|
pysr/sr.py
CHANGED
|
@@ -13,8 +13,6 @@ import shutil
|
|
| 13 |
from pathlib import Path
|
| 14 |
from datetime import datetime
|
| 15 |
import warnings
|
| 16 |
-
from .export_jax import sympy2jax
|
| 17 |
-
from .export_torch import sympy2torch
|
| 18 |
|
| 19 |
global_equation_file = 'hall_of_fame.csv'
|
| 20 |
global_n_features = None
|
|
@@ -125,11 +123,12 @@ def pysr(X, y, weights=None,
|
|
| 125 |
update=True,
|
| 126 |
temp_equation_file=False,
|
| 127 |
output_jax_format=False,
|
|
|
|
| 128 |
optimizer_algorithm="BFGS",
|
| 129 |
optimizer_nrestarts=3,
|
| 130 |
optimize_probability=1.0,
|
| 131 |
-
optimizer_iterations=10
|
| 132 |
-
|
| 133 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
| 134 |
Note: most default parameters have been tuned over several example
|
| 135 |
equations, but you should adjust `niterations`,
|
|
@@ -242,6 +241,8 @@ def pysr(X, y, weights=None,
|
|
| 242 |
delete_tempfiles argument.
|
| 243 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
| 244 |
containing jax-callable functions and the default parameters in a jax array.
|
|
|
|
|
|
|
| 245 |
:returns: pd.DataFrame or list, Results dataframe,
|
| 246 |
giving complexity, MSE, and equations (as strings), as well as functional
|
| 247 |
forms. If list, each element corresponds to a dataframe of equations
|
|
@@ -337,6 +338,7 @@ def pysr(X, y, weights=None,
|
|
| 337 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 338 |
julia_project=julia_project, loss=loss,
|
| 339 |
output_jax_format=output_jax_format,
|
|
|
|
| 340 |
multioutput=multioutput, nout=nout)
|
| 341 |
|
| 342 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
@@ -727,6 +729,7 @@ def run_feature_selection(X, y, select_k_features):
|
|
| 727 |
|
| 728 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 729 |
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
|
| 730 |
multioutput=None, nout=None, **kwargs):
|
| 731 |
"""Get the equations from a hall of fame file. If no arguments
|
| 732 |
entered, the ones used previously from a call to PySR will be used."""
|
|
@@ -771,6 +774,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 771 |
lambda_format = []
|
| 772 |
if output_jax_format:
|
| 773 |
jax_format = []
|
|
|
|
|
|
|
| 774 |
use_custom_variable_names = (len(variable_names) != 0)
|
| 775 |
local_sympy_mappings = {
|
| 776 |
**extra_sympy_mappings,
|
|
@@ -786,10 +791,19 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 786 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 787 |
sympy_format.append(eqn)
|
| 788 |
if output_jax_format:
|
|
|
|
| 789 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 790 |
jax_format.append({'callable': func, 'parameters': params})
|
|
|
|
| 791 |
|
| 792 |
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 793 |
curMSE = output.loc[i, 'MSE']
|
| 794 |
curComplexity = output.loc[i, 'Complexity']
|
| 795 |
|
|
@@ -809,6 +823,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
| 809 |
if output_jax_format:
|
| 810 |
output_cols += ['jax_format']
|
| 811 |
output['jax_format'] = jax_format
|
|
|
|
|
|
|
|
|
|
| 812 |
|
| 813 |
ret_outputs.append(output[output_cols])
|
| 814 |
|
|
|
|
| 13 |
from pathlib import Path
|
| 14 |
from datetime import datetime
|
| 15 |
import warnings
|
|
|
|
|
|
|
| 16 |
|
| 17 |
global_equation_file = 'hall_of_fame.csv'
|
| 18 |
global_n_features = None
|
|
|
|
| 123 |
update=True,
|
| 124 |
temp_equation_file=False,
|
| 125 |
output_jax_format=False,
|
| 126 |
+
output_torch_format=False,
|
| 127 |
optimizer_algorithm="BFGS",
|
| 128 |
optimizer_nrestarts=3,
|
| 129 |
optimize_probability=1.0,
|
| 130 |
+
optimizer_iterations=10
|
| 131 |
+
):
|
| 132 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
| 133 |
Note: most default parameters have been tuned over several example
|
| 134 |
equations, but you should adjust `niterations`,
|
|
|
|
| 241 |
delete_tempfiles argument.
|
| 242 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
| 243 |
containing jax-callable functions and the default parameters in a jax array.
|
| 244 |
+
:param output_torch_format: Whether to create a 'torch_format' column in the output,
|
| 245 |
+
containing a torch module with trainable parameters.
|
| 246 |
:returns: pd.DataFrame or list, Results dataframe,
|
| 247 |
giving complexity, MSE, and equations (as strings), as well as functional
|
| 248 |
forms. If list, each element corresponds to a dataframe of equations
|
|
|
|
| 338 |
extra_sympy_mappings=extra_sympy_mappings,
|
| 339 |
julia_project=julia_project, loss=loss,
|
| 340 |
output_jax_format=output_jax_format,
|
| 341 |
+
output_torch_format=output_torch_format,
|
| 342 |
multioutput=multioutput, nout=nout)
|
| 343 |
|
| 344 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
|
| 729 |
|
| 730 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
| 731 |
extra_sympy_mappings=None, output_jax_format=False,
|
| 732 |
+
output_torch_format=False,
|
| 733 |
multioutput=None, nout=None, **kwargs):
|
| 734 |
"""Get the equations from a hall of fame file. If no arguments
|
| 735 |
entered, the ones used previously from a call to PySR will be used."""
|
|
|
|
| 774 |
lambda_format = []
|
| 775 |
if output_jax_format:
|
| 776 |
jax_format = []
|
| 777 |
+
if output_torch_format:
|
| 778 |
+
torch_format = []
|
| 779 |
use_custom_variable_names = (len(variable_names) != 0)
|
| 780 |
local_sympy_mappings = {
|
| 781 |
**extra_sympy_mappings,
|
|
|
|
| 791 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
| 792 |
sympy_format.append(eqn)
|
| 793 |
if output_jax_format:
|
| 794 |
+
from .export_jax import sympy2jax
|
| 795 |
func, params = sympy2jax(eqn, sympy_symbols)
|
| 796 |
jax_format.append({'callable': func, 'parameters': params})
|
| 797 |
+
<<<<<<< HEAD
|
| 798 |
|
| 799 |
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
| 800 |
+
=======
|
| 801 |
+
if output_torch_format:
|
| 802 |
+
from .export_torch import sympy2torch
|
| 803 |
+
func, params = sympy2torch(eqn, sympy_symbols)
|
| 804 |
+
torch_format.append({'callable': func, 'parameters': params})
|
| 805 |
+
lambda_format.append(lambdify(sympy_symbols, eqn))
|
| 806 |
+
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
| 807 |
curMSE = output.loc[i, 'MSE']
|
| 808 |
curComplexity = output.loc[i, 'Complexity']
|
| 809 |
|
|
|
|
| 823 |
if output_jax_format:
|
| 824 |
output_cols += ['jax_format']
|
| 825 |
output['jax_format'] = jax_format
|
| 826 |
+
if output_torch_format:
|
| 827 |
+
output_cols += ['torch_format']
|
| 828 |
+
output['torch_format'] = torch_format
|
| 829 |
|
| 830 |
ret_outputs.append(output[output_cols])
|
| 831 |
|