Source code for tensorwaves.optimizer.minuit

# cspell: ignore nfcn

"""Minuit2 adapter to the `iminuit.Minuit` package."""

import time
from copy import deepcopy
from datetime import datetime
from typing import Dict, Optional

from iminuit import Minuit
from tqdm.auto import tqdm

from tensorwaves.interfaces import Estimator, Optimizer

from .callbacks import Callback, CallbackList


class Minuit2(Optimizer):
    """The Minuit2 adapter.

    Implements the `~.interfaces.Optimizer` interface.

    Args:
        callback: Optional `.Callback` that is notified after every
            evaluation of the estimator and once more after the minimization
            has finished. If omitted, an empty `.CallbackList` is used.
    """

    def __init__(self, callback: Optional[Callback] = None) -> None:
        self.__callback: Callback = CallbackList([])
        if callback is not None:
            self.__callback = callback

    def optimize(
        self, estimator: Estimator, initial_parameters: Dict[str, float]
    ) -> dict:
        """Minimize an estimator with ``Minuit.migrad``.

        Args:
            estimator: Callable object whose parameters are set through
                ``update_parameters`` and whose call returns the value that
                is to be minimized.
            initial_parameters: Mapping of parameter names to their starting
                values. The mapping is deep-copied and not modified.

        Returns:
            A dictionary with the keys ``"parameter_values"`` and
            ``"parameter_errors"`` (both mappings of parameter name to
            `float`), ``"log_likelihood"`` (function value reported by
            Minuit), ``"function_calls"`` (number of function calls reported
            by Minuit), and ``"execution_time"`` (wall-clock duration of the
            ``migrad`` call in seconds).
        """
        parameters = deepcopy(initial_parameters)
        progress_bar = tqdm()
        n_function_calls = 0

        def wrapped_function(pars: list) -> float:
            """Evaluate the estimator for the values proposed by Minuit."""
            nonlocal n_function_calls
            n_function_calls += 1
            # Minuit passes the values positionally, in the same order as the
            # names that were handed to the Minuit constructor below. Only
            # existing keys are reassigned, so the dict size does not change
            # while iterating.
            for name, value in zip(parameters, pars):
                parameters[name] = value
            estimator.update_parameters(parameters)
            estimator_value = estimator()
            progress_bar.set_postfix({"estimator": estimator_value})
            progress_bar.update()
            logs = {
                "time": datetime.now(),
                "estimator": {
                    # NOTE(review): this stores the class name of the
                    # optimizer under the "estimator" key -- confirm whether
                    # the class name of the estimator was intended.
                    "type": self.__class__.__name__,
                    "value": float(estimator_value),
                },
                "parameters": {
                    name: float(value) for name, value in parameters.items()
                },
            }
            self.__callback.on_iteration_end(n_function_calls, logs)
            return estimator_value

        minuit = Minuit(
            wrapped_function,
            tuple(parameters.values()),
            name=tuple(parameters),
        )
        # Initial step size: 10% of the magnitude of the starting value, or
        # 0.1 if the starting value is zero. The magnitude is used so that a
        # negative starting value does not result in a negative step size.
        minuit.errors = tuple(
            0.1 * abs(x) if x != 0.0 else 0.1 for x in parameters.values()
        )
        minuit.errordef = Minuit.LIKELIHOOD

        try:
            start_time = time.time()
            minuit.migrad()
            end_time = time.time()
        finally:
            # Release the progress bar, also if the minimization raised.
            progress_bar.close()

        self.__callback.on_function_call_end()

        parameter_values = {}
        parameter_errors = {}
        for i, name in enumerate(parameters):
            par_state = minuit.params[i]
            parameter_values[name] = par_state.value
            parameter_errors[name] = par_state.error

        return {
            "parameter_values": parameter_values,
            "parameter_errors": parameter_errors,
            "log_likelihood": minuit.fmin.fval,
            "function_calls": minuit.fmin.nfcn,
            "execution_time": end_time - start_time,
        }