In [None]:
%%capture
%config Completer.use_jedi = False
%config InlineBackend.figure_formats = ['svg']
import os

STATIC_WEB_PAGE = {"EXECUTE_NB", "READTHEDOCS"}.intersection(os.environ)

# Install on Google Colab
import subprocess
import sys

from IPython import get_ipython

# Packages are only installed when the active IPython shell reports that it
# is running on Google Colab.
install_packages = "google.colab" in str(get_ipython())
if install_packages:
    pip_install = (sys.executable, "-m", "pip", "install")
    for package in ("tensorwaves[doc]", "graphviz"):
        # check_call raises CalledProcessError if pip exits with an error.
        subprocess.check_call([*pip_install, package])

# Step 3: Perform fit

As explained in the {doc}`previous step <step2>`, a {class}`.Model` object and its lambdified {class}`.Function` form behave just like a mathematical function that takes a set of data points as an argument and returns a list of intensities (real numbers). At this stage, we want to optimize the parameters of the intensity model, so that it matches the distribution of our data sample. This is what we call 'performing a fit'.

First, load the relevant data from the previous steps.

In [None]:
import pickle

import qrules as q

from tensorwaves.model import SympyModel

# Load the files that were written in the previous steps.
result = q.io.load("transitions.json")
# NOTE: pickle.load can execute arbitrary code while unpickling, so only load
# pickle files that you created yourself.
with open("helicity_model.pickle", "rb") as stream:
    model = pickle.load(stream)
with open("data_set.pickle", "rb") as stream:
    data_set = pickle.load(stream)
with open("phsp_set.pickle", "rb") as stream:
    phsp_set = pickle.load(stream)
# Wrap the model expression together with its default parameter values.
sympy_model = SympyModel(
    expression=model.expression,
    parameters=model.parameter_defaults,
)

## 3.1 Define estimator

To perform a fit, you need to define an {class}`.Estimator`. An estimator is a measure for the discrepancy between the intensity model and the data distribution to which you fit. In PWA, we usually use an *unbinned negative log likelihood estimator*.

Generally, the intensity model is not normalized, but a log likelihood estimator requires a normalized function. This is where the {ref}`phase space dataset <usage/step2:2.1 Generate phase space sample>` comes into play again: the intensity is evaluated separately with the phase space dataset so that the output of the {class}`.Function` can be normalized with it. The phase space sample is therefore a required argument!

```{margin}
If you want to correct for the efficiency of the detector, you should insert a *detector-reconstructed* phase space sample here.
```

In [None]:
from tensorwaves.estimator import UnbinnedNLL

# The estimator receives both the data sample and the phase space sample,
# and is expressed in the jax backend.
estimator = UnbinnedNLL(
    sympy_model,
    data_set,
    phsp_set,
    backend="jax",
)

Note that the {class}`.UnbinnedNLL` can be expressed with different backends (it creates a {class}`.LambdifiedFunction` internally from the template {class}`.Model`). Here, we use {func}`jax <jax.jit>`, which turns out to be the fastest backend for this model.

## 3.2 Optimize intensity model

### Specify fit parameters

