sympy
import tensorwaves.function.sympy
Lambdify sympy expression trees to a Function.
- create_function(expression: Expr, backend: str, *, use_cse: bool = True, use_jit: bool | None = None, max_complexity: int | None = None) -> PositionalArgumentFunction
Convert a SymPy expression to a computational function.
- Parameters:
expression – The SymPy expression that you want to lambdify. Its free_symbols become arguments to the resulting PositionalArgumentFunction.
backend – The computational backend in which to express the function.
use_cse – Identify common sub-expressions in the function. This usually makes the function faster and speeds up lambdification.
use_jit – Decorate the numerical function with a Just-in-Time decorator for the selected backend. By default (None), functions are JIT-compiled if the backend supports JIT compilation.
max_complexity – See Specifying complexity and Speed up lambdifying.
Example
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expression = x**2 + y**2
>>> function = create_function(expression, backend="jax")
>>> array = np.linspace(0, 3, num=4)
>>> data = {"x": array, "y": array}
>>> function(data).tolist()
[0.0, 2.0, 8.0, 18.0]
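The conversion from a mapping of named arrays to positional arguments can be sketched with plain SymPy and NumPy. This is not the tensorwaves implementation: the helper make_dict_function is hypothetical, ordering the free symbols alphabetically by name is an assumption made here to get a deterministic argument order, and JIT compilation and common sub-expression elimination are left out.

```python
import numpy as np
import sympy as sp


def make_dict_function(expression: sp.Expr):
    """Hypothetical helper, not part of tensorwaves.

    Lambdify an expression so that it can be called with a mapping of
    symbol names to arrays instead of with positional arguments.
    """
    # Assumption: free symbols are ordered alphabetically by name
    symbols = sorted(expression.free_symbols, key=lambda symbol: symbol.name)
    argument_order = [symbol.name for symbol in symbols]
    numerical_function = sp.lambdify(symbols, expression, modules="numpy")

    def function(data):
        # Look up each positional argument in the data mapping by its name
        return numerical_function(*(data[name] for name in argument_order))

    return function, argument_order


x, y = sp.symbols("x y")
function, argument_order = make_dict_function(x**2 + y**2)
array = np.linspace(0, 3, num=4)
print(argument_order)  # ['x', 'y']
print(function({"x": array, "y": array}).tolist())  # [0.0, 2.0, 8.0, 18.0]
```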
- create_parametrized_function(expression: Expr, parameters: Mapping[Basic, ParameterValue], backend: str, *, max_complexity: int | None = None, use_cse: bool = True, use_jit: bool | None = None) -> ParametrizedBackendFunction
Convert a SymPy expression to a parametrized function.
This is an extended version of create_function(), which allows one to identify certain symbols in the expression as parameters.
- Parameters:
expression – See create_function().
parameters – The symbols in the expression that are to be identified as parameters in the returned ParametrizedBackendFunction.
backend – See create_function().
use_cse – See create_function().
use_jit – Decorate the numerical function with a Just-in-Time decorator for the selected backend. By default (None), functions are JIT-compiled if the backend supports JIT compilation.
max_complexity – See create_function().
Example
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_parametrized_function
>>> a, b, x, y = sp.symbols("a b x y")
>>> expression = a * x**2 + b * y**2
>>> function = create_parametrized_function(
...     expression,
...     parameters={a: -1, b: 2.5},
...     backend="jax",
... )
>>> array = np.linspace(0, 1, num=5)
>>> data = {"x": array, "y": array}
>>> function.update_parameters({"b": 1})
>>> function(data).tolist()
[0.0, 0.0, 0.0, 0.0, 0.0]
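The point of separating parameters from data is that parameter values can change without lambdifying the expression again. A minimal sketch of that idea, assuming plain SymPy and NumPy; the class SimpleParametrizedFunction is hypothetical and is not tensorwaves' ParametrizedBackendFunction. It reproduces the result of the example above.

```python
import numpy as np
import sympy as sp


class SimpleParametrizedFunction:
    """Hypothetical stand-in for a parametrized function, not tensorwaves' class.

    Parameter values are stored separately from the data, so they can be
    updated without lambdifying the expression again.
    """

    def __init__(self, expression: sp.Expr, parameters: dict):
        parameter_symbols = list(parameters)
        data_symbols = sorted(
            expression.free_symbols - set(parameter_symbols),
            key=lambda symbol: symbol.name,
        )
        self._data_names = [symbol.name for symbol in data_symbols]
        self._parameter_names = [symbol.name for symbol in parameter_symbols]
        self.parameters = {s.name: value for s, value in parameters.items()}
        # Lambdify once: data symbols come first, parameter symbols last
        self._function = sp.lambdify(
            data_symbols + parameter_symbols, expression, modules="numpy"
        )

    def update_parameters(self, new_parameters: dict) -> None:
        for name, value in new_parameters.items():
            if name not in self.parameters:
                raise KeyError(f"No parameter called {name!r}")
            self.parameters[name] = value

    def __call__(self, data: dict):
        data_args = [data[name] for name in self._data_names]
        parameter_args = [self.parameters[name] for name in self._parameter_names]
        return self._function(*data_args, *parameter_args)


a, b, x, y = sp.symbols("a b x y")
function = SimpleParametrizedFunction(a * x**2 + b * y**2, parameters={a: -1, b: 2.5})
array = np.linspace(0, 1, num=5)
data = {"x": array, "y": array}
function.update_parameters({"b": 1})
print(function(data).tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0]
```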
- lambdify(expression: Expr, symbols: Sequence[Basic], backend: str, *, use_cse: bool = True, use_jit: bool | None = None) -> Callable
A wrapper around lambdify().
- Parameters:
expression – The sympy.Expr that you want to express as a function in a certain computational back-end.
symbols – The Symbol instances in the expression that you want to serve as positional arguments in the lambdified function. Note that positional arguments are ordered.
backend – Computational back-end in which to express the lambdified function.
use_cse – Lambdify with common sub-expressions (see the cse argument in lambdify()).
use_jit – Decorate the numerical function with a Just-in-Time decorator for the selected backend. By default (None), functions are JIT-compiled if the backend supports JIT compilation.
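What use_cse changes can be illustrated with SymPy's own lambdify(), which accepts a cse argument (this assumes SymPy 1.9 or later). The sub-expression x + y occurs three times below; with cse=True the generated code computes it once and stores it in a temporary variable, while the numerical result stays the same. This sketch does not go through tensorwaves and uses the math module as backend.

```python
import inspect
import math

import sympy as sp

x, y = sp.symbols("x y")
# x + y occurs three times in this expression
expression = sp.sin(x + y) ** 2 + sp.cos(x + y) ** 2 + sp.exp(x + y)

plain_function = sp.lambdify((x, y), expression, modules="math")
cse_function = sp.lambdify((x, y), expression, modules="math", cse=True)

# The generated source of the CSE version assigns x + y to a temporary variable
print(inspect.getsource(cse_function))

plain_value = plain_function(0.3, 0.4)
cse_value = cse_function(0.3, 0.4)
print(math.isclose(plain_value, cse_value))  # True
```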
- fast_lambdify(expression: Expr, symbols: Sequence[Basic], backend: str, *, use_cse: bool = True, use_jit: bool | None = None, max_complexity: int, min_complexity: int = 0) -> Callable
Speed up lambdify() with split_expression().
For a simple example of the reasoning behind this, see Speed up lambdifying.
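The idea of lambdifying a split expression can be sketched with plain SymPy and NumPy: each sub-expression and the top expression are lambdified separately, and at call time the sub-expression values are fed into the top function. Unlike fast_lambdify(), which splits by max_complexity, the split below is chosen by hand (the symbols f0 and f1 and the helper composed_function are hypothetical), and the sketch only checks that the composed function agrees with a direct lambdification.

```python
import numpy as np
import sympy as sp

x, y = sp.symbols("x y")
r_squared = x**2 + y**2
expression = sp.sqrt(r_squared) * sp.exp(-r_squared) + sp.sin(x * y)

# Hypothetical, hand-picked split: each sub-expression is replaced by a new symbol
f0, f1 = sp.symbols("f0 f1")
sub_expressions = {
    f0: sp.sqrt(r_squared) * sp.exp(-r_squared),
    f1: sp.sin(x * y),
}
top_expression = f0 + f1
assert top_expression.xreplace(sub_expressions) == expression

# Lambdify each small sub-expression and the top expression separately
sub_functions = [
    sp.lambdify((x, y), sub_expression, modules="numpy")
    for sub_expression in sub_expressions.values()
]
top_function = sp.lambdify(list(sub_expressions), top_expression, modules="numpy")


def composed_function(x_values, y_values):
    # Evaluate the sub-expressions first, then feed their values to the top function
    sub_values = [function(x_values, y_values) for function in sub_functions]
    return top_function(*sub_values)


direct_function = sp.lambdify((x, y), expression, modules="numpy")
x_values = np.linspace(-1, 1, num=7)
y_values = np.linspace(0, 2, num=7)
composed = composed_function(x_values, y_values)
direct = direct_function(x_values, y_values)
print(np.allclose(composed, direct))  # True
```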
- extract_constant_sub_expressions(expression: Expr, free_symbols: Iterable[Basic], fix_order: bool = False) -> tuple[Expr, dict[Symbol, Expr]]
Collapse and extract constant sub-expressions.
Along with prepare_caching(), this function prepares a sympy.Expr for caching during a fit procedure. The function returns a top expression where the constant sub-expressions have been substituted by new symbols \(f_i\) for each substituted sub-expression, and a dict that gives the sub-expressions that those symbols represent. The top expression can be given to create_parametrized_function(), while the dict of sub-expressions can be given to SympyDataTransformer.from_sympy().
- Parameters:
expression – The Expr from which to extract constant sub-expressions.
free_symbols – Symbol instances in the main expression that are not constant.
fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
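The collapsing step can be sketched as a top-down walk over the expression tree: any node that contains none of the non-constant free_symbols is replaced by a new symbol, and the replaced nodes are collected in a dict. The function collapse_constant_sub_expressions below is hypothetical and is not the tensorwaves implementation; in particular its numbering of f0, f1, ... follows traversal order and is not guaranteed to be stable, which is the kind of non-determinism the fix_order flag addresses.

```python
import sympy as sp


def collapse_constant_sub_expressions(expression, free_symbols):
    """Hypothetical sketch of the idea, not tensorwaves' implementation.

    Walk the tree top-down and replace every non-atomic node that contains
    none of the non-constant free_symbols by a new symbol f0, f1, ...
    """
    non_constant = set(free_symbols)
    symbol_for = {}  # constant sub-expression -> replacement symbol

    def visit(node):
        if node.is_Atom:
            return node  # keep symbols and numbers as they are
        if not node.free_symbols & non_constant:
            if node not in symbol_for:
                symbol_for[node] = sp.Symbol(f"f{len(symbol_for)}")
            return symbol_for[node]
        # The node depends on a non-constant symbol: descend into its arguments
        return node.func(*(visit(arg) for arg in node.args))

    top_expression = visit(expression)
    sub_expressions = {symbol: node for node, symbol in symbol_for.items()}
    return top_expression, sub_expressions


a, b, x, y = sp.symbols("a b x y")
expression = a * sp.sin(x) + b * sp.exp(y**2)
top_expression, sub_expressions = collapse_constant_sub_expressions(
    expression, free_symbols=[a, b]
)
print(top_expression)
print(sub_expressions)
```

Here sin(x) and exp(y**2) do not depend on a or b, so each is collapsed into its own f-symbol, and substituting the dict back with xreplace() recovers the original expression.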
See also
- prepare_caching(expression: Expr, parameters: Mapping[Basic, ParameterValue], free_parameters: Iterable[Basic], fix_order: bool = False) -> tuple[Expr, dict[Basic, Expr]]
Prepare an expression for optimizing with caching.
When fitting a ParametrizedFunction, only its free ParametrizedFunction.parameters are updated on each iteration. This allows for an optimization: all sub-expressions that are unaffected by these free parameters can be cached as a constant DataSample. The strategy here is to create a top expression that contains only the parameters that are to be optimized.
Along with extract_constant_sub_expressions(), this function prepares a sympy.Expr for this caching procedure. The function returns a top expression where the constant sub-expressions have been substituted by new symbols \(f_i\) for each substituted sub-expression, and a dict that gives the sub-expressions that those symbols represent.
The top expression can be given to create_parametrized_function(), while the dict of sub-expressions can be given to SympyDataTransformer.from_sympy().
- Parameters:
expression – The Expr from which to extract constant sub-expressions.
parameters – A mapping of values for each of the parameter symbols in the expression. Parameters that are not free_parameters are substituted in the returned expressions with xreplace().
free_parameters – Symbol instances in the main expression that are to be considered parameters and that will be optimized by an Optimizer later on.
fix_order – If False, the generated symbols for the sub-expressions are not deterministic, because they depend on the hashes of those sub-expressions. Setting this to True makes the order deterministic, but this is slower, because it requires lambdifying each sub-expression to str first.
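The caching strategy can be sketched end to end with plain SymPy and NumPy: substitute the fixed parameters with xreplace(), evaluate the sub-expressions that contain no free parameter once on the data, and during the fit evaluate only the small top expression. The split into f0 and f1 is chosen by hand here (a hypothetical stand-in for what prepare_caching() would produce), and the loop over b values only mimics the parameter updates an Optimizer would make.

```python
import numpy as np
import sympy as sp

a, b, x = sp.symbols("a b x")
expression = a * sp.sin(x) + b * x**2
parameters = {a: 2.0, b: 0.5}
free_parameters = {b}

# 1. Substitute the parameters that stay fixed during the fit
fixed_values = {s: sp.Float(v) for s, v in parameters.items() if s not in free_parameters}
expression_fixed = expression.xreplace(fixed_values)

# 2. Hypothetical, hand-picked constant sub-expressions (no free parameter inside)
f0, f1 = sp.symbols("f0 f1")
sub_expressions = {f0: fixed_values[a] * sp.sin(x), f1: x**2}
top_expression = f0 + b * f1

# 3. Evaluate the constant sub-expressions on the data only once
x_data = np.linspace(0, 1, num=5)
cache = [sp.lambdify(x, e, modules="numpy")(x_data) for e in sub_expressions.values()]

# 4. During the fit, only the cheap top expression is evaluated again
top_function = sp.lambdify((b, f0, f1), top_expression, modules="numpy")
full_function = sp.lambdify((b, x), expression_fixed, modules="numpy")
max_deviation = 0.0
for b_value in [0.5, 1.0, 1.5]:
    cached_result = top_function(b_value, *cache)
    direct_result = full_function(b_value, x_data)
    max_deviation = max(max_deviation, float(np.max(np.abs(cached_result - direct_result))))
print(max_deviation)
```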
See also
- split_expression(expression: Expr, max_complexity: int, min_complexity: int = 1) -> tuple[Expr, dict[Basic, Expr]]
Split an expression into a “top expression” and several sub-expressions.
Replace nodes in the expression tree of a sympy.Expr that lie within a certain complexity range (see count_ops()) with symbols, and keep a mapping of each of these symbols to the sub-expression that it replaced.
See also
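The replacement described for split_expression() can be sketched as a top-down walk that uses sympy.count_ops() as the complexity measure. The function split_by_complexity below is hypothetical and is not the tensorwaves implementation: a node whose operation count lies within [min_complexity, max_complexity] is replaced by a new symbol, while a node outside that range is split further into its arguments.

```python
import sympy as sp


def split_by_complexity(expression, max_complexity, min_complexity=1):
    """Hypothetical sketch of the idea, not tensorwaves' implementation.

    Walk the expression tree top-down and replace every node whose operation
    count (sympy.count_ops) lies in [min_complexity, max_complexity] by a new
    symbol f0, f1, ...
    """
    sub_expressions = {}

    def visit(node):
        if node.is_Atom:
            return node
        if min_complexity <= sp.count_ops(node) <= max_complexity:
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        # Outside the complexity range: descend into the arguments
        return node.func(*(visit(arg) for arg in node.args))

    return visit(expression), sub_expressions


x, y = sp.symbols("x y")
expression = sp.sin(x**2 + y) * sp.exp(x * y) + sp.cos(y) ** 2
top_expression, sub_expressions = split_by_complexity(expression, max_complexity=3)
print(top_expression)
for symbol, sub_expression in sub_expressions.items():
    print(symbol, "=", sub_expression, "| ops:", sp.count_ops(sub_expression))
```

Substituting the mapping back into the top expression with xreplace() recovers the original expression, which is what makes the split safe to lambdify piece by piece.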