"""Implementations of `.PhaseSpaceGenerator` and `.UniformRealNumberGenerator`."""
from typing import Optional, Tuple
import ampform as pwa
import numpy as np
import phasespace
import tensorflow as tf
from phasespace.random import get_rng
from tensorwaves.interfaces import (
MomentumSample,
PhaseSpaceGenerator,
UniformRealNumberGenerator,
)
class TFPhaseSpaceGenerator(PhaseSpaceGenerator):
    """Phase space generator that delegates to the `phasespace` package.

    Call `setup` with the reaction kinematics before calling `generate`;
    until then no decay tree exists and `generate` refuses to run.
    """

    def __init__(self) -> None:
        # Decay tree built by `setup`; None means "not set up yet".
        self.__phsp_gen: Optional[phasespace.GenParticle] = None

    def setup(self, reaction_info: pwa.kinematics.ReactionInfo) -> None:
        """Build the n-body decay described by ``reaction_info``.

        The single initial-state particle provides the top mass. Each
        final-state particle contributes its mass, and its key in
        ``reaction_info.final_state`` (converted with `str`) becomes the
        particle name passed to `phasespace.nbody_decay`.

        Args:
            reaction_info: Reaction kinematics. ``initial_state`` must hold
                exactly one particle.

        Raises:
            ValueError: If ``initial_state`` does not contain exactly one
                particle.
        """
        parents = tuple(reaction_info.initial_state.values())
        if len(parents) != 1:
            raise ValueError("Not a 1-to-n body decay")
        (parent,) = parents

        # Masses and names both follow the iteration order of the same
        # mapping, so the i-th name belongs to the i-th mass.
        final_state = reaction_info.final_state
        self.__phsp_gen = phasespace.nbody_decay(
            mass_top=parent.mass,
            masses=[particle.mass for particle in final_state.values()],
            names=[str(state_id) for state_id in final_state],
        )

    def generate(
        self, size: int, rng: UniformRealNumberGenerator
    ) -> Tuple[MomentumSample, np.ndarray]:
        """Draw ``size`` phase space events together with their weights.

        Args:
            size: Number of events to request from the decay tree.
            rng: Must be a `TFUniformRealNumberGenerator`. Its ``generator``
                attribute is forwarded to `phasespace` as the ``seed``.

        Returns:
            A ``(momentum_pool, weights)`` tuple.

            - ``momentum_pool`` maps each particle name produced by `setup`,
              converted back with `int`, to a numpy array of its momenta.
              Column 3 of each array is moved to the front (presumably
              turning phasespace's ``(px, py, pz, E)`` into
              ``(E, px, py, pz)``; confirm against the phasespace docs).
            - ``weights`` is the event-weight tensor converted to numpy.

        Raises:
            TypeError: If ``rng`` is not a `TFUniformRealNumberGenerator`.
            ValueError: If `setup` has not been called yet.
        """
        if not isinstance(rng, TFUniformRealNumberGenerator):
            message = (
                f"{TFPhaseSpaceGenerator.__name__} requires a "
                f"{TFUniformRealNumberGenerator.__name__}, but fed a "
                f"{rng.__class__.__name__}"
            )
            raise TypeError(message)

        decay = self.__phsp_gen
        if decay is None:
            raise ValueError("Phase space generator has not been set up")

        weights, particles = decay.generate(n_events=size, seed=rng.generator)

        # Column permutation applied to every (n_events, 4) momentum array.
        column_order = [3, 0, 1, 2]
        momentum_pool = {
            int(name): tensor.numpy()[:, column_order]
            for name, tensor in particles.items()
        }
        return momentum_pool, weights.numpy()