Starting the fit itself is quite simple: just create an optimizer instance of your choice, here [Minuit2](https://root.cern.ch/doc/master/Minuit2Page.html), and call its {meth}`~.Minuit2.optimize` method to start the fitting process. The {meth}`~.Minuit2.optimize` method requires a mapping of parameter names to their initial values. **Only the parameters listed in the mapping are optimized.**

Let's first select a few of the parameters that we saw in {ref}`Step 3.1 <usage/step3:3.1 Define estimator>` and feed them to the optimizer to run the fit. Notice that we modify the parameters slightly to make the fit more interesting (we are fitting to a data sample that was generated with this very same amplitude model after all).

In [None]:
# Starting values for the fit, keyed by parameter name. This mapping is what
# is later passed to the optimizer.
initial_parameters = {
    # Coefficient given as a complex number
    "C[J/\\psi(1S) \\to f_{0}(1500)_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}]": 1.0
    + 0.0j,
    "Gamma_f(0)(980)": 0.15,
    "Gamma_f(0)(1500)": 0.2,
    "m_f(0)(1710)": 1.78,
}

Recall that a {class}`.Function` object computes the intensity for a certain dataset. This can be seen now nicely when we use these intensities as weights on the phase space sample and plot it together with the original dataset. Here, we look at the invariant mass distribution projection of the final states `1` and `2`, which, {ref}`as we saw before <usage/step2:2.3 Visualize kinematic variables>`, are the final state particles $\pi^0,\pi^0$.

Don't forget to use {meth}`~.Function.update_parameters` first!

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm

# Particles of the model that appear neither in the final nor in the initial
# state, ordered by increasing mass.
reaction_info = model.adapter.reaction_info
external_particles = [
    *reaction_info.final_state.values(),
    *reaction_info.initial_state.values(),
]
intermediate_states = sorted(
    filter(
        lambda particle: particle not in external_particles,
        model.particles,
    ),
    key=lambda particle: particle.mass,
)

# One color per intermediate state, sampled at evenly spaced positions of the
# rainbow color map.
evenly_spaced_interval = np.linspace(0, 1, len(intermediate_states))
colors = list(map(cm.rainbow, evenly_spaced_interval))


def indicate_masses():
    """Mark the masses of the intermediate states on the current axes.

    Labels the x-axis as a mass in GeV and draws one dotted vertical line per
    entry of ``intermediate_states``, colored with the matching entry of
    ``colors`` and labeled with the particle name.
    """
    plt.xlabel("$m$ [GeV]")
    axis = plt.gca()
    for state, color in zip(intermediate_states, colors):
        axis.axvline(
            x=state.mass,
            color=color,
            label=state.name,
            linestyle="dotted",
        )


def compare_model(
    variable_name,
    data_set,
    phsp_set,
    intensity_model,
    bins=100,
    label="fit model",
):
    """Plot a data distribution together with the intensity model.

    Two normalized histograms are drawn on the current axes: the data sample
    and the phase space sample weighted by the intensities of the model.

    Args:
        variable_name: Key of the variable to histogram; it is looked up in
            both ``data_set`` and ``phsp_set``.
        data_set: Mapping of variable names to the values of the data sample.
        phsp_set: Mapping of variable names to the values of the phase space
            sample. It is also the argument with which ``intensity_model`` is
            called.
        intensity_model: Callable that returns one weight per phase space
            event when called with ``phsp_set``.
        bins: Number of bins (or bin edges) for both histograms.
        label: Legend entry for the model histogram. The function is used for
            the initial as well as for the optimized parameters, so the
            default does not refer to either of them.
    """
    data = data_set[variable_name]
    phsp = phsp_set[variable_name]
    intensities = intensity_model(phsp_set)
    axis = plt.gca()
    axis.hist(data, bins=bins, alpha=0.5, label="data", density=True)
    axis.hist(
        phsp,
        weights=intensities,
        bins=bins,
        histtype="step",
        color="red",
        label=label,
        density=True,
    )
    # Adds the mass markers, so that they end up in the legend as well.
    indicate_masses()
    plt.legend()

In [None]:
from tensorwaves.model import LambdifiedFunction

# Express the model in the numpy backend; the resulting function is used for
# all plots in this notebook.
intensity = LambdifiedFunction(sympy_model, backend="numpy")

In [None]:
# Set the starting values first, so that the plotted intensities correspond
# to the parameters with which the fit starts.
intensity.update_parameters(initial_parameters)
compare_model("m_12", data_set, phsp_set, intensity)

### Custom callbacks

The {class}`.Minuit2` class allows one to insert certain {mod}`~tensorwaves.optimizer.callbacks`. Callbacks are used to insert behavior into the {meth}`.Optimizer.optimize` method. The {mod}`~tensorwaves.optimizer.callbacks` module provides several handy callback classes, but it's also possible to insert custom behavior into the {class}`.Optimizer` by defining a custom {class}`.Callback` class. Here's one that live-updates a plot of the latest fit model!

In [None]:
import os

from IPython.display import clear_output

from tensorwaves.optimizer.callbacks import Callback

# Name of the variable whose distribution is drawn by the callback.
variable_to_plot = "m_12"
# True if READTHEDOCS or EXECUTE_NB is set in the environment; the callback
# skips its live plot updates in that case.
is_doc = any(name in os.environ for name in ("READTHEDOCS", "EXECUTE_NB"))


class PyplotCallback(Callback):
    """Callback that redraws a histogram of the fit model during a fit.

    The distribution of ``variable_to_plot`` in ``data_set`` is compared with
    the same variable in ``phsp_set``, weighted by the intensities that the
    module-level ``intensity`` function computes for the most recent
    parameter values.

    Args:
        step_size: Redraw the plot once every ``step_size`` function calls.
    """

    def __init__(self, step_size=10):
        self.__step_size = step_size
        self.__fig = None
        self.__ax = None
        self.__latest_parameters = {}

    def on_optimize_start(self, logs):
        """Create the figure that is updated while the optimizer runs."""
        self.__fig, self.__ax = plt.subplots(1, figsize=(8, 5))

    def on_optimize_end(self, logs):
        """Draw the final state with a finer binning and drop the figure."""
        self.update_plot(nbins=200)
        self.__ax = None
        self.__fig = None

    def on_iteration_end(self, iteration, logs=None):
        """Do nothing: the plot is updated per function call instead."""

    def on_function_call_end(self, function_call, logs):
        """Remember the latest parameters and redraw the plot if it is due.

        Args:
            function_call: Number of the function call that just ended.
            logs: Mapping that contains the current values under the key
                ``"parameters"``.
        """
        self.__latest_parameters = logs["parameters"]
        if is_doc:
            return
        if function_call % self.__step_size != 0:
            return
        # No live updates anymore after the first 75 function calls.
        if function_call > 75:
            return
        self.update_plot(nbins=80)
        clear_output(wait=True)
        # Show the figure of this callback rather than whichever figure
        # happens to be the current one.
        display(self.__fig)

    def update_plot(self, nbins: int):
        """Redraw both histograms with the latest parameter values.

        Args:
            nbins: Number of bins for the data and the model histogram.
        """
        if self.__fig is None or self.__ax is None:
            self.__fig, self.__ax = plt.subplots(1, figsize=(8, 5))
        data = data_set[variable_to_plot]
        phsp = phsp_set[variable_to_plot]
        intensity.update_parameters(self.__latest_parameters)
        intensities = intensity(phsp_set)
        self.__ax.cla()
        self.__ax.hist(data, bins=nbins, alpha=0.5, label="data", density=True)
        self.__ax.hist(
            phsp,
            weights=intensities,
            bins=nbins,
            histtype="step",
            color="red",
            label="fit model",
            density=True,
        )
        self.__ax.set_xlim((0.25, 2.5))
        self.__ax.set_ylim((0, 1.9))
        # indicate_masses() draws on the current axes, so make sure that
        # these are the axes of this callback.
        plt.sca(self.__ax)
        indicate_masses()
        # Figure.legend() adds a new legend on each call and ax.cla() does
        # not remove figure legends, so drop the ones of previous redraws.
        for legend in list(self.__fig.legends):
            legend.remove()
        self.__fig.legend()

In the fit example below, we want to use several callbacks together. To do so, we 'stack' them together with {class}`.CallbackList`:

In [None]:
from tensorwaves.optimizer.callbacks import (
    CallbackList,
    CSVSummary,
    TFSummary,
    YAMLSummary,
)

# Stack several callbacks; the CSV and YAML files written here are read
# again further down in this notebook.
callbacks = CallbackList(
    [
        CSVSummary("fit_traceback_minuit.csv"),
        PyplotCallback(),  # comment this line for raw fit performance
        TFSummary(),
        YAMLSummary("current_fit_result.yaml"),
    ]
)

### Perform fit

Finally, we create an {class}`.Optimizer` to {meth}`~.Minuit2.optimize` the model (which is embedded in the {class}`.Estimator`). Here, we choose the {class}`.Minuit2` optimizer, which is the most common optimizer in high-energy physics. Since we constructed the {class}`.UnbinnedNLL` with {obj}`jax <jax.jit>`, we can also optimize the model with an analytic gradient over the {class}`.LambdifiedFunction` (see the `use_analytic_gradient` argument):

```{margin}
The computation time depends on the complexity of the model, number of data events, the size of the phase space sample, and the number of free parameters. This model is rather small and has but a few free parameters, so the optimization shouldn't take more than a **few seconds**.

The custom callback example slows down the optimization, but you can remove it to get the raw fit performance.
```

In [None]:
from tensorwaves.optimizer.minuit import Minuit2

# NOTE(review): the analytic gradient is switched off here, although the
# estimator was created with the jax backend -- confirm this is intended.
minuit2 = Minuit2(
    callback=callbacks,
    use_analytic_gradient=False,
)
# Only the parameters in `initial_parameters` are handed to the optimizer.
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result

As can be seen, the values of the optimized parameters in the result are again comparable to the original values that are contained in the model ({attr}`.SympyModel.parameters`):

In [None]:
# For each fitted parameter, list the starting value, the fit result, and the
# value that is stored in the model.
optimized_parameters = fit_result.parameter_values
for p in optimized_parameters:
    print(
        p,
        f"  initial:   {initial_parameters[p]:.3}",
        f"  optimized: {optimized_parameters[p]:.3}",
        f"  original:  {sympy_model.parameters[p]:.3}",
        sep="\n",
    )

:::{seealso}

{ref}`usage/step3:SciPy optimizer`

:::

## 3.3 Export and import

In {ref}`usage/step3:3.2 Optimize intensity model`, we initialized {obj}`.Minuit2` with some callbacks that are {class}`.Loadable`. Such callback classes offer the possibility to {meth}`.Loadable.load_latest_parameters`, so you can pick up the optimization process in case it crashes or if you pause it. Loading the latest parameters goes as follows:

In [None]:
# Read back the file that the YAMLSummary callback was configured to write.
latest_parameters = YAMLSummary.load_latest_parameters(
    "current_fit_result.yaml"
)
latest_parameters

To restart the fit with the latest parameters, simply rerun as before.

In [None]:
# A new optimizer without callbacks, started from the loaded parameters.
minuit2 = Minuit2()
minuit2.optimize(estimator, latest_parameters)

Lo and behold: the parameters were already optimized, so the fit converged faster!

## 3.4 Visualize

### Plot optimized model

Using the same method as above, we renew the parameters of the {class}`.Function` and plot it again over the data distribution.

In [None]:
# Switch the function to the loaded parameter values before plotting.
intensity.update_parameters(latest_parameters)
compare_model("m_12", data_set, phsp_set, intensity)

### Intensity components

Notice that the {class}`~ampform.helicity.HelicityModel` contains a {attr}`~ampform.helicity.HelicityModel.components` attribute. This {obj}`dict` of component names to {class}`sympy.Expr <sympy.core.expr.Expr>`s helps us to identify amplitudes and intensities in the total amplitude.

In [None]:
# List the names of all available components in sorted order.
sorted(model.components)

In [None]:
# Select one amplitude by name, substitute the default parameter values, and
# evaluate the resulting expression with doit().
model.components[
    R"A[J/\psi(1S)_{+1} \to f_{0}(1370)_{0} \gamma_{+1}; f_{0}(1370)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]"
].subs(model.parameter_defaults).doit()

Just like in {ref}`usage/step2:2.2 Generate intensity-based sample`, these _intensity components_ can each be expressed in a computational backend. This can be done with the helper function {func}`.create_intensity_component`. After that, we update the parameters with the optimized parameter values we found in {ref}`usage/step3:Perform fit`. The two components in the example below should be the same:

In [None]:
from tensorwaves.physics import create_intensity_component

# Component that is built from a list of amplitudes, which all belong to the
# same helicities (+1) of the J/psi and the photon.
from_amplitudes = create_intensity_component(
    model,
    components=[
        R"A[J/\psi(1S)_{+1} \to f_{0}(500)_{0} \gamma_{+1}; f_{0}(500)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]",
        R"A[J/\psi(1S)_{+1} \to f_{0}(980)_{0} \gamma_{+1}; f_{0}(980)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]",
        R"A[J/\psi(1S)_{+1} \to f_{0}(1370)_{0} \gamma_{+1}; f_{0}(1370)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]",
        R"A[J/\psi(1S)_{+1} \to f_{0}(1500)_{0} \gamma_{+1}; f_{0}(1500)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]",
        R"A[J/\psi(1S)_{+1} \to f_{0}(1710)_{0} \gamma_{+1}; f_{0}(1710)_{0} \to \pi^{0}_{0} \pi^{0}_{0}]",
    ],
    backend="numpy",
)
# Use the parameter values that were loaded from the fit result file.
from_amplitudes.update_parameters(latest_parameters)

In [None]:
# Component that is built from a single intensity with the same helicities
# (+1) of the J/psi and the photon.
from_intensity = create_intensity_component(
    model,
    components=R"I[J/\psi(1S)_{+1} \to \gamma_{+1} \pi^{0}_{0} \pi^{0}_{0}]",
    backend="numpy",
)
# Use the parameter values that were loaded from the fit result file.
from_intensity.update_parameters(latest_parameters)

In [None]:
import numpy as np

# Compare the two components event by event. An average over the signed
# differences could be zero even if the components differ, because positive
# and negative deviations cancel.
amplitude_based = np.asarray(from_amplitudes(phsp_set))
intensity_based = np.asarray(from_intensity(phsp_set))
# Largest deviation between the two components over the phase space sample.
difference = np.max(np.abs(amplitude_based - intensity_based))
np.testing.assert_allclose(
    amplitude_based,
    intensity_based,
    rtol=1e-12,
    atol=1e-12,
)

The result is a {class}`.LambdifiedFunction` that can be plotted just like in {ref}`usage/step3:Plot optimized model`:

In [None]:
fig, ax = plt.subplots(1, figsize=(8, 5))
bins = 150
# Phase space sample weighted with the full intensity.
ax.hist(
    phsp_set["m_12"],
    weights=intensity(phsp_set),
    bins=bins,
    alpha=0.2,
    label="full intensity",
)
# Phase space sample weighted with a single component. The label names the
# helicities (+1) of the component from which `from_intensity` was created.
ax.hist(
    phsp_set["m_12"],
    weights=from_intensity(phsp_set),
    bins=bins,
    histtype="step",
    label=R"$J/\psi(1S)_{+1} \to \gamma_{+1} \pi^0 \pi^0$",
)
ax.set_xlim(0.25, 2.5)
ax.set_xlabel(R"$m_{\pi^0\pi^0}$ [GeV]")
ax.set_yticks([])
plt.legend();

Or generically, so that we can stack all the sub-intensities:

In [None]:
import logging

from tensorwaves.data import generate_data

# Only show log messages of level ERROR and higher.
logging.basicConfig()
logging.getLogger().setLevel(logging.ERROR)

# Names of all components that start with "I"; the same selection is used for
# the generated samples and for the legend labels.
intensity_component_names = [
    name for name in model.components if name.startswith("I")
]
intensity_components = [
    create_intensity_component(model, name, backend="numpy")
    for name in intensity_component_names
]
# Generate one sample of 5,000 events per component.
sub_intensities = [
    generate_data(
        size=5_000,
        reaction_info=model.adapter.reaction_info,
        data_transformer=model.adapter,
        intensity=component,
    )
    for component in intensity_components
]

fig, ax = plt.subplots(1, figsize=(8, 5))
plt.hist(
    [model.adapter.transform(sample)["m_12"] for sample in sub_intensities],
    bins=100,
    stacked=True,
    alpha=0.6,
)
ax.set_xlim(0.25, 2.5)
ax.set_xlabel(R"$m_{\pi^0\pi^0}$ [GeV]")
ax.set_yticks([])
# Strip the leading "I[" and the trailing "]" from the names for the labels.
ax.legend(labels=[f"${name[2:-1]}$" for name in intensity_component_names]);

### Analyze optimization process

Note that {ref}`in Step 3.2 <usage/step3:3.2 Optimize intensity model>`, we initialized {class}`.Minuit2` with a {class}`.TFSummary` callback as well. Its output files provide a nice, interactive representation of the fit process and can be viewed with [TensorBoard](https://www.tensorflow.org/tensorboard/get_started) as follows:

````{tabbed} Terminal
```bash
tensorboard --logdir logs
```
````

````{tabbed} Python
```python
import tensorboard as tb

tb.notebook.list()  # View open TensorBoard instances
tb.notebook.start(args_string="--logdir logs")
```
See more info [here](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks#tensorboard_in_notebooks)
````

````{tabbed} Jupyter notebook
```ipython
%load_ext tensorboard
%tensorboard --logdir logs
```
See more info [here](https://www.tensorflow.org/tensorboard/tensorboard_in_notebooks#tensorboard_in_notebooks)
````

An alternative would be to use the output of the {class}`.CSVSummary` callback. Here's an example:

In [None]:
import pandas as pd
import sympy as sp

# Read the parameter columns as the real part of the value in the CSV file.
converters = {p: lambda s: complex(s).real for p in initial_parameters}
fit_traceback = pd.read_csv("fit_traceback_minuit.csv", converters=converters)
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 8), gridspec_kw={"height_ratios": [1, 1.8]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
# This traceback was written during the Minuit2 fit, so name that optimizer.
fig.suptitle("Minuit2 optimizer process", fontsize=16)
ax1.set_title("Negative log likelihood")
ax2.set_title("Parameter values")
ax1.set_xlabel("function call")
ax2.set_xlabel("function call")
fig.tight_layout()
ax1.legend().remove()
# Render the parameter names in the legend as LaTeX.
legend_texts = ax2.legend().get_texts()
for text in legend_texts:
    latex = f"${sp.latex(sp.Symbol(text.get_text()))}$"
    latex = latex.replace("\\\\", "\\")
    # NOTE(review): index 0 is the leading "$", so index 2 is the second
    # character of the LaTeX name -- confirm that this matches the
    # coefficient labels as intended.
    if latex[2] == "C":
        latex = fR"\left|{latex}\right|"
    text.set_text(latex)
# Draw the value that is stored in the model as a dotted horizontal line, in
# the color of the corresponding parameter curve.
for line in ax2.get_lines():
    label = line.get_label()
    color = line.get_color()
    ax2.axhline(
        y=complex(sympy_model.parameters[label]).real,
        color=color,
        alpha=0.5,
        linestyle="dotted",
    )

## SciPy optimizer

As an alternative to the {class}`.Minuit2` optimizer, you can use the {class}`.ScipyMinimizer`. As opposed to Minuit, this optimizer does not compute errors. See {doc}`scipy:tutorial/optimize` and {mod}`scipy.optimize` for more info.

In [None]:
from tensorwaves.optimizer.scipy import ScipyMinimizer

# NOTE(review): the variable name `scipy` would shadow the scipy package if
# that package were imported in this notebook.
scipy = ScipyMinimizer(
    # Write to a separate file, with a step size of 1 for both function calls
    # and iterations.
    callback=CSVSummary(
        "fit_traceback_scipy.csv",
        function_call_step_size=1,
        iteration_step_size=1,
    ),
    use_analytic_gradient=False,
)
# This rebinds `fit_result`, which held the result of the Minuit2 fit.
fit_result = scipy.optimize(estimator, initial_parameters)
fit_result

In [None]:
# Read the parameter columns as the real part of the value in the CSV file.
converters = dict.fromkeys(initial_parameters, lambda s: complex(s).real)
fit_traceback = pd.read_csv("fit_traceback_scipy.csv", converters=converters)
fig, (ax1, ax2) = plt.subplots(
    nrows=2, figsize=(7, 8), gridspec_kw={"height_ratios": [1, 1.8]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
fig.suptitle("SciPy optimizer process", fontsize=16)
ax1.set(title="Negative log likelihood", xlabel="function call")
ax2.set(title="Parameter values", xlabel="function call")
fig.tight_layout()
ax1.legend().remove()
# Render the parameter names in the legend as LaTeX.
legend_texts = ax2.legend().get_texts()
for text in legend_texts:
    symbol_latex = sp.latex(sp.Symbol(text.get_text()))
    latex = f"${symbol_latex}$".replace("\\\\", "\\")
    if latex[2] == "C":
        latex = fR"\left|{latex}\right|"
    text.set_text(latex)
# Draw the value that is stored in the model as a dotted horizontal line, in
# the color of the corresponding parameter curve.
for line in ax2.get_lines():
    ax2.axhline(
        y=complex(sympy_model.parameters[line.get_label()]).real,
        color=line.get_color(),
        alpha=0.5,
        linestyle="dotted",
    )