# cspell: ignore nfcn
"""Minuit2 adapter to the `iminuit.Minuit` package."""
import logging
import time
from datetime import datetime
from typing import Any, Dict, Iterable, Mapping, Optional, Union
from iminuit import Minuit
from tqdm.auto import tqdm
from tensorwaves.interfaces import Estimator, FitResult, Optimizer
from ._parameter import ParameterFlattener
from .callbacks import Callback, CallbackList
[docs]class Minuit2(Optimizer):
"""The Minuit2 adapter.
Implements the `~.interfaces.Optimizer` interface.
"""
def __init__(
self,
callback: Optional[Callback] = None,
use_analytic_gradient: bool = False,
) -> None:
if callback is not None:
self.__callback = callback
else:
self.__callback = CallbackList([])
self.__use_gradient = use_analytic_gradient
[docs] def optimize( # pylint: disable=too-many-locals
self,
estimator: Estimator,
initial_parameters: Mapping[str, Union[complex, float]],
) -> FitResult:
parameter_handler = ParameterFlattener(initial_parameters)
flattened_parameters = parameter_handler.flatten(initial_parameters)
progress_bar = tqdm(
disable=logging.getLogger().level > logging.WARNING
)
n_function_calls = 0
def create_log(
estimator_value: float, parameters: Dict[str, Any]
) -> Dict[str, Any]:
return {
"time": datetime.now(),
"estimator": {
"type": self.__class__.__name__,
"value": float(estimator_value),
},
"parameters": parameters,
}
parameters = parameter_handler.unflatten(flattened_parameters)
self.__callback.on_optimize_start(
logs=create_log(float(estimator(parameters)), parameters)
)
def update_parameters(pars: list) -> None:
for i, k in enumerate(flattened_parameters):
flattened_parameters[k] = pars[i]
def wrapped_function(pars: list) -> float:
nonlocal n_function_calls
n_function_calls += 1
update_parameters(pars)
parameters = parameter_handler.unflatten(flattened_parameters)
estimator_value = estimator(parameters)
progress_bar.set_postfix({"estimator": float(estimator_value)})
progress_bar.update()
logs = create_log(estimator_value, parameters)
self.__callback.on_function_call_end(n_function_calls, logs)
return estimator_value
def wrapped_gradient(pars: list) -> Iterable[float]:
update_parameters(pars)
parameters = parameter_handler.unflatten(flattened_parameters)
grad = estimator.gradient(parameters)
return parameter_handler.flatten(grad).values()
minuit = Minuit(
wrapped_function,
tuple(flattened_parameters.values()),
grad=wrapped_gradient if self.__use_gradient else None,
name=tuple(flattened_parameters),
)
minuit.errors = tuple(
0.1 * x if x != 0.0 else 0.1 for x in flattened_parameters.values()
)
minuit.errordef = (
Minuit.LIKELIHOOD
) # that error definition should be defined in the estimator
start_time = time.time()
minuit.migrad()
end_time = time.time()
parameter_values = dict()
parameter_errors = dict()
for i, name in enumerate(flattened_parameters):
par_state = minuit.params[i]
parameter_values[name] = par_state.value
parameter_errors[name] = par_state.error
parameter_values = parameter_handler.unflatten(parameter_values)
parameter_errors = parameter_handler.unflatten(parameter_errors)
self.__callback.on_optimize_end(
logs=create_log(
estimator_value=float(estimator(parameters)),
parameters=parameter_values,
)
)
return FitResult(
minimum_valid=minuit.valid,
execution_time=end_time - start_time,
function_calls=minuit.fmin.nfcn,
estimator_value=minuit.fmin.fval,
parameter_values=parameter_values,
parameter_errors=parameter_errors,
specifics=minuit,
)