Spaces:
Running
Running
Commit
·
aadb328
1
Parent(s):
18afca5
Add ability to pass strings defining operators
Browse files- pysr/sr.py +29 -8
pysr/sr.py
CHANGED
|
@@ -92,6 +92,17 @@ def pysr(X=None, y=None, weights=None, threads=4,
|
|
| 92 |
|
| 93 |
"""
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
| 96 |
|
| 97 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
@@ -115,7 +126,24 @@ def pysr(X=None, y=None, weights=None, threads=4,
|
|
| 115 |
|
| 116 |
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
| 117 |
|
| 118 |
-
def_hyperparams =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
| 120 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
| 121 |
const ns=10;
|
|
@@ -144,13 +172,6 @@ const mutationWeights = [
|
|
| 144 |
]
|
| 145 |
"""
|
| 146 |
|
| 147 |
-
assert len(X.shape) == 2
|
| 148 |
-
assert len(y.shape) == 1
|
| 149 |
-
assert X.shape[0] == y.shape[0]
|
| 150 |
-
if weights is not None:
|
| 151 |
-
assert len(weights.shape) == 1
|
| 152 |
-
assert X.shape[0] == weights.shape[0]
|
| 153 |
-
|
| 154 |
if X.shape[1] == 1:
|
| 155 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|
| 156 |
else:
|
|
|
|
| 92 |
|
| 93 |
"""
|
| 94 |
|
| 95 |
+
# Check for potential errors before they happen
|
| 96 |
+
assert len(binary_operators) > 0
|
| 97 |
+
assert len(unary_operators) > 0
|
| 98 |
+
assert len(X.shape) == 2
|
| 99 |
+
assert len(y.shape) == 1
|
| 100 |
+
assert X.shape[0] == y.shape[0]
|
| 101 |
+
if weights is not None:
|
| 102 |
+
assert len(weights.shape) == 1
|
| 103 |
+
assert X.shape[0] == weights.shape[0]
|
| 104 |
+
|
| 105 |
+
|
| 106 |
rand_string = f'{"".join([str(np.random.rand())[2] for i in range(20)])}'
|
| 107 |
|
| 108 |
if isinstance(binary_operators, str): binary_operators = [binary_operators]
|
|
|
|
| 126 |
|
| 127 |
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
| 128 |
|
| 129 |
+
def_hyperparams = ""
|
| 130 |
+
|
| 131 |
+
# Add pre-defined functions to Julia
|
| 132 |
+
for op_list in [binary_operators, unary_operators]:
|
| 133 |
+
for i in range(len(op_list)):
|
| 134 |
+
op = op_list[i]
|
| 135 |
+
if '(' not in op:
|
| 136 |
+
continue
|
| 137 |
+
|
| 138 |
+
def_hyperparams += op + "\n"
|
| 139 |
+
first_non_char = [
|
| 140 |
+
j for j in range(len(op))
|
| 141 |
+
if not (op[j].isalpha() or op[j].isdigit())][0]
|
| 142 |
+
function_name = op[:first_non_char]
|
| 143 |
+
op_list[i] = function_name
|
| 144 |
+
print(op_list)
|
| 145 |
+
|
| 146 |
+
def_hyperparams += f"""include("{pkg_directory}/operators.jl")
|
| 147 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
| 148 |
const unaops = {'[' + ', '.join(unary_operators) + ']'}
|
| 149 |
const ns=10;
|
|
|
|
| 172 |
]
|
| 173 |
"""
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
if X.shape[1] == 1:
|
| 176 |
X_str = 'transpose([' + str(X.tolist()).replace(']', '').replace(',', '').replace('[', '') + '])'
|
| 177 |
else:
|