Source code for tensorwaves.model

"""Evaluateable physics models for amplitude analysis.

The `.model` module takes care of lambdifying mathematical expressions to
computational backends.
"""

# pyright: reportUnusedImport=false
from typing import Dict, Mapping, Union

import numpy as np

from tensorwaves.interface import DataSample, Function, Model, ParameterValue

from .sympy import SympyModel  # noqa: F401


class LambdifiedFunction(Function):
    """Implements `.Function` based on a `.Model` using {meth}`~.Model.lambdify`.

    The model is lambdified once, at construction. Every call then passes one
    positional argument per entry of the model's argument order to the
    lambdified callable: the value comes from the data sample if the name is
    present there, and from the current parameter values otherwise.

    Args:
        model: Provides the lambdified callable, the initial parameter values
            and the positional argument order.
        backend: Forwarded unchanged to {meth}`~.Model.lambdify`.
    """

    def __init__(
        self,
        model: Model,
        backend: Union[str, tuple, dict] = "numpy",
    ) -> None:
        self.__lambdified_model = model.lambdify(backend=backend)
        # Keep a private copy: update_parameters() mutates this mapping in
        # place and must never write through to a dict owned by the model
        # (or shared with another function built from the same model).
        self.__parameters = dict(model.parameters)
        self.__ordered_args = model.argument_order

    def __call__(self, dataset: DataSample) -> np.ndarray:
        """Evaluate the lambdified model on a data sample.

        Args:
            dataset: Values keyed by argument name. Names found here take
                precedence over parameters of the same name.

        Returns:
            Whatever the lambdified model returns for these arguments.

        Raises:
            KeyError: If an argument name is neither in ``dataset`` nor among
                the parameters.
        """
        return self.__lambdified_model(
            *[
                dataset[var_name]
                if var_name in dataset
                else self.__parameters[var_name]
                for var_name in self.__ordered_args
            ],
        )

    @property
    def parameters(self) -> Dict[str, ParameterValue]:
        """Current parameter values, keyed by parameter name."""
        return self.__parameters

    def update_parameters(
        self, new_parameters: Mapping[str, ParameterValue]
    ) -> None:
        """Overwrite the values of existing parameters.

        All names are validated before anything is changed, so a rejected
        call leaves the current parameter values untouched.

        Args:
            new_parameters: New values keyed by parameter name. Only names
                that are already parameters of this function are accepted.

        Raises:
            ValueError: If ``new_parameters`` contains a name that is not a
                parameter of this function.
        """
        over_defined = set(new_parameters) - set(self.__parameters)
        if over_defined:
            sep = "\n    "
            parameter_listing = f"{sep}".join(sorted(self.__parameters))
            raise ValueError(
                f"Parameters {over_defined} do not exist in function"
                f" arguments. Expecting one of:{sep}{parameter_listing}"
            )
        self.__parameters.update(new_parameters)