Spaces:
Sleeping
Sleeping
| import functools as ft | |
| import sympy | |
| import string | |
| import random | |
# Sentinels for SymPy's n-ary Mul/Add, which may hold more than two arguments:
# they are emitted as infix chains ("(a) * (b) * (c)") rather than as a call.
MUL = 0
ADD = 1

# Map from SymPy expression class to either one of the sentinels above or the
# JAX source text that implements it. The text is spliced verbatim into the
# generated function, so it may refer to module globals such as `jnp`, `jsp`
# and `ft`.
_jnp_func_lookup = {
    sympy.Mul: MUL,
    sympy.Add: ADD,
    # NOTE(review): jax.numpy exposes `divide`, not `div`; this entry may never
    # match an expression's `.func` -- confirm before relying on it.
    sympy.div: "jnp.div",
    sympy.Abs: "jnp.abs",
    sympy.sign: "jnp.sign",
    # Note: May raise error for ints.
    sympy.ceiling: "jnp.ceil",
    sympy.floor: "jnp.floor",
    sympy.log: "jnp.log",
    sympy.exp: "jnp.exp",
    sympy.sqrt: "jnp.sqrt",
    # The NumPy-style arc* names are used because the short acos-style
    # aliases only exist in recent jax.numpy releases.
    sympy.cos: "jnp.cos",
    sympy.acos: "jnp.arccos",
    sympy.sin: "jnp.sin",
    sympy.asin: "jnp.arcsin",
    sympy.tan: "jnp.tan",
    sympy.atan: "jnp.arctan",
    sympy.atan2: "jnp.arctan2",
    # Note: Also may give NaN for complex results.
    sympy.cosh: "jnp.cosh",
    sympy.acosh: "jnp.arccosh",
    sympy.sinh: "jnp.sinh",
    sympy.asinh: "jnp.arcsinh",
    sympy.tanh: "jnp.tanh",
    sympy.atanh: "jnp.arctanh",
    sympy.Pow: "jnp.power",
    sympy.re: "jnp.real",
    sympy.im: "jnp.imag",
    sympy.arg: "jnp.angle",
    # Note: May raise error for ints and complexes
    sympy.erf: "jsp.erf",
    sympy.erfc: "jsp.erfc",
    # SymPy's LessThan/GreaterThan are the NON-strict relations (<=, >=);
    # the Strict* classes are < and >.
    sympy.LessThan: "jnp.less_equal",
    sympy.GreaterThan: "jnp.greater_equal",
    sympy.StrictLessThan: "jnp.less",
    sympy.StrictGreaterThan: "jnp.greater",
    sympy.And: "jnp.logical_and",
    sympy.Or: "jnp.logical_or",
    sympy.Not: "jnp.logical_not",
    # SymPy's Max/Min are element-wise and n-ary, whereas jnp.max/jnp.min are
    # reductions whose second positional argument is an axis. Fold the binary
    # element-wise functions over all arguments instead.
    sympy.Max: "(lambda *args: ft.reduce(jnp.maximum, args))",
    sympy.Min: "(lambda *args: ft.reduce(jnp.minimum, args))",
    sympy.Mod: "jnp.mod",
    sympy.Heaviside: "jnp.heaviside",
    sympy.core.numbers.Half: "(lambda: 0.5)",
    sympy.core.numbers.One: "(lambda: 1.0)",
}
def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
    """Recursively translate a SymPy expression tree into JAX source text.

    Args:
        expr: The SymPy expression (sub)tree to translate.
        parameters: A list that is appended to, in traversal order, with the
            value of every ``sympy.Float`` encountered. The generated text
            refers to these values as ``parameters[i]`` so they remain
            adjustable/differentiable inputs of the generated function.
        symbols_in: Sequence of input symbols. Symbol ``symbols_in[i]`` is
            emitted as the column ``X[:, i]``.
        extra_jax_mappings: Optional dict from SymPy function class to JAX
            source text. Entries take precedence over ``_jnp_func_lookup``.

    Returns:
        A string holding a Python expression in terms of ``X`` and
        ``parameters``.

    Raises:
        IndexError: If ``expr`` contains a symbol that is not in ``symbols_in``.
        KeyError: If ``expr`` contains a function with no JAX mapping.
    """
    if issubclass(expr.func, sympy.Float):
        parameters.append(float(expr))
        return f"parameters[{len(parameters) - 1}]"
    # Integer must be checked before Rational: sympy.Integer is a subclass of
    # sympy.Rational, so testing Rational first left this branch unreachable
    # and turned every integer literal into a float.
    if issubclass(expr.func, sympy.Integer):
        return f"{int(expr)}"
    if issubclass(expr.func, sympy.Rational):
        return f"{float(expr)}"
    if issubclass(expr.func, sympy.Symbol):
        for index, symbol in enumerate(symbols_in):
            if symbol == expr:
                return f"X[:, {index}]"
        raise IndexError(
            f"Symbol {expr} does not appear in symbols_in={list(symbols_in)}."
        )

    if extra_jax_mappings is None:
        extra_jax_mappings = {}
    # User-supplied mappings override the built-in table.
    if expr.func in extra_jax_mappings:
        _func = extra_jax_mappings[expr.func]
    elif expr.func in _jnp_func_lookup:
        _func = _jnp_func_lookup[expr.func]
    else:
        raise KeyError(
            f"Function {expr.func} was not found in JAX function mappings. "
            "Please add it to extra_jax_mappings in the format, e.g., "
            "{sympy.sqrt: 'jnp.sqrt'}."
        )

    args = [
        sympy2jaxtext(
            arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
        )
        for arg in expr.args
    ]
    # Mul/Add are n-ary in SymPy: emit a parenthesised infix chain.
    if _func == MUL:
        return " * ".join(f"({arg})" for arg in args)
    if _func == ADD:
        return " + ".join(f"({arg})" for arg in args)
    return f'{_func}({", ".join(args)})'
