Spaces:
Running
Running
Commit
·
70dcb83
1
Parent(s):
aa16a1e
Add reset function for state saving.
Browse files- pysr/sr.py +11 -6
pysr/sr.py
CHANGED
|
@@ -322,7 +322,7 @@ def _write_project_file(tmp_dir):
|
|
| 322 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
| 323 |
|
| 324 |
[compat]
|
| 325 |
-
SymbolicRegression = "0.7.
|
| 326 |
julia = "1.5"
|
| 327 |
"""
|
| 328 |
|
|
@@ -640,7 +640,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 640 |
self.equations = None
|
| 641 |
self.params_hash = None
|
| 642 |
self.raw_julia_state = None
|
| 643 |
-
self.raw_julia_hof = None
|
| 644 |
|
| 645 |
self.multioutput = None
|
| 646 |
self.equation_file = equation_file
|
|
@@ -861,6 +860,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 861 |
return [eq["torch_format"] for eq in best]
|
| 862 |
return best["torch_format"]
|
| 863 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
def _run(self, X, y, weights, variable_names):
|
| 865 |
global already_ran
|
| 866 |
global Main
|
|
@@ -1074,7 +1079,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1074 |
"Warning: PySR options have changed since the last run. "
|
| 1075 |
"This is experimental and may not work. "
|
| 1076 |
"For example, if the operators change, or even their order,"
|
| 1077 |
-
" the saved equations will be in the wrong format."
|
|
|
|
|
|
|
| 1078 |
)
|
| 1079 |
|
| 1080 |
self.params_hash = cur_hash
|
|
@@ -1140,9 +1147,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
| 1140 |
|
| 1141 |
cprocs = 0 if multithreading else procs
|
| 1142 |
|
| 1143 |
-
|
| 1144 |
-
# state = (returnPops, hallOfFame)
|
| 1145 |
-
self.raw_julia_state, self.raw_julia_hof = Main.EquationSearch(
|
| 1146 |
Main.X,
|
| 1147 |
Main.y,
|
| 1148 |
weights=Main.weights,
|
|
|
|
| 322 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
| 323 |
|
| 324 |
[compat]
|
| 325 |
+
SymbolicRegression = "0.7.2"
|
| 326 |
julia = "1.5"
|
| 327 |
"""
|
| 328 |
|
|
|
|
| 640 |
self.equations = None
|
| 641 |
self.params_hash = None
|
| 642 |
self.raw_julia_state = None
|
|
|
|
| 643 |
|
| 644 |
self.multioutput = None
|
| 645 |
self.equation_file = equation_file
|
|
|
|
| 860 |
return [eq["torch_format"] for eq in best]
|
| 861 |
return best["torch_format"]
|
| 862 |
|
| 863 |
+
def reset(self):
|
| 864 |
+
"""Reset the search state."""
|
| 865 |
+
self.equations = None
|
| 866 |
+
self.params_hash = None
|
| 867 |
+
self.raw_julia_state = None
|
| 868 |
+
|
| 869 |
def _run(self, X, y, weights, variable_names):
|
| 870 |
global already_ran
|
| 871 |
global Main
|
|
|
|
| 1079 |
"Warning: PySR options have changed since the last run. "
|
| 1080 |
"This is experimental and may not work. "
|
| 1081 |
"For example, if the operators change, or even their order,"
|
| 1082 |
+
" the saved equations will be in the wrong format."
|
| 1083 |
+
"\n\n"
|
| 1084 |
+
"To reset the search state, run `.reset()`. "
|
| 1085 |
)
|
| 1086 |
|
| 1087 |
self.params_hash = cur_hash
|
|
|
|
| 1147 |
|
| 1148 |
cprocs = 0 if multithreading else procs
|
| 1149 |
|
| 1150 |
+
self.raw_julia_state = Main.EquationSearch(
|
|
|
|
|
|
|
| 1151 |
Main.X,
|
| 1152 |
Main.y,
|
| 1153 |
weights=Main.weights,
|