# pylint: disable=import-outside-toplevel
"""Implementations of `.DataGenerator` and `.WeightedDataGenerator`."""
import logging
from typing import Mapping, Tuple
import numpy as np
from tqdm.auto import tqdm
from tensorwaves.function._backend import raise_missing_module_error
from tensorwaves.interface import (
DataGenerator,
DataSample,
RealNumberGenerator,
WeightedDataGenerator,
)
from ._data_sample import (
finalize_progress_bar,
get_number_of_events,
merge_events,
select_events,
)
from .rng import TFUniformRealNumberGenerator
class TFPhaseSpaceGenerator(DataGenerator):
    """Implements a phase space generator using tensorflow.

    Unweighted events are obtained by hit-and-miss sampling on top of a
    `TFWeightedPhaseSpaceGenerator`: a weighted event is kept when its weight
    exceeds a random number drawn from the provided random number generator.

    Args:
        initial_state_mass: Mass of the decaying state.
        final_state_masses: A mapping of final state IDs to the corresponding
            masses.
        bunch_size: Size of a bunch that is generated during a hit-and-miss
            iteration. Has to be positive.

    Raises:
        ValueError: If :code:`bunch_size` is not positive. With such a value,
            no events could ever be added to the sample, so `generate` would
            not terminate.
    """

    def __init__(
        self,
        initial_state_mass: float,
        final_state_masses: Mapping[int, float],
        bunch_size: int = 50_000,
    ) -> None:
        # A bunch of zero (or fewer) events can never grow the momentum pool
        # in generate(), which would then loop forever.
        if bunch_size <= 0:
            raise ValueError(
                f"bunch_size has to be a positive integer, but got {bunch_size}"
            )
        self.__phsp_generator = TFWeightedPhaseSpaceGenerator(
            initial_state_mass, final_state_masses
        )
        self.__bunch_size = bunch_size
        # Controls whether generate() displays a tqdm progress bar, see
        # https://github.com/ComPWA/tensorwaves/issues/395
        self.show_progress = True

    def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
        r"""Generate a `.DataSample` of phase space four-momenta.

        Args:
            size: Number of events to generate.
            rng: Random number generator. It is passed on to the underlying
                weighted phase space generator and is also used to draw the
                hit-and-miss random numbers.

        Returns:
            A `.DataSample` of **four-momenta** arrays of shape
            :math:`n \times 4`.

        .. seealso:: :ref:`amplitude-analysis:2.1 Generate phase space sample`
        """
        progress_bar = tqdm(
            total=size,
            desc="Generating phase space sample",
            # Hide the bar when requested, or when the root logger is set to a
            # level quieter than WARNING.
            disable=not self.show_progress
            or logging.getLogger().level > logging.WARNING,
        )
        momentum_pool: DataSample = {}
        try:
            while get_number_of_events(momentum_pool) < size:
                phsp_momenta, weights = self.__phsp_generator.generate(
                    self.__bunch_size, rng
                )
                # Hit-and-miss: keep only events whose weight exceeds the
                # corresponding random number.
                hit_and_miss_randoms = rng(self.__bunch_size)
                bunch = select_events(
                    phsp_momenta, selector=weights > hit_and_miss_randoms
                )
                momentum_pool = merge_events(momentum_pool, bunch)
                progress_bar.update(n=get_number_of_events(bunch))
        except BaseException:
            # Do not leave a dangling progress bar behind when generation
            # fails (for instance because of an incompatible rng) or is
            # interrupted.
            progress_bar.close()
            raise
        finalize_progress_bar(progress_bar)
        # The last bunch usually overshoots the requested number of events,
        # so truncate the pool to the first `size` events.
        return select_events(momentum_pool, selector=slice(None, size))
class TFWeightedPhaseSpaceGenerator(WeightedDataGenerator):
    """Implements a phase space generator **with weights** using tensorflow.

    Event generation is delegated to the :code:`phasespace` package. That
    package is imported lazily and appears to be an optional dependency
    (the :code:`phsp` extra).

    Args:
        initial_state_mass: Mass of the decaying state.
        final_state_masses: A mapping of final state IDs to the corresponding
            masses.

    .. seealso:: :ref:`amplitude-analysis:2.2 Generate intensity-based sample`
    """

    def __init__(
        self,
        initial_state_mass: float,
        final_state_masses: Mapping[int, float],
    ) -> None:
        try:
            import phasespace
        except ImportError:  # pragma: no cover
            raise_missing_module_error("phasespace", extras_require="phsp")
        # Hand the final state particles to phasespace in ascending ID order
        # and name each one after its ID. generate() presumably receives the
        # particles back keyed by these names.
        state_ids = sorted(final_state_masses)
        ordered_masses = [final_state_masses[state_id] for state_id in state_ids]
        particle_names = [str(state_id) for state_id in state_ids]
        self.__phsp_gen = phasespace.nbody_decay(
            mass_top=initial_state_mass,
            masses=ordered_masses,
            names=particle_names,
        )

    def generate(
        self, size: int, rng: RealNumberGenerator
    ) -> Tuple[DataSample, np.ndarray]:
        r"""Generate a `.DataSample` of phase space four-momenta with weights.

        Args:
            size: Number of events to generate.
            rng: Has to be a `.TFUniformRealNumberGenerator`. Its
                :code:`generator` attribute is used as the seed for
                :code:`phasespace`.

        Returns:
            A `tuple` of a `.DataSample` (**four-momenta**) with an event-wise
            sequence of weights. The four-momenta are arrays of shape
            :math:`n \times 4`.

        Raises:
            TypeError: If :code:`rng` is not a `.TFUniformRealNumberGenerator`.
        """
        if not isinstance(rng, TFUniformRealNumberGenerator):
            expected_name = TFUniformRealNumberGenerator.__name__
            raise TypeError(
                f"{type(self).__name__} requires a {expected_name}, but got a "
                f"{type(rng).__name__}"
            )
        weights, particles = self.__phsp_gen.generate(
            n_events=size, seed=rng.generator
        )
        # Move the last momentum component to the front. This presumably turns
        # phasespace's (px, py, pz, E) layout into (E, px, py, pz).
        energy_first = [3, 0, 1, 2]
        four_momenta = {
            f"p{name}": tf_momenta.numpy()[:, energy_first]
            for name, tf_momenta in particles.items()
        }
        return four_momenta, weights.numpy()