# Source code for tensorwaves.optimizer.callbacks

# pylint: disable=consider-using-with
"""Collection of loggers that can be inserted into an optimizer as callback."""

import csv
from abc import ABC, abstractmethod
from datetime import datetime
from pathlib import Path
from typing import IO, Any, Dict, Iterable, List, Optional, Type, Union

import numpy as np
import yaml

from tensorwaves.interface import Estimator, Optimizer, ParameterValue


class Loadable(ABC):
    """Mixin for callbacks whose log file can be read back afterwards."""

    @staticmethod
    @abstractmethod
    def load_latest_parameters(filename: Union[Path, str]) -> dict:
        """Return the most recently logged parameter values stored in ``filename``.

        Implementations decide how the file is parsed; this abstract version
        only declares the interface.
        """
class Callback(ABC):
    """Interface for callbacks such as `.CSVSummary`.

    A callback exposes four hooks that receive an optional ``logs``
    dictionary. Concrete callbacks have to implement all of them.

    .. seealso:: :ref:`usage/step3:Custom callbacks`
    """

    @abstractmethod
    def on_optimize_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Hook for the start of an optimization."""

    @abstractmethod
    def on_optimize_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Hook for the end of an optimization."""

    @abstractmethod
    def on_iteration_end(
        self, iteration: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Hook for the end of optimizer iteration number ``iteration``."""

    @abstractmethod
    def on_function_call_end(
        self, function_call: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Hook for the end of estimator function call number ``function_call``."""
class CallbackList(Callback):
    """Class for combining `Callback` s.

    Combine different `Callback` classes into a chain as follows:

    >>> from tensorwaves.optimizer.callbacks import (
    ...     CallbackList, TFSummary, YAMLSummary
    ... )
    >>> from tensorwaves.optimizer.minuit import Minuit2
    >>> optimizer = Minuit2(
    ...     callback=CallbackList([TFSummary(), YAMLSummary("fit_result.yml")])
    ... )

    Every hook is forwarded, with the same arguments, to each wrapped
    callback in the order in which the callbacks were given.
    """

    def __init__(self, callbacks: Iterable[Callback]) -> None:
        # Materialize once so that one-shot iterables (e.g. generators) can be
        # traversed again on every hook call.
        self.__callbacks: List[Callback] = list(callbacks)

    @property
    def callbacks(self) -> List[Callback]:
        """Copy of the wrapped callbacks; mutating it does not alter the chain."""
        return list(self.__callbacks)

    def __eq__(self, other: object) -> bool:
        """Two `CallbackList` s are equal if their wrapped callbacks are equal.

        For other types `NotImplemented` is returned, so that Python can try
        the reflected comparison (and falls back to identity, i.e. `False`).
        Note that defining ``__eq__`` makes instances unhashable.
        """
        if isinstance(other, CallbackList):
            return self.callbacks == other.callbacks
        return NotImplemented

    def on_optimize_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Forward `on_optimize_start` to all wrapped callbacks."""
        for callback in self.__callbacks:
            callback.on_optimize_start(logs)

    def on_optimize_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Forward `on_optimize_end` to all wrapped callbacks."""
        for callback in self.__callbacks:
            callback.on_optimize_end(logs)

    def on_iteration_end(
        self, iteration: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Forward `on_iteration_end` to all wrapped callbacks."""
        for callback in self.__callbacks:
            callback.on_iteration_end(iteration, logs)

    def on_function_call_end(
        self, function_call: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Forward `on_function_call_end` to all wrapped callbacks."""
        for callback in self.__callbacks:
            callback.on_function_call_end(function_call, logs)
class CSVSummary(Callback, Loadable):
    """Log fit parameters and the estimator value to a CSV file.

    Args:
        filename: Path of the CSV file. It is (re)opened for writing on
            every `on_optimize_start`.
        function_call_step_size: Write a row every this many function
            calls. A value ``<= 0`` disables function call logging (and the
            ``function_call`` column).
        iteration_step_size: Write a row every this many iterations. `None`
            or a value ``<= 0`` disables iteration logging (and the
            ``iteration`` column).

    Raises:
        ValueError: If both step sizes disable logging.
    """

    def __init__(
        self,
        filename: Union[Path, str],
        function_call_step_size: int = 1,
        iteration_step_size: Optional[int] = None,
    ) -> None:
        # Assign the stream before validating, so that __del__ does not hit
        # an AttributeError when the ValueError below is raised.
        self.__stream: Optional[IO] = None
        if iteration_step_size is None:
            iteration_step_size = 0
        if function_call_step_size <= 0 and iteration_step_size <= 0:
            raise ValueError(
                "either function call or iteration step size should be > 0."
            )
        self.__function_call_step_size = function_call_step_size
        self.__iteration_step_size = iteration_step_size
        self.__latest_function_call: Optional[int] = None
        self.__latest_iteration: Optional[int] = None
        self.__writer: Optional[csv.DictWriter] = None
        self.__filename = filename

    def __del__(self) -> None:
        _close_stream(self.__stream)

    def on_optimize_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Open the CSV file and write the header derived from ``logs``.

        Raises:
            ValueError: If ``logs`` is `None`, because the column names are
                taken from it.
        """
        if logs is None:
            raise ValueError(
                f"{self.__class__.__name__} requires logs on optimize start"
                " to determine header names"
            )
        # A counter is only tracked (and therefore only becomes a column) for
        # the step sizes that are enabled.
        self.__latest_function_call = (
            0 if self.__function_call_step_size > 0 else None
        )
        self.__latest_iteration = 0 if self.__iteration_step_size > 0 else None
        _close_stream(self.__stream)
        self.__stream = open(self.__filename, "w", newline="")
        self.__writer = csv.DictWriter(
            self.__stream,
            fieldnames=list(self.__log_to_rowdict(logs)),
            quoting=csv.QUOTE_NONNUMERIC,
        )
        self.__writer.writeheader()

    def on_optimize_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Write a final row (if ``logs`` are given) and close the file."""
        if logs is not None:
            # The final row carries no counters; DictWriter fills the missing
            # counter columns with its default restval (an empty string).
            self.__latest_function_call = None
            self.__latest_iteration = None
            self.__write(logs)
        _close_stream(self.__stream)

    def on_iteration_end(
        self, iteration: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Write a row every ``iteration_step_size`` iterations."""
        if self.__iteration_step_size <= 0:
            # Iteration logging is disabled: there is no 'iteration' column to
            # fill, and the modulo below would divide by zero.
            return
        self.__latest_iteration = iteration
        if logs is None:
            return
        if iteration % self.__iteration_step_size != 0:
            return
        self.__write(logs)

    def on_function_call_end(
        self, function_call: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Write a row every ``function_call_step_size`` function calls."""
        if self.__function_call_step_size <= 0:
            # Function call logging is disabled: there is no 'function_call'
            # column to fill, and the modulo below would divide by zero.
            return
        self.__latest_function_call = function_call
        if logs is None:
            return
        if function_call % self.__function_call_step_size != 0:
            return
        self.__write(logs)

    def __write(self, logs: Dict[str, Any]) -> None:
        if self.__writer is None:
            raise ValueError(
                f"{csv.DictWriter.__name__} has not been initialized"
            )
        self.__writer.writerow(self.__log_to_rowdict(logs))

    def __log_to_rowdict(self, logs: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten ``logs`` into one CSV row, counters first."""
        counters: Dict[str, Any] = {}
        if self.__latest_iteration is not None:
            counters["iteration"] = self.__latest_iteration
        if self.__latest_function_call is not None:
            counters["function_call"] = self.__latest_function_call
        return {
            **counters,
            "time": logs["time"],
            "optimizer": logs["optimizer"],
            "estimator_type": logs["estimator"]["type"],
            "estimator_value": logs["estimator"]["value"],
            **logs["parameters"],
        }

    @staticmethod
    def load_latest_parameters(filename: Union[Path, str]) -> dict:
        """Read the last row of a CSV file written by this callback.

        Numeric fields are converted to `float`, or to `complex` if they
        cannot be read as a real number. All other fields stay strings.
        """

        def cast_non_numeric(value: str) -> Union[complex, float, str]:
            # https://docs.python.org/3/library/csv.html#csv.QUOTE_NONNUMERIC
            # does not work well for complex numbers, so fields are read as
            # strings and converted here. float() has to be tried first:
            # complex() also accepts every real number, so trying it first
            # would turn all real values into complex ones.
            try:
                return float(value)
            except ValueError:
                pass
            try:
                return complex(value)
            except ValueError:
                return value

        # newline="" as recommended by the csv module documentation
        with open(filename, newline="") as stream:
            reader = csv.DictReader(stream)
            last_line = list(reader)[-1]
        return {
            name: cast_non_numeric(value) for name, value in last_line.items()
        }
class TFSummary(Callback):
    """Log fit parameters and the estimator value to a `tf.summary`.

    The logs can be viewed with `TensorBoard
    <https://www.tensorflow.org/tensorboard>`_ via:

    .. code-block:: shell

        tensorboard --logdir logs

    Args:
        logdir: Base directory. Each optimization writes into
            ``<logdir>/<YYYYmmdd-HHMMSS>`` (plus ``/<subdir>`` if given).
        step_size: Write scalars every this many function calls.
        subdir: Optional extra directory level below the time stamp.
    """

    def __init__(
        self,
        logdir: str = "logs",
        step_size: int = 10,
        subdir: Optional[str] = None,
    ) -> None:
        self.__logdir = logdir
        self.__subdir = subdir
        self.__step_size = step_size
        self.__stream: Optional[Any] = None

    def on_optimize_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Create a time-stamped summary file writer and make it the default."""
        # pylint: disable=import-outside-toplevel, no-member
        import tensorflow as tf

        path_parts = [self.__logdir, datetime.now().strftime("%Y%m%d-%H%M%S")]
        if self.__subdir is not None:
            path_parts.append(self.__subdir)
        self.__stream = tf.summary.create_file_writer("/".join(path_parts))
        self.__stream.set_as_default()  # type: ignore[attr-defined]

    def on_optimize_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Close the summary writer, if one was created."""
        if not self.__stream:
            return
        self.__stream.close()

    def on_iteration_end(
        self, iteration: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Iterations are not logged by this callback."""

    def on_function_call_end(
        self, function_call: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Every ``step_size`` calls, write each parameter and the estimator value."""
        # pylint: disable=import-outside-toplevel, no-member
        import tensorflow as tf

        # 'or' short-circuits, so the modulo is not evaluated without logs
        if logs is None or function_call % self.__step_size:
            return
        for name, value in logs["parameters"].items():
            tf.summary.scalar(name, value, step=function_call)
        estimator = logs.get("estimator", {}).get("value")
        if estimator is not None:
            tf.summary.scalar("estimator", estimator, step=function_call)
        if self.__stream is not None:
            self.__stream.flush()
class YAMLSummary(Callback, Loadable):
    """Log fit parameters and the estimator value to a YAML file.

    The file is emptied and rewritten each time the logs are dumped, so it
    only holds the most recently logged state of the fit.

    Args:
        filename: Path of the YAML file. It is (re)opened for writing on
            every `on_optimize_start`.
        step_size: Dump the logs every this many function calls.
    """

    def __init__(
        self, filename: Union[Path, str], step_size: int = 10
    ) -> None:
        self.__step_size = step_size
        self.__filename = filename
        self.__stream: Optional[IO] = None

    def __del__(self) -> None:
        _close_stream(self.__stream)

    def on_optimize_start(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """(Re)open the output file for writing."""
        _close_stream(self.__stream)
        self.__stream = open(self.__filename, "w")

    def on_optimize_end(self, logs: Optional[Dict[str, Any]] = None) -> None:
        """Dump the final logs (if given) and always close the file."""
        if logs is not None:
            self.__dump_to_yaml(logs)
        # Close even without logs; previously the stream stayed open here.
        _close_stream(self.__stream)

    def on_iteration_end(
        self, iteration: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Iterations are not logged by this callback."""

    def on_function_call_end(
        self, function_call: int, logs: Optional[Dict[str, Any]] = None
    ) -> None:
        """Dump the logs every ``step_size`` function calls."""
        if logs is None:
            return
        if function_call % self.__step_size != 0:
            return
        self.__dump_to_yaml(logs)

    def __dump_to_yaml(self, logs: Dict[str, Any]) -> None:
        """Overwrite the file with ``logs``, parameters cast to float/complex."""
        _empty_file(self.__stream)
        cast_logs = dict(logs)
        cast_logs["parameters"] = {
            p: _cast_value(v) for p, v in logs["parameters"].items()
        }
        yaml.dump(
            cast_logs,
            self.__stream,
            sort_keys=False,
            Dumper=_IncreasedIndent,
            default_flow_style=False,
        )
        # Flush so that the file on disk reflects the latest dump while the
        # fit is still running.
        if self.__stream is not None:
            self.__stream.flush()

    @staticmethod
    def load_latest_parameters(filename: Union[Path, str]) -> dict:
        """Return the ``parameters`` mapping from a YAML file written by this class."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
        # only load files from trusted sources. It is used because the file
        # is written with a yaml.Dumper subclass, which emits Python-specific
        # tags (e.g. !!python/complex for complex parameters) that
        # yaml.SafeLoader would reject.
        with open(filename) as stream:
            fit_stats = yaml.load(stream, Loader=yaml.Loader)
        return fit_stats["parameters"]
def _cast_value(value: Any) -> ParameterValue:
    """Convert ``value`` to a built-in `complex` or `float`.

    Values with a non-zero imaginary part, and any `complex` instance, become
    `complex`; everything else is passed through `float`.
    """
    # cspell:ignore iscomplex
    keep_complex = np.iscomplex(value) or isinstance(value, complex)
    return complex(value) if keep_complex else float(value)


class _IncreasedIndent(yaml.Dumper):  # pylint: disable=too-many-ancestors
    """YAML dumper that never emits indentless sequences."""

    def increase_indent(
        self, flow: bool = False, indentless: bool = False
    ) -> None:
        # Ignore the requested ``indentless`` flag and always indent.
        return super().increase_indent(flow, False)


def _close_stream(stream: Optional[IO]) -> None:
    """Close ``stream``; do nothing when it is `None`."""
    if stream is None:
        return
    stream.close()


def _empty_file(stream: Optional[IO]) -> None:
    """Rewind ``stream`` and truncate it to zero length; no-op for `None`."""
    if stream is not None:
        stream.seek(0)
        stream.truncate()


def _create_log(  # pyright: reportUnusedFunction=false
    optimizer: Type[Optimizer],
    estimator_value: float,
    estimator_type: Type[Estimator],
    parameters: Dict[str, Any],
    function_call: int,
) -> Dict[str, Any]:
    """Assemble a ``logs`` dictionary in the layout the callbacks here read."""
    estimator_info = {
        "type": estimator_type.__name__,
        "value": float(estimator_value),
    }
    return {
        "time": datetime.now(),
        "optimizer": optimizer.__name__,
        "estimator": estimator_info,
        "function_call": function_call,
        "parameters": parameters,
    }