Core ideas illustrated#

At its core, tensorwaves is a package that can ‘fit’ arbitrary mathematical expressions to a data set using different computational back-ends. It can also use those expressions to describe a distribution from which to generate data samples.

This page illustrates what’s going on behind the scenes with some simple 1-dimensional and 2-dimensional expressions. The main steps are:

  1. Formulate a mathematical expression with sympy.

  2. Generate a distribution data sample for that expression.

  3. Convert the expression to a function in some computational back-end.

  4. Tweak the parameters and fit the ParametrizedFunction to the generated distribution.

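A hidden cell here collects the imports used throughout this page. As a reconstruction (module paths assumed from the classes that appear below):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sympy as sp
from IPython.display import display
from sympy.plotting import plot3d

from tensorwaves.data import (
    IntensityDistributionGenerator,
    NumpyDomainGenerator,
    NumpyUniformRNG,
)
from tensorwaves.estimator import UnbinnedNLL
from tensorwaves.function.sympy import create_parametrized_function
from tensorwaves.optimizer import Minuit2, ScipyMinimizer
from tensorwaves.optimizer.callbacks import (
    CallbackList,
    CSVSummary,
    TFSummary,
    YAMLSummary,
)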

Formulate model#

First of all, we’ll formulate some expression with sympy. In this example, we take a sum of Gaussians plus some Poisson distribution.

def gaussian(x: sp.Symbol, mu: sp.Symbol, sigma: sp.Symbol) -> sp.Expr:
    return sp.exp(-(((x - mu) / sigma) ** 2) / 2)


def poisson(x: sp.Symbol, k: sp.Symbol) -> sp.Expr:
    return x**k * sp.exp(-x) / sp.factorial(k)
x, mu, sigma = sp.symbols("x mu sigma")
k = sp.Symbol("k", integer=True)
lam = sp.Symbol("lambda", positive=True)
display(
    gaussian(x, mu, sigma),
    poisson(lam, k),
)
\[\displaystyle e^{- \frac{\left(- \mu + x\right)^{2}}{2 \sigma^{2}}}\]
\[\displaystyle \frac{\lambda^{k} e^{- \lambda}}{k!}\]
x, a, b, c, mu1, mu2, sigma1, sigma2 = sp.symbols("x (a:c) mu_(:2) sigma_(:2)")
expression_1d = (
    a * gaussian(x, mu1, sigma1) + b * gaussian(x, mu2, sigma2) + c * poisson(x, k=2)
)
expression_1d
\[\displaystyle a e^{- \frac{\left(- \mu_{0} + x\right)^{2}}{2 \sigma_{0}^{2}}} + b e^{- \frac{\left(- \mu_{1} + x\right)^{2}}{2 \sigma_{1}^{2}}} + \frac{c x^{2} e^{- x}}{2}\]

The expression above consists of a number of Symbols that we want to identify as parameters (which we want to optimize with regard to a certain data sample) and variables (in which the data sample is expressed). Let’s say \(x\) is the variable and the rest of the Symbols are the parameters.

Here, we’ll pick some default values for the parameters and use them to plot the model as a function of the variable \(x\) (see subs()). The default values are also used later on, when we generate data.

parameter_defaults = {
    a: 0.15,
    b: 0.05,
    c: 0.3,
    mu1: 1.0,
    sigma1: 0.3,
    mu2: 2.7,
    sigma2: 0.5,
}
x_range = (x, 0, 5)
substituted_expr_1d = expression_1d.subs(parameter_defaults)
p1 = sp.plot(substituted_expr_1d, x_range, show=False, line_color="red")
p2 = sp.plot(*substituted_expr_1d.args, x_range, show=False, line_color="gray")
p2.append(p1[0])
p2.show()

Convert to backend#

So far, all we did was use sympy to formulate a mathematical expression symbolically. We now need to lambdify() that expression to some computational backend, so that we can efficiently generate data and/or optimize the parameters of the function to ‘fit’ the model to some data sample. TensorWaves can do this with the create_parametrized_function() function:

function_1d = create_parametrized_function(
    expression=expression_1d,
    parameters=parameter_defaults,
    backend="jax",
    use_cse=False,
)

Tip

Here, we used use_cse=False in create_parametrized_function(). Setting this argument to True (the default) causes sympy to search for common sub-expressions, which speeds up lambdification in large expressions and makes the lambdified source code more efficient. See also cse().
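To see what common sub-expression elimination does, here’s a small, standalone example with cse() (independent of the model above):

sub_exprs, reduced = sp.cse(sp.exp(x**2) + sp.sin(x**2))
sub_exprs  # [(x0, x**2)]: the repeated x**2 is computed only once
reduced  # [exp(x0) + sin(x0)]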

The resulting ParametrizedBackendFunction internally carries some source code that numpy understands. With get_source_code(), we can see that it indeed looks similar to the expression that we formulated in Formulate model:

from black import FileMode, format_str

from tensorwaves.function import get_source_code

src = get_source_code(function_1d)
src = format_str(src, mode=FileMode())
print(src)
def _lambdifygenerated(a, b, c, mu_0, mu_1, sigma_0, sigma_1, x):
    return (
        a * exp(-1 / 2 * (-mu_0 + x) ** 2 / sigma_0**2)
        + b * exp(-1 / 2 * (-mu_1 + x) ** 2 / sigma_1**2)
        + (1 / 2) * c * x**2 * exp(-x)
    )

The ParametrizedBackendFunction also carries the original default values for the parameters that we defined earlier on.
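These can be inspected through its parameters attribute:

function_1d.parameters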

{'a': 0.15,
 'b': 0.05,
 'c': 0.3,
 'mu_0': 1.0,
 'sigma_0': 0.3,
 'mu_1': 2.7,
 'sigma_1': 0.5}

The ParametrizedFunction.__call__() takes a dict of variable names (here, "x" only) to the value(s) that should be used in their place.

function_1d({"x": 0})
Array(0.00057991, dtype=float64, weak_type=True)

This is where data generation comes into play: the input values are usually an array of values (expressed in the backend):

rng = np.random.default_rng()
x_values = np.linspace(0, 5, num=20)
y_values = function_1d({"x": x_values})
y_values
Array([0.00057991, 0.01533186, 0.06767626, 0.1597476 , 0.20593748,
       0.15694528, 0.10445861, 0.09506037, 0.10579904, 0.1189151 ,
       0.12428972, 0.11587303, 0.09647026, 0.07504325, 0.05834281,
       0.0473477 , 0.03998121, 0.03433191, 0.02951671, 0.02526857],      dtype=float64)
plt.scatter(x_values, y_values)
plt.gca().set_xlabel("$x$")
plt.gca().set_ylabel("$f(x)$");

Generate data#

So, we now have a function \(f\) of \(x\), expressed in some computational backend. This function is meant to describe a distribution over \(x\). In the real world, \(x\) is an observable from a process you measure. But sometimes it’s useful to generate a ‘toy’ data sample for your model function as well, to try it out.

Hit & miss#

The challenge is to generate values of \(x\) with a density that is proportional to the value of the function evaluated at that point. To do this, we use a hit & miss approach:

  1. Generate a random value for \(x\) within the domain \((x_\mathrm{min}, x_\mathrm{max})\) on which you want to generate the data sample.

  2. Generate a random value \(y\) between \(0\) and the maximum value \(y_\mathrm{max}\) of the function over the domain of \(x\).

  3. Check if \(y\) lies below \(f(x)\) (“hit”) or above (“miss”).

  4. If there is a “hit”, accept this value of \(x\) and add it to the data sample.

We keep performing this until the sample of \(x\) values contains the desired number of events.
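In plain NumPy, a minimal sketch of this loop could look as follows (an illustration only, not the actual tensorwaves implementation; it assumes \(y_\mathrm{max}\) is known):

def hit_and_miss_1d(function, x_min, x_max, y_max, size, rng):
    # Naive hit & miss sampling of a 1-dimensional distribution
    accepted = []
    while len(accepted) < size:
        x_random = rng.uniform(x_min, x_max)  # step 1
        y_random = rng.uniform(0, y_max)  # step 2
        if y_random < function({"x": x_random}):  # step 3: a "hit"
            accepted.append(x_random)  # step 4
    return np.array(accepted)


toy_sample = hit_and_miss_1d(
    function_1d, x_min=0, x_max=5, y_max=0.21, size=1_000, rng=np.random.default_rng(0)
)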

x_domain = np.linspace(0, 5, num=200)
y_values = function_1d({"x": x_domain})
fig = plt.figure(figsize=(8, 5))
plt.plot(x_domain, y_values)
plt.gca().set_xlabel("$x$")
plt.gca().set_ylabel("$f(x)$")

x_min = x_range[1]
x_max = x_range[2]
y_max = 0.21
x_value = 1.5

line_segment = [[0, 0], [0, y_max]]
plt.plot(*line_segment, color="black")
plt.text(
    -0.22,
    y_max / 2 * 0.5,
    "uniform sample $(0, y_{max})$",
    rotation="vertical",
)
plt.axhline(y=y_max, linestyle="dotted", color="black")
plt.text(
    x_min + 0.1,
    y_max - 0.01,
    "$y_{max}$",
)

line_segment = [[x_min, x_max], [0, 0]]
plt.plot(*line_segment, color="black")
plt.text(
    (x_max - x_min) / 2 - 0.22,
    0.005,
    R"uniform sample $(x_\mathrm{min}, x_\mathrm{max})$",
)
plt.scatter(x_value, function_1d({"x": x_value}))
plt.axvline(x=x_value, linestyle="dotted")


def draw_y_hit(x_random, y_random):
    y_value = function_1d({"x": x_random})
    color = "green" if y_random < y_value else "red"
    text = "hit" if y_random < y_value else "miss"
    plt.scatter(0, y_random, color=color)
    plt.arrow(
        x=0,
        y=y_random,
        dx=x_random,
        dy=0,
        head_length=0.15,
        length_includes_head=True,
        color=color,
        linestyle="dotted",
    )
    plt.text(x_value + 0.05, y_random, text)


draw_y_hit(x_random=x_value, y_random=0.05)
draw_y_hit(x_random=x_value, y_random=0.17)

There is one problem though: how to determine \(y_\mathrm{max}\)? In this example, we can just read off the value of \(y_\mathrm{max}\), or even compute it analytically from the original sympy.Expr. Generally, though, this is not the case, so we need to apply a trick.

Since we are generating uniformly distributed random values of \(x\) and computing their \(f(x)\) values, we can keep track of the highest value of \(f(x)\) seen so far. Starting with \(y_\mathrm{max} = 0\), we simply set \(y_\mathrm{max} = f(x)\) whenever \(f(x) > y_\mathrm{max}\) and completely restart the generation loop. Eventually, some value of \(x\) will lie near the absolute maximum of \(f\) and the data generation will happily continue until the requested number of events has been reached.
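Adapting the sketch from above, the trick looks like this (again just an illustration, not the tensorwaves internals):

def hit_and_miss_adaptive(function, x_min, x_max, size, rng):
    y_max = 0.0
    accepted = []
    while len(accepted) < size:
        x_random = rng.uniform(x_min, x_max)
        f_x = float(function({"x": x_random}))
        if f_x > y_max:
            y_max = f_x  # higher maximum found...
            accepted = []  # ...so restart the generation loop
        elif rng.uniform(0, y_max) < f_x:
            accepted.append(x_random)
    return np.array(accepted)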

Warning

There are two caveats:

  1. The domain sample (here: the uniformly distributed values of \(x\)) has to be large in order for the data sample to accurately describe the original function.

  2. If the function displays narrow structures, like some sharp global maximum containing \(y_\mathrm{max}\), chances are smaller that a value of \(x\) will lie within this peak. The domain sample will therefore have to be even larger. It will also take longer before \(y_\mathrm{max}\) is found.

Domain distribution#

First of all, we need to randomly generate values of \(x\). In this simple, 1-dimensional example, we could just use a random generator like numpy.random.Generator and feed its output to ParametrizedFunction.__call__(). Generally, though, we want to cover \(n\)-dimensional cases. The class NumpyDomainGenerator allows us to generate such a uniform distribution for each variable within a certain range. It requires a RealNumberGenerator (here we use NumpyUniformRNG) and it also requires us to define boundaries for each variable in the resulting DataSample.

Tip

Set a seed in the RealNumberGenerator if you want to generate a deterministic data sample. If you leave it unspecified, you get a non-deterministic data sample.

Tip

You can disable the progress bar through the logging module:

import logging

logging.getLogger("tensorwaves.data").setLevel(logging.ERROR)

Use "tensorwaves" to regulate all tensorwaves logging.

When we feed the generated domain sample to the ParametrizedBackendFunction and use its output values as weights for a histogram of the uniform domain sample, we see that this nicely reproduces the distribution expected from the model we defined:

plt.hist(
    domain["x"],
    bins=200,
    density=True,
    alpha=0.5,
    label="uniform",
)
plt.hist(
    domain["x"],
    weights=np.array(function_1d(domain)),
    bins=200,
    alpha=0.5,
    density=True,
    label="weighted with $f$",
)
plt.legend();

Note

In PWA, the sample on which we perform hit-and-miss is not uniform, because the available space is limited by the masses of the initial and final state (phase space). See TFPhaseSpaceGenerator and Step 2: Generate data.

Intensity distribution#

With a Domain distribution in hand, we can work out an implementation for the Hit & miss approach. The IntensityDistributionGenerator class helps us to do this:
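A sketch of its usage (the original cell is hidden, so the sample size is an assumption):

data_generator = IntensityDistributionGenerator(domain_generator, function_1d)
data = data_generator.generate(1_000_000, rng)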

And indeed, it results in the correct distribution!

plt.hist(data["x"], bins=200);

Optimize the model#

For the rest, the procedure is really just the same as that sketched in Step 3: Perform fit.

We tweak the parameters a bit, then use ParametrizedBackendFunction.update_parameters() to change the function…

initial_parameters = {
    "a": 0.2,
    "b": 0.1,
    "c": 0.2,
    "mu_0": 0.9,
    "sigma_0": 0.4,
    "sigma_1": 0.4,
}
function_1d.update_parameters(initial_parameters)

…see what this looks like compared to the data…

plt.hist(data["x"], bins=200, density=True)
plt.hist(
    domain["x"],
    weights=np.array(function_1d(domain)),
    bins=200,
    histtype="step",
    color="red",
    density=True,
);

…define an Estimator and choose jax as backend…

estimator = UnbinnedNLL(function_1d, data, domain, backend="jax")

…optimize with Minuit2 (the callback argument is optional; see Callbacks).

minuit2 = Minuit2(
    callback=CallbackList(
        [
            CSVSummary("traceback-1D.csv"),
            YAMLSummary("fit-result-1D.yaml"),
            TFSummary(),
        ]
    )
)
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result
FitResult(
 minimum_valid=True,
 execution_time=13.819352626800537,
 function_calls=212,
 estimator_value=-186820.05513799438,
 parameter_values={
  'a': 0.1403761700107824,
  'b': 0.047290394143682746,
  'c': 0.27950294131845443,
  'mu_0': 1.0019881244455366,
  'sigma_0': 0.3011346771376966,
  'sigma_1': 0.4904526974298494,
 },
 parameter_errors={
  'a': 0.01896954672835483,
  'b': 0.006408559981444219,
  'c': 0.03774871918893573,
  'mu_0': 0.0009358262171881016,
  'sigma_0': 0.0007578363938386547,
  'sigma_1': 0.004097733521447734,
 },
)

Tip

For complicated expressions, the fit can be made faster with create_cached_function(). See Constant sub-expressions.

And again, we have a look at the resulting fit, as well as what happened during the optimization.

plt.hist(data["x"], bins=200, density=True)
plt.hist(
    domain["x"],
    weights=np.array(function_1d(domain)),
    bins=200,
    histtype="step",
    color="red",
    density=True,
);

Callbacks#

The Minuit2 optimizer above was constructed with callbacks. Callbacks allow us to insert behavior into the fit procedure of the optimizer. In this example, we use CallbackList to stack some Callback classes: CSVSummary, YAMLSummary, and TFSummary.

YAMLSummary writes the latest fit result to disk. It’s a Loadable callback and can be used to pick up a fit later on, for instance if it was aborted.

latest_parameters = YAMLSummary.load_latest_parameters("fit-result-1D.yaml")
latest_parameters
{'a': 0.1403761700107824,
 'b': 0.047290394143682746,
 'c': 0.27950294131845443,
 'mu_0': 1.0019881244455366,
 'sigma_0': 0.3011346771376966,
 'sigma_1': 0.4904526974298494}
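Passing these as starting values to the optimizer resumes the fit from where it ended:

minuit2.optimize(estimator, latest_parameters)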
FitResult(
 minimum_valid=True,
 execution_time=6.3444013595581055,
 function_calls=99,
 estimator_value=-186820.0551417682,
 parameter_values={
  'a': 0.1403765366355947,
  'b': 0.04729020830379075,
  'c': 0.2795023332563104,
  'mu_0': 1.001988144033014,
  'sigma_0': 0.3011346392962339,
  'sigma_1': 0.4904483693287645,
 },
 parameter_errors={
  'a': 0.018960537981428667,
  'b': 0.006405414972712265,
  'c': 0.03772997895924757,
  'mu_0': 0.0009358257773431991,
  'sigma_0': 0.0007578365284404684,
  'sigma_1': 0.004097692788137594,
 },
)

CSVSummary records the parameter values in each iteration and can be used to analyze the fit process:

fit_traceback = pd.read_csv("traceback-1D.csv")
fig, (ax1, ax2) = plt.subplots(
    2, figsize=(7, 9), sharex=True, gridspec_kw={"height_ratios": [1, 2]}
)
fit_traceback.plot("function_call", "estimator_value", ax=ax1)
fit_traceback.plot("function_call", sorted(initial_parameters), ax=ax2)
fig.tight_layout()
ax2.set_xlabel("function call");

TFSummary provides a nice, interactive representation of the fit process and can be viewed with TensorBoard as follows:

From the command line:

tensorboard --logdir logs

Or with the TensorBoard Python API:

import tensorboard as tb

tb.notebook.list()  # View open TensorBoard instances
tb.notebook.start(args_string="--logdir logs")

Or with the Jupyter magic commands:

%load_ext tensorboard
%tensorboard --logdir logs

Example in 2D#

The idea illustrated above works for any number of dimensions. Let’s multiply the expression we had by some \(\cos\) function of \(y\):

y, omega = sp.symbols("y omega")
expression_2d = expression_1d * sp.cos(y * omega) ** 2
expression_2d
\[\displaystyle \left(a e^{- \frac{\left(- \mu_{0} + x\right)^{2}}{2 \sigma_{0}^{2}}} + b e^{- \frac{\left(- \mu_{1} + x\right)^{2}}{2 \sigma_{1}^{2}}} + \frac{c x^{2} e^{- x}}{2}\right) \cos^{2}{\left(\omega y \right)}\]
parameter_defaults[omega] = 0.5
y_range = (y, -sp.pi, +sp.pi)
substituted_expr_2d = expression_2d.subs(parameter_defaults)
plot3d(substituted_expr_2d, x_range, y_range)
function_2d = create_parametrized_function(
    expression=expression_2d,
    parameters=parameter_defaults,
    backend="jax",
)

Generate 2D data#
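The 2D domain and data samples are generated just as in the 1D case. A sketch, with assumed sample sizes (the original cell is hidden):

domain_generator_2d = NumpyDomainGenerator(
    boundaries={"x": (0, 5), "y": (-np.pi, np.pi)}
)
domain_2d = domain_generator_2d.generate(300_000, rng)
data_generator_2d = IntensityDistributionGenerator(domain_generator_2d, function_2d)
data_2d = data_generator_2d.generate(30_000, rng)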

fig, axes = plt.subplots(1, 2, figsize=(8, 3))
intensities = np.array(function_2d(domain_2d))
kwargs = {
    "weights": intensities,
    "bins": 100,
    "density": True,
}
axes[0].hist(domain_2d["x"], **kwargs)
axes[1].hist(domain_2d["y"], **kwargs)
axes[0].set_xlabel("$x$")
axes[1].set_xlabel("$y$")
axes[0].set_ylabel("$f(x, y)$")
axes[0].set_yticks([])
axes[1].set_yticks([])
fig.tight_layout()

Perform fit with different optimizers#

initial_parameters = {
    "a": 0.1,
    "b": 0.1,
    "c": 0.2,
    "mu_0": 0.9,
    "omega": 0.35,
    "sigma_0": 0.4,
    "sigma_1": 0.4,
}
function_2d.update_parameters(initial_parameters)
fig, axes = plt.subplots(1, 2, figsize=(9, 4), sharey=True, tight_layout=True)
axes[0].hist2d(**data_2d, bins=50)
axes[1].hist2d(**domain_2d, weights=function_2d(domain_2d), bins=50)
axes[0].set_xlabel("$x$")
axes[0].set_ylim([-3, +3])
axes[1].set_xlabel("$x$")
axes[0].set_ylabel("$y$")
axes[0].set_title("Data sample")
axes[1].set_title("Function with optimized parameters");

Minuit2#

Commonly, one would construct a Minuit2 instance and call its optimize() method. For more advanced options, one can pass a small function to the minuit_modifier argument of the Minuit2 constructor. In this example, we use it to set the tol attribute. For other options, see iminuit.Minuit.

def tweak_minuit(minuit) -> None:
    minuit.tol = 0.2

For the rest, the fit procedure goes just as in Optimize parameters:

estimator = UnbinnedNLL(function_2d, data_2d, domain_2d, backend="jax")
minuit2 = Minuit2(
    callback=CSVSummary("traceback.csv"),
    minuit_modifier=tweak_minuit,
)
fit_result = minuit2.optimize(estimator, initial_parameters)
fit_result
FitResult(
 minimum_valid=True,
 execution_time=2.879922389984131,
 function_calls=224,
 estimator_value=-14864.25265603604,
 parameter_values={
  'a': 0.11310361916925236,
  'b': 0.03696163962720964,
  'c': 0.218539613825823,
  'mu_0': 0.9995001064242798,
  'omega': 0.5021536330548182,
  'sigma_0': 0.29908430206770875,
  'sigma_1': 0.5359594533715704,
 },
 parameter_errors={
  'a': 0.07220286406269912,
  'b': 0.02371132957912234,
  'c': 0.13935003558435563,
  'mu_0': 0.005369575868901312,
  'omega': 0.0014545776271939564,
  'sigma_0': 0.004358983875864369,
  'sigma_1': 0.0241204911214227,
 },
)

Note that further information about the internal iminuit.Minuit optimizer is available through FitResult.specifics, e.g. computing the hesse() afterwards:

fit_result.specifics.hesse()
Migrad
FCN = -1.486e+04                      Nfcn = 274
EDM = 4.11e-05 (Goal: 0.0002)
Valid Minimum                         No Parameters at limit
Below EDM threshold (goal x 10)       Below call limit
Covariance: Hesse ok, Accurate, Pos. def., Not forced

 #  Name     Value   Hesse Error
 0  a        0.1     0.4
 1  b        0.04    0.12
 2  c        0.2     0.7
 3  mu_0     1.000   0.005
 4  omega    0.5022  0.0014
 5  sigma_0  0.299   0.004
 6  sigma_1  0.536   0.024

Covariance matrix (correlation coefficients in parentheses):

         a                 b                 c                 mu_0                omega             sigma_0           sigma_1
a        0.131             0.0427 (1.000)    0.253 (1.000)     4.27e-06 (0.002)    2.78e-10          1.32e-06          9.58e-05 (0.011)
b        0.0427 (1.000)    0.014             0.0825 (1.000)    2.95e-06 (0.005)    7.94e-10          3.27e-06 (0.006)  3.18e-05 (0.011)
c        0.253 (1.000)     0.0825 (1.000)    0.488             3.57e-06            -1.33e-08         3.25e-06 (0.001)  7.8e-05 (0.005)
mu_0     4.27e-06 (0.002)  2.95e-06 (0.005)  3.57e-06          2.88e-05            1.47e-08 (0.002)  8.98e-06 (0.384)  -7.83e-06 (-0.060)
omega    2.78e-10          7.94e-10          -1.33e-08         1.47e-08 (0.002)    2.12e-06          5.28e-09          -2.2e-10
sigma_0  1.32e-06          3.27e-06 (0.006)  3.25e-06 (0.001)  8.98e-06 (0.384)    5.28e-09          1.9e-05           3.49e-06 (0.033)
sigma_1  9.58e-05 (0.011)  3.18e-05 (0.011)  7.8e-05 (0.005)   -7.83e-06 (-0.060)  -2.2e-10          3.49e-06 (0.033)  0.000585

Scipy#

scipy = ScipyMinimizer(
    method="Nelder-Mead",
    callback=CSVSummary("traceback-scipy.csv"),
)
fit_result = scipy.optimize(estimator, initial_parameters)
fit_result
FitResult(
 minimum_valid=True,
 execution_time=9.696416139602661,
 function_calls=402,
 estimator_value=-14864.25268072169,
 parameter_values={
  'a': 0.11921369723669326,
  'b': 0.03896060103993386,
  'c': 0.23033372118197631,
  'mu_0': 0.9995186249006902,
  'omega': 0.502145064595906,
  'sigma_0': 0.29909655068790947,
  'sigma_1': 0.5359184972331181,
 },
 iterations=263,
)

Warning

Scipy does not provide error values for the optimized parameters.

Analyze fit process#

If we update the parameters in the ParametrizedFunction with the optimized parameter values found by the Optimizer, we can compare the data distribution with the function.
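For instance, with the fit_result of the Scipy fit above:

function_2d.update_parameters(fit_result.parameter_values)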

fig, axes = plt.subplots(1, 2, figsize=(9, 4), sharey=True, tight_layout=True)
fig.suptitle("Final fit result")
axes[0].hist2d(**data_2d, bins=50)
axes[1].hist2d(**domain_2d, weights=function_2d(domain_2d), bins=50)
axes[0].set_xlabel("$x$")
axes[0].set_ylim([-3, +3])
axes[1].set_xlabel("$x$")
axes[0].set_ylabel("$y$")
axes[0].set_title("Data sample")
axes[1].set_title("Function with optimized parameters");

In addition, the callbacks allow us to inspect how the parameter values evolved during the fit with the ScipyMinimizer and Minuit2 optimizers:

minuit_traceback = pd.read_csv("traceback.csv")
scipy_traceback = pd.read_csv("traceback-scipy.csv")
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(
    ncols=2,
    nrows=2,
    figsize=(10, 9),
    gridspec_kw={"height_ratios": [1, 3]},
)
fig.suptitle("Evolution of the parameter values during the fit")
ax1.set_title("Minuit2")
ax2.set_title("Scipy (Nelder-Mead)")
ax3.sharex(ax1)  # link the function-call axes of each optimizer column
ax4.sharex(ax2)
ax2.sharey(ax1)  # link the estimator-value and parameter-value axes
ax4.sharey(ax3)
ax2.set_ylim(
    1.02 * scipy_traceback["estimator_value"].min(),
    0.98 * scipy_traceback["estimator_value"].max(),
)
minuit_traceback.plot("function_call", "estimator_value", ax=ax1, legend=False)
scipy_traceback.plot("function_call", "estimator_value", ax=ax2)
minuit_traceback.plot("function_call", initial_parameters, ax=ax3, legend=False)
scipy_traceback.plot(
    "function_call", initial_parameters, ax=ax4, legend=True
).legend(loc="upper right")
fig.tight_layout()
ax2.set_xlabel("function call");