Speed up lambdifying

import logging

import ampform
import graphviz
import qrules
import sympy as sp
from ampform.dynamics.builder import (
    create_relativistic_breit_wigner_with_ff,
)
from IPython.display import HTML, SVG

from tensorwaves.model import (
    LambdifiedFunction,
    SympyModel,
    optimized_lambdify,
    split_expression,
)

logger = logging.getLogger()
logger.setLevel(logging.ERROR)

Split expression

Lambdifying a SymPy expression can take rather long when the expression is complicated. Fortunately, TensorWaves offers a way to speed up the lambdify process. The idea is to split up an expression into sub-expressions, lambdify those separately, and then recombine them. Let’s illustrate that idea with the following simplified example:

x, y, z = sp.symbols("x:z")
expr = x ** z + 2 * y + sp.log(y * z)
expr
\[\displaystyle x^{z} + 2 y + \log{\left(y z \right)}\]

This expression can be represented in a tree of mathematical operations.
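In SymPy, every node of this tree is itself an expression. Its `func` gives the operation and its `args` give the child nodes. The sketch below walks the tree recursively and prints one node per line, indented by depth. The helper names `print_tree` and `count_nodes` are hypothetical and only serve as illustration.

```python
import sympy as sp

x, y, z = sp.symbols("x:z")
expr = x**z + 2 * y + sp.log(y * z)


def print_tree(node: sp.Basic, depth: int = 0) -> None:
    """Print each node of a SymPy expression tree, indented by its depth."""
    # Leaves (symbols, numbers) have no args, so print their value instead of their type
    label = type(node).__name__ if node.args else str(node)
    print("  " * depth + label)
    for child in node.args:
        print_tree(child, depth + 1)


def count_nodes(node: sp.Basic) -> int:
    """Count all nodes in the tree, including the leaves."""
    return sum(1 for _ in sp.preorder_traversal(node))


print_tree(expr)
print(count_nodes(expr))
```

The order in which SymPy stores the arguments of `Add` and `Mul` is canonical but not necessarily the order in which the expression was typed. The printed sibling order may therefore differ from the graph.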

dot = sp.dotprint(expr)
graphviz.Source(dot)
[Figure: expression tree of x**z + 2*y + log(y*z)]

The function split_expression() can now be used to split up this expression tree into a ‘top expression’ plus definitions for each of the sub-expressions into which it was split:

top_expr, sub_expressions = split_expression(expr, max_complexity=3)
top_expr
\[\displaystyle f_{0} + f_{1} + f_{2}\]
sub_expressions
{f0: x**z, f1: 2*y, f2: log(y*z)}
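A toy splitter can make `max_complexity` more concrete. This is a minimal sketch, and `toy_split` is a hypothetical name. It *assumes* that complexity is measured with SymPy's `count_ops()`. That assumption happens to yield the same three sub-expressions as above, but it is not a statement about how TensorWaves implements `split_expression()`. The sketch walks the tree from the top and replaces every sub-tree whose operation count lies strictly between 0 and `max_complexity` with a new symbol `f0`, `f1`, and so on. Larger sub-trees are searched further down.

```python
from __future__ import annotations

import sympy as sp


def toy_split(
    expression: sp.Expr, max_complexity: int
) -> tuple[sp.Expr, dict[sp.Symbol, sp.Expr]]:
    """Hypothetical splitter: replace small sub-trees with placeholder symbols."""
    definitions: dict[sp.Symbol, sp.Expr] = {}

    def replace_small_subtrees(node: sp.Basic) -> sp.Basic:
        if not node.args:  # leaf: symbol or number
            return node
        new_args = []
        for arg in node.args:
            complexity = sp.count_ops(arg)  # assumption: complexity = operation count
            if 0 < complexity < max_complexity:
                symbol = sp.Symbol(f"f{len(definitions)}")
                definitions[symbol] = arg
                new_args.append(symbol)
            else:  # too large (or a leaf): look further down the tree
                new_args.append(replace_small_subtrees(arg))
        return node.func(*new_args)

    return replace_small_subtrees(expression), definitions


x, y, z = sp.symbols("x:z")
expr = x**z + 2 * y + sp.log(y * z)
top_expr, sub_expressions = toy_split(expr, max_complexity=3)
print(top_expr, sub_expressions)
```

Under this assumption, lowering `max_complexity` to 2 keeps `log` in the top expression and splits off only its argument `y*z`. The numbering of the `f` symbols follows SymPy's internal argument order.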

The original expression can easily be reconstructed with subs() or xreplace():

top_expr.xreplace(sub_expressions)
\[\displaystyle x^{z} + 2 y + \log{\left(y z \right)}\]
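Splitting pays off at the numerical level. Each small sub-expression can be lambdified on its own. The top expression can be lambdified as a function of the placeholder symbols `f0`, `f1`, `f2`, and the two can then be chained. The sketch below illustrates this idea with plain `sympy.lambdify()` and NumPy. It is not TensorWaves’ `optimized_lambdify()` itself, and the helper `recombined` is hypothetical. The split is hard-coded from the output above, and the sketch checks that the chained function agrees with lambdifying the full expression directly.

```python
import numpy as np
import sympy as sp

x, y, z = sp.symbols("x:z")
expr = x**z + 2 * y + sp.log(y * z)

# The split produced above, written out by hand
f0, f1, f2 = sp.symbols("f0:3")
top_expr = f0 + f1 + f2
sub_expressions = {f0: x**z, f1: 2 * y, f2: sp.log(y * z)}

input_symbols = (x, y, z)
# Lambdify each small sub-expression separately...
sub_functions = [
    sp.lambdify(input_symbols, definition, "numpy")
    for definition in sub_expressions.values()
]
# ...and the top expression as a function of the placeholder symbols
top_function = sp.lambdify(list(sub_expressions), top_expr, "numpy")


def recombined(*values):
    """Evaluate the sub-expressions first, then pass their results to the top expression."""
    return top_function(*(func(*values) for func in sub_functions))


full_function = sp.lambdify(input_symbols, expr, "numpy")
x_values = np.array([0.5, 1.0, 2.0])
y_values = np.array([1.0, 2.0, 3.0])
z_values = np.array([0.5, 1.5, 2.5])
result = recombined(x_values, y_values, z_values)
expected = full_function(x_values, y_values, z_values)
print(np.allclose(result, expected))
```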

Each of these expression trees is now smaller than the original:

dot = sp.dotprint(top_expr)
graphviz.Source(dot)
[Figure: expression tree of the top expression f0 + f1 + f2]
for symbol, definition in sub_expressions.items():
    dot = sp.dotprint(definition)
    graph = graphviz.Source(dot)
    graph.render(filename=f"sub_expr_{symbol.name}", format="svg")

html = "<table>\n"
html += "  <tr>\n"
html += "".join(
    f'    <th style="text-align:center; background-color:white">{symbol.name}</th>\n'
    for symbol in sub_expressions
)
html += "  </tr>\n"
html += "  <tr>\n"
for symbol in sub_expressions:
    svg = SVG(f"sub_expr_{symbol.name}.svg").data
    html += f'    <td style="background-color:white">{svg}</td>\n'
html += "  </tr>\n"
html += "</table>"
HTML(html)
[Figure: expression trees of the sub-expressions f0, f1, and f2]

Optimized lambdify

Generally, the lambdify time scales exponentially with the size of an expression tree. With larger expression trees, it’s therefore much faster to lambdify the sub-expressions separately and then recombine them. TensorWaves offers a function that does this for you: optimized_lambdify(). We’ll use a HelicityModel to illustrate this:
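You can check how lambdify time grows for your own expressions by timing `sympy.lambdify()` on expressions of increasing size. The sketch below uses a hypothetical, synthetic family of expressions, a sum of `exp(-x_i**2) * sin(x_i)` terms. The helper `time_lambdify` is also hypothetical. Its timings depend on the machine and the SymPy version, and on their own they do not show how the time scales for amplitude models like the one below.

```python
import time

import sympy as sp


def time_lambdify(n_terms: int) -> float:
    """Return the wall time (in seconds) to lambdify a synthetic sum of n_terms terms."""
    symbols = sp.symbols(f"x:{n_terms}")  # x0, x1, ..., x{n_terms - 1}
    # Hypothetical test expression; substitute your own expression here
    expression = sp.Add(*(sp.exp(-s**2) * sp.sin(s) for s in symbols))
    start = time.perf_counter()
    sp.lambdify(symbols, expression, "numpy")
    return time.perf_counter() - start


for n_terms in (10, 20, 40, 80):
    print(f"{n_terms:3d} terms: {time_lambdify(n_terms):.3f} s")
```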

reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [+1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=["f(0)"],
)
model_builder = ampform.get_builder(reaction)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(
        name, create_relativistic_breit_wigner_with_ff
    )
model = model_builder.formulate()
expression = model.expression.doit()
sorted_symbols = sorted(expression.free_symbols, key=lambda s: s.name)
%%time
lambdified_optimized = optimized_lambdify(
    sorted_symbols,
    expression,
    max_complexity=100,
)
CPU times: user 964 ms, sys: 3.75 ms, total: 967 ms
Wall time: 967 ms
%%time
sp.lambdify(sorted_symbols, expression)
CPU times: user 9.44 s, sys: 39.2 ms, total: 9.48 s
Wall time: 9.48 s
<function _lambdifygenerated(Dummy_604, Dummy_603, Dummy_602, Dummy_601, Dummy_600, Dummy_599, Dummy_598, Dummy_597, Dummy_596, Dummy_595, Dummy_594, Dummy_593, Dummy_592, Dummy_591, Dummy_590, m_1, m_12, m_2, Dummy_589, Dummy_588, Dummy_587, Dummy_586, Dummy_585, Dummy_584, Dummy_583)>

Specifying complexity

In the usual workflow (see Usage), TensorWaves uses SymPy’s own lambdify() by default. You can change this behavior with the max_complexity argument of SympyModel:

sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
    max_complexity=100,
)

If max_complexity is specified (i.e., is not None), LambdifiedFunction uses TensorWaves’s optimized_lambdify().

%%time
intensity = LambdifiedFunction(sympy_model, backend="jax")
CPU times: user 1.2 s, sys: 15.9 ms, total: 1.22 s
Wall time: 1.22 s