Source code for tensorwaves.function

"""Express mathematical expressions in terms of computational functions."""
from __future__ import annotations

import inspect
from typing import Callable, Iterable, Mapping

import attrs
import numpy as np
from attrs import field, frozen

from tensorwaves.interface import (

def _all_str(
    _: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str]
) -> None:
    if not all(isinstance(s, str) for s in value):
        raise TypeError(f"Not all arguments are of type {str.__name__}")

def _all_unique(
    _: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str]
) -> None:
    argument_names = list(value)
    if len(set(argument_names)) != len(argument_names):
        duplicate_arguments = []
        for arg_name in argument_names:
            n_occurrences = argument_names.count(arg_name)
            if n_occurrences > 1:
        raise ValueError(
            f"There are duplicate argument names: {sorted(set(duplicate_arguments))}"

def _validate_arguments(
    instance: PositionalArgumentFunction, _: attrs.Attribute, value: Callable
) -> None:
    if not callable(value):
        raise TypeError("Function is not callable")
    n_args = len(instance.argument_order)
    signature = inspect.signature(value)
    if len(signature.parameters) != n_args:
        if len(signature.parameters) == 1:
            parameter = next(iter(signature.parameters.values()))
            if parameter.kind == parameter.VAR_POSITIONAL:
        raise ValueError(
            f"Lambdified function expects {len(signature.parameters)}"
            f" arguments, but {n_args} sorted arguments were provided."

def _to_tuple(argument_order: Iterable[str]) -> tuple[str, ...]:
    return tuple(argument_order)

@frozen
class PositionalArgumentFunction(Function):
    """Wrap a callable that only takes positional arguments.

    Calling an instance with a `.DataSample` looks up each key listed in
    :attr:`argument_order` and passes the values, in that order, as the
    `positional arguments
    <https://docs.python.org/3/glossary.html#term-argument>`_ of
    :attr:`function`.

    .. seealso:: :func:`.create_function`
    """

    #: A function with positional arguments only.
    function: Callable[..., np.ndarray] = field(validator=_validate_arguments)
    #: Ordered labels for each positional argument.
    argument_order: tuple[str, ...] = field(
        converter=_to_tuple, validator=[_all_str, _all_unique]
    )

    def __call__(self, data: DataSample) -> np.ndarray:
        positional_values = (data[key] for key in self.argument_order)
        return self.function(*positional_values)


class ParametrizedBackendFunction(ParametrizedFunction):
    """Implement `.ParametrizedFunction` for a specific computational back-end.

    The callable is stored as a `PositionalArgumentFunction`. On every call,
    the current parameter values are merged into the given `.DataSample`,
    replacing any data entries that have the same name.

    .. seealso:: :func:`.create_parametrized_function`
    """

    def __init__(
        self,
        function: Callable[..., np.ndarray],
        argument_order: Iterable[str],
        parameters: Mapping[str, ParameterValue],
    ) -> None:
        self.__function = PositionalArgumentFunction(function, argument_order)
        self.__parameters = dict(parameters)

    def __call__(self, data: DataSample) -> np.ndarray:
        # Parameters are unpacked last, so they take precedence over data keys.
        return self.__function({**data, **self.__parameters})  # type: ignore[arg-type]

    @property
    def function(self) -> Callable[..., np.ndarray]:
        """The wrapped callable with positional arguments."""
        return self.__function.function

    @property
    def argument_order(self) -> tuple[str, ...]:
        """Keys looked up in the merged data and parameters, in positional order."""
        return self.__function.argument_order

    @property
    def parameters(self) -> dict[str, ParameterValue]:
        """A copy of the current parameter values."""
        return dict(self.__parameters)

    def update_parameters(self, new_parameters: Mapping[str, ParameterValue]) -> None:
        """Overwrite the values of existing parameters.

        Raises:
            ValueError: If *new_parameters* contains a name that is not already
                a parameter. In that case no value is updated.
        """
        unknown_names = set(new_parameters) - set(self.__parameters)
        if not unknown_names:
            self.__parameters.update(new_parameters)
            return
        sep = "\n "
        parameter_listing = sep.join(sorted(self.__parameters))
        raise ValueError(
            f"Parameters {unknown_names} do not exist in function"
            f" arguments. Expecting one of:{sep}{parameter_listing}"
        )


def get_source_code(function: Function) -> str:
    """Return the backend source code that was compiled into *function*.

    Only the wrappers defined in this module are supported. For a
    `PositionalArgumentFunction` or `ParametrizedBackendFunction`, the source
    of the wrapped callable is obtained with `inspect.getsource`.

    Raises:
        NotImplementedError: For any other `.Function` implementation.

    >>> import sympy as sp
    >>> from tensorwaves.function.sympy import create_function
    >>> x, y = sp.symbols("x y")
    >>> expr = x**2 + y**2
    >>> func = create_function(expr, backend="jax", use_cse=False)
    >>> src = get_source_code(func)
    >>> print(src)
    def _lambdifygenerated(x, y):
        return x**2 + y**2
    """
    supported_types = (PositionalArgumentFunction, ParametrizedBackendFunction)
    if not isinstance(function, supported_types):
        raise NotImplementedError(
            f"Cannot get source code for {Function.__name__} type {type(function).__name__}"
        )
    return inspect.getsource(function.function)