Commit 42cd6af (1 parent: 4c39e04)

Add jax, pytorch, sympy output from Regressor

Files changed:
- docs/options.md (+15, -11)
- pysr/sklearn.py (+13, -5)
docs/options.md

@@ -198,17 +198,18 @@ over `X` (as a PyTorch tensor). This is differentiable, and the
 parameters of this PyTorch module correspond to the learned parameters
 in the equation, and are trainable.
 ```python
-
-
+torch_model = model.pytorch()
+torch_model(X)
 ```
+**Warning: if you are using custom operators, you must define `extra_torch_mappings` or `extra_jax_mappings` (both are `dict`s of callables) to provide an equivalent definition of those functions.** (At any time, you can set these or any other parameters with `model.set_params`.)
 
-For JAX, you can equivalently
+For JAX, you can equivalently call `model.jax()`.
 This will return a dictionary containing a `'callable'` (a JAX function)
 and `'parameters'` (a list of parameters in the equation).
 You can execute this function with:
 ```python
-
-
+jax_model = model.jax()
+jax_model['callable'](X, jax_model['parameters'])
 ```
 Since the parameter list is a JAX array, this also lets you
 train the parameters within JAX (and is differentiable).
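Since the exported module's constants are ordinary trainable PyTorch parameters, the discovered equation can be fine-tuned like any other `torch.nn.Module`. A minimal sketch, assuming an already-fitted `model` and tensors `X`, `y` (the data names and optimizer settings are illustrative):

```python
import torch

torch_model = model.pytorch()  # nn.Module wrapping the best equation
optimizer = torch.optim.Adam(torch_model.parameters(), lr=1e-3)

for _ in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(torch_model(X), y)
    loss.backward()  # gradients flow into the equation's constants
    optimizer.step()
```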
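Likewise, because the JAX export is a pure function of `(X, parameters)`, the constants can be optimized with `jax.grad`. A minimal sketch under the same assumptions:

```python
import jax
import jax.numpy as jnp

jax_model = model.jax()
f, params = jax_model['callable'], jax_model['parameters']

def mse(params, X, y):
    return jnp.mean((f(X, params) - y) ** 2)

grads = jax.grad(mse)(params, X, y)  # d(loss)/d(constants)
params = params - 1e-3 * grads       # one plain gradient-descent step
```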
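Regarding the custom-operator warning: the mapping dicts supply a framework-level implementation for each custom operator used in the search. A hypothetical sketch; the `square` operator and the string-keyed dict convention are assumptions for illustration, not taken from this commit:

```python
# Suppose the search used a custom operator "square(x) = x^2" on the
# Julia side; the torch export then needs an equivalent callable.
model.set_params(
    extra_torch_mappings={"square": lambda x: x * x},  # assumed key format
)
torch_model = model.pytorch()
```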
@@ -226,26 +227,29 @@ Here are some additional examples:
 
 abs(x-y) loss
 ```python
-
+PySRRegressor(..., loss="f(x, y) = abs(x - y)^1.5")
 ```
 Note that the function name doesn't matter:
 ```python
-
+PySRRegressor(..., loss="loss(x, y) = abs(x * y)")
 ```
 With weights:
 ```python
-
+model = PySRRegressor(..., loss="myloss(x, y, w) = w * abs(x - y)")
+model.fit(..., weights=weights)
 ```
 Weights can be used in arbitrary ways:
 ```python
-
+model = PySRRegressor(..., weights=weights, loss="myloss(x, y, w) = abs(x - y)^2/w^2")
+model.fit(..., weights=weights)
 ```
 Built-in losses are faster; see [losses](https://astroautomata.com/SymbolicRegression.jl/dev/losses/).
 This one computes the L3 norm:
 ```python
-
+PySRRegressor(..., loss="LPDistLoss{3}()")
 ```
 You can also use these losses in weighted form (weighted-average):
 ```python
-
+model = PySRRegressor(..., weights=weights, loss="LPDistLoss{3}()")
+model.fit(..., weights=weights)
 ```
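To make the custom-loss pattern above concrete, here is a hedged end-to-end sketch; the synthetic data, operator set, and hyperparameters are illustrative rather than taken from this commit:

```python
import numpy as np
from pysr import PySRRegressor

X = np.random.randn(100, 2)
y = 2.5 * np.abs(X[:, 0])
weights = np.ones(100)  # per-sample weights, all equal here

# The Julia-side loss receives prediction x, target y, and weight w.
model = PySRRegressor(
    binary_operators=["+", "*"],
    unary_operators=["abs"],
    loss="myloss(x, y, w) = w * abs(x - y)",
)
model.fit(X, y, weights=weights)
```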
pysr/sklearn.py
@@ -1,4 +1,4 @@
-from pysr import pysr, best_row
+from pysr import pysr, best_row, get_hof
 from sklearn.base import BaseEstimator, RegressorMixin
 import inspect
 import pandas as pd
@@ -94,14 +94,22 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
         return self
 
     def predict(self, X):
-
-        np_format = equation_row["lambda_format"]
-
+        np_format = self.get_best()["lambda_format"]
        return np_format(X)
 
+    def sympy(self):
+        return self.get_best()["sympy_format"]
 
+    def jax(self):
+        self.equations = get_hof(output_jax_format=True)
+        return self.get_best()["jax_format"]
+
+    def pytorch(self):
+        self.equations = get_hof(output_torch_format=True)
+        return self.get_best()["torch_format"]
 
-
+
+# Add the docs from pysr() to PySRRegressor():
 _pysr_docstring_split = []
 _start_recording = False
 for line in inspect.getdoc(pysr).split("\n"):
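With these methods in place, a fitted `PySRRegressor` can export its best equation to each backend. A hedged usage sketch (data and settings are illustrative):

```python
import numpy as np
from pysr import PySRRegressor

X = np.random.randn(100, 2)
y = X[:, 0] ** 2 + X[:, 1]

model = PySRRegressor()  # default search settings, for illustration
model.fit(X, y)

expr = model.sympy()            # SymPy expression for the best equation
torch_module = model.pytorch()  # torch module; constants are trainable
jax_dict = model.jax()          # {'callable': ..., 'parameters': ...}
```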
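The trailing loop (only partially visible in this hunk) splices the `pysr()` docstring onto `PySRRegressor`. The general technique looks roughly like the following self-contained sketch; the start marker is a placeholder assumption, since the real condition is not shown in the diff:

```python
import inspect

def pysr(X=None, y=None):
    """Run a symbolic regression search.

    Parameters
    ----------
    niterations : int
        Number of iterations of the search.
    """

class PySRRegressor:
    """Scikit-learn-style wrapper around pysr."""

# Sketch of the docstring-splicing loop: start recording at a marker
# (a placeholder here; the real marker is not visible in this hunk),
# then append the recorded lines to the class docstring.
_pysr_docstring_split = []
_start_recording = False
for line in inspect.getdoc(pysr).split("\n"):
    if line.strip() == "Parameters":
        _start_recording = True
    if _start_recording:
        _pysr_docstring_split.append(line)
PySRRegressor.__doc__ += "\n" + "\n".join(_pysr_docstring_split)
```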