Spaces:
Sleeping
Sleeping
Commit
·
c6c8728
1
Parent(s):
8f218cc
Test all aspects of generated LaTeX table
Browse files- test/test.py +107 -19
test/test.py
CHANGED
|
@@ -281,19 +281,38 @@ class TestPipeline(unittest.TestCase):
|
|
| 281 |
self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
|
| 282 |
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
class TestBest(unittest.TestCase):
|
| 285 |
def setUp(self):
|
| 286 |
self.rstate = np.random.RandomState(0)
|
| 287 |
self.X = self.rstate.randn(10, 2)
|
| 288 |
self.y = np.cos(self.X[:, 0]) ** 2
|
| 289 |
-
self.model = PySRRegressor(
|
| 290 |
-
progress=False,
|
| 291 |
-
niterations=1,
|
| 292 |
-
extra_sympy_mappings={},
|
| 293 |
-
output_jax_format=False,
|
| 294 |
-
model_selection="accuracy",
|
| 295 |
-
equation_file="equation_file.csv",
|
| 296 |
-
)
|
| 297 |
equations = pd.DataFrame(
|
| 298 |
{
|
| 299 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
|
@@ -301,17 +320,7 @@ class TestBest(unittest.TestCase):
|
|
| 301 |
"complexity": [1, 2, 3],
|
| 302 |
}
|
| 303 |
)
|
| 304 |
-
|
| 305 |
-
# Set up internal parameters as if it had been fitted:
|
| 306 |
-
self.model.equation_file_ = "equation_file.csv"
|
| 307 |
-
self.model.nout_ = 1
|
| 308 |
-
self.model.selection_mask_ = None
|
| 309 |
-
self.model.feature_names_in_ = np.array(["x0", "x1"], dtype=object)
|
| 310 |
-
equations["complexity loss equation".split(" ")].to_csv(
|
| 311 |
-
"equation_file.csv.bkup", sep="|"
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
self.model.refresh()
|
| 315 |
self.equations_ = self.model.equations_
|
| 316 |
|
| 317 |
def test_best(self):
|
|
@@ -485,3 +494,82 @@ class TestMiscellaneous(unittest.TestCase):
|
|
| 485 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 486 |
# If any checks failed don't let the test pass.
|
| 487 |
self.assertEqual(len(exception_messages), 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
|
| 282 |
|
| 283 |
|
| 284 |
+
def manually_create_model(equations, feature_names=None):
|
| 285 |
+
if feature_names is None:
|
| 286 |
+
feature_names = ["x0", "x1"]
|
| 287 |
+
|
| 288 |
+
model = PySRRegressor(
|
| 289 |
+
progress=False,
|
| 290 |
+
niterations=1,
|
| 291 |
+
extra_sympy_mappings={},
|
| 292 |
+
output_jax_format=False,
|
| 293 |
+
model_selection="accuracy",
|
| 294 |
+
equation_file="equation_file.csv",
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
# Set up internal parameters as if it had been fitted:
|
| 298 |
+
model.equation_file_ = "equation_file.csv"
|
| 299 |
+
model.nout_ = 1
|
| 300 |
+
model.selection_mask_ = None
|
| 301 |
+
model.feature_names_in_ = np.array(feature_names, dtype=object)
|
| 302 |
+
equations["complexity loss equation".split(" ")].to_csv(
|
| 303 |
+
"equation_file.csv.bkup", sep="|"
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
model.refresh()
|
| 307 |
+
|
| 308 |
+
return model
|
| 309 |
+
|
| 310 |
+
|
| 311 |
class TestBest(unittest.TestCase):
|
| 312 |
def setUp(self):
|
| 313 |
self.rstate = np.random.RandomState(0)
|
| 314 |
self.X = self.rstate.randn(10, 2)
|
| 315 |
self.y = np.cos(self.X[:, 0]) ** 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
equations = pd.DataFrame(
|
| 317 |
{
|
| 318 |
"equation": ["1.0", "cos(x0)", "square(cos(x0))"],
|
|
|
|
| 320 |
"complexity": [1, 2, 3],
|
| 321 |
}
|
| 322 |
)
|
| 323 |
+
self.model = manually_create_model(equations)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
self.equations_ = self.model.equations_
|
| 325 |
|
| 326 |
def test_best(self):
|
|
|
|
| 494 |
print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
|
| 495 |
# If any checks failed don't let the test pass.
|
| 496 |
self.assertEqual(len(exception_messages), 0)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
class TestLaTeXTable(unittest.TestCase):
|
| 500 |
+
def create_true_latex(self, middle_part, include_score=False):
|
| 501 |
+
if include_score:
|
| 502 |
+
true_latex_table_str = r"""
|
| 503 |
+
\begin{table}[h]
|
| 504 |
+
\begin{center}
|
| 505 |
+
\begin{tabular}{@{}clll@{}}
|
| 506 |
+
\toprule
|
| 507 |
+
Equation & Complexity & Loss & Score \\
|
| 508 |
+
\midrule"""
|
| 509 |
+
else:
|
| 510 |
+
true_latex_table_str = r"""
|
| 511 |
+
\begin{table}[h]
|
| 512 |
+
\begin{center}
|
| 513 |
+
\begin{tabular}{@{}cll@{}}
|
| 514 |
+
\toprule
|
| 515 |
+
Equation & Complexity & Loss \\
|
| 516 |
+
\midrule"""
|
| 517 |
+
true_latex_table_str += middle_part
|
| 518 |
+
true_latex_table_str += r"""\bottomrule
|
| 519 |
+
\end{tabular}
|
| 520 |
+
\end{center}
|
| 521 |
+
\end{table}
|
| 522 |
+
"""
|
| 523 |
+
# First, remove empty lines:
|
| 524 |
+
true_latex_table_str = "\n".join(
|
| 525 |
+
[line.strip() for line in true_latex_table_str.split("\n") if len(line) > 0]
|
| 526 |
+
)
|
| 527 |
+
return true_latex_table_str.strip()
|
| 528 |
+
|
| 529 |
+
def test_simple_table(self):
|
| 530 |
+
equations = pd.DataFrame(
|
| 531 |
+
dict(
|
| 532 |
+
equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
|
| 533 |
+
loss=[1.052, 0.02315, 1.12347e-15],
|
| 534 |
+
complexity=[1, 2, 8],
|
| 535 |
+
)
|
| 536 |
+
)
|
| 537 |
+
model = manually_create_model(equations)
|
| 538 |
+
|
| 539 |
+
# Regular table:
|
| 540 |
+
latex_table_str = model.latex_table()
|
| 541 |
+
middle_part = r"""
|
| 542 |
+
$x_{0}$ & 1 & 1.05 \\
|
| 543 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.0232 \\
|
| 544 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
|
| 545 |
+
"""
|
| 546 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
| 547 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 548 |
+
|
| 549 |
+
# Different precision:
|
| 550 |
+
latex_table_str = model.latex_table(precision=5)
|
| 551 |
+
middle_part = r"""
|
| 552 |
+
$x_{0}$ & 1 & 1.052 \\
|
| 553 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.02315 \\
|
| 554 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.1235e-15 \\
|
| 555 |
+
"""
|
| 556 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
| 557 |
+
self.assertEqual(latex_table_str, self.create_true_latex(middle_part))
|
| 558 |
+
|
| 559 |
+
# Including score:
|
| 560 |
+
latex_table_str = model.latex_table(include_score=True)
|
| 561 |
+
middle_part = r"""
|
| 562 |
+
$x_{0}$ & 1 & 1.05 & 0 \\
|
| 563 |
+
$\cos{\left(x_{0} \right)}$ & 2 & 0.0232 & 3.82 \\
|
| 564 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 & 5.11 \\
|
| 565 |
+
"""
|
| 566 |
+
true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
|
| 567 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|
| 568 |
+
|
| 569 |
+
# Only last equation:
|
| 570 |
+
latex_table_str = model.latex_table(indices=[2])
|
| 571 |
+
middle_part = r"""
|
| 572 |
+
$x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & 8 & 1.12e-15 \\
|
| 573 |
+
"""
|
| 574 |
+
true_latex_table_str = self.create_true_latex(middle_part)
|
| 575 |
+
self.assertEqual(latex_table_str, true_latex_table_str)
|