# Constant sub-expressions

As mentioned in 3.1 Prepare parametrized function, once we know which parameters in a `ParametrizedFunction` we want to optimize, we can apply several optimizations before running `optimize()`. The most important of these is to identify sub-expressions that are unaffected by a change to one of the `parameters` (constant sub-expressions). It's smart to compute these sub-expressions beforehand (caching), so that only the top-expression has to be recomputed for each iteration of the `Optimizer`.

If we are creating the `ParametrizedFunction` from a `sympy.Expr`, the strategy is as follows:

1. Create a top-expression where the constant sub-expressions are collapsed into constant nodes (represented by `Symbol`s) and a mapping of those `Symbol`s to the substituted sub-expressions. This can be done with `extract_constant_sub_expressions()`.
2. Create a new `ParametrizedFunction` for this top-expression and a `SympyDataTransformer` for the sub-expressions.
3. Transform the original `DataSample`s with that `SympyDataTransformer` (this is where the caching takes place).

This procedure is facilitated with the function `create_cached_function()`; a by-hand sketch of the three steps follows below.
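
For orientation, here is a minimal by-hand sketch of these three steps on a toy expression. This is our own illustration, not part of the original procedure, and it assumes that `SympyDataTransformer.from_sympy()` accepts a mapping of `Symbol`s to expressions:

```
import numpy as np
import sympy as sp

from tensorwaves.data.transform import SympyDataTransformer
from tensorwaves.function.sympy import (
    create_parametrized_function,
    extract_constant_sub_expressions,
)

a, x, y = sp.symbols("a x y")
expression = a * x + y**2  # only a is to be optimized

# Step 1: collapse the constant sub-expression y**2 into a symbol (f0)
top_expression, sub_expressions = extract_constant_sub_expressions(expression, {a})

# Step 2: a function for the top-expression, a transformer for the sub-expressions
function = create_parametrized_function(top_expression, {a: 1.0}, backend="numpy")
transformer = SympyDataTransformer.from_sympy(sub_expressions, backend="numpy")

# Step 3: transform the data once; the cached sample is reused on every iteration
data = {"x": np.array([1.0, 2.0]), "y": np.array([3.0, 4.0])}
cached_data = {**data, **transformer(data)}  # keep x available next to f0
function(cached_data)
```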

## Determine free parameters

Let's have a look at how this works for a simple expression. Caching makes more sense for complicated expressions like the ones in Amplitude analysis, but this simple expression illustrates the idea.

```
import sympy as sp
a, b, c, d, x, y = sp.symbols("a b c d x y")
expression = a * x + b * (c * x**2 + d * y**2)
expression
```

Now, imagine that we have a data distribution over \(x\) and that we *only* want to optimize the **free parameters** \(a\) and \(d\).

```
free_symbols = {a, d}
```

Normally, we would just use `create_parametrized_function()` over the entire expression, without thinking about which `Symbol`s other than the **variables** \(x\) and \(y\) are to be optimizable parameters:

```
from tensorwaves.function.sympy import create_parametrized_function

parameter_defaults = {a: -2.5, b: 1, c: 0.0, d: 3.7}
original_func = create_parametrized_function(
    expression,
    parameter_defaults,
    backend="numpy",
)
```
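
As a quick sanity check (our own addition, not in the original page), the resulting function can be evaluated on any mapping of variable names to arrays. Here is a hypothetical two-point sample:

```
import numpy as np

# Hypothetical two-point DataSample; x and y are the only required keys,
# since a, b, c, d are bound as parameters
toy_sample = {"x": np.array([0.0, 1.0]), "y": np.array([2.0, 3.0])}
original_func(toy_sample)
```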

Note, however, that the resulting `ParametrizedFunction` will have to compute the entire expression tree on each iteration, even though we only want to optimize the parameters highlighted in blue below:

```
import graphviz
from IPython.display import display


class SymbolIdentifiable(sp.Symbol):
    # SymbolIdentifiable because of alphabetical sorting in dotprint
    @classmethod
    def from_symbol(cls, symbol):
        return SymbolIdentifiable(symbol.name, **symbol.assumptions0)


dot_style = (
    (sp.Basic, {"color": "blue", "shape": "ellipse"}),
    (sp.Expr, {"color": "black"}),
    (sp.Atom, {"color": "gray"}),
    (SymbolIdentifiable, {"color": "blue"}),
)


def visualize_free_symbols(expression, free_symbols):
    def substitute_identifiable_symbols(expression, symbols):
        substitutions = {s: SymbolIdentifiable.from_symbol(s) for s in symbols}
        return expression.xreplace(substitutions)

    dot = sp.dotprint(
        substitute_identifiable_symbols(expression, symbols=free_symbols),
        styles=dot_style,
        bgcolor="none",
    )
    graph = graphviz.Source(dot)
    display(graph)


visualize_free_symbols(expression, free_symbols)
```

## Extract constant sub-expressions

The function `extract_constant_sub_expressions()` helps us to extract sub-expressions that remain constant with respect to some of the expression's `Symbol`s. It returns a new top-expression in which the sub-expressions are substituted by symbols \(f_0, f_1, \dots\), as well as a mapping with sub-expression definitions for these symbols.

```
from tensorwaves.function.sympy import extract_constant_sub_expressions

top_expression, sub_expressions = extract_constant_sub_expressions(
    expression, free_symbols
)
```
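
As a quick sanity check of our own: substituting the sub-expression definitions back into the top-expression should reproduce the original expression, which confirms that the decomposition is lossless.

```
# Substituting the f-symbols back should reconstruct the original expression
restored = top_expression.xreplace(sub_expressions)
assert sp.simplify(restored - expression) == 0
```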


```
from IPython.display import Math

display(top_expression)
for symbol, expr in sub_expressions.items():
    latex = sp.multiline_latex(symbol, expr, environment="eqnarray")
    display(Math(latex))
```

Now, notice how we have split up the original expression tree into a top tree with parameters that are to be optimized and sub-trees that remain constant:


```
visualize_free_symbols(top_expression, free_symbols)
for expr in sub_expressions.values():
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))
```

As an additional optimization, we can substitute the non-optimized parameters with the values to which they are fixed. This can be done with `prepare_caching()`. Notice how one of the sub-expression trees disappears altogether, because we chose \(c=0\) in the `parameter_defaults`, and how the top tree has been simplified since \(b=1\)!

```
from tensorwaves.function.sympy import prepare_caching

cache_expression, transformer_expressions = prepare_caching(
    expression, parameter_defaults, free_symbols
)
```


```
visualize_free_symbols(cache_expression.evalf(2), free_symbols)
for symbol, expr in transformer_expressions.items():
    if expr is symbol:
        continue
    dot = sp.dotprint(expr, styles=dot_style, bgcolor="none")
    display(graphviz.Source(dot))
```
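
To convince ourselves that nothing was lost, we can substitute the transformer expressions back into the cache expression and compare against the original expression with the fixed parameters filled in. This is our own sanity check, assuming `transformer_expressions` maps each `Symbol` to its defining expression:

```
# b and c are fixed, since they are not in free_symbols
fixed_parameters = {
    symbol: value
    for symbol, value in parameter_defaults.items()
    if symbol not in free_symbols
}
restored = cache_expression.xreplace(transformer_expressions)
assert sp.simplify(restored - expression.xreplace(fixed_parameters)) == 0
```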

## Caching

All of the above is mainly useful when optimizing a `ParametrizedFunction` with regard to some `Estimator`. For this reason, the `estimator` module brings this all together with the function `create_cached_function()`. This function prepares the expression trees just as we saw above and creates a "cached" `ParametrizedFunction` from the top-expression, as well as a `DataTransformer` that creates a "cached" `DataSample` as input for that cached function.

```
from tensorwaves.estimator import create_cached_function

cached_func, cache_transformer = create_cached_function(
    expression,
    parameter_defaults,
    free_parameters=free_symbols,
    backend="numpy",
)
```

Notice that only the free parameters appear as `parameters` in the "cached" function, how the `DataTransformer` defines the remaining symbols, and how the variables \(x, y\) are the only required arguments to the functions in the `DataTransformer`:

```
cached_func.parameters
```

```
{'a': -2.5, 'd': 3.7}
```

```
cache_transformer.functions
```

```
{'f0': PositionalArgumentFunction(function=<function _lambdifygenerated at 0x7f0d821bbbe0>, argument_order=('x', 'y')),
'x': PositionalArgumentFunction(function=<function _lambdifygenerated at 0x7f0d821badd0>, argument_order=('x', 'y'))}
```

## Performance check

How do we use this "cached" `ParametrizedFunction` and `DataTransformer`? And is the output of that function the same as that of the normal function created with `create_parametrized_function()`? Let's generate a small `DataSample` for the domain \(x, y\):

```
from tensorwaves.data import NumpyDomainGenerator, NumpyUniformRNG

boundaries = {
    "x": (-1, +1),
    "y": (-1, +1),
}
domain_generator = NumpyDomainGenerator(boundaries)
rng = NumpyUniformRNG()
domain = domain_generator.generate(10_000, rng)
```
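
A `DataSample` is a plain mapping of variable names to arrays, so the generated domain can be inspected directly (our own quick check):

```
# Each variable maps to an array of 10,000 generated values
{name: array.shape for name, array in domain.items()}
```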

The domain `DataSample` can be given directly to the original function:

```
intensities = original_func(domain)
intensities
```

```
array([ 1.21693661,  0.23233727,  2.10437094, ...,  0.04068075,
       -0.85892522,  0.15663904])
```

For the "cached" function, we first need to transform the domain. **This is where the caching takes place!**

```
cached_domain = cache_transformer(domain)
intensities_from_cache = cached_func(cached_domain)
intensities_from_cache
```

```
array([ 1.21693661,  0.23233727,  2.10437094, ...,  0.04068075,
       -0.85892522,  0.15663904])
```
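
Rather than comparing by eye, we can verify the agreement numerically (our own check):

```
import numpy as np

# The cached pipeline should reproduce the original function up to
# floating-point round-off
np.testing.assert_allclose(intensities, intensities_from_cache)
```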

The results are indeed the same and the cached function is faster as well!

```
%timeit -n100 original_func(domain)
%timeit -n100 cached_func(cached_domain)
```

```
75.4 µs ± 712 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
26.3 µs ± 287 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
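
The speed-up pays off most inside an optimizer loop, where only the free parameters change between iterations while the cached domain stays fixed. A minimal sketch of our own (the parameter values are arbitrary):

```
# cached_domain was computed once; only the free parameter a changes here
for a_value in (-3.0, -2.5, -2.0):
    cached_func.update_parameters({"a": a_value})
    total = cached_func(cached_domain).sum()
    print(f"a = {a_value}: summed intensity = {total:.2f}")
```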