# Handles to the JAX modules, populated lazily by _initialize_jax() so that
# JAX is only imported when an export is actually requested.
jax_initialized = False
jax = None
jnp = None
jsp = None


def _initialize_jax():
    """Import JAX on first use and publish it through this module's globals.

    Code generated by ``sympy2jax`` is exec'd with this module's globals, so
    it resolves ``jnp`` / ``jsp`` through the names assigned here.
    """
    global jax_initialized
    global jax
    global jnp
    global jsp
    if jax_initialized:
        return
    import jax as _jax
    from jax import numpy as _jnp
    from jax.scipy import special as _jsp

    jax = _jax
    jnp = _jnp
    jsp = _jsp
    # Previously this flag was never set, so the guard above never fired.
    jax_initialized = True
def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
    """Returns a function f and its parameters;
    the function takes an input matrix, and a list of arguments:
            f(X, parameters)
    where the parameters appear in the JAX equation.

    Args:
        expression: The SymPy expression to export.
        symbols_in: The input symbols, in the same order as the feature
            columns of ``X``.
        selection: Optional column selection applied to ``X`` before the
            expression is evaluated (inserted as ``X = X[:, selection]``).
        extra_jax_mappings: Optional dict from SymPy function class to JAX
            source text, e.g. ``{sympy.sqrt: 'jnp.sqrt'}``.

    Returns:
        A tuple ``(f, params)``: the generated function and a ``jnp.array``
        with the initial values of every float constant in ``expression``.

    # Examples:

        Let's create a function in SymPy:
        ```python
        import sympy
        x, y = sympy.symbols('x y')
        cosx = 1.0 * sympy.cos(x) + 3.2 * y
        ```
        Let's get the JAX version. We pass the equation, and
        the symbols required.
        ```python
        f, params = sympy2jax(cosx, [x, y])
        ```
        The order you supply the symbols is the same order
        you should supply the features when calling
        the function `f` (shape `[nrows, nfeatures]`).
        In this case, features=2 for x and y.
        The `params` in this case will be
        `jnp.array([1.0, 3.2])`. You pass these parameters
        when calling the function, which will let you change them
        and take gradients.

        Let's generate some JAX data to pass:
        ```python
        from jax import random
        key = random.PRNGKey(0)
        X = random.normal(key, (10, 2))
        ```

        We can call the function with:
        ```python
        f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```

        We can take gradients with respect
        to the parameters for each row with JAX
        gradient parameters now:
        ```python
        jac_f = jax.jacobian(f, argnums=1)
        jac_f(X, params)

        #> DeviceArray([[ 0.49364874, -0.9692889 ],
        #               [ 0.8283714 , -0.0318858 ],
        #               [-0.7447336 , -1.8784496 ],
        #               [ 0.70755106, -0.3137085 ],
        #               [ 0.944834  ,  1.767703  ],
        #               [ 0.51673377,  1.4111717 ],
        #               [ 0.87347716, -0.52637756],
        #               [ 0.8760679 ,  1.0549792 ],
        #               [ 0.9961824 ,  0.79581654],
        #               [-0.88465923, -0.5822907 ]], dtype=float32)
        ```

        We can also JIT-compile our function:
        ```python
        compiled_f = jax.jit(f)
        compiled_f(X, params)

        #> DeviceArray([-2.6080756 ,  0.72633684, -6.7557726 , -0.2963162 ,
        #                6.6014843 ,  5.032483  , -0.810931  ,  4.2520013 ,
        #                3.5427954 , -2.7479894 ], dtype=float32)
        ```
    """
    _initialize_jax()
    parameters = []
    functional_form_text = sympy2jaxtext(
        expression, parameters, symbols_in, extra_jax_mappings
    )
    hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
    text = f"def {hash_string}(X, parameters):\n"
    if selection is not None:
        # Impose the feature selection. Elements are converted to plain Python
        # scalars first: under NumPy >= 2 a NumPy scalar reprs as e.g.
        # "np.int64(0)", which would be an undefined name in the generated code.
        selection_list = [
            value.item() if hasattr(value, "item") else value for value in selection
        ]
        text += f"    X = X[:, {selection_list!r}]\n"
    text += "    return " + functional_form_text
    ldict = {}
    # The generated source embeds `extra_jax_mappings` values verbatim and is
    # executed here; only pass mappings from trusted code.
    exec(text, globals(), ldict)
    return ldict[hash_string], jnp.array(parameters)