Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	File size: 4,050 Bytes
			
			| 2f38c9c 976f8d8 41e5fd5 9bfcbfa 41e5fd5 976f8d8 a2fd8f3 7d4300a 2f38c9c 51a6b05 2f38c9c a2fd8f3 7d4300a 2f38c9c 7d4300a c7187a6 a2fd8f3 c7187a6 fbb7cf7 4b56660 fbb7cf7 c7187a6 593c674 c7187a6 593c674 c7187a6 fbb7cf7 c7187a6 a15823e c7187a6 9bfcbfa a2fd8f3 b07eb2d fbb7cf7 4b56660 fbb7cf7 7d4300a b444c7e 593c674 7d4300a 9bfcbfa 593c674 7d4300a 9bfcbfa fbb7cf7 d398bf9 9bfcbfa 7d4300a f5577ea 9bfcbfa ce5b119 7cda629 beaf20b 7cda629 beaf20b ce5b119 4b56660 7cda629 4b56660 7cda629 4b56660 7cda629 beaf20b ce5b119 beaf20b ce5b119 a15823e a2fd8f3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | import unittest
from functools import partial
import numpy as np
import pandas as pd
import sympy
from .. import PySRRegressor, sympy2jax
class TestJAX(unittest.TestCase):
    """Tests for exporting PySR equations to JAX (``sympy2jax`` / ``PySRRegressor.jax``)."""

    def setUp(self):
        # Seed NumPy's global RNG so the random inputs below are reproducible.
        np.random.seed(0)

    def test_sympy2jax(self):
        """A hand-built sympy expression converts to a JAX callable with matching output."""
        from jax import numpy as jnp
        from jax import random

        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
        # ``z`` is listed as a symbol but does not appear in ``cosx``
        # (and ``X`` only has two columns).
        f, params = sympy2jax(cosx, [x, y, z])
        self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())

    def _check_refreshed_jax_output(self, model, X, X_values):
        """Refresh ``model`` from a hand-written equation file and check its JAX export.

        Writes a three-row equation table (``1.0``, ``cos(x1)``,
        ``square(cos(x1))``), refreshes ``model`` from it, and asserts that the
        exported JAX callable evaluated on ``X`` matches ``cos(X[:, 1])**2``.

        Args:
            model: a fitted ``PySRRegressor`` built with ``output_jax_format=True``.
            X: input handed to ``jnp.array`` (NumPy array or pandas DataFrame).
            X_values: ``X`` as a plain NumPy array, used for the reference values.
        """
        import os
        import tempfile

        from jax import numpy as jnp

        equations = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "Loss": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )
        # Use a private temporary directory so the test neither leaves files
        # behind nor overwrites an ``equation_file.csv.bkup`` in the cwd.
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_file = os.path.join(tmpdir, "equation_file.csv")
            # The table is written with a ``.bkup`` suffix while ``refresh`` is
            # given the un-suffixed name; presumably ``refresh`` reads the
            # backup copy — TODO confirm against PySRRegressor.refresh.
            equations[["Complexity", "Loss", "Equation"]].to_csv(
                checkpoint_file + ".bkup"
            )
            model.refresh(checkpoint_file=checkpoint_file)
            # Call ``jax()`` while the file still exists: it may re-read it.
            jformat = model.jax()
            output = jformat["callable"](jnp.array(X), jformat["parameters"])
        np.testing.assert_almost_equal(
            np.array(output),
            np.square(np.cos(X_values[:, 1])),  # Select feature 1
            decimal=3,
        )

    def test_pipeline_pandas(self):
        """JAX export works for a model fitted on a pandas DataFrame."""
        X = pd.DataFrame(np.random.randn(100, 10))
        y = np.ones(X.shape[0])
        model = PySRRegressor(
            progress=False,
            max_evals=10000,
            output_jax_format=True,
        )
        model.fit(X, y)
        self._check_refreshed_jax_output(model, X, X.values)

    def test_pipeline(self):
        """JAX export works for a model fitted on a NumPy array."""
        X = np.random.randn(100, 10)
        y = np.ones(X.shape[0])
        model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
        model.fit(X, y)
        self._check_refreshed_jax_output(model, X, X)

    def test_feature_selection_custom_operators(self):
        """Feature selection plus a custom unary operator round-trips to JAX.

        Both ``model.predict`` and the exported JAX callable must reproduce
        ``y`` to 3 decimals.
        """
        rstate = np.random.RandomState(0)
        X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
        # NOTE(review): despite its name this is not the cosine Taylor series
        # (the x**6 term would be negative). It is defined identically in the
        # Julia, sympy and JAX mappings below, so the test is self-consistent.
        cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
        y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
        model = PySRRegressor(
            progress=False,
            unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
            select_k_features=3,
            maxsize=10,
            early_stop_condition=1e-5,
            extra_sympy_mappings={"cos_approx": cos_approx},
            extra_jax_mappings={
                "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
            },
            random_state=0,
            deterministic=True,
            procs=0,
            multithreading=False,
        )
        np.random.seed(0)
        model.fit(X.values, y.values)
        # Look entries up by key instead of unpacking ``.values()``, which
        # silently depends on the dict's key order and number of entries.
        jformat = model.jax()
        f = jformat["callable"]
        parameters = jformat["parameters"]
        np_prediction = model.predict
        jax_prediction = partial(f, parameters=parameters)
        np_output = np_prediction(X.values)
        jax_output = jax_prediction(X.values)
        np.testing.assert_almost_equal(y.values, np_output, decimal=3)
        np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
def runtests():
    """Collect every test method of ``TestJAX`` into a suite and run it.

    Returns:
        The result object returned by ``unittest.TextTestRunner().run``.
    """
    cases = unittest.TestLoader().loadTestsFromTestCase(TestJAX)
    suite = unittest.TestSuite()
    suite.addTests(cases)
    return unittest.TextTestRunner().run(suite)
 |