# Source code for mbirjax.parameter_handler

import warnings
import copy
import logging
import io
from collections.abc import Iterable, Sized
from typing import Literal, Union, Any, TextIO, overload
import os

import jax.numpy as jnp
import numpy as np
from ruamel.yaml import YAML

from mbirjax._utils import Param
import mbirjax as mj

# NOTE:  additions/deletions here should also be made to _utils.py
# Parameter names used to annotate ParameterHandler.get_params, so that type checkers and IDEs
# can flag or complete names.  Listed one per line in the same order as before.
ParamNames = Literal[
    'geometry_type',
    'file_format',
    'sinogram_shape',
    'delta_det_channel',
    'delta_det_row',
    'det_row_offset',
    'det_channel_offset',
    'sigma_y',
    'alu_unit',
    'alu_value',
    'recon_shape',
    'delta_voxel',
    'sigma_x',
    'sigma_prox',
    'p',
    'q',
    'T',
    'qggmrf_nbr_wts',
    'auto_regularize_flag',
    'positivity_flag',
    'snr_db',
    'sharpness',
    'granularity',
    'partition_sequence',
    'verbose',
    'use_gpu',
    'max_overrelaxation',
]


class ParameterHandler:
    """
    Hold, update, compare, and (de)serialize the parameters of an mbirjax model.

    Parameters live in ``self.params``, a dict mapping each parameter name to a ``Param`` entry
    with a ``val`` and a ``recompile_flag``.  Arrays are written to YAML as strings that start with
    ``array_prefix`` together with a separate ``shape`` entry, and are converted back to arrays
    when parameters are loaded.
    """
    # Marks the string form of an array so it can be recognized when parameters are read back.
    array_prefix = ':ARRAY:'
    # Types treated as real scalars when comparing parameter values within a tolerance.
    _SCALAR_TYPES = (int, float, np.generic, jnp.generic)

    def __init__(self):
        """Initialize the parameters from mj._utils.get_default_params(); no logger is created yet."""
        self.params = mj._utils.get_default_params()
        self.logger = None
        self.log_buffer = None

    def setup_logger(self, *, logfile_path: str = "./logs/recon.log", print_logs: bool = True):
        """
        Initialize self.logger and self.log_buffer.

        The logging level follows the 'verbose' parameter:
        verbose < 1 -> WARNING, verbose < 2 -> INFO, otherwise DEBUG.
        Messages always go to an in-memory buffer (self.log_buffer), and optionally to the console
        and to a log file.  All messages are formatted as the bare message text.

        Args:
            logfile_path (str, optional): Path to the log file. If None or empty, file logging is skipped.
            print_logs (bool, optional): If True, emit logs to the console.

        Raises:
            Exception: If the log file's directory cannot be created (via mj.makedirs) or the
                file cannot be opened.
        """
        verbose = self.get_params('verbose')
        if verbose < 1:
            level = logging.WARNING
        elif verbose < 2:
            level = logging.INFO
        else:
            level = logging.DEBUG

        # The logger is looked up by class name, so a previous call may have attached handlers.
        # Close and remove them to prevent leaked file descriptors and duplicated output.
        logger = logging.getLogger(self.__class__.__name__)
        logger.setLevel(level)
        for old_handler in list(logger.handlers):
            try:
                old_handler.flush()
            finally:
                old_handler.close()
                logger.removeHandler(old_handler)

        # In-memory buffer handler (always enabled), then optional console and file handlers.
        self.log_buffer = io.StringIO()
        handlers = [logging.StreamHandler(self.log_buffer)]
        if print_logs:
            handlers.append(logging.StreamHandler())
        if logfile_path:
            mj.makedirs(logfile_path)
            handlers.append(logging.FileHandler(logfile_path, mode='w'))

        for handler in handlers:
            handler.setLevel(level)
            handler.setFormatter(logging.Formatter('%(message)s'))
            logger.addHandler(handler)

        self.logger = logger

    def print_params(self):
        """
        Print the current parameter values in the model.

        This method prints all parameters stored in the model's internal dictionary.
        If the model's verbosity level is less than 3, it prints only the parameter names and values,
        and skips the parameter named by 'view_params_name'.
        If verbosity is 3 or higher, it also includes the `recompile_flag` status for each parameter.

        Example:
            >>> ct_model = mj.ParallelBeamModel(sinogram_shape, angles)
            >>> ct_model.set_params(sharpness=0.7, recon_shape=(128, 128, 128))
            >>> ct_model.print_params()
        """
        verbose, view_params_name = self.get_params(['verbose', 'view_params_name'])
        print("----")
        for key, entry in self.params.items():
            if verbose < 3:
                if key == view_params_name:
                    continue
                print("{} = {}".format(key, entry.val))
            else:
                print("{} = {}, recompile_flag = {}".format(key, entry.val, entry.recompile_flag))
        print("----")
        # NOTE(review): re-applies the current use_gpu value through set_params, presumably so that
        # a subclass override of set_params refreshes device-related state.  Confirm intent.
        self.set_params(use_gpu=self.get_params('use_gpu'))

    @staticmethod
    def _array_to_string(array):
        """Return array_prefix followed by the flattened array values, space-separated with 7 decimals."""
        flat_values = np.asarray(array).flatten()
        return ParameterHandler.array_prefix + " ".join(f"{x:.7f}" for x in flat_values)

    @staticmethod
    def convert_arrays_to_strings(cur_params):
        """
        Replace any jax or numpy arrays in cur_params with a flattened string representation and the array shape.

        For array entries, ``val`` becomes ``array_prefix`` followed by the space-separated values
        (7 decimal places), and a ``shape`` attribute holding the original shape is added.
        All other entries have numpy/jax scalar values converted to native Python types.
        The entries are modified in place.

        Args:
            cur_params (dict): Parameter dictionary mapping names to Param entries.

        Returns:
            dict: The same dictionary with arrays replaced by strings.
        """
        for entry in cur_params.values():
            param_val = entry.val
            if isinstance(param_val, (jnp.ndarray, np.ndarray)):
                entry.val = ParameterHandler._array_to_string(param_val)
                entry.shape = param_val.shape
            else:
                # Normalize only non-array values so the array string above is not overwritten.
                entry.val = ParameterHandler.normalize_scalar(param_val)
        return cur_params

    @staticmethod
    def convert_strings_to_arrays(cur_params):
        """
        Convert the string representation of an array back to an array.

        Entries whose ``val`` is a string starting with ``array_prefix`` are parsed into a jax array.
        If the entry has a ``shape`` attribute, the array is reshaped to it and the attribute is
        removed (it is needed only for the yaml file); otherwise the array is left flat.
        The entries are modified in place.

        Args:
            cur_params (dict): Parameter dictionary mapping names to Param entries.

        Returns:
            dict: The same dictionary with array strings replaced by arrays.

        Raises:
            ValueError: If the array string contains a token that is not a number.
        """
        array_prefix = ParameterHandler.array_prefix
        for entry in cur_params.values():
            param_val = entry.val
            if not (isinstance(param_val, str) and param_val.startswith(array_prefix)):
                continue
            # Strip the prefix, then remove any bracket delimiters before parsing the numbers.
            param_str = param_val[len(array_prefix):]
            clean_str = param_str.replace('[', '').replace(']', '')
            flat_array = jnp.array(np.array([float(token) for token in clean_str.split()]))

            shape = getattr(entry, 'shape', None)
            if shape is not None:
                entry.val = flat_array.reshape(tuple(shape))
                del entry.shape
            else:
                entry.val = flat_array
        return cur_params

    @staticmethod
    def normalize_scalar(val):
        """
        Convert numpy/jax scalar types to Python native types.

        Also recursively normalize lists or tuples of scalars, preserving the container type.
        Leave strings, bools, None, and arrays untouched.

        Args:
            val: Any parameter value.

        Returns:
            Cleaned version suitable for serialization and comparison.
        """
        if isinstance(val, (np.generic, jnp.generic)):
            return val.item()
        if isinstance(val, (list, tuple)):
            return type(val)(ParameterHandler.normalize_scalar(v) for v in val)
        return val

    @staticmethod
    def serialize_parameter(param_obj):
        """
        Convert a Param object to a YAML-safe dictionary by serializing arrays and normalizing scalars.

        Args:
            param_obj (Param): A parameter object with fields `val` and `recompile_flag`.

        Returns:
            dict: {'val', 'shape', 'recompile_flag'} for arrays (val is the prefixed string form),
            or {'val', 'recompile_flag'} otherwise, with lists converted to tuples.
        """
        val = param_obj.val
        if isinstance(val, (jnp.ndarray, np.ndarray)):
            return {
                'val': ParameterHandler._array_to_string(val),
                'shape': tuple(np.shape(val)),
                'recompile_flag': param_obj.recompile_flag,
            }

        val = ParameterHandler.normalize_scalar(val)
        # Convert lists to tuples for consistency
        if isinstance(val, list):
            val = tuple(val)
        return {
            'val': val,
            'recompile_flag': param_obj.recompile_flag,
        }

    @staticmethod
    def deserialize_parameter(entry):
        """
        Convert a dictionary loaded from YAML into a Param object.

        Array strings are kept as strings (with their shape attached) for a later call to
        convert_strings_to_arrays.  Other values are normalized and lists become tuples.

        Args:
            entry (dict): A dictionary with 'val' and optionally 'recompile_flag' and 'shape'.
                A missing 'recompile_flag' defaults to True.

        Returns:
            Param: A reconstructed Param object.
        """
        val = entry['val']
        recompile_flag = entry.get('recompile_flag', True)
        if isinstance(val, str) and val.startswith(ParameterHandler.array_prefix):
            param = Param(val=val, recompile_flag=recompile_flag)
            if 'shape' in entry:
                param.shape = tuple(entry['shape'])
            return param

        val = ParameterHandler.normalize_scalar(val)
        if isinstance(val, list):
            val = tuple(val)
        return Param(val=val, recompile_flag=recompile_flag)

    @staticmethod
    def is_flat_iterable(x):
        """Return True if x is a sized iterable (list, tuple, array, ...) other than str or bytes."""
        return isinstance(x, Iterable) and isinstance(x, Sized) and not isinstance(x, (str, bytes))

    @staticmethod
    def compare_flat_iterables(v1, v2, atol=1e-6):
        """
        Verify that 2 iterables (tuples or lists, etc) have the same length and entries, up to atol.

        Numeric entries are compared within atol; all other entries are compared with ``!=``.

        Args:
            v1 (Sized Iterable): First iterable
            v2 (Sized Iterable): Second iterable
            atol (float, optional): Absolute floating point tolerance for equality. Defaults to 1e-6.

        Returns:
            bool: True when both iterables have matching lengths and element values within tolerance.
        """
        if len(v1) != len(v2):
            return False
        scalar_types = ParameterHandler._SCALAR_TYPES
        for a, b in zip(v1, v2):
            if isinstance(a, scalar_types) and isinstance(b, scalar_types):
                if not abs(float(a) - float(b)) <= atol:
                    return False
            elif a != b:
                return False
        return True

    @staticmethod
    def compare_parameter_handlers(ph1, ph2, atol=1e-6, verbose=False):
        """
        Compare the parameters of two ParameterHandler instances for equality.

        Args:
            ph1 (ParameterHandler): First instance.
            ph2 (ParameterHandler): Second instance.
            atol (float): Absolute tolerance for float/array comparison.
            verbose (bool): If True, print mismatch details.

        Returns:
            bool: True if both have the same parameter names and all values match within
            tolerance (arrays must also have the same shape), False otherwise.
        """
        keys1 = set(ph1.params)
        keys2 = set(ph2.params)
        if keys1 != keys2:
            if verbose:
                print("Parameter key mismatch:")
                print("Only in ph1:", keys1 - keys2)
                print("Only in ph2:", keys2 - keys1)
            return False

        array_types = (np.ndarray, jnp.ndarray)
        scalar_types = ParameterHandler._SCALAR_TYPES
        for key in keys1:
            val1 = ph1.params[key].val
            val2 = ph2.params[key].val
            if isinstance(val1, array_types) and isinstance(val2, array_types):
                arr1, arr2 = np.asarray(val1), np.asarray(val2)
                # np.allclose broadcasts, so check shapes explicitly first.
                equal = arr1.shape == arr2.shape and bool(np.allclose(arr1, arr2, atol=atol))
            elif isinstance(val1, scalar_types) and isinstance(val2, scalar_types):
                equal = abs(float(val1) - float(val2)) <= atol
            elif ParameterHandler.is_flat_iterable(val1) and ParameterHandler.is_flat_iterable(val2):
                equal = ParameterHandler.compare_flat_iterables(val1, val2, atol)
            else:
                equal = val1 == val2

            if not equal:
                if verbose:
                    print(f"Mismatch in key '{key}': {val1} != {val2}")
                return False
        return True

    @staticmethod
    def save_params(params, filename=None):
        """
        Serialize parameters to YAML. If filename is provided, write to file; otherwise return YAML text.

        Args:
            params (dict): The parameters dict from a TomographyModel. It is not modified.
            filename (str or None): Path to save YAML file (must end in .yml/.yaml). If None or empty, return YAML string.

        Returns:
            None if filename was provided; otherwise, YAML-formatted string of parameters.

        Raises:
            ValueError: If filename is invalid.
        """
        # serialize_parameter builds new dicts and never mutates its input, so no copy is needed.
        output_params = {key: ParameterHandler.serialize_parameter(entry) for key, entry in params.items()}

        yaml_writer = YAML()
        yaml_writer.default_flow_style = False

        if filename:
            if not filename.lower().endswith(('.yml', '.yaml')):
                raise ValueError(f"Filename must end in .yaml or .yml: {filename}")
            # Ensure output directory exists
            mj.makedirs(filename)
            with open(filename, 'w') as file:
                yaml_writer.dump(output_params, file)
            return None

        stream = io.StringIO()
        yaml_writer.dump(output_params, stream)
        return stream.getvalue()

    @staticmethod
    def load_param_dict(source: Union[str, TextIO, io.StringIO], required_param_names=None,
                        values_only=True) -> tuple[dict, dict]:
        """
        Load parameters from a YAML file, a YAML string, or a file-like object.

        Args:
            source: A filename (str) ending in .yml or .yaml, a YAML string (str), or a file-like object.
            required_param_names (list of strings, optional): Names of parameters that are required
                for a class; these are returned separately. If None, no parameters are required.
            values_only (bool): If True, then extract and return the values of each entry only.

        Returns:
            tuple: (required_params, params) where required_params (dict) holds the required
            entries and params (dict) holds all other parameters.

        Raises:
            ValueError: If an existing file does not end in .yml/.yaml, or the YAML content is not a mapping.
            FileNotFoundError: If source looks like a .yml/.yaml filename but the file does not exist.
            TypeError: If source is neither a str nor a file-like object.
            KeyError: If a required parameter is missing.
        """
        yaml_stream: TextIO
        if isinstance(source, str):
            if os.path.exists(source):
                if not source.lower().endswith(('.yml', '.yaml')):
                    raise ValueError("Filename must end in .yml or .yaml")
                with open(source, 'r') as f:
                    yaml_stream = io.StringIO(f.read())
            elif '\n' not in source and source.strip().lower().endswith(('.yml', '.yaml')):
                # A single-line .yml/.yaml name is almost certainly a missing file, not YAML text.
                raise FileNotFoundError(f"Parameter file not found: {source}")
            else:
                yaml_stream = io.StringIO(source)
        elif hasattr(source, 'read'):
            yaml_stream = source
        else:
            raise TypeError("Invalid source type for load_param_dict. Must be str or file-like object.")

        loaded = YAML(typ="safe").load(yaml_stream)
        if not isinstance(loaded, dict):
            raise ValueError("Parameter YAML must be a mapping from parameter names to entries.")

        # deserialize_parameter already normalizes scalars and converts lists to tuples.
        param_dict = {key: ParameterHandler.deserialize_parameter(entry) for key, entry in loaded.items()}
        param_dict = ParameterHandler.convert_strings_to_arrays(param_dict)

        return ParameterHandler.get_required_params_from_dict(param_dict,
                                                              required_param_names=required_param_names,
                                                              values_only=values_only)

    @staticmethod
    def get_required_params_from_dict(param_dict, required_param_names=None, values_only=True):
        """
        Separate the required parameters from param_dict. The input dict is not modified.

        Args:
            param_dict (dict): Dictionary mapping names to Param entries.
            required_param_names (list of str, optional): Names to move into the required dict.
                If None, no parameters are treated as required.
            values_only (bool, optional): If True, return each entry's ``val`` instead of the Param.

        Returns:
            tuple: (required_params, other_params), both dicts.

        Raises:
            KeyError: If a required name is not present in param_dict.
        """
        remaining = dict(param_dict)
        required_params = {name: remaining.pop(name) for name in (required_param_names or [])}

        if values_only:
            required_params = {key: entry.val for key, entry in required_params.items()}
            remaining = {key: entry.val for key, entry in remaining.items()}

        return required_params, remaining

    def set_params(self, no_warning=False, no_compile=False, **kwargs):
        """
        Update parameters using keyword arguments.

        This method updates internal model parameters. If any key geometry-related parameters are modified,
        it triggers recompilation of the projector system unless suppressed via the `no_compile` flag.

        Directly setting sigma_y, sigma_x, or sigma_prox disables auto-regularization (with a warning)
        unless `no_warning` is True.  Setting sharpness or snr_db re-enables auto-regularization if it
        was disabled.

        Args:
            no_warning (bool, optional): If True, disables validity checking and warning messages. Defaults to False.
            no_compile (bool, optional): If True, suppresses projector recompilation after updates. Defaults to False.
            **kwargs: Arbitrary keyword arguments specifying parameter names and values to update.

        Returns:
            bool: True if projector recompilation is required and not suppressed by `no_compile`, otherwise False.

        Raises:
            ValueError: If a parameter name is not recognized and `no_warning` is False.

        Example:
            >>> import mbirjax as mj
            >>> ct_model = mj.ParallelBeamModel(sinogram_shape, angles)
            >>> ct_model.set_params(recon_shape=(128, 128, 128), sharpness=0.7)
        """
        recompile = False
        regularization_parameter_change = False
        meta_parameter_change = False

        for key, val in kwargs.items():
            if key in self.params:
                recompile_flag = self.params[key].recompile_flag
            elif no_warning:
                # Default to forcing a recompile for new parameters.
                recompile_flag = True
            else:
                # Check if this is a valid parameter. This is disabled for initialization.
                error_message = '{} is not a recognized parameter'.format(key)
                error_message += '\nValid parameters are: \n'
                error_message += ''.join(' {}\n'.format(valid_key) for valid_key in self.params)
                raise ValueError(error_message)

            self.params[key] = Param(ParameterHandler.normalize_scalar(val), recompile_flag)

            # Handle special cases
            if recompile_flag:
                recompile = True
            elif key in ("sigma_y", "sigma_x", "sigma_prox"):
                regularization_parameter_change = True
            elif key in ("sharpness", "snr_db"):
                meta_parameter_change = True

        # Directly setting a regularization parameter disables auto-regularization.  With
        # no_warning=True (used e.g. during initialization) the values are set without this side effect.
        if regularization_parameter_change and not no_warning:
            self.set_params(auto_regularize_flag=False)
            warnings.warn('You are directly setting regularization parameters, sigma_x, sigma_y or sigma_prox. '
                          'This is an advanced feature that will disable auto-regularization.')

        # Setting a meta regularization parameter re-enables auto-regularization if it was disabled.
        if meta_parameter_change and self.get_params('auto_regularize_flag') is False:
            self.set_params(auto_regularize_flag=True)
            if not no_warning:
                warnings.warn('You have re-enabled auto-regularization by setting sharpness or snr_db. '
                              'It was previously disabled')

        # Signal a recompile only when one is needed and not suppressed.
        return recompile and not no_compile

    @staticmethod
    def get_params_from_dict(param_dict, parameter_names: Union[str, list[str]]):
        """
        Get the values of the listed parameter names from the supplied dict.

        Args:
            param_dict (dict): The dictionary of parameters
            parameter_names (str or list of str): String or list of strings

        Returns:
            Single value if a string is passed, otherwise a list of values in the given order.

        Raises:
            NameError: If a parameter name is not defined in param_dict.
        """
        def lookup(name):
            if name not in param_dict:
                raise NameError('"{}" is not a recognized argument'.format(name))
            return param_dict[name].val

        if isinstance(parameter_names, str):
            return lookup(parameter_names)
        return [lookup(name) for name in parameter_names]

    def get_params(self, parameter_names: "Union[ParamNames, list[ParamNames]]") -> Any:
        """
        Get the values of the listed parameter names from the internal parameter dictionary.

        This method retrieves the current values of one or more parameters managed by the model.

        Args:
            parameter_names (str or list of str): Name of a parameter, or a list of parameter names.

        Returns:
            Any or list: Single parameter value if a string is passed, or a list of values if a list is passed.

        Raises:
            NameError: If any of the provided parameter names are not recognized.

        Example:
            >>> sharpness = model.get_params('sharpness')
            >>> recon_shape, sharpness = model.get_params(['recon_shape', 'sharpness'])
        """
        return ParameterHandler.get_params_from_dict(self.params, parameter_names)

    def get_magnification(self):
        """
        Compute the scale factor from a voxel at iso (at the origin on the center of rotation) to
        its projection on the detector.  For parallel beam, this is 1, but it may be parameter-dependent
        for other geometries.

        Returns:
            (float): magnification

        Raises:
            NotImplementedError: Always; subclasses are expected to override this method.
        """
        raise NotImplementedError('get_magnification is not implemented.')

    def verify_valid_params(self):
        """
        Verify any conditions that must be satisfied among parameters for correct projections.

        Subclasses of TomographyModel should call super().verify_valid_params() before checking any
        subclass-specific conditions.

        Raises:
            ValueError: If sinogram_shape does not have 3 entries, or geometry_type does not match
                the type of this model.
        """
        sinogram_shape = self.get_params('sinogram_shape')
        if len(sinogram_shape) != 3:
            error_message = "sinogram_shape must be (views, rows, channels). \n"
            error_message += "Got {} for sinogram shape.".format(sinogram_shape)
            raise ValueError(error_message)

        geometry_type = self.get_params('geometry_type')
        if geometry_type != str(type(self)):
            raise ValueError('Parameters are associated with {}, but the current model is {}'.format(
                geometry_type, str(type(self))))