sympy

import tensorwaves.function.sympy

Lambdify sympy expression trees to a Function.

create_function(expression: Expr, backend: str, use_cse: bool = True, max_complexity: int | None = None) → PositionalArgumentFunction

Convert a SymPy expression to a computational function.

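Parameters:
  • expression – The sympy.Expr to convert. Its free symbols become the arguments of the returned function; in the example below, these are the "x" and "y" keys of the input data.

  • backend – Computational back-end in which to express the function, for instance "jax".

  • use_cse – See lambdify().

  • max_complexity – See fast_lambdify().
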
Example

>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expression = x**2 + y**2
>>> function = create_function(expression, backend="jax")
>>> array = np.linspace(0, 3, num=4)
>>> data = {"x": array, "y": array}
>>> function(data).tolist()
[0.0, 2.0, 8.0, 18.0]
create_parametrized_function(expression: Expr, parameters: Mapping[Symbol, ParameterValue], backend: str, use_cse: bool = True, max_complexity: int | None = None) → ParametrizedBackendFunction

Convert a SymPy expression to a parametrized function.

This is an extended version of create_function(), which allows one to identify certain symbols in the expression as parameters.

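Parameters:
  • expression – The sympy.Expr to convert.

  • parameters – A mapping of initial values for the Symbol instances in the expression that are to be treated as parameters. The remaining free symbols become data arguments of the function.

  • backend – Computational back-end in which to express the function, for instance "jax".

  • use_cse – See lambdify().

  • max_complexity – See fast_lambdify().
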
Example

>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_parametrized_function
>>> a, b, x, y = sp.symbols("a b x y")
>>> expression = a * x**2 + b * y**2
>>> function = create_parametrized_function(
...     expression,
...     parameters={a: -1, b: 2.5},
...     backend="jax",
... )
>>> array = np.linspace(0, 1, num=5)
>>> data = {"x": array, "y": array}
>>> function.update_parameters({"b": 1})
>>> function(data).tolist()
[0.0, 0.0, 0.0, 0.0, 0.0]
lambdify(expression: Expr, symbols: Sequence[Symbol], backend: str, use_cse: bool = True) → Callable

A wrapper around SymPy's lambdify().

Parameters:
  • expression – The sympy.Expr that you want to express as a function in a certain computational back-end.

  • symbols – The Symbol instances in the expression that you want to serve as positional arguments in the lambdified function. The order of this sequence determines the order of the positional arguments.

  • backend – Computational back-end in which to express the lambdified function.

  • use_cse – Lambdify with common sub-expressions (see cse argument in lambdify()).
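
As a rough sketch of what such a wrapper boils down to, assuming the NumPy back-end, the following calls sympy.lambdify() directly. It shows that the order of symbols fixes the order of the positional arguments and that use_cse corresponds to SymPy's cse argument. How tensorwaves maps a backend string to SymPy modules is not reproduced here.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
expression = x**2 - y

# symbols=[x, y]: the first positional argument feeds x, the second feeds y
function_xy = sp.lambdify([x, y], expression, modules="numpy", cse=True)
# symbols=[y, x]: same expression, but the positional arguments are swapped
function_yx = sp.lambdify([y, x], expression, modules="numpy", cse=True)

first = np.array([1.0, 2.0])
second = np.array([0.5, 1.0])
result_xy = function_xy(first, second)  # x=first, y=second -> [0.5, 3.0]
result_yx = function_yx(first, second)  # y=first, x=second -> [-0.75, -1.0]
```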

fast_lambdify(expression: Expr, symbols: Sequence[Symbol], backend: str, *, min_complexity: int = 0, max_complexity: int, use_cse: bool = True) → Callable

Speed up lambdify() with split_expression().

For a simple example of the reasoning behind this, see Speed up lambdifying.
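
The composition idea can be sketched with plain sympy.lambdify() and NumPy, assuming a split that is picked by hand instead of computed by split_expression(). Every sub-expression is lambdified on its own, the top expression is lambdified over the placeholder symbols, and the pieces are chained at call time. The placeholder names f0 and f1 and the helper composed_function() are hypothetical, and an expression this small shows no speed-up; the sketch only illustrates the composition.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
f0, f1 = sp.symbols("f0 f1")  # hypothetical placeholder symbols

expression = sp.sqrt(x**2 + y**2) + sp.sin(x * y)

# Hand-picked split (split_expression() would choose such nodes automatically)
sub_expressions = {f0: x**2 + y**2, f1: x * y}
top_expression = sp.sqrt(f0) + sp.sin(f1)

# Lambdify each small piece separately ...
sub_functions = {
    symbol: sp.lambdify([x, y], sub_expression, modules="numpy")
    for symbol, sub_expression in sub_expressions.items()
}
top_function = sp.lambdify([f0, f1], top_expression, modules="numpy")


def composed_function(x_values, y_values):
    # ... then evaluate the sub-expressions first and feed them to the top
    f0_values = sub_functions[f0](x_values, y_values)
    f1_values = sub_functions[f1](x_values, y_values)
    return top_function(f0_values, f1_values)
```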

extract_constant_sub_expressions(expression: Expr, free_symbols: Iterable[Symbol], fix_order: bool = False) → tuple[Expr, dict[Symbol, Expr]]

Collapse and extract constant sub-expressions.

Along with prepare_caching(), this function prepares a sympy.Expr for caching during a fit procedure. The function returns a top expression, in which the constant sub-expressions have been substituted by new symbols \(f_i\) (one for each substituted sub-expression), and a dict that maps those symbols to the sub-expressions that they represent. The top expression can be passed to create_parametrized_function(), while the dict of sub-expressions can be passed to SympyDataTransformer.from_sympy().

Parameters:
  • expression – The Expr from which to extract constant sub-expressions.

  • free_symbols – Symbol instances in the main expression that are not constant.

  • fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
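
A simplified sketch of the idea in plain SymPy follows. The function extract_constants_sketch() and the generated names f0, f1, … are hypothetical. Unlike the documented function, it has no fix_order option, numbers the symbols in traversal order, and does not guard against name clashes. It walks the expression tree and replaces every maximal sub-expression that does not depend on the free symbols by a new symbol, collapsing the constant terms of a sum (or the constant factors of a product) into a single sub-expression.

```python
import sympy as sp


def extract_constants_sketch(expression, free_symbols):
    """Hypothetical simplification of extract_constant_sub_expressions()."""
    free = set(free_symbols)
    sub_expressions = {}

    def is_constant(node):
        # "Constant" means: the node does not depend on any free symbol
        return not (node.free_symbols & free)

    def visit(node):
        if is_constant(node):
            if not node.free_symbols:  # plain numbers are kept as they are
                return node
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        if node.is_Atom:  # one of the free symbols itself
            return node
        if isinstance(node, (sp.Add, sp.Mul)):
            constant_args = [arg for arg in node.args if is_constant(arg)]
            other_args = [visit(arg) for arg in node.args if not is_constant(arg)]
            if constant_args:  # collapse all constant terms/factors into one
                other_args.append(visit(node.func(*constant_args)))
            return node.func(*other_args)
        return node.func(*(visit(arg) for arg in node.args))

    return visit(expression), sub_expressions


a, b, c, x = sp.symbols("a b c x")
expression = a * x**2 + b * x + c
top_expression, sub_expressions = extract_constants_sketch(expression, [a])
# top_expression: a*f0 + f1, sub_expressions: {f0: x**2, f1: b*x + c}
```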

prepare_caching(expression: Expr, parameters: Mapping[Symbol, ParameterValue], free_parameters: Iterable[Symbol], fix_order: bool = False) → tuple[Expr, dict[Symbol, Expr]]

Prepare an expression for optimizing with caching.

When fitting a ParametrizedFunction, only its free ParametrizedFunction.parameters are updated on each iteration. This allows for an optimization: all sub-expressions that are unaffected by these free parameters can be cached as a constant DataSample. The strategy here is to create a top expression that contains only the parameters that are to be optimized.

Along with extract_constant_sub_expressions(), this function prepares a sympy.Expr for this caching procedure. The function returns a top expression where the constant sub-expressions have been substituted by new symbols \(f_i\) for each substituted sub-expression and a dict that gives the sub-expressions that those symbols represent.

The top expression can be passed to create_parametrized_function(), while the dict of sub-expressions can be passed to SympyDataTransformer.from_sympy().

Parameters:
  • expression – The Expr from which to extract constant sub-expressions.

  • parameters – A mapping of values for each of the parameter symbols in the expression. Parameters that are not free_parameters are substituted in the returned expressions with xreplace().

  • free_parameters – Symbol instances in the main expression that are to be considered parameters and that will be optimized by an Optimizer later on.

  • fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
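
A minimal sketch of the caching strategy with plain SymPy and NumPy, assuming a hand-picked split instead of the one prepare_caching() would compute; the placeholder symbols f0 and f1 are hypothetical. The values of the non-free parameters b and c are substituted with xreplace(), the constant sub-expressions are evaluated once over the data, and only the top expression, which depends on the free parameter a, is re-evaluated for each new parameter value.

```python
import numpy as np
import sympy as sp

a, b, c, x = sp.symbols("a b c x")
f0, f1 = sp.symbols("f0 f1")  # hypothetical placeholder symbols
expression = a * sp.exp(-b * x**2) + c * x
parameters = {a: 2.0, b: 0.5, c: 1.0}
free_parameters = {a}

# Step 1: substitute the values of the parameters that are NOT optimized
fixed_values = {p: v for p, v in parameters.items() if p not in free_parameters}
reduced_expression = expression.xreplace(fixed_values)

# Step 2: constant sub-expressions, picked by hand here
sub_expressions = {
    f0: sp.exp(-b * x**2).xreplace(fixed_values),
    f1: (c * x).xreplace(fixed_values),
}
top_expression = a * f0 + f1  # depends only on the free parameter a

# Step 3: evaluate the constant sub-expressions once over the data (the cache)
data = np.linspace(-2.0, 2.0, num=9)
cache = {
    symbol: sp.lambdify([x], sub_expression, modules="numpy")(data)
    for symbol, sub_expression in sub_expressions.items()
}
top_function = sp.lambdify([a, f0, f1], top_expression, modules="numpy")

# Each optimizer iteration now only evaluates the cheap top expression
fit_values = [top_function(a_value, cache[f0], cache[f1]) for a_value in (1.0, 1.5, 2.0)]
```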

split_expression(expression: Expr, max_complexity: int, min_complexity: int = 1) → tuple[Expr, dict[Symbol, Expr]]

Split an expression into a ‘top expression’ and several sub-expressions.

Replace nodes in the expression tree of a sympy.Expr that lie within a certain complexity range (see count_ops()) with symbols, and keep a mapping of each of these symbols to the sub-expression that it replaced.
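
A simplified top-down sketch of this splitting, using sympy.count_ops() as the complexity measure. The function split_sketch() and the generated symbol names f0, f1, … are hypothetical, and the actual split_expression() may traverse the tree and select nodes differently. A node is replaced by a symbol as soon as its operation count falls within [min_complexity, max_complexity]; larger nodes are split further into their arguments.

```python
import sympy as sp


def split_sketch(expression, max_complexity, min_complexity=1):
    """Hypothetical simplification of split_expression()."""
    sub_expressions = {}

    def visit(node):
        if node.is_Atom:
            return node
        complexity = sp.count_ops(node)
        if complexity < min_complexity:
            return node  # too small to be worth a separate symbol
        if complexity <= max_complexity:
            # first (largest) node on the way down that fits the range
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        # too complex: rebuild the node from its (possibly replaced) arguments
        return node.func(*(visit(arg) for arg in node.args))

    return visit(expression), sub_expressions


x, y, z = sp.symbols("x y z")
expression = sp.sqrt(x**2 + y**2) * sp.exp(-z) + sp.sin(x * y * z)
top_expression, sub_expressions = split_sketch(expression, max_complexity=3)
```

Substituting the sub-expressions back into the top expression with xreplace() recovers the original expression.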