In [None]:
# WARNING: advised to install a specific version, e.g. tensorwaves==0.1.2
%pip install -q tensorwaves[doc,jax,pwa,viz] IPython

In [None]:
%config InlineBackend.figure_formats = ['svg']
import os

from IPython.display import display  # noqa: F401

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

```{autolink-concat}
```

# Constant sub-expressions

As mentioned in {ref}`amplitude-analysis:3.1 Prepare parametrized function`, once we know which parameters in a {class}`.ParametrizedFunction` we want to optimize, we can apply several optimizations before running {meth}`~.Optimizer.optimize`. The most important of these is to identify **constant sub-expressions**: sub-expressions that are unaffected by a change to any of the {attr}`~.ParametrizedFunction.parameters` that are being optimized. These sub-expressions can be computed once beforehand (caching), so that only the top-expression has to be recomputed for each iteration of the {class}`.Optimizer`.

If we are creating the {class}`.ParametrizedFunction` from a {class}`sympy.Expr <sympy.core.expr.Expr>`, the strategy is as follows:
1. Create a top-expression where the constant sub-expressions are collapsed into constant nodes (represented by {class}`~sympy.core.symbol.Symbol`s) and a mapping of those {class}`~sympy.core.symbol.Symbol`s to the substituted sub-expressions. This can be done with {func}`.extract_constant_sub_expressions`.
2. Create a new {class}`.ParametrizedFunction` for this top-expression and a {class}`.SympyDataTransformer` for the sub-expressions.
3. Transform the original {obj}`.DataSample`s with that {class}`.SympyDataTransformer` (this is where the caching takes place).

This procedure is facilitated with the function {func}`.create_cached_function`.
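
The idea behind step 1 can be sketched with plain SymPy. The helper `extract_constants` below is hypothetical and written for illustration only: it is not the implementation of {func}`.extract_constant_sub_expressions`, and it only collapses whole sub-trees that contain none of the free symbols.

```python
import itertools

import sympy as sp


def extract_constants(expr, free_symbols, prefix="f"):
    """Illustrative sketch: collapse maximal sub-trees without free symbols."""
    counter = itertools.count()
    definitions = {}

    def collapse(node):
        if node.is_Atom:
            # Numbers and plain symbols are already leaves: nothing to cache
            return node
        if not node.free_symbols & free_symbols:
            # Sub-tree is independent of all free parameters: replace by a symbol
            symbol = sp.Symbol(f"{prefix}{next(counter)}")
            definitions[symbol] = node
            return symbol
        # Sub-tree depends on a free parameter: keep the node, recurse into it
        return node.func(*map(collapse, node.args))

    return collapse(expr), definitions


a, b, c, d, x, y = sp.symbols("a b c d x y")
expression = a * x + b * (c * x**2 + d * y**2)
top, definitions = extract_constants(expression, free_symbols={a, d})
print(top)
print(definitions)
```

Substituting the definitions back into the top-expression with `top.xreplace(definitions)` recovers the original expression, which is a quick sanity check for any such split.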

## Determine free parameters

Let's have a look at how this works for a simple expression. Caching makes more sense in complicated expressions like the ones in {doc}`/amplitude-analysis`, but this simple expression illustrates the idea.

In [None]:
import sympy as sp

a, b, c, d, x, y = sp.symbols("a b c d x y")
expression = a * x + b * (c * x**2 + d * y**2)
expression

Now, imagine that we have a data distribution over $x$ and $y$ and that we _only_ want to optimize the **free parameters** $a$ and $d$.

In [None]:
free_symbols = {a, d}

Normally, we would just use {func}`.create_parametrized_function` over the entire expression without thinking about which {class}`~sympy.core.symbol.Symbol`s other than **variables** $x$ and $y$ are to be optimizable parameters:

In [None]:
from tensorwaves.function.sympy import create_parametrized_function

parameter_defaults = {a: -2.5, b: 1, c: 0.0, d: 3.7}
original_func = create_parametrized_function(
    expression,
    parameter_defaults,
    backend="numpy",
)

Note, however, that the resulting {class}`.ParametrizedFunction` will have to compute the entire expression tree on each iteration, even though we only want to optimize the free parameters (shown in blue):

In [None]:
import graphviz


class SymbolIdentifiable(sp.Symbol):
    # SymbolIdentifiable because of alphabetical sorting in dotprint
    @classmethod
    def from_symbol(cls, symbol):
        return cls(symbol.name, **symbol.assumptions0)


dot_style = (
    (sp.Basic, {"color": "blue", "shape": "ellipse"}),
    (sp.Expr, {"color": "black"}),
    (sp.Atom, {"color": "gray"}),
    (SymbolIdentifiable, {"color": "blue"}),
)


def visualize_free_symbols(expression, free_symbols):
    def substitute_identifiable_symbols(expression, symbols):
        substitutions = {s: SymbolIdentifiable.from_symbol(s) for s in symbols}
        return expression.xreplace(substitutions)

    dot = sp.dotprint(
        substitute_identifiable_symbols(expression, symbols=free_symbols),
        styles=dot_style,
        bgcolor="none",
    )
    graph = graphviz.Source(dot)
    display(graph)


visualize_free_symbols(expression, free_symbols)

## Extract constant sub-expressions

The function {func}`.extract_constant_sub_expressions` helps us to extract sub-expressions that remain constant with regard to some of its {class}`~sympy.core.symbol.Symbol`s. It returns a new top-expression where the sub-expressions are substituted by symbols $f_0, f_1, \dots$, as well as a mapping with sub-expression definitions for these symbols.

In [None]:
from tensorwaves.function.sympy import extract_constant_sub_expressions

top_expression, sub_expressions = extract_constant_sub_expressions(
    expression, free_symbols
)

In [None]:
from IPython.display import Math

display(top_expression)
for symbol, expr in sub_expressions.items():
    latex = sp.multiline_latex(symbol, expr, environment="eqnarray")
    display(Math(latex))

Now, notice how we have split up the original expression tree into a top tree with parameters that are to be optimized and sub-trees that remain constant:

In [None]:
visualize_free_symbols(top_expression, free_symbols)
for symbol, expr in sub_expressions.items():
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))

As an additional optimization, we could further substitute the non-optimized parameters with the values to which they are fixed. This can be done with {func}`.prepare_caching`. Notice how one of the sub-expression trees disappears altogether, because we decided to choose $c=0$ in the `parameter_defaults` and how the top tree has been simplified since $b=1$!
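
The effect of fixing the non-free parameters can be reproduced with SymPy alone. This is a minimal sketch that uses `subs` on the example expression; it does not claim to mirror what {func}`.prepare_caching` does internally, and it uses the integers $b=1$ and $c=0$ for simplicity.

```python
import sympy as sp

a, b, c, d, x, y = sp.symbols("a b c d x y")
expression = a * x + b * (c * x**2 + d * y**2)

# Fix the non-free parameters to their values; a and d remain symbolic
fixed_values = {b: 1, c: 0}
simplified = expression.subs(fixed_values)
print(simplified)
```

The term with $c$ vanishes and the factor $b$ drops out, so the only constant sub-expression left to cache is $y^2$.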

In [None]:
# see text in previous cell
assert parameter_defaults[c] == 0
assert parameter_defaults[b] == 1

In [None]:
from tensorwaves.function.sympy import prepare_caching

cache_expression, transformer_expressions = prepare_caching(
    expression, parameter_defaults, free_symbols
)

In [None]:
visualize_free_symbols(cache_expression.evalf(2), free_symbols)
for symbol, expr in transformer_expressions.items():
    if expr is symbol:
        continue
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))

## Caching

All of the above is mainly useful when {ref}`optimizing <usage/basics:Optimize the model>` a {class}`.ParametrizedFunction` with regard to some {class}`.Estimator`. For this reason, the {mod}`.estimator` module brings this all together with the function {func}`.create_cached_function`. This function prepares the expression trees just as we saw above and creates a 'cached' {class}`.ParametrizedFunction` from the top-expression, as well as a {class}`.DataTransformer` to create a 'cached' {obj}`.DataSample` as input for that cached function.
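
To make the division of labor concrete, here is a rough sketch that builds both pieces by hand with {func}`sympy.lambdify <sympy.utilities.lambdify.lambdify>`. The names `transform_functions` and `cached_function` are illustrative assumptions, as are the hard-coded top-expression and sub-expressions; the objects returned by {func}`.create_cached_function` have their own interfaces.

```python
import numpy as np
import sympy as sp

a, d, x, y, f0 = sp.symbols("a d x y f0")
top_expression = a * x + d * f0  # free parameters a, d; cached nodes x, f0
sub_expressions = {x: x, f0: y**2}  # definitions of the cached nodes

# 'Transformer': one numerical function per cached node, needing only x and y
transform_functions = {
    symbol.name: sp.lambdify([x, y], expr, modules="numpy")
    for symbol, expr in sub_expressions.items()
}
# 'Cached function': takes the cached nodes plus the free parameters
cached_function = sp.lambdify([x, f0, a, d], top_expression, modules="numpy")

data = {"x": np.linspace(-1, 1, 5), "y": np.linspace(-1, 1, 5)}
# The transformation is done once, before any optimization starts
cached_data = {
    name: func(data["x"], data["y"])
    for name, func in transform_functions.items()
}
# Each optimizer iteration would only evaluate the top-expression
result = cached_function(cached_data["x"], cached_data["f0"], -2.5, 3.7)
```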

In [None]:
from tensorwaves.estimator import create_cached_function

cached_func, cache_transformer = create_cached_function(
    expression,
    parameter_defaults,
    free_parameters=free_symbols,
    backend="numpy",
)

Notice that only the free parameters appear as {attr}`~.ParametrizedFunction.parameters` in the 'cached' function, how the {class}`.DataTransformer` defines the remaining symbols, and how variables $x, y$ are the only required arguments to the functions in the {class}`.DataTransformer`:

In [None]:
cached_func.parameters

In [None]:
cache_transformer.functions

In [None]:
free_parameter_names = set(map(str, free_symbols))
cache_variable_names = set(cached_func.argument_order) - free_parameter_names
assert set(cached_func.parameters) == free_parameter_names
assert set(cache_transformer.functions) == cache_variable_names

## Performance check

How do we use this 'cached' {class}`.ParametrizedFunction` and {class}`.DataTransformer`? And is the output of that function the same as that of the normal function created with {func}`.create_parametrized_function`? Let's generate a small {obj}`.DataSample` for the domain $x, y$:

In [None]:
from tensorwaves.data import NumpyDomainGenerator, NumpyUniformRNG

boundaries = {
    "x": (-1, +1),
    "y": (-1, +1),
}
domain_generator = NumpyDomainGenerator(boundaries)
rng = NumpyUniformRNG()
domain = domain_generator.generate(10_000, rng)

The domain {obj}`.DataSample` can be given directly to the original function:

In [None]:
intensities = original_func(domain)
intensities

For the 'cached' function, we first need to transform the domain. **This is where the caching takes place!**
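
What the transformation buys us can be illustrated by hand with NumPy only. In this sketch, the functions `full_function` and `cached_function` and the arrays `f0` and `f1` are hypothetical stand-ins for the objects created by tensorwaves; the fixed values $b=1$, $c=0$ follow the `parameter_defaults` used earlier.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
x = rng.uniform(-1, 1, size=10_000)
y = rng.uniform(-1, 1, size=10_000)
b, c = 1.0, 0.0  # fixed (non-free) parameters


def full_function(a, d):
    # Recomputes every term, including the ones that never change
    return a * x + b * (c * x**2 + d * y**2)


# 'Transform' step: evaluate the constant sub-expressions once
f0 = c * x**2
f1 = y**2


def cached_function(a, d):
    # Only the top-expression is evaluated on each call
    return a * x + b * (f0 + d * f1)


# Both versions agree for any choice of the free parameters
for a, d in [(-2.5, 3.7), (0.3, -1.2)]:
    np.testing.assert_allclose(full_function(a, d), cached_function(a, d))
```

The arrays `f0` and `f1` play the role of the cached data sample: they are computed once and reused for every new value of $a$ and $d$.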

In [None]:
cached_domain = cache_transformer(domain)
intensities_from_cache = cached_func(cached_domain)
intensities_from_cache

In [None]:
import numpy as np

np.testing.assert_allclose(intensities, intensities_from_cache)

The results are indeed the same and the cached function is faster as well!

```{autolink-skip}
```

In [None]:
%timeit -n100 original_func(domain)
%timeit -n100 cached_func(cached_domain)