"""Phase space generation using tensorflow."""
from typing import Optional, Tuple

import numpy as np
import phasespace
import tensorflow as tf
from phasespace.random import get_rng

from tensorwaves.interfaces import (
    PhaseSpaceGenerator,
    UniformRealNumberGenerator,
)
from tensorwaves.physics.helicity_formalism.kinematics import (
    ParticleReactionKinematicsInfo,
)
class TFPhaseSpaceGenerator(PhaseSpaceGenerator):
    """Implements a phase space generator using tensorflow.

    Event generation is delegated to the decay object returned by
    `phasespace.nbody_decay`, which is built once in the constructor.
    """

    def __init__(
        self, reaction_kinematics_info: ParticleReactionKinematicsInfo
    ) -> None:
        """Build the underlying n-body decay generator.

        Args:
            reaction_kinematics_info: Supplies the ``total_invariant_mass``
                and ``final_state_masses`` that are forwarded to
                `phasespace.nbody_decay`.
        """
        self.phsp_gen = phasespace.nbody_decay(
            reaction_kinematics_info.total_invariant_mass,
            reaction_kinematics_info.final_state_masses,
        )

    def generate(
        self, size: int, rng: UniformRealNumberGenerator
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate a phase space sample of ``size`` events.

        Args:
            size: Number of events, passed on as ``n_events``.
            rng: Random number generator; it has to be a
                `TFUniformRealNumberGenerator`, because its ``generator``
                attribute is handed to the underlying generator as seed.

        Returns:
            A tuple ``(particles, weights)``. ``particles`` stacks one array
            per entry of the generated particle mapping along the first
            axis; ``weights`` holds the weights converted to a numpy array.
            NOTE(review): presumably one weight per event and one
            four-momentum per particle and event -- confirm against the
            phasespace documentation.

        Raises:
            TypeError: If ``rng`` is not a `TFUniformRealNumberGenerator`.
        """
        if not isinstance(rng, TFUniformRealNumberGenerator):
            raise TypeError(
                f"{TFPhaseSpaceGenerator.__name__} requires a "
                f"{TFUniformRealNumberGenerator.__name__}, but fed a "
                f"{rng.__class__.__name__}"
            )
        weights, particles = self.phsp_gen.generate(
            n_events=size, seed=rng.generator
        )
        # Convert every per-particle tensor to numpy and stack them in the
        # iteration order of the mapping returned by the generator.
        stacked_momenta = np.array(
            tuple(momenta.numpy() for momenta in particles.values())
        )
        return stacked_momenta, weights.numpy()