Spaces:
Sleeping
Sleeping
Save options to PySRRegressor
Browse files
- pysr/julia_helpers.py +12 -2
- pysr/sr.py +25 -15
pysr/julia_helpers.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
"""Functions for initializing the Julia environment and installing deps."""
|
| 2 |
import warnings
|
| 3 |
|
|
|
|
| 4 |
from juliacall import convert as jl_convert # type: ignore
|
| 5 |
|
| 6 |
from .julia_import import jl
|
|
@@ -8,6 +9,9 @@ from .julia_import import jl
|
|
| 8 |
jl.seval("using Serialization: Serialization")
|
| 9 |
jl.seval("using PythonCall: PythonCall")
|
| 10 |
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
def install(*args, **kwargs):
|
| 13 |
del args, kwargs
|
|
@@ -35,10 +39,16 @@ def jl_array(x):
|
|
| 35 |
return jl_convert(jl.Array, x)
|
| 36 |
|
| 37 |
|
| 38 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
if s is None:
|
| 40 |
return s
|
| 41 |
buf = jl.IOBuffer()
|
| 42 |
jl.write(buf, jl_array(s))
|
| 43 |
jl.seekstart(buf)
|
| 44 |
-
return
|
|
|
|
| 1 |
"""Functions for initializing the Julia environment and installing deps."""
|
| 2 |
import warnings
|
| 3 |
|
| 4 |
+
import numpy as np
|
| 5 |
from juliacall import convert as jl_convert # type: ignore
|
| 6 |
|
| 7 |
from .julia_import import jl
|
|
|
|
| 9 |
jl.seval("using Serialization: Serialization")
|
| 10 |
jl.seval("using PythonCall: PythonCall")
|
| 11 |
|
| 12 |
+
Serialization = jl.Serialization
|
| 13 |
+
PythonCall = jl.PythonCall
|
| 14 |
+
|
| 15 |
|
| 16 |
def install(*args, **kwargs):
|
| 17 |
del args, kwargs
|
|
|
|
| 39 |
return jl_convert(jl.Array, x)
|
| 40 |
|
| 41 |
|
| 42 |
+
def jl_serialize(obj):
|
| 43 |
+
buf = jl.IOBuffer()
|
| 44 |
+
Serialization.serialize(buf, obj)
|
| 45 |
+
return np.array(jl.take_b(buf))
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def jl_deserialize(s):
|
| 49 |
if s is None:
|
| 50 |
return s
|
| 51 |
buf = jl.IOBuffer()
|
| 52 |
jl.write(buf, jl_array(s))
|
| 53 |
jl.seekstart(buf)
|
| 54 |
+
return Serialization.deserialize(buf)
|
pysr/sr.py
CHANGED
|
@@ -33,10 +33,12 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
|
|
| 33 |
from .export_torch import sympy2torch
|
| 34 |
from .feature_selection import run_feature_selection
|
| 35 |
from .julia_helpers import (
|
|
|
|
| 36 |
_escape_filename,
|
| 37 |
_load_cluster_manager,
|
| 38 |
jl_array,
|
| 39 |
-
|
|
|
|
| 40 |
)
|
| 41 |
from .julia_import import SymbolicRegression, jl
|
| 42 |
from .utils import (
|
|
@@ -602,11 +604,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 602 |
Path to the temporary equations directory.
|
| 603 |
equation_file_ : str
|
| 604 |
Output equation file name produced by the julia backend.
|
| 605 |
-
|
| 606 |
The serialized state for the julia SymbolicRegression.jl backend (after fitting),
|
| 607 |
stored as an array of uint8, produced by Julia's Serialization.serialize function.
|
| 608 |
-
julia_state_
|
| 609 |
The deserialized state.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 610 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 611 |
Contents of the equation file output by the Julia backend.
|
| 612 |
show_pickle_warnings_ : bool
|
|
@@ -1053,7 +1059,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1053 |
serialization.
|
| 1054 |
|
| 1055 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
| 1056 |
-
`
|
| 1057 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
| 1058 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
| 1059 |
to be serialized. Note: Jax and Torch format equations are also removed
|
|
@@ -1121,15 +1127,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1121 |
)
|
| 1122 |
return self.equations_
|
| 1123 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1124 |
@property
|
| 1125 |
def julia_state_(self):
|
| 1126 |
-
return
|
| 1127 |
|
| 1128 |
@property
|
| 1129 |
def raw_julia_state_(self):
|
| 1130 |
warnings.warn(
|
| 1131 |
"PySRRegressor.raw_julia_state_ is now deprecated. "
|
| 1132 |
-
"Please use PySRRegressor.julia_state_ instead, or
|
| 1133 |
"for the raw stream of bytes.",
|
| 1134 |
FutureWarning,
|
| 1135 |
)
|
|
@@ -1675,6 +1685,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1675 |
define_helper_functions=False,
|
| 1676 |
)
|
| 1677 |
|
|
|
|
|
|
|
| 1678 |
# Convert data to desired precision
|
| 1679 |
test_X = np.array(X)
|
| 1680 |
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
|
@@ -1718,7 +1730,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1718 |
else:
|
| 1719 |
jl_y_variable_names = None
|
| 1720 |
|
| 1721 |
-
|
| 1722 |
out = SymbolicRegression.equation_search(
|
| 1723 |
jl_X,
|
| 1724 |
jl_y,
|
|
@@ -1741,12 +1753,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1741 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
| 1742 |
verbosity=int(self.verbosity),
|
| 1743 |
)
|
| 1744 |
-
|
| 1745 |
|
| 1746 |
-
|
| 1747 |
-
buf = jl.IOBuffer()
|
| 1748 |
-
jl.Serialization.serialize(buf, out)
|
| 1749 |
-
self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
|
| 1750 |
|
| 1751 |
# Set attributes
|
| 1752 |
self.equations_ = self.get_hof()
|
|
@@ -1810,10 +1819,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1810 |
Fitted estimator.
|
| 1811 |
"""
|
| 1812 |
# Init attributes that are not specified in BaseEstimator
|
| 1813 |
-
if self.warm_start and hasattr(self, "
|
| 1814 |
pass
|
| 1815 |
else:
|
| 1816 |
-
if hasattr(self, "
|
| 1817 |
warnings.warn(
|
| 1818 |
"The discovered expressions are being reset. "
|
| 1819 |
"Please set `warm_start=True` if you wish to continue "
|
|
@@ -1823,7 +1832,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
| 1823 |
self.equations_ = None
|
| 1824 |
self.nout_ = 1
|
| 1825 |
self.selection_mask_ = None
|
| 1826 |
-
self.
|
|
|
|
| 1827 |
self.X_units_ = None
|
| 1828 |
self.y_units_ = None
|
| 1829 |
|
|
|
|
| 33 |
from .export_torch import sympy2torch
|
| 34 |
from .feature_selection import run_feature_selection
|
| 35 |
from .julia_helpers import (
|
| 36 |
+
PythonCall,
|
| 37 |
_escape_filename,
|
| 38 |
_load_cluster_manager,
|
| 39 |
jl_array,
|
| 40 |
+
jl_deserialize,
|
| 41 |
+
jl_serialize,
|
| 42 |
)
|
| 43 |
from .julia_import import SymbolicRegression, jl
|
| 44 |
from .utils import (
|
|
|
|
| 604 |
Path to the temporary equations directory.
|
| 605 |
equation_file_ : str
|
| 606 |
Output equation file name produced by the julia backend.
|
| 607 |
+
julia_state_stream_ : ndarray
|
| 608 |
The serialized state for the julia SymbolicRegression.jl backend (after fitting),
|
| 609 |
stored as an array of uint8, produced by Julia's Serialization.serialize function.
|
| 610 |
+
julia_state_
|
| 611 |
The deserialized state.
|
| 612 |
+
julia_options_stream_ : ndarray
|
| 613 |
+
The serialized julia options, stored as an array of uint8, produced by Julia's Serialization.serialize function.
|
| 614 |
+
julia_options_
|
| 615 |
+
The deserialized julia options.
|
| 616 |
equation_file_contents_ : list[pandas.DataFrame]
|
| 617 |
Contents of the equation file output by the Julia backend.
|
| 618 |
show_pickle_warnings_ : bool
|
|
|
|
| 1059 |
serialization.
|
| 1060 |
|
| 1061 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
| 1062 |
+
`julia_state_stream_` attribute must be hidden from pickle. This will
|
| 1063 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
| 1064 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
| 1065 |
to be serialized. Note: Jax and Torch format equations are also removed
|
|
|
|
| 1127 |
)
|
| 1128 |
return self.equations_
|
| 1129 |
|
| 1130 |
+
@property
|
| 1131 |
+
def julia_options_(self):
|
| 1132 |
+
return jl_deserialize(self.julia_options_stream_)
|
| 1133 |
+
|
| 1134 |
@property
|
| 1135 |
def julia_state_(self):
|
| 1136 |
+
return jl_deserialize(self.julia_state_stream_)
|
| 1137 |
|
| 1138 |
@property
|
| 1139 |
def raw_julia_state_(self):
|
| 1140 |
warnings.warn(
|
| 1141 |
"PySRRegressor.raw_julia_state_ is now deprecated. "
|
| 1142 |
+
"Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
|
| 1143 |
"for the raw stream of bytes.",
|
| 1144 |
FutureWarning,
|
| 1145 |
)
|
|
|
|
| 1685 |
define_helper_functions=False,
|
| 1686 |
)
|
| 1687 |
|
| 1688 |
+
self.julia_options_stream_ = jl_serialize(options)
|
| 1689 |
+
|
| 1690 |
# Convert data to desired precision
|
| 1691 |
test_X = np.array(X)
|
| 1692 |
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
|
|
|
| 1730 |
else:
|
| 1731 |
jl_y_variable_names = None
|
| 1732 |
|
| 1733 |
+
PythonCall.GC.disable()
|
| 1734 |
out = SymbolicRegression.equation_search(
|
| 1735 |
jl_X,
|
| 1736 |
jl_y,
|
|
|
|
| 1753 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
| 1754 |
verbosity=int(self.verbosity),
|
| 1755 |
)
|
| 1756 |
+
PythonCall.GC.enable()
|
| 1757 |
|
| 1758 |
+
self.julia_state_stream_ = jl_serialize(out)
|
|
|
|
|
|
|
|
|
|
| 1759 |
|
| 1760 |
# Set attributes
|
| 1761 |
self.equations_ = self.get_hof()
|
|
|
|
| 1819 |
Fitted estimator.
|
| 1820 |
"""
|
| 1821 |
# Init attributes that are not specified in BaseEstimator
|
| 1822 |
+
if self.warm_start and hasattr(self, "julia_state_stream_"):
|
| 1823 |
pass
|
| 1824 |
else:
|
| 1825 |
+
if hasattr(self, "julia_state_stream_"):
|
| 1826 |
warnings.warn(
|
| 1827 |
"The discovered expressions are being reset. "
|
| 1828 |
"Please set `warm_start=True` if you wish to continue "
|
|
|
|
| 1832 |
self.equations_ = None
|
| 1833 |
self.nout_ = 1
|
| 1834 |
self.selection_mask_ = None
|
| 1835 |
+
self.julia_state_stream_ = None
|
| 1836 |
+
self.julia_options_stream_ = None
|
| 1837 |
self.X_units_ = None
|
| 1838 |
self.y_units_ = None
|
| 1839 |
|