Usage

TensorWaves is a package for fitting general mathematical expressions to data distributions. The fundamentals behind the package are demonstrated in Core ideas illustrated.

While general in design, the package is intended for Partial Wave Analysis. First, qrules determines which transitions are allowed from some initial state to some final state. The ampform package then formulates those transitions mathematically as an amplitude model. TensorWaves can lambdify() this symbolic expression to a computational backend and ‘fit’ the resulting model to a data sample. Optionally, a data sample can be generated from the model itself.
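As a minimal sketch of what lambdify() means here (a toy line shape, not the amplitude model constructed below), SymPy can turn a symbolic expression into a fast numerical function for a chosen backend:

import numpy as np
import sympy as sp

# toy line shape, only to illustrate the symbolic-to-numerical conversion
m, m0, w0 = sp.symbols("m m0 Gamma0")
breit_wigner = 1 / (m0**2 - m**2 - sp.I * m0 * w0)
intensity_func = sp.lambdify((m, m0, w0), sp.Abs(breit_wigner) ** 2, "numpy")
intensity_func(np.linspace(0.5, 2.5, 5), 1.0, 0.1)  # evaluates on an array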

This page gives a brief overview of the complete workflow. More information about each step can be found under Step-by-step workflow.

Overview

import logging
import warnings

import ampform
import graphviz
import matplotlib.pyplot as plt
import pandas as pd
import qrules
import sympy as sp
from ampform.dynamics.builder import create_relativistic_breit_wigner_with_ff

from tensorwaves.data import generate_data, generate_phsp
from tensorwaves.data.transform import HelicityTransformer
from tensorwaves.estimator import UnbinnedNLL
from tensorwaves.model import LambdifiedFunction, SympyModel
from tensorwaves.optimizer.callbacks import CSVSummary
from tensorwaves.optimizer.minuit import Minuit2

logger = logging.getLogger()
logger.setLevel(logging.ERROR)
warnings.filterwarnings("ignore")

Construct a model

reaction = qrules.generate_transitions(
    initial_state=("J/psi(1S)", [-1, +1]),
    final_state=["gamma", "pi0", "pi0"],
    allowed_intermediate_particles=[
        "f(0)(980)",
        "f(0)(1500)",
        "f(0)(1710)",
    ],
    allowed_interaction_types=["strong", "EM"],
    formalism="canonical-helicity",
)
dot = qrules.io.asdot(reaction, collapse_graphs=True)
graphviz.Source(dot)
[Figure: collapsed transition graph for J/psi(1S) → gamma pi0 pi0 via the f(0) resonances]
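As a quick sanity check, the resonances that qrules found can be listed from the ReactionInfo object; with the settings above, these are the three requested f(0) states:

reaction.get_intermediate_particles().names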
model_builder = ampform.get_builder(reaction)
for name in reaction.get_intermediate_particles().names:
    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)
model = model_builder.formulate()
model.components[
    R"A_{J/\psi(1S)_{-1} \xrightarrow[S=1]{L=0} f_{0}(980)_{0} \gamma_{-1};"
    R" f_{0}(980)_{0} \xrightarrow[S=0]{L=0} \pi^{0}_{0} \pi^{0}_{0}}"
].doit()
\[\displaystyle \frac{C_{J/\psi(1S) \xrightarrow[S=1]{L=0} f_{0}(980) \gamma; f_{0}(980) \xrightarrow[S=0]{L=0} \pi^{0} \pi^{0}} \Gamma_{f(0)(980)} m_{f(0)(980)} \left(\frac{1}{2} - \frac{\cos{\left(\theta_{1+2} \right)}}{2}\right) e^{- i \phi_{1+2}}}{- \frac{i \Gamma_{f(0)(980)} m_{f(0)(980)} \sqrt{\frac{\left(m_{12}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{12}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{12}^{2}}} \sqrt{m_{f(0)(980)}^{2}}}{\sqrt{\frac{\left(m_{f(0)(980)}^{2} - \left(m_{1} - m_{2}\right)^{2}\right) \left(m_{f(0)(980)}^{2} - \left(m_{1} + m_{2}\right)^{2}\right)}{m_{f(0)(980)}^{2}}} \left|{m_{12}}\right|} - m_{12}^{2} + m_{f(0)(980)}^{2}}\]
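The symbols that the fit can later vary, together with their default values, are collected in parameter_defaults (the exact defaults depend on the particle database that qrules loads). A quick way to inspect them:

for symbol, value in model.parameter_defaults.items():
    print(symbol.name, "=", value)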

Generate data sample

sympy_model = SympyModel(
    expression=model.expression.doit(),
    parameters=model.parameter_defaults,
)
intensity = LambdifiedFunction(sympy_model, backend="jax")
data_converter = HelicityTransformer(model.adapter)
reaction_info = model.adapter.reaction_info
initial_state_mass = reaction_info.initial_state[-1].mass
final_state_masses = {i: p.mass for i, p in reaction_info.final_state.items()}
phsp_sample = generate_phsp(100_000, initial_state_mass, final_state_masses)
data_sample = generate_data(
    10_000,
    initial_state_mass,
    final_state_masses,
    data_converter,
    intensity,
)
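generate_phsp produces a uniform phase-space sample, while generate_data weights phase space by the intensity, conceptually through hit-and-miss (accept/reject) sampling. The following sketch only illustrates that idea with the objects defined above; it is not how the package itself should be called:

import numpy as np

# evaluate the intensity on the transformed phase-space sample and keep each
# point with a probability proportional to its intensity
weights = np.asarray(intensity(data_converter.transform(phsp_sample)))
rng = np.random.default_rng()
accepted = rng.uniform(0, weights.max(), size=len(weights)) < weights
print("accepted", accepted.sum(), "of", len(weights), "phase-space points")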
import numpy as np
from matplotlib import cm

reaction_info = model.adapter.reaction_info
intermediate_states = sorted(
    (
        p
        for p in model.particles
        if p not in reaction_info.final_state.values()
        and p not in reaction_info.initial_state.values()
    ),
    key=lambda p: p.mass,
)

evenly_spaced_interval = np.linspace(0, 1, len(intermediate_states))
colors = [cm.rainbow(x) for x in evenly_spaced_interval]


def indicate_masses():
    plt.xlabel("$m$ [GeV]")
    for i, p in enumerate(intermediate_states):
        plt.gca().axvline(
            x=p.mass, linestyle="dotted", label=p.name, color=colors[i]
        )
phsp_set = data_converter.transform(phsp_sample)
data_set = data_converter.transform(data_sample)
data_frame = pd.DataFrame(data_set)
fig, ax = plt.subplots(figsize=(9, 4))
data_frame["m_12"].hist(bins=100, alpha=0.5, density=True, ax=ax)
indicate_masses()
plt.legend();
[Figure: histogram of the invariant mass m_12 in the generated data sample, with the intermediate f(0) masses indicated by dotted lines]

Optimize the model

import matplotlib.pyplot as plt
import numpy as np


def compare_model(
    variable_name,
    data_set,
    phsp_set,
    intensity_model,
    bins=150,
):
    data = data_set[variable_name]
    phsp = phsp_set[variable_name]
    intensities = intensity_model(phsp_set)
    fig, ax = plt.subplots(figsize=(9, 4))
    plt.hist(
        data,
        bins=bins,
        alpha=0.5,
        label="data",
        density=True,
    )
    plt.hist(
        phsp,
        weights=intensities,
        bins=bins,
        histtype="step",
        color="red",
        label="initial fit model",
        density=True,
    )
    indicate_masses()
    plt.legend()
estimator = UnbinnedNLL(
    intensity,
    data_set,
    phsp_set,
    backend="jax",
)
# deliberately shift some parameters away from their defaults so that the fit
# has to find its way back to the values used to generate the data
initial_parameters = {
    "m_f(0)(980)": 0.93,
    "m_f(0)(1500)": 1.45,
    "m_f(0)(1710)": 1.8,
    "Gamma_f(0)(980)": 0.1,
    "Gamma_f(0)(1710)": 0.2,
}
intensity.update_parameters(initial_parameters)
compare_model("m_12", data_set, phsp_set, intensity)
print("Number of free parameters:", len(initial_parameters))
Number of free parameters: 5
[Figure: comparison of the data distribution in m_12 with the initial fit model]
callback = CSVSummary("fit_traceback.csv")
minuit2 = Minuit2(callback)
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result
FitResult(
 minimum_valid=True,
 execution_time=4.752262592315674,
 function_calls=152,
 estimator_value=-7697.045979271293,
 parameter_values={
  'm_f(0)(980)': 0.9912307810562814,
  'm_f(0)(1500)': 1.5064498157282065,
  'm_f(0)(1710)': 1.7058351908207452,
  'Gamma_f(0)(980)': 0.05967904023243938,
  'Gamma_f(0)(1710)': 0.11620962577233274,
 },
 parameter_errors={
  'm_f(0)(980)': 0.0010087654000929575,
  'm_f(0)(1500)': 0.0011703223061067685,
  'm_f(0)(1710)': 0.0011682568550078468,
  'Gamma_f(0)(980)': 0.0015760682918362458,
  'Gamma_f(0)(1710)': 0.002956378707025239,
 },
)
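The optimized values and their uncertainty estimates can be printed side by side, using only the attributes shown in the FitResult above:

for name, value in fit_result.parameter_values.items():
    error = fit_result.parameter_errors[name]
    print(f"{name} = {value:.4f} ± {error:.4f}")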
optimized_parameters = fit_result.parameter_values
intensity.update_parameters(optimized_parameters)
compare_model("m_12", data_set, phsp_set, intensity)
[Figure: comparison of the data distribution in m_12 with the optimized fit model]
converters = {p: lambda s: complex(s).real for p in initial_parameters}
fit_traceback = pd.read_csv("fit_traceback.csv", converters=converters)
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 8), gridspec_kw={"height_ratios": [1, 1.8]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
ax1.set_title("Negative log likelihood")
ax2.set_title("Parameter values")
ax1.set_xlabel("function call")
ax2.set_xlabel("function call")
fig.tight_layout()
ax1.legend().remove()
legend_texts = ax2.legend().get_texts()
for text in legend_texts:
    latex = f"${sp.latex(sp.Symbol(text.get_text()))}$"
    latex = latex.replace("\\\\", "\\")
    if latex[2] == "C":
        latex = fR"\left|{latex}\right|"
    text.set_text(latex)
for line in ax2.get_lines():
    label = line.get_label()
    color = line.get_color()
    ax2.axhline(
        y=complex(sympy_model.parameters[label]).real,
        color=color,
        alpha=0.5,
        linestyle="dotted",
    )
[Figure: negative log likelihood and parameter values as a function of the number of function calls during the fit]