Source code for tensorwaves.data.rng

# pylint:disable=import-outside-toplevel
"""Implementations of `.RealNumberGenerator`."""

from typing import Optional

import numpy as np

from tensorwaves.interface import RealNumberGenerator


[docs]class NumpyUniformRNG(RealNumberGenerator): """Implements a uniform real random number generator using `numpy`.""" def __init__(self, seed: Optional[float] = None): self.seed = seed
[docs] def __call__( self, size: int, min_value: float = 0.0, max_value: float = 1.0 ) -> np.ndarray: return self.generator.uniform(size=size, low=min_value, high=max_value)
@property def seed(self) -> Optional[float]: return self.__seed @seed.setter def seed(self, value: Optional[float]) -> None: self.__seed = value self.generator: np.random.Generator = np.random.default_rng( seed=self.seed )
[docs]class TFUniformRealNumberGenerator(RealNumberGenerator): """Implements a uniform real random number generator using tensorflow.""" def __init__(self, seed: Optional[float] = None): from tensorflow import float64 self.seed = seed self.dtype = float64
[docs] def __call__( self, size: int, min_value: float = 0.0, max_value: float = 1.0 ) -> np.ndarray: return self.generator.uniform( shape=[size], minval=min_value, maxval=max_value, dtype=self.dtype, ).numpy()
@property def seed(self) -> Optional[float]: return self.__seed @seed.setter def seed(self, value: Optional[float]) -> None: from phasespace.random import get_rng self.__seed = value self.generator = get_rng(self.seed)