sympy
import tensorwaves.function.sympy
Lambdify sympy expression trees to a Function.
- create_function(expression: Expr, backend: str, use_cse: bool = True, max_complexity: int | None = None) → PositionalArgumentFunction
Convert a SymPy expression to a computational function.
- Parameters:
expression – The SymPy expression that you want to lambdify. Its free_symbols become arguments to the resulting PositionalArgumentFunction.
backend – The computational backend in which to express the function.
use_cse – Identify common sub-expressions in the function. This usually makes the function faster and speeds up lambdification.
max_complexity – See Specifying complexity and Speed up lambdifying.
Example
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expression = x**2 + y**2
>>> function = create_function(expression, backend="jax")
>>> array = np.linspace(0, 3, num=4)
>>> data = {"x": array, "y": array}
>>> function(data).tolist()
[0.0, 2.0, 8.0, 18.0]
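The use_cse option builds on SymPy's common sub-expression elimination. As a rough illustration that uses plain SymPy and NumPy rather than tensorwaves itself, sympy.cse() shows which sub-expressions get factored out, and the cse argument of sympy.lambdify() (available in recent SymPy versions) applies the same idea when generating code. The expression and data are arbitrary choices for this sketch.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
# x + y occurs twice, so it is a candidate for a common sub-expression
expression = (x + y) ** 2 + sp.sin(x + y)

# cse() returns (symbol, sub-expression) pairs plus the reduced expression(s)
replacements, reduced = sp.cse(expression)
print(replacements)  # the shared sub-expression x + y gets its own symbol
print(reduced[0])    # the expression rewritten in terms of that symbol

# Generate a NumPy function that evaluates x + y only once per call
function = sp.lambdify((x, y), expression, modules="numpy", cse=True)
array = np.linspace(0, 3, num=4)
values = function(array, array)
print(values)
```

The generated function gives the same values as without cse=True; only the number of repeated operations in the generated code changes.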
- create_parametrized_function(expression: Expr, parameters: Mapping[Symbol, ParameterValue], backend: str, use_cse: bool = True, max_complexity: int | None = None) → ParametrizedBackendFunction
Convert a SymPy expression to a parametrized function.
This is an extended version of create_function(), which allows one to identify certain symbols in the expression as parameters.
- Parameters:
expression – See create_function().
parameters – The symbols in the expression that are to be identified as parameters in the returned ParametrizedBackendFunction, mapped to their initial values.
backend – See create_function().
use_cse – See create_function().
max_complexity – See create_function().
Example
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_parametrized_function
>>> a, b, x, y = sp.symbols("a b x y")
>>> expression = a * x**2 + b * y**2
>>> function = create_parametrized_function(
...     expression,
...     parameters={a: -1, b: 2.5},
...     backend="jax",
... )
>>> array = np.linspace(0, 1, num=5)
>>> data = {"x": array, "y": array}
>>> function.update_parameters({"b": 1})
>>> function(data).tolist()
[0.0, 0.0, 0.0, 0.0, 0.0]
- lambdify(expression: Expr, symbols: Sequence[Symbol], backend: str, use_cse: bool = True) → Callable
A wrapper around lambdify().
- Parameters:
expression – The sympy.Expr that you want to express as a function in a certain computational back-end.
symbols – The Symbol instances in the expression that you want to serve as positional arguments in the lambdified function. The order of this sequence determines the order of the positional arguments.
backend – Computational back-end in which to express the lambdified function.
use_cse – Lambdify with common sub-expressions (see the cse argument in lambdify()).
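Because this function wraps SymPy's lambdify(), the role of the order of symbols can be illustrated with sympy.lambdify() directly. This is a plain SymPy/NumPy sketch, not a call into tensorwaves; it only shows that the generated function takes its positional arguments in the order in which the symbols are listed.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
expression = x - 2 * y  # not symmetric, so the argument order matters

# The sequence of symbols fixes the order of the positional arguments
f_xy = sp.lambdify([x, y], expression, modules="numpy")
f_yx = sp.lambdify([y, x], expression, modules="numpy")

x_values = np.array([1.0, 2.0])
y_values = np.array([10.0, 20.0])
result_xy = f_xy(x_values, y_values)  # x=x_values, y=y_values
result_yx = f_yx(y_values, x_values)  # same assignment, swapped call order
print(result_xy, result_yx)
```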
- fast_lambdify(expression: Expr, symbols: Sequence[Symbol], backend: str, *, min_complexity: int = 0, max_complexity: int, use_cse: bool = True) → Callable
Speed up lambdify() with split_expression().
For a simple example of the reasoning behind this, see Speed up lambdifying.
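The idea behind this speed-up can be sketched with plain SymPy: instead of lambdifying one large expression tree, the tree is split into a top expression and smaller sub-expressions, each piece is lambdified separately, and the resulting functions are composed. In this sketch the split is made by hand (the symbols f0 and f1 are hypothetical stand-ins for what split_expression() would generate), and it is not the tensorwaves implementation.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
f0, f1 = sp.symbols("f0 f1")  # hypothetical placeholder symbols

# Hand-made split of sin(x + y)**2 + exp(x*y) into a top expression and sub-expressions
top_expression = f0 + f1
sub_expressions = {f0: sp.sin(x + y) ** 2, f1: sp.exp(x * y)}
full_expression = top_expression.xreplace(sub_expressions)

symbols = [x, y]
# Lambdify each small piece separately instead of the whole tree at once
sub_functions = {
    symbol: sp.lambdify(symbols, expr, modules="numpy")
    for symbol, expr in sub_expressions.items()
}
top_function = sp.lambdify(list(sub_expressions), top_expression, modules="numpy")


def composed_function(*args):
    """Evaluate the sub-expressions first, then feed their values to the top expression."""
    sub_values = [sub_functions[symbol](*args) for symbol in sub_expressions]
    return top_function(*sub_values)


print(composed_function(np.array([0.0, 1.0]), np.array([1.0, 2.0])))
```

The composed function computes the same values as the lambdified full expression; the gain is that each lambdify call only has to process a small tree.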
- extract_constant_sub_expressions(expression: Expr, free_symbols: Iterable[Symbol], fix_order: bool = False) → tuple[Expr, dict[Symbol, Expr]]
Collapse and extract constant sub-expressions.
Along with prepare_caching(), this function prepares a sympy.Expr for caching during a fit procedure. The function returns a top expression where the constant sub-expressions have been substituted by new symbols \(f_i\), one for each substituted sub-expression, and a dict that gives the sub-expressions that those symbols represent. The top expression can be given to create_parametrized_function(), while the dict of sub-expressions can be given to SympyDataTransformer.from_sympy.
- Parameters:
expression – The Expr from which to extract constant sub-expressions.
free_symbols – Symbol instances in the main expression that are not constant.
fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
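A minimal sketch of the collapsing step in plain SymPy: walk the expression tree from the top and replace every maximal sub-tree that does not contain any of the free_symbols by a new symbol. The helper collapse_constant_nodes() and the symbol names f0, f1, and so on are hypothetical; this is not the tensorwaves implementation and it has no equivalent of fix_order.

```python
import sympy as sp


def collapse_constant_nodes(expression, free_symbols):
    """Hypothetical helper: replace maximal constant sub-trees by new symbols.

    A sub-tree counts as constant if it contains none of the free_symbols.
    """
    free = set(free_symbols)
    sub_expressions = {}

    def visit(node):
        if node.is_Atom:
            return node  # symbols and numbers are kept as they are
        if not node.free_symbols & free:
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        # The node depends on a free symbol: keep it and descend into its arguments
        return node.func(*(visit(arg) for arg in node.args))

    return visit(expression), sub_expressions


a, b, x = sp.symbols("a b x")
expression = a * sp.sin(x) ** 2 + b * sp.cos(x)
top_expression, sub_expressions = collapse_constant_nodes(expression, free_symbols=[a, b])
print(top_expression)   # contains only a, b, and the new symbols
print(sub_expressions)  # the constant parts, which only depend on x
```

Substituting the dict back into the top expression with xreplace() recovers the original expression, which is a quick consistency check for any such split.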
- prepare_caching(expression: Expr, parameters: Mapping[Symbol, ParameterValue], free_parameters: Iterable[Symbol], fix_order: bool = False) → tuple[Expr, dict[Symbol, Expr]]
Prepare an expression for optimizing with caching.
When fitting a ParametrizedFunction, only its free ParametrizedFunction.parameters are updated on each iteration. This allows for an optimization: all sub-expressions that are unaffected by these free parameters can be cached as a constant DataSample. The strategy here is to create a top expression that contains only the parameters that are to be optimized.
Along with extract_constant_sub_expressions(), this function prepares a sympy.Expr for this caching procedure. The function returns a top expression where the constant sub-expressions have been substituted by new symbols \(f_i\), one for each substituted sub-expression, and a dict that gives the sub-expressions that those symbols represent.
The top expression can be given to create_parametrized_function(), while the dict of sub-expressions can be given to SympyDataTransformer.from_sympy.
- Parameters:
expression – The Expr from which to extract constant sub-expressions.
parameters – A mapping of values for each of the parameter symbols in the expression. Parameters that are not free_parameters are substituted in the returned expressions with xreplace().
free_parameters – Symbol instances in the main expression that are to be considered parameters and that will be optimized by an Optimizer later on.
fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
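The caching strategy can be sketched in plain SymPy and NumPy. Parameters that are not free are first substituted with xreplace(); the remaining expression is then split by hand (the symbols f0 and f1 are hypothetical stand-ins for the generated \(f_i\)) into a top expression that contains only the free parameter and constant sub-expressions. The constant parts are evaluated once over the data and reused for every parameter update, mimicking a fit loop. This illustrates the strategy only and is not the tensorwaves implementation.

```python
import numpy as np
import sympy as sp

a, b, x = sp.symbols("a b x")
expression = a * sp.sin(x) ** 2 + b * sp.cos(x)
parameters = {a: 1.0, b: 2.5}
free_parameters = {a}

# Substitute the parameters that will not be optimized by their fixed values
fixed_values = {s: v for s, v in parameters.items() if s not in free_parameters}
expression = expression.xreplace(fixed_values)  # now contains a, x, and the value of b

# Hand-made split: the top expression only contains the free parameter a
f0, f1 = sp.symbols("f0 f1")  # hypothetical stand-ins for the generated symbols
top_expression = a * f0 + f1
sub_expressions = {f0: sp.sin(x) ** 2, f1: 2.5 * sp.cos(x)}

# Evaluate the constant sub-expressions once over the data sample and cache them
data_x = np.linspace(0.0, np.pi, num=5)
cache = {
    symbol: sp.lambdify(x, sub_expression, modules="numpy")(data_x)
    for symbol, sub_expression in sub_expressions.items()
}

# In a fit loop, only the cheap top expression is re-evaluated per parameter update
top_function = sp.lambdify([a, f0, f1], top_expression, modules="numpy")
results = {
    value: top_function(value, cache[f0], cache[f1]) for value in (0.5, 1.0, 2.0)
}
print(results)
```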
- split_expression(expression: Expr, max_complexity: int, min_complexity: int = 1) → tuple[Expr, dict[Symbol, Expr]]
Split an expression into a “top expression” and several sub-expressions.
Replace nodes in the expression tree of a sympy.Expr that lie within a certain complexity range (see count_ops()) with symbols, and keep a mapping of each of these symbols to the sub-expression that it replaced.
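A minimal sketch of such a split in plain SymPy, assuming a top-down traversal: a node whose count_ops() value lies within [min_complexity, max_complexity] is replaced by a new symbol, a more complex node is kept and its arguments are visited, and simpler nodes are left untouched. The helper split_by_complexity() and the symbol names f0, f1, and so on are hypothetical; the tensorwaves implementation may traverse the tree differently.

```python
import sympy as sp


def split_by_complexity(expression, max_complexity, min_complexity=1):
    """Hypothetical helper: swap sub-trees within a complexity range for symbols."""
    sub_expressions = {}

    def visit(node):
        complexity = sp.count_ops(node)
        if min_complexity <= complexity <= max_complexity:
            # Within range: replace the whole sub-tree by a new symbol
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        if complexity > max_complexity and node.args:
            # Too complex: keep this node and try to split its arguments
            return node.func(*(visit(arg) for arg in node.args))
        # Below min_complexity (for instance a symbol or a number): keep as is
        return node

    return visit(expression), sub_expressions


x, y = sp.symbols("x y")
expression = sp.sin(x + y) ** 2 + sp.cos(x * y) ** 2 + sp.exp(x - y)
top_expression, sub_expressions = split_by_complexity(expression, max_complexity=3)
print(top_expression)
print(sub_expressions)
```

Because replaced nodes are not visited further, the sub-expressions never contain each other, so the original expression is recovered with a single xreplace() of the returned dict.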