# Speed up lambdifying

```python
import logging

import ampform
import graphviz
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff
from IPython.display import HTML, SVG

from tensorwaves.function.sympy import (
    create_parametrized_function,
    fast_lambdify,
    split_expression,
)

logging.getLogger("tensorwaves.data").setLevel(logging.ERROR)  # hide progress bars
```


Note

Since #374, expressions are lambdified with common sub-expression elimination. This alone should reduce lambdification time significantly, and it also results in faster computational functions.
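As a rough illustration of what common sub-expression elimination means, the snippet below uses plain SymPy, not TensorWaves internals. sp.cse() pulls a repeated sub-tree out into an intermediate symbol, and sp.lambdify() accepts a cse=True flag that applies the same idea to the generated code. This assumes a recent SymPy release (the cse flag of lambdify() is, as far as we know, available from SymPy 1.9 on) and uses the math module as a stand-in backend.

```python
import math

import sympy as sp

x, y = sp.symbols("x y")
# x + y occurs twice, so it is a common sub-expression
expr = sp.sin(x + y) + sp.cos(x + y) ** 2

# sp.cse() returns the extracted definitions and the reduced expression(s)
replacements, reduced = sp.cse(expr)
print(replacements)  # [(x0, x + y)]
print(reduced)

# lambdify() can apply the same elimination to the code it generates
func_cse = sp.lambdify((x, y), expr, modules="math", cse=True)
func_plain = sp.lambdify((x, y), expr, modules="math")
print(math.isclose(func_cse(0.3, 0.4), func_plain(0.3, 0.4)))  # True
```

Both functions compute the same values; the CSE variant evaluates x + y only once per call.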

## Split expression

Lambdifying a SymPy expression can take a long time when the expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split up an expression into sub-expressions, lambdify those separately, and then recombine them. Let’s illustrate that idea with the following simplified example:

```python
x, y, z = sp.symbols("x:z")
expr = x**z + 2 * y + sp.log(y * z)
expr
```

$\displaystyle x^{z} + 2 y + \log{\left(y z \right)}$

This expression can be represented in a tree of mathematical operations.

```python
dot = sp.dotprint(expr, bgcolor="none")
graphviz.Source(dot)
```

The function split_expression() can now be used to split up this expression tree into a ‘top expression’ plus definitions for each of the sub-expressions into which it was split:

```python
top_expr, sub_expressions = split_expression(expr, max_complexity=3)
top_expr
```

$\displaystyle f_{0} + f_{1} + f_{2}$
```python
sub_expressions
```

{f0: x**z, f1: 2*y, f2: log(y*z)}
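To make the role of a complexity limit more concrete, here is a minimal sketch of one possible splitting strategy. The helpers count_nodes() and split_by_complexity() are hypothetical and are not how TensorWaves’ split_expression() is implemented: the sketch assumes that complexity means the number of nodes in the expression tree (counted with sp.preorder_traversal()) and replaces every non-atomic sub-tree that fits within max_complexity by a new symbol. Because the rule differs, it may split the example differently from the output above (here it also splits inside the logarithm).

```python
import sympy as sp


def count_nodes(expr):
    """Number of nodes in the expression tree (an assumed complexity measure)."""
    return sum(1 for _ in sp.preorder_traversal(expr))


def split_by_complexity(expr, max_complexity):
    """Hypothetical splitter: replace small-enough sub-trees by symbols f0, f1, ..."""
    sub_expressions = {}

    def visit(node):
        if node.is_Atom:
            return node  # keep bare symbols and numbers in place
        if count_nodes(node) <= max_complexity:
            symbol = sp.Symbol(f"f{len(sub_expressions)}")
            sub_expressions[symbol] = node
            return symbol
        # too complex: rebuild this node from its (possibly replaced) children
        return node.func(*(visit(arg) for arg in node.args))

    if count_nodes(expr) <= max_complexity:
        return expr, sub_expressions  # nothing to split
    return visit(expr), sub_expressions


x, y, z = sp.symbols("x:z")
expr = x**z + 2 * y + sp.log(y * z)
sketch_top, sketch_subs = split_by_complexity(expr, max_complexity=3)
print(sketch_top)
print(sketch_subs)
```

Whatever the splitting rule, substituting the definitions back into the top expression must give the original expression.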


The original expression can easily be reconstructed with subs() or xreplace():

```python
top_expr.xreplace(sub_expressions)
```

$\displaystyle x^{z} + 2 y + \log{\left(y z \right)}$
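For plain symbols such as f0, f1, and f2, both methods give the same result, but they are not interchangeable in general. xreplace() only swaps nodes that literally appear in the expression tree, whereas subs() performs a more mathematically aware substitution. The following plain-SymPy snippet (the symbol a is introduced only for illustration) shows the difference when the target is not an exact sub-tree:

```python
import sympy as sp

x, y, z, a = sp.symbols("x y z a")
product = x * y * z  # a single Mul node with three arguments

# x*y is not a node of this tree, so xreplace() leaves the expression unchanged
print(product.xreplace({x * y: a}))  # x*y*z
# subs() recognizes x*y as a factor of the product
print(product.subs(x * y, a))  # a*z
```

Since the top expression contains the sub-expression symbols as exact nodes, the cheaper xreplace() is sufficient for reconstruction.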

Each of the expression trees is now smaller than the original:

```python
dot = sp.dotprint(top_expr, bgcolor="none")
graphviz.Source(dot)
```
```python
# render each sub-expression tree to an SVG file
for symbol, definition in sub_expressions.items():
    dot = sp.dotprint(definition, bgcolor="none")
    graph = graphviz.Source(dot)
    graph.render(filename=f"sub_expr_{symbol.name}", format="svg")

# show the rendered trees side by side in an HTML table
html = "<table>\n"
html += "  <tr>\n"
html += "".join(
    f'    <th style="text-align:center; background-color:white">{symbol.name}</th>\n'
    for symbol in sub_expressions
)
html += "  </tr>\n"
html += "  <tr>\n"
for symbol in sub_expressions:
    svg = SVG(f"sub_expr_{symbol.name}.svg").data
    html += f'    <td style="background-color:white">{svg}</td>\n'
html += "  </tr>\n"
html += "</table>"
HTML(html)
```

(Expression trees of the sub-expressions f0, f1, and f2.)
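One simple way to quantify “smaller” is to count the nodes of each tree with sp.preorder_traversal(). The sketch below rebuilds the split shown above by hand (the definitions of f0, f1, and f2 are copied from the split_expression() output) and compares node counts. Node count is only an assumed complexity measure here; it is not necessarily the one TensorWaves uses for max_complexity.

```python
import sympy as sp

x, y, z = sp.symbols("x:z")
f0, f1, f2 = sp.symbols("f0:3")
expr = x**z + 2 * y + sp.log(y * z)

# split copied by hand from the split_expression() output above
top_expr = f0 + f1 + f2
sub_expressions = {f0: x**z, f1: 2 * y, f2: sp.log(y * z)}


def count_nodes(e):
    """Number of nodes (operations, symbols, and numbers) in the expression tree."""
    return sum(1 for _ in sp.preorder_traversal(e))


print("original:", count_nodes(expr))
print("top:", count_nodes(top_expr))
for symbol, definition in sub_expressions.items():
    print(f"{symbol}:", count_nodes(definition))
```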

## Fast lambdify

Generally, the lambdify time scales exponentially with the size of an expression tree. For large expression trees, it is therefore much faster to lambdify the sub-expressions separately and then recombine them. TensorWaves offers a function that does this for you: fast_lambdify(). We’ll use a HelicityModel to illustrate this:

```python
reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
)
model_builder = ampform.get_builder(reaction)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.formulate()
```

```python
expression = model.expression.doit()
sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)
```

```python
%%time
split_function = fast_lambdify(
    expression,
    sorted_symbols,
    max_complexity=100,
    backend="numpy",
)
```

CPU times: user 1.05 s, sys: 3.33 ms, total: 1.05 s
Wall time: 1.13 s

```python
%%time
normal_function = sp.lambdify(sorted_symbols, expression)
```

CPU times: user 7.99 s, sys: 12.3 ms, total: 8 s
Wall time: 8.02 s
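The speed-up in these timings comes from lambdifying several small expressions instead of one large one. Below is a minimal sketch of that recombination idea in plain SymPy. The helper lambdify_split() is hypothetical and is not TensorWaves’ implementation of fast_lambdify(): it lambdifies each sub-expression on its own, lambdifies the top expression with the sub-expression symbols as extra arguments, and wraps both in a function that feeds the results of the former into the latter. The split itself is copied by hand from the toy example earlier on this page, and the math module stands in for a numerical backend.

```python
import math

import sympy as sp


def lambdify_split(top_expr, sub_expressions, symbols, modules="math"):
    """Hypothetical helper: lambdify sub-expressions separately and recombine."""
    sub_symbols = list(sub_expressions)
    # one small function per sub-expression, all taking the original arguments
    sub_functions = [
        sp.lambdify(symbols, sub_expressions[s], modules=modules) for s in sub_symbols
    ]
    # the top expression may still contain original symbols, so accept both
    top_function = sp.lambdify([*sub_symbols, *symbols], top_expr, modules=modules)

    def function(*args):
        sub_values = [f(*args) for f in sub_functions]
        return top_function(*sub_values, *args)

    return function


x, y, z = sp.symbols("x:z")
f0, f1, f2 = sp.symbols("f0:3")
expr = x**z + 2 * y + sp.log(y * z)
top_expr = f0 + f1 + f2
sub_expressions = {f0: x**z, f1: 2 * y, f2: sp.log(y * z)}

split_sketch = lambdify_split(top_expr, sub_expressions, (x, y, z))
plain_sketch = sp.lambdify((x, y, z), expr, modules="math")
print(math.isclose(split_sketch(1.5, 2.0, 3.0), plain_sketch(1.5, 2.0, 3.0)))  # True
```

The recombined function costs a few extra Python calls per evaluation, which is usually negligible compared to the time saved during lambdification.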


## Specifying complexity

Parametrized functions are created with the create_parametrized_function() function. By default, it internally calls SymPy’s own lambdify() function. If you specify its max_complexity argument, however, create_parametrized_function() uses TensorWaves’ fast_lambdify() instead.

```python
%%time
function = create_parametrized_function(
    expression=model.expression.doit(),
    parameters=model.parameter_defaults,
    max_complexity=100,
    backend="numpy",
)
```

CPU times: user 3.97 s, sys: 3.95 ms, total: 3.98 s
Wall time: 4.06 s
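Conceptually, a parametrized function separates parameter symbols, whose values are stored on the function and can be updated, from data symbols, whose values are passed in on each call. The class below is a hypothetical, much-simplified stand-in written in plain SymPy; it is not the object returned by create_parametrized_function(). It only mimics the default path that calls sp.lambdify() directly and ignores max_complexity. The expression a * x**2 + b and its parameter values are made up for illustration.

```python
import sympy as sp


class SimpleParametrizedFunction:
    """Hypothetical sketch, not the object returned by create_parametrized_function()."""

    def __init__(self, expression, parameters, modules="math"):
        # parameters: {Symbol: default value}; all other free symbols are data variables
        self._parameter_symbols = sorted(parameters, key=lambda s: s.name)
        self._data_symbols = sorted(
            expression.free_symbols - set(parameters), key=lambda s: s.name
        )
        self.parameters = {s.name: value for s, value in parameters.items()}
        # default path: a single call to SymPy's lambdify() without splitting
        self._function = sp.lambdify(
            [*self._data_symbols, *self._parameter_symbols], expression, modules=modules
        )

    def update_parameters(self, new_parameters):
        unknown = set(new_parameters) - set(self.parameters)
        if unknown:
            raise KeyError(f"Unknown parameters: {sorted(unknown)}")
        self.parameters.update(new_parameters)

    def __call__(self, data):
        # data: {data variable name: value}
        data_values = [data[s.name] for s in self._data_symbols]
        parameter_values = [self.parameters[s.name] for s in self._parameter_symbols]
        return self._function(*data_values, *parameter_values)


x, a, b = sp.symbols("x a b")
sketch_function = SimpleParametrizedFunction(a * x**2 + b, parameters={a: 2.0, b: 1.0})
print(sketch_function({"x": 3.0}))  # 19.0
sketch_function.update_parameters({"b": 0.5})
print(sketch_function({"x": 3.0}))  # 18.5
```

Keeping the parameters out of the call signature is what makes it convenient to vary them, for instance during a fit, while the data stay fixed.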