# pylint: disable=too-many-arguments
"""The `.data` module takes care of data generation."""
from __future__ import annotations
import logging
import numpy as np
from tqdm.auto import tqdm
from tensorwaves.interface import (
DataGenerator,
DataSample,
DataTransformer,
Function,
RealNumberGenerator,
)
from ._data_sample import (
finalize_progress_bar,
get_number_of_events,
merge_events,
select_events,
)
# pyright: reportUnusedImport=false
from .phasespace import ( # noqa:F401
TFPhaseSpaceGenerator,
TFWeightedPhaseSpaceGenerator,
)
from .rng import NumpyUniformRNG, TFUniformRealNumberGenerator # noqa:F401
from .transform import IdentityTransformer, SympyDataTransformer # noqa:F401
_LOGGER = logging.getLogger(__name__)
[docs]class NumpyDomainGenerator(DataGenerator):
"""Generate a uniform `.DataSample` as a domain for a `.Function`.
Args:
boundaries: A mapping of the keys in the `.DataSample` that is to be
generated. The boundaries have to be a `tuple` of a minimum and a maximum
value that define the range for each key in the `.DataSample`.
"""
def __init__(self, boundaries: dict[str, tuple[float, float]]) -> None:
self.__boundaries = boundaries
[docs] def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
return {
var_name: rng(size, min_value, max_value)
for var_name, (min_value, max_value) in self.__boundaries.items()
}
[docs]class IntensityDistributionGenerator(DataGenerator):
"""Generate an hit-and-miss `.DataSample` distribution for a `.Function`.
Args:
domain_generator: A `.DataGenerator` that can be used to generate a
**domain** `.DataSample` over which to evaluate the :code:`function`.
function: An **intensity** `.Function` with which the output
distribution `.DataSample` is generated using a :ref:`hit-and-miss strategy
<usage/basics:Hit & miss>`.
domain_transformer: Optional `.DataTransformer` that can convert a generated
**domain** `.DataSample` to a `.DataSample` that the :code:`function` can
take as input.
bunch_size: Size of a bunch that is generated during a hit-and-miss
iteration.
"""
def __init__(
self,
domain_generator: DataGenerator,
function: Function,
domain_transformer: DataTransformer | None = None,
bunch_size: int = 50_000,
) -> None:
self.__domain_generator = domain_generator
if domain_transformer is not None:
self.__domain_transformer = domain_transformer
else:
self.__domain_transformer = IdentityTransformer()
self.__function = function
self.__bunch_size = bunch_size
[docs] def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
progress_bar = tqdm(
total=size,
desc="Generating intensity-based sample",
disable=_LOGGER.level > logging.WARNING,
)
returned_data: DataSample = {}
current_max_intensity = 0.0
while get_number_of_events(returned_data) < size:
data_bunch, bunch_max = self._generate_bunch(rng)
if bunch_max > current_max_intensity:
current_max_intensity = 1.05 * bunch_max
if get_number_of_events(returned_data) > 0:
_LOGGER.info(
f"Processed bunch maximum of {bunch_max} is over"
f" current maximum {current_max_intensity}. Restarting"
" generation!"
)
returned_data = {}
# reset progress bar
progress_bar.update(n=-progress_bar.n)
continue
if len(returned_data):
returned_data = merge_events(returned_data, data_bunch)
else:
returned_data = data_bunch
progress_bar.update(n=get_number_of_events(returned_data) - progress_bar.n)
finalize_progress_bar(progress_bar)
return select_events(returned_data, selector=slice(None, size))
def _generate_bunch(self, rng: RealNumberGenerator) -> tuple[DataSample, float]:
domain = _generate_without_progress_bar(
self.__domain_generator, self.__bunch_size, rng
)
transformed_domain = self.__domain_transformer(domain)
computed_intensities = self.__function(transformed_domain)
max_intensity: float = np.max(computed_intensities)
random_intensities = rng(size=self.__bunch_size, max_value=max_intensity)
weights = domain.get("weights", 1)
hit_and_miss_sample = select_events(
domain,
selector=weights * computed_intensities > random_intensities,
)
return hit_and_miss_sample, max_intensity
def _generate_without_progress_bar(
domain_generator: DataGenerator, bunch_size: int, rng: RealNumberGenerator
) -> DataSample:
# https://github.com/ComPWA/tensorwaves/issues/395
show_progress = getattr(domain_generator, "show_progress", None)
if show_progress is not None:
domain_generator.show_progress = False # type: ignore[attr-defined]
domain = domain_generator.generate(bunch_size, rng)
if show_progress is not None:
domain_generator.show_progress = show_progress # type: ignore[attr-defined]
return domain