"""Defines top-level interfaces of tensorwaves."""
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
Dict,
FrozenSet,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import attr
import numpy as np
from ampform.kinematics import ReactionInfo
try:
from IPython.lib.pretty import PrettyPrinter # type: ignore[import]
except ImportError:
PrettyPrinter = Any
ParameterValue = Union[complex, float]
"""Value of a single model parameter: a real or a complex number."""

DataSample = Mapping[str, np.ndarray]
"""Input data for a `Function`."""

# Data classes from ampform do not work with jax and jit, so momenta are
# described with plain tuples and mappings instead. See:
# https://github.com/google/jax/issues/3092
# https://github.com/google/jax/issues/4416
FourMomentum = Tuple[float, float, float, float]
"""One four-momentum, stored as a plain tuple of four floats."""

MomentumSample = Mapping[int, Sequence[FourMomentum]]
"""Mapping of final state IDs to sequences of four-momentum tuples."""
class Function(ABC):
    """Interface of a callable function.

    The parameters of the model are kept separate from its domain variables.
    This mirrors the mathematical definition, in which a function defines its
    domain and its parameters, while specific points in the domain are not
    part of that definition. Accordingly, domain variables are passed as the
    argument of an evaluation (see :func:`~Function.__call__`), whereas the
    parameters are read through a getter (see :func:`~Function.parameters`)
    and changed through :func:`~Function.update_parameters`. Keeping the two
    apart makes it easier to react when parameter values change.
    """

    @abstractmethod
    def __call__(self, dataset: DataSample) -> np.ndarray:
        """Evaluate the function on a set of domain points.

        Args:
            dataset: a `dict` with domain variable names as keys.

        Returns:
            Result of the function evaluation. Type depends on the input type.
        """

    @property
    @abstractmethod
    def parameters(self) -> Dict[str, ParameterValue]:
        """Get a `dict` of the current parameter values, keyed by name."""

    @abstractmethod
    def update_parameters(
        self, new_parameters: Mapping[str, ParameterValue]
    ) -> None:
        """Update the collection of parameters.

        Args:
            new_parameters: Parameter names mapped to their new values.
        """
class Model(ABC):
    """Interface of a model that can be lambdified into a callable."""

    @abstractmethod
    def lambdify(self, backend: Union[str, tuple, dict]) -> Callable:
        """Lambdify the model into a `~typing.Callable`.

        The arguments of the resulting callable are the union of the model's
        variables and parameters.

        Args:
            backend: Choice of backend for fast evaluations.

        Returns:
            A callable whose return value is typed as `~typing.Any`. In
            principle its type should depend on the model, but no typing
            support has been implemented for this yet.
        """

    @property
    @abstractmethod
    def parameters(self) -> Dict[str, ParameterValue]:
        """Get mapping of parameters to suggested initial values."""

    @property
    @abstractmethod
    def variables(self) -> FrozenSet[str]:
        """Expected input variable names."""

    @property
    def argument_order(self) -> Tuple[str, ...]:
        """Order of arguments of lambdified function signature."""
        # NOTE(review): this property is not abstract and its body is only a
        # docstring, so subclasses that do not override it get `None` here
        # despite the `Tuple[str, ...]` annotation — confirm this is intended.
class Estimator(ABC):
    """Estimator for the discrepancy between a model and data."""

    @abstractmethod
    def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
        """Evaluate the discrepancy for the given parameter values.

        Args:
            parameters: Parameter names mapped to the values to evaluate at.

        Returns:
            The estimator value as a `float`.
        """

    @abstractmethod
    def gradient(
        self, parameters: Mapping[str, ParameterValue]
    ) -> Dict[str, ParameterValue]:
        """Calculate the gradient for the given parameter mapping.

        Args:
            parameters: Parameter names mapped to the values to evaluate at.

        Returns:
            A `dict` keyed by parameter name.
        """
@attr.s(frozen=True, auto_attribs=True)
class FitResult:  # pylint: disable=too-many-instance-attributes
    """Immutable record of the outcome of an optimization.

    It is produced by `Optimizer.optimize`. Fields that still hold their
    default value are omitted from the IPython pretty representation, and so
    is the ``specifics`` field.
    """

    minimum_valid: bool
    execution_time: float
    function_calls: int
    estimator_value: float
    parameter_values: Dict[str, ParameterValue]
    parameter_errors: Optional[Dict[str, ParameterValue]] = None
    iterations: Optional[int] = None
    specifics: Optional[Any] = None
    """Any additional info provided by the specific optimizer."""

    def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
        """Render this result for IPython's pretty printer.

        Args:
            p: The IPython pretty printer to write to.
            cycle: `True` when IPython detects a reference cycle; only an
                abbreviated ``ClassName(...)`` is printed then.
        """
        name = type(self).__name__
        if cycle:
            p.text(f"{name}(...)")
            return

        # Decide up front which fields are worth showing: skip the
        # optimizer-specific payload and anything still at its default.
        shown = []
        for attribute in attr.fields(type(self)):
            if attribute.name == "specifics":
                continue
            current = getattr(self, attribute.name)
            if current != attribute.default:
                shown.append((attribute.name, current))

        with p.group(indent=1, open=f"{name}("):
            for field_name, field_value in shown:
                p.breakable()
                p.text(f"{field_name}=")
                p.pretty(field_value)
                p.text(",")
        p.breakable()
        p.text(")")
class Optimizer(ABC):
    """Optimize a fit model to a data set."""

    @abstractmethod
    def optimize(
        self,
        estimator: Estimator,
        initial_parameters: Mapping[str, ParameterValue],
    ) -> FitResult:
        """Execute the optimization.

        Args:
            estimator: The `Estimator` whose value is to be optimized.
            initial_parameters: Parameter names mapped to their starting
                values.

        Returns:
            A `FitResult` describing the outcome.
        """
class PhaseSpaceGenerator(ABC):
    """Abstract class for generating phase space samples."""

    @abstractmethod
    def setup(self, reaction_info: ReactionInfo) -> None:
        """Hook for initialization of the `PhaseSpaceGenerator`.

        Called before any `generate` calls.

        Args:
            reaction_info: Description of the reaction to generate samples
                for.
        """

    @abstractmethod
    def generate(
        # The annotation is a string (forward reference): this module has no
        # `from __future__ import annotations`, so a bare name would be
        # evaluated at class creation and raise NameError if
        # UniformRealNumberGenerator is defined further down the file.
        self,
        size: int,
        rng: "UniformRealNumberGenerator",
    ) -> Tuple[MomentumSample, np.ndarray]:
        """Generate a phase space sample.

        Args:
            size: Requested size of the sample.
            rng: Uniform random number generator used as the source of
                randomness.

        Returns:
            A `tuple` of a mapping of final state IDs to `numpy.array` s with
            four-momentum tuples, together with a `numpy.ndarray`
            (presumably per-event weights — TODO confirm).
        """