# Source code for tensorwaves.optimizer.minuit

# cspell: ignore nfcn
"""Minuit2 adapter to the `iminuit.Minuit` package."""
from __future__ import annotations

import logging
import time
from typing import Any, Callable, Iterable, Mapping

import iminuit
from tqdm.auto import tqdm

from tensorwaves.interface import Estimator, FitResult, Optimizer, ParameterValue

from ._parameter import ParameterFlattener
from .callbacks import Callback, _create_log  # pyright: ignore[reportPrivateUsage]

_LOGGER = logging.getLogger(__name__)


class Minuit2(Optimizer):
    """Adapter to `Minuit2 <https://root.cern/doc/master/md_math_minuit2_doc_Minuit2.html#Minuit2Page>`_.

    Implements the `~.interface.Optimizer` interface using `iminuit.Minuit`.

    Args:
        callback: Optionally insert behavior through :mod:`.callbacks` into
            the :meth:`optimize` method.
        use_analytic_gradient: Use the :meth:`.Estimator.gradient` when
            calling :meth:`optimize`.
        minuit_modifier: Modify the internal `iminuit.Minuit` optimizer that
            is constructed during the :meth:`optimize` call. See
            :ref:`usage/basics:Minuit2` for an example.
        migrad_args: Keyword arguments given to :meth:`iminuit.Minuit.migrad`.
    """

    def __init__(
        self,
        callback: Callback | None = None,
        use_analytic_gradient: bool = False,
        minuit_modifier: Callable[[iminuit.Minuit], None] | None = None,
        migrad_args: dict[str, Any] | None = None,
    ) -> None:
        self.__callback = callback
        self.__use_gradient = use_analytic_gradient
        # Fail fast on a non-callable modifier so the error surfaces at
        # construction time instead of deep inside optimize().
        if minuit_modifier is not None and not callable(minuit_modifier):
            raise TypeError(
                "minuit_modifier has to be a callable that takes a"
                f" {iminuit.Minuit.__module__}.{iminuit.Minuit.__name__} "
                "instance. See constructor signature."
            )
        self.__minuit_modifier = minuit_modifier
        self.__migrad_args = {} if migrad_args is None else migrad_args

    def optimize(  # pylint: disable=too-many-locals
        self,
        estimator: Estimator,
        initial_parameters: Mapping[str, ParameterValue],
    ) -> FitResult:
        """Run Migrad over ``estimator``, starting from ``initial_parameters``.

        Returns a `~.interface.FitResult` with the optimized parameter values
        and errors; the raw `iminuit.Minuit` instance is attached as
        ``specifics``.
        """
        # Complex-valued parameters are flattened to real-valued entries so
        # that Minuit (which only handles real parameters) can vary them.
        parameter_handler = ParameterFlattener(initial_parameters)
        flattened_parameters = parameter_handler.flatten(initial_parameters)

        # Progress bar is silenced when the logger is configured above WARNING.
        progress_bar = tqdm(disable=_LOGGER.level > logging.WARNING)
        n_function_calls = 0
        parameters = parameter_handler.unflatten(flattened_parameters)
        if self.__callback is not None:
            self.__callback.on_optimize_start(
                logs=_create_log(
                    optimizer=type(self),
                    estimator_type=type(estimator),
                    estimator_value=estimator(parameters),
                    function_call=n_function_calls,
                    parameters=parameters,
                )
            )

        def update_parameters(pars: list) -> None:
            # Write Minuit's positional parameter list back into the ordered
            # flattened-parameter mapping (dicts preserve insertion order).
            for i, k in enumerate(flattened_parameters):
                flattened_parameters[k] = pars[i]

        def wrapped_function(pars: list) -> float:
            # Objective function handed to Minuit: evaluates the estimator
            # and forwards progress to the progress bar and callback.
            nonlocal n_function_calls
            n_function_calls += 1
            update_parameters(pars)
            parameters = parameter_handler.unflatten(flattened_parameters)
            estimator_value = float(estimator(parameters))
            progress_bar.set_postfix({"estimator": estimator_value})
            progress_bar.update()
            if self.__callback is not None:
                self.__callback.on_function_call_end(
                    n_function_calls,
                    logs=_create_log(
                        optimizer=type(self),
                        estimator_type=type(estimator),
                        estimator_value=estimator_value,
                        function_call=n_function_calls,
                        parameters=parameters,
                    ),
                )
            return estimator_value

        def wrapped_gradient(pars: list) -> Iterable[float]:
            # Analytic gradient in the same flattened parameter order that
            # Minuit expects (only used when use_analytic_gradient=True).
            update_parameters(pars)
            parameters = parameter_handler.unflatten(flattened_parameters)
            grad = estimator.gradient(parameters)
            return parameter_handler.flatten(grad).values()

        minuit = iminuit.Minuit(
            wrapped_function,
            tuple(flattened_parameters.values()),
            grad=wrapped_gradient if self.__use_gradient else None,
            name=tuple(flattened_parameters),
        )
        # Initial step sizes: 10% of the starting value, with a fallback for
        # parameters that start at exactly zero.
        minuit.errors = tuple(
            0.1 * abs(x) if abs(x) != 0.0 else 0.1
            for x in flattened_parameters.values()
        )
        minuit.errordef = (
            iminuit.Minuit.LIKELIHOOD
        )  # that error definition should be defined in the estimator
        # User hook to tweak the Minuit instance (strategy, limits, ...)
        # before the minimization starts.
        if self.__minuit_modifier is not None:
            self.__minuit_modifier(minuit)

        start_time = time.time()
        minuit.migrad(**self.__migrad_args)
        end_time = time.time()
        # Fix: close the tqdm bar so its display/file handle is released
        # (it was previously left open after every optimize() call).
        progress_bar.close()

        parameter_values = {}
        parameter_errors = {}
        for i, name in enumerate(flattened_parameters):
            par_state = minuit.params[i]
            parameter_values[name] = par_state.value
            parameter_errors[name] = par_state.error

        # fmin is populated by migrad(); guards the attribute accesses below.
        assert minuit.fmin is not None
        fit_result = FitResult(
            minimum_valid=minuit.valid,
            execution_time=end_time - start_time,
            function_calls=minuit.fmin.nfcn,
            estimator_value=minuit.fmin.fval,
            parameter_values=parameter_handler.unflatten(parameter_values),
            parameter_errors=parameter_handler.unflatten(parameter_errors),
            specifics=minuit,
        )
        if self.__callback is not None:
            self.__callback.on_optimize_end(
                logs=_create_log(
                    optimizer=type(self),
                    estimator_type=type(estimator),
                    estimator_value=fit_result.estimator_value,
                    function_call=fit_result.function_calls,
                    parameters=fit_result.parameter_values,
                )
            )
        return fit_result