Source code for tensorwaves.data.rng

# pylint:disable=import-outside-toplevel
"""Implementations of `.RealNumberGenerator`."""

from typing import TYPE_CHECKING, Optional, Union

import numpy as np

from tensorwaves.function._backend import raise_missing_module_error
from tensorwaves.interface import RealNumberGenerator

if TYPE_CHECKING:  # pragma: no cover
    # tensorflow is an optional dependency (see the `extras_require="tf"`
    # fallbacks in this module), so it is imported here only for static type
    # checkers; runtime code imports it lazily inside the functions that use it.
    import tensorflow as tf

    # Everything `_get_tensorflow_rng` knows how to turn into a generator:
    # `None`, an `int` seed, or an existing `tf.random.Generator`.
    SeedLike = Optional[Union[int, tf.random.Generator]]


class NumpyUniformRNG(RealNumberGenerator):
    """Implements a uniform real random number generator using `numpy`.

    Args:
        seed: Optional seed for the underlying `numpy.random.Generator`. It
            has to be integer-valued: an integer, or a number without a
            fractional part such as :code:`42.0`. `None` is forwarded to
            `numpy.random.default_rng` as-is.

    Raises:
        ValueError: If a non-integer seed (e.g. :code:`1.5`) is given.
    """

    def __init__(self, seed: Optional[float] = None):
        self.seed = seed

    def __call__(
        self, size: int, min_value: float = 0.0, max_value: float = 1.0
    ) -> np.ndarray:
        """Draw ``size`` uniform samples between ``min_value`` and ``max_value``."""
        return self.generator.uniform(size=size, low=min_value, high=max_value)

    @property
    def seed(self) -> Optional[float]:
        """Seed exactly as it was passed in.

        Assigning a new value re-creates the underlying generator.
        """
        return self.__seed

    @seed.setter
    def seed(self, value: Optional[float]) -> None:
        self.__seed = value
        generator_seed: Optional[Union[float, int]] = value
        if generator_seed is not None:
            # Integers are handed to NumPy directly: routing them through
            # float() just to call is_integer() would raise OverflowError for
            # integers beyond the float range, although NumPy accepts them.
            if not isinstance(generator_seed, (int, np.integer)):
                if not float(generator_seed).is_integer():
                    raise ValueError("NumPy generator seed has to be integer")
            generator_seed = int(generator_seed)
        self.generator: np.random.Generator = np.random.default_rng(
            seed=generator_seed
        )
class TFUniformRealNumberGenerator(RealNumberGenerator):
    """Implements a uniform real random number generator using tensorflow.

    Args:
        seed: Optional seed. Integer-valued floats (e.g. :code:`42.0`) and
            NumPy integers are converted to `int`, consistent with
            `NumpyUniformRNG`. The resulting value is resolved by
            `_get_tensorflow_rng`:

            - `None` selects TensorFlow's global generator.
            - An `int` creates a new generator with
              `tf.random.Generator.from_seed`.
            - A `tf.random.Generator` is used as-is.
            - Anything else raises `TypeError`.
    """

    def __init__(self, seed: Optional[float] = None):
        try:
            from tensorflow import float64
        except ImportError:  # pragma: no cover
            raise_missing_module_error("tensorflow", extras_require="tf")
        self.seed = seed
        self.dtype = float64

    def __call__(
        self, size: int, min_value: float = 0.0, max_value: float = 1.0
    ) -> np.ndarray:
        """Draw ``size`` uniform samples between ``min_value`` and ``max_value``.

        The samples are generated as a tensor of ``self.dtype`` and returned
        as a NumPy array.
        """
        return self.generator.uniform(
            shape=[size],
            minval=min_value,
            maxval=max_value,
            dtype=self.dtype,
        ).numpy()

    @property
    def seed(self) -> Optional[float]:
        """Seed exactly as it was passed in.

        Assigning a new value re-creates the underlying generator.
        """
        return self.__seed

    @seed.setter
    def seed(self, value: Optional[float]) -> None:
        self.__seed = value
        generator_seed = value
        # `seed` is annotated as a float, but `_get_tensorflow_rng` only
        # accepts Python ints. Normalize integer-valued numbers here. Anything
        # else (e.g. 1.5) is passed through so that the TypeError raised by
        # `_get_tensorflow_rng` stays the same as before.
        if isinstance(generator_seed, np.integer):
            generator_seed = int(generator_seed)
        elif (
            isinstance(generator_seed, (float, np.floating))
            and float(generator_seed).is_integer()
        ):
            generator_seed = int(generator_seed)
        self.generator = _get_tensorflow_rng(generator_seed)
def _get_tensorflow_rng(seed: "SeedLike" = None) -> "tf.random.Generator":
    """Resolve ``seed`` into a `tf.random.Generator`.

    Adapted from
    https://github.com/zfit/phasespace/blob/5998e2b/phasespace/random.py#L15-L41

    Args:
        seed: Resolved as follows:

            - `None` returns TensorFlow's global generator.
            - An `int` creates a fresh generator via
              `tf.random.Generator.from_seed`.
            - An existing `tf.random.Generator` is returned unchanged.

    Raises:
        TypeError: If ``seed`` is of any other type.
    """
    try:
        import tensorflow as tf
    except ImportError:  # pragma: no cover
        raise_missing_module_error("tensorflow", extras_require="tf")

    if seed is None:
        rng = tf.random.get_global_generator()
    elif isinstance(seed, int):
        rng = tf.random.Generator.from_seed(seed=seed)
    elif isinstance(seed, tf.random.Generator):
        rng = seed
    else:
        offending_type = type(seed).__name__
        raise TypeError(
            f"Cannot create a tf.random.Generator from a {offending_type}"
        )
    return rng