# pylint: disable=import-outside-toplevel
"""Implementations of `.PhaseSpaceGenerator` and `.UniformRealNumberGenerator`."""
from typing import Mapping, Optional, Tuple
import numpy as np
from tensorwaves.interface import (
DataSample,
PhaseSpaceGenerator,
UniformRealNumberGenerator,
)
class TFPhaseSpaceGenerator(PhaseSpaceGenerator):
    """Phase space generator backed by the ``phasespace`` package.

    Call `setup` once to define the decay, then `generate` to draw events.
    """

    def __init__(self) -> None:
        # Remains None until setup() has built the decay generator; generate()
        # checks for this and refuses to run on an unconfigured instance.
        self.__phsp_gen = None

    def setup(
        self,
        initial_state_mass: float,
        final_state_masses: Mapping[int, float],
    ) -> None:
        """Build the n-body decay generator for the given masses.

        Args:
            initial_state_mass: Mass of the decaying state, passed on as
                ``mass_top``.
            final_state_masses: Mapping of final state ID to its mass. The
                IDs are sorted, and their string form is used as the name of
                each decay product.
        """
        # Deliberately imported here rather than at module level (see the
        # pylint pragma at the top of the file); presumably to defer loading
        # the heavy backend until it is needed -- TODO confirm.
        import phasespace

        ordered_ids = sorted(final_state_masses)
        masses = [final_state_masses[state_id] for state_id in ordered_ids]
        names = [str(state_id) for state_id in ordered_ids]
        self.__phsp_gen = phasespace.nbody_decay(
            mass_top=initial_state_mass,
            masses=masses,
            names=names,
        )

    def generate(
        self, size: int, rng: UniformRealNumberGenerator
    ) -> Tuple[DataSample, np.ndarray]:
        """Generate ``size`` events together with their weights.

        Args:
            size: Number of events to request from the underlying generator.
            rng: Random number generator; must be a
                `TFUniformRealNumberGenerator`, whose ``generator`` attribute
                is handed over as the seed.

        Returns:
            A tuple of the phase space sample and the event weights. The
            sample maps each particle label (converted to `int`) to a NumPy
            array of its momenta with the columns reordered as
            ``[3, 0, 1, 2]``; the weights are returned as a NumPy array.

        Raises:
            TypeError: If ``rng`` is not a `TFUniformRealNumberGenerator`.
            ValueError: If `setup` has not been called yet.
        """
        if not isinstance(rng, TFUniformRealNumberGenerator):
            message = (
                f"{TFPhaseSpaceGenerator.__name__} requires a "
                f"{TFUniformRealNumberGenerator.__name__}, but fed a "
                f"{rng.__class__.__name__}"
            )
            raise TypeError(message)

        decay_generator = self.__phsp_gen
        if decay_generator is None:
            raise ValueError("Phase space generator has not been set up")

        weights, particles = decay_generator.generate(
            n_events=size, seed=rng.generator
        )

        phsp_sample = {}
        for label, momenta in particles.items():
            particle_id = int(label)
            # Move column 3 to the front; presumably this turns
            # (px, py, pz, E) into (E, px, py, pz) -- TODO confirm against
            # the phasespace output convention.
            phsp_sample[particle_id] = momenta.numpy()[:, [3, 0, 1, 2]]
        return phsp_sample, weights.numpy()