Constant sub-expressions#

As mentioned in 3.1 Prepare parametrized function, once we know which parameters in a ParametrizedFunction we want to optimize, we can apply several optimizations before running optimize(). The most important of these is to identify sub-expressions that are unaffected by a change to one of the parameters (constant sub-expressions). It's smart to compute these sub-expressions beforehand (caching), so that only the top-expression has to be recomputed for each iteration of the Optimizer.

If we are creating the ParametrizedFunction from a sympy.Expr, the strategy is as follows:

  1. Create a top-expression where the constant sub-expressions are collapsed into constant nodes (represented by Symbols) and a mapping of those Symbols to the substituted sub-expressions. This can be done with extract_constant_sub_expressions().

  2. Create a new ParametrizedFunction for this top-expression and a SympyDataTransformer for the sub-expressions.

  3. Transform the original DataSamples with that SympyDataTransformer (this is where the caching takes place).

This procedure is facilitated with the function create_cached_function().

Determine free parameters#

Let's have a look at how this works for a simple expression. Caching makes more sense for complicated expressions like the ones in Amplitude analysis, but this simple expression illustrates the idea.

import sympy as sp

a, b, c, d, x, y = sp.symbols("a b c d x y")
expression = a * x + b * (c * x**2 + d * y**2)
expression
\[\displaystyle a x + b \left(c x^{2} + d y^{2}\right)\]

Now, imagine that we have a data distribution over \(x\) and that we only want to optimize the free parameters \(a\) and \(d\).

free_symbols = {a, d}

Normally, we would just use create_parametrized_function() over the entire expression without thinking about which Symbols, other than the variables \(x\) and \(y\), are to be treated as optimizable parameters:

from tensorwaves.function.sympy import create_parametrized_function

parameter_defaults = {a: -2.5, b: 1, c: 0.0, d: 3.7}
original_func = create_parametrized_function(
    expression,
    parameter_defaults,
    backend="numpy",
)
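All four Symbols now appear as parameters of the resulting function, even though we only want to vary \(a\) and \(d\). The values are taken from parameter_defaults, so the output should look like this:

original_func.parameters
{'a': -2.5, 'b': 1, 'c': 0.0, 'd': 3.7}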

Note, however, that the resulting ParametrizedFunction will have to compute the entire expression tree on each iteration, even though we only want to optimize the blue parameters:

import graphviz
from IPython.display import display


class SymbolIdentifiable(sp.Symbol):
    # SymbolIdentifiable because of alphabetical sorting in dotprint
    @classmethod
    def from_symbol(cls, symbol):
        return SymbolIdentifiable(symbol.name, **symbol.assumptions0)


dot_style = (
    (sp.Basic, {"color": "blue", "shape": "ellipse"}),
    (sp.Expr, {"color": "black"}),
    (sp.Atom, {"color": "gray"}),
    (SymbolIdentifiable, {"color": "blue"}),
)


def visualize_free_symbols(expression, free_symbols):
    def substitute_identifiable_symbols(expression, symbols):
        substitutions = {s: SymbolIdentifiable.from_symbol(s) for s in symbols}
        return expression.xreplace(substitutions)

    dot = sp.dotprint(
        substitute_identifiable_symbols(expression, symbols=free_symbols),
        styles=dot_style,
        bgcolor="none",
    )
    graph = graphviz.Source(dot)
    display(graph)


visualize_free_symbols(expression, free_symbols)
[Figure: expression tree of the full expression, with the free parameters \(a\) and \(d\) highlighted in blue]

Extract constant sub-expressions#

The function extract_constant_sub_expressions() helps us to extract sub-expressions that remain constant with regard to some of its Symbols. It returns a new top-expression where the sub-expressions are substituted by symbols \(f_0, f_1, \dots\), as well as a mapping with sub-expression definitions for these symbols.

from tensorwaves.function.sympy import extract_constant_sub_expressions

top_expression, sub_expressions = extract_constant_sub_expressions(
    expression, free_symbols
)
from IPython.display import Math

display(top_expression)
for symbol, expr in sub_expressions.items():
    latex = sp.multiline_latex(symbol, expr, environment="eqnarray")
    display(Math(latex))
\[\displaystyle a x + b \left(d f_{0} + f_{1}\right)\]
\[\displaystyle \begin{eqnarray} f_{0} & = & y^{2} \end{eqnarray}\]
\[\displaystyle \begin{eqnarray} f_{1} & = & c x^{2} \end{eqnarray}\]
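As a quick check, substituting the sub-expression definitions back into the top-expression should give back the original expression:

# sanity check: the extraction is just a substitution of sub-expressions
assert top_expression.xreplace(sub_expressions) == expression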

Now, notice how we have split up the original expression tree into a top tree with parameters that are to be optimized and sub-trees that remain constant:

visualize_free_symbols(top_expression, free_symbols)
for expr in sub_expressions.values():
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))
[Figures: expression tree of the top-expression with the free parameters highlighted in blue, followed by the trees for the constant sub-expressions \(f_0 = y^2\) and \(f_1 = c x^2\)]

As an additional optimization, we could further substitute the non-optimized parameters with the values to which they are fixed. This can be done with prepare_caching(). Notice how one of the sub-expression trees disappears altogether, because we chose \(c=0\) in parameter_defaults, and how the top tree has been simplified, since \(b=1\)!

from tensorwaves.function.sympy import prepare_caching

cache_expression, transformer_expressions = prepare_caching(
    expression, parameter_defaults, free_symbols
)
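As a quick check, the fixed parameters \(b\) and \(c\) should no longer appear in the cached top-expression, while the free parameters do:

# b and c were substituted by their fixed values, so only the free parameters remain
assert not cache_expression.free_symbols & {b, c}
assert free_symbols <= cache_expression.free_symbols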
visualize_free_symbols(cache_expression.evalf(2), free_symbols)
for symbol, expr in transformer_expressions.items():
    if expr is symbol:
        continue
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))
[Figures: simplified top-expression tree with the free parameters \(a\) and \(d\) highlighted in blue, followed by the tree of the remaining sub-expression \(y^2\)]

Caching#

All of the above is mainly useful when optimizing a ParametrizedFunction with regard to some Estimator. For this reason, the estimator module brings this all together with the function create_cached_function(). This function prepares the expression trees just like we see above and creates a 'cached' ParametrizedFunction from the top-expression, as well as a DataTransformer to create a 'cached' DataSample as input for that cached function.

from tensorwaves.estimator import create_cached_function

cached_func, cache_transformer = create_cached_function(
    expression,
    parameter_defaults,
    free_parameters=free_symbols,
    backend="numpy",
)
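This is roughly equivalent to performing the steps from the previous sections by hand. A minimal sketch, assuming that SympyDataTransformer.from_sympy() accepts the sub-expression mapping returned by prepare_caching():

from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import prepare_caching

# substitute fixed parameters and extract constant sub-expressions
cache_expression, transformer_expressions = prepare_caching(
    expression, parameter_defaults, free_symbols
)
# 'cached' function over the top-expression, with only the free parameters
manual_cached_func = create_parametrized_function(
    cache_expression,
    {symbol: parameter_defaults[symbol] for symbol in free_symbols},
    backend="numpy",
)
# DataTransformer that computes the constant sub-expressions once
manual_transformer = SympyDataTransformer.from_sympy(
    transformer_expressions, backend="numpy"
)

In practice, create_cached_function() wraps these steps for you.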

Notice how only the free parameters appear as parameters in the 'cached' function, how the DataTransformer defines the remaining symbols, and how the variables \(x, y\) are the only required arguments to the functions in the DataTransformer:

cached_func.parameters
{'a': -2.5, 'd': 3.7}
cache_transformer.functions
{'f0': PositionalArgumentFunction(function=<function _lambdifygenerated at 0x7f0d821bbbe0>, argument_order=('x', 'y')),
 'x': PositionalArgumentFunction(function=<function _lambdifygenerated at 0x7f0d821badd0>, argument_order=('x', 'y'))}

Performance check#

How do we use this 'cached' ParametrizedFunction and DataTransformer? And is its output the same as that of the normal function created with create_parametrized_function()? Let's generate a small DataSample for the domain \(x, y\):

from tensorwaves.data import NumpyDomainGenerator, NumpyUniformRNG

boundaries = {
    "x": (-1, +1),
    "y": (-1, +1),
}
domain_generator = NumpyDomainGenerator(boundaries)
rng = NumpyUniformRNG()
domain = domain_generator.generate(10_000, rng)

The domain DataSample can be given directly to the original function:

intensities = original_func(domain)
intensities
array([ 1.21693661,  0.23233727,  2.10437094, ...,  0.04068075,
       -0.85892522,  0.15663904])

For the 'cached' function, we first need to transform the domain. This is where the caching takes place!

cached_domain = cache_transformer(domain)
intensities_from_cache = cached_func(cached_domain)
intensities_from_cache
array([ 1.21693661,  0.23233727,  2.10437094, ...,  0.04068075,
       -0.85892522,  0.15663904])
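To be sure, we can check that both arrays agree numerically:

import numpy as np

np.testing.assert_allclose(intensities, intensities_from_cache)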

The results are indeed the same and the cached function is faster as well!

%timeit -n100 original_func(domain)
75.4 µs ± 712 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit -n100 cached_func(cached_domain)
26.3 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)