Source code for tensorwaves.data.rng
# pylint:disable=import-outside-toplevel
"""Implementations of `.RealNumberGenerator`."""
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
from tensorwaves.function._backend import raise_missing_module_error
from tensorwaves.interface import RealNumberGenerator
if TYPE_CHECKING: # pragma: no cover
import tensorflow as tf
SeedLike = Optional[Union[int, tf.random.Generator]]
[docs]class NumpyUniformRNG(RealNumberGenerator):
"""Implements a uniform real random number generator using `numpy`."""
def __init__(self, seed: float | None = None):
self.seed = seed
def __call__(
self, size: int, min_value: float = 0.0, max_value: float = 1.0
) -> np.ndarray:
return self.generator.uniform(size=size, low=min_value, high=max_value)
@property
def seed(self) -> float | None:
return self.__seed
@seed.setter
def seed(self, value: float | None) -> None:
self.__seed = value
generator_seed: float | int | None = self.seed
if generator_seed is not None:
if not float(generator_seed).is_integer():
raise ValueError("NumPy generator seed has to be integer")
generator_seed = int(generator_seed)
self.generator: np.random.Generator = np.random.default_rng(seed=generator_seed)
[docs]class TFUniformRealNumberGenerator(RealNumberGenerator):
"""Implements a uniform real random number generator using tensorflow."""
def __init__(self, seed: float | None = None):
try:
from tensorflow import float64
except ImportError: # pragma: no cover
raise_missing_module_error("tensorflow", extras_require="tf")
self.seed = seed
self.dtype = float64
def __call__(
self, size: int, min_value: float = 0.0, max_value: float = 1.0
) -> np.ndarray:
return self.generator.uniform(
shape=[size],
minval=min_value,
maxval=max_value,
dtype=self.dtype,
).numpy()
@property
def seed(self) -> float | None:
return self.__seed
@seed.setter
def seed(self, value: float | None) -> None:
self.__seed = value
self.generator = _get_tensorflow_rng(self.seed)
def _get_tensorflow_rng(seed: SeedLike = None) -> tf.random.Generator:
"""Get or create a `tf.random.Generator`.
https://github.com/zfit/phasespace/blob/5998e2b/phasespace/random.py#L15-L41
"""
try:
import tensorflow as tf
except ImportError: # pragma: no cover
raise_missing_module_error("tensorflow", extras_require="tf")
if seed is None:
return tf.random.get_global_generator()
if isinstance(seed, int):
return tf.random.Generator.from_seed(seed=seed)
if isinstance(seed, tf.random.Generator):
return seed
raise TypeError(f"Cannot create a tf.random.Generator from a {type(seed).__name__}")