Tomography Model
The TomographyModel provides the basic interface for all specific geometries for tomographic projection
and reconstruction.
Constructor
- class mbirjax.TomographyModel(sinogram_shape, **kwargs)
Bases: ParameterHandler
Represents a general model for tomographic reconstruction using MBIRJAX. This class encapsulates the parameters and methods for the forward and back projection processes required in tomographic imaging.
Note that this class is a template for specific subclasses. TomographyModel by itself does not implement projectors or recon. Use self.print_params() to print the parameters of the model after initialization.
- Parameters:
sinogram_shape (tuple) – The shape of the sinogram array expected (num_views, num_det_rows, num_det_channels).
recon_shape (tuple) – The shape of the reconstruction array (num_rows, num_cols, num_slices).
**kwargs (dict) – Arbitrary keyword arguments for setting model parameters dynamically. See the full list of parameters and their descriptions at Parameter Documentation.
Sets up the reconstruction size and parameters.
Reconstruction and Projection
- TomographyModel.recon(sinogram, weights=None, init_recon=None, max_iterations=15, stop_threshold_change_pct=0.2, first_iteration=0, compute_prior_loss=False, logfile_path='./logs/recon.log', print_logs=True)
Perform MBIR reconstruction using the Multi-Granular Vector Coordinate Descent algorithm. This function generates its own partitions and partition sequence. To restart a recon using the same partition sequence, set first_iteration to the number of iterations completed so far and set init_recon to the output of the previous recon. The new recon then continues the partition sequence from where the previous recon left off.
- Parameters:
sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
weights (ndarray or jax array, optional) – 3D positive weights with the same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1.
init_recon (jax array or None or 0, optional) – Initial reconstruction to use in reconstruction. If None, then direct_recon is called with default arguments. Defaults to None.
max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 15.
stop_threshold_change_pct (float, optional) – Stop reconstruction when the percent change 100 * ||delta_recon||_1 / ||recon||_1 from one iteration to the next falls below stop_threshold_change_pct. Defaults to 0.2. Set this to 0 to guarantee exactly max_iterations iterations.
first_iteration (int, optional) – Set this to the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.
compute_prior_loss (bool, optional) – Set to True to calculate and return the prior model loss. This leads to slower reconstructions and is meant only for small recons. Defaults to False.
logfile_path (str, optional) – Path to the output log file. Defaults to ‘./logs/recon.log’.
print_logs (bool, optional) – If True, print logs to the console. Defaults to True.
- Returns:
(recon, recon_dict) – the reconstruction array and a dict containing the recon parameters.
- recon (jax array): the reconstruction volume.
- recon_dict (dict): a dict obtained from get_recon_dict() with entries 'recon_params', 'notes', 'recon_log', and 'model_params'.
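The stopping rule and the restart behavior described above can be illustrated with a toy iterative solver. This is a minimal NumPy sketch, not MBIRJAX's VCD implementation: plain gradient descent on a small least-squares problem stands in for the reconstruction, and the hypothetical step_sequence stands in for the partition sequence, with iterations past its end reusing the last entry (an assumption made only for illustration).

```python
import numpy as np

def pct_change(delta, recon):
    """Stopping statistic: 100 * ||delta||_1 / ||recon||_1."""
    denom = np.sum(np.abs(recon))
    return 100.0 * np.sum(np.abs(delta)) / denom if denom > 0 else np.inf

def toy_recon(A, y, init_recon, max_iterations=15, stop_threshold_change_pct=0.2,
              first_iteration=0, step_sequence=(0.5, 0.25, 0.1)):
    """Gradient descent on 0.5 * ||y - A x||^2 (a stand-in for VCD, not MBIRJAX code)."""
    x = np.array(init_recon, dtype=float)
    for i in range(max_iterations):
        k = first_iteration + i  # global iteration index into the hypothetical sequence
        step = step_sequence[min(k, len(step_sequence) - 1)]
        delta = -step * (A.T @ (A @ x - y))
        x = x + delta
        if pct_change(delta, x) < stop_threshold_change_pct:
            break
    return x

A = np.array([[1.0, 0.2], [0.1, 0.9], [0.3, 0.4]])
y = A @ np.array([2.0, -1.0])

# Six iterations in one call versus three iterations followed by a restart.
x_full = toy_recon(A, y, np.zeros(2), max_iterations=6, stop_threshold_change_pct=0)
x_half = toy_recon(A, y, np.zeros(2), max_iterations=3, stop_threshold_change_pct=0)
x_resumed = toy_recon(A, y, x_half, max_iterations=3, stop_threshold_change_pct=0,
                      first_iteration=3)
```

With stop_threshold_change_pct=0 the resumed run reproduces the uninterrupted run, because first_iteration picks up the sequence at the right index; restarting with first_iteration=0 would instead replay the early, larger steps.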
- TomographyModel.direct_recon(sinogram, filter_name=None, view_batch_size=100)
Do a direct (non-iterative) reconstruction, typically using a form of filtered backprojection. The implementation details are geometry specific, and direct_recon may not be available for all geometries.
- Parameters:
sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
filter_name (string or None, optional) – The name of the filter to use, defaults to None, in which case the geometry specific method chooses a default, typically ‘ramp’.
view_batch_size (int, optional) – An integer specifying the size of a view batch to limit memory use. Defaults to 100.
- Returns:
recon (jax array) – The reconstructed volume after direct reconstruction.
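The filtering stage of a filtered backprojection, and the role of view_batch_size, can be sketched in NumPy. This is an illustration under assumptions, not MBIRJAX's direct_recon: it applies an ideal ramp filter |f| along the detector-channel axis, processes views in batches to bound memory, and omits the backprojection step and any geometry-specific weighting. The function name ramp_filter_views is hypothetical.

```python
import numpy as np

def ramp_filter_views(sinogram, view_batch_size=100):
    """Ramp-filter each detector row along the channel axis, one batch of views at a time.

    sinogram has shape (num_views, num_det_rows, num_det_channels).
    """
    num_views, _, num_channels = sinogram.shape
    n_fft = 2 * num_channels                 # zero-pad to reduce wrap-around artifacts
    ramp = np.abs(np.fft.fftfreq(n_fft))     # frequency response |f| of the ideal ramp filter
    filtered = np.empty(sinogram.shape, dtype=float)
    for start in range(0, num_views, view_batch_size):
        stop = min(start + view_batch_size, num_views)
        spectrum = np.fft.fft(sinogram[start:stop], n=n_fft, axis=-1)
        # Multiply in frequency, transform back, and drop the zero-padded tail.
        filtered[start:stop] = np.real(np.fft.ifft(spectrum * ramp, axis=-1))[..., :num_channels]
    return filtered

rng = np.random.default_rng(0)
sino = rng.standard_normal((7, 2, 16))
small_batches = ramp_filter_views(sino, view_batch_size=3)
one_batch = ramp_filter_views(sino, view_batch_size=100)
```

Because each view is filtered independently, the batch size changes only peak memory, not the result.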
- TomographyModel.prox_map(prox_input, sinogram, sigma_prox=None, weights=None, init_recon=None, do_initialization=True, stop_threshold_change_pct=0.2, max_iterations=3, first_iteration=0, logfile_path='./logs/prox.log', print_logs=True)
Proximal Map function for use in Plug-and-Play applications. This function is similar to recon, but it essentially uses a prior with a mean of prox_input and a standard deviation of sigma_prox.
- Parameters:
prox_input (jax array) – proximal map input with same shape as reconstruction.
sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
sigma_prox (None or float, optional) – The standard deviation of the proximal map prior term. If None, then set automatically from the sinogram. Defaults to None.
weights (jax array, optional) – 3D positive weights with same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1s.
init_recon (jax array, optional) – optional reconstruction to be used for initialization. Defaults to None, in which case the initial recon is determined by vcd_recon.
do_initialization (bool, optional) – If True, then initialize parameters and place arrays on appropriate devices. Defaults to True. Set to False if initialization has already been performed on this sinogram, and prox_input and init_recon are on main_device and sinogram and weights are on sinogram_device.
stop_threshold_change_pct (float, optional) – Stop reconstruction when NMAE percent change from one iteration to the next is below stop_threshold_change_pct. Defaults to 0.2.
max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 3.
first_iteration (int, optional) – Set this to the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.
logfile_path (str, optional) – Path to the output log file. Defaults to ‘./logs/prox.log’.
print_logs (bool, optional) – If True, print logs to the console. Defaults to True.
- Returns:
(recon, recon_dict) – the reconstruction array and a dict containing the recon parameters.
- recon (jax array): the reconstruction volume.
- recon_dict (dict): a dict obtained from get_recon_dict() with entries 'recon_params', 'notes', 'recon_log', and 'model_params'.
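For intuition about what prox_map computes, the sketch below solves a small dense version in closed form. It assumes a quadratic data term (1/(2 sigma_y^2)) ||y - A x||_W^2 plus a Gaussian prior term (1/(2 sigma_prox^2)) ||x - prox_input||^2, with a toy matrix A in place of the projector. The function toy_prox_map and its sigma_y argument are illustrative assumptions; the MBIRJAX method itself is iterative.

```python
import numpy as np

def toy_prox_map(prox_input, A, y, sigma_prox, weights=None, sigma_y=1.0):
    """Minimize (1/(2 sigma_y^2)) ||y - A x||_W^2 + (1/(2 sigma_prox^2)) ||x - prox_input||^2."""
    if weights is None:
        weights = np.ones_like(y)        # implicit all-ones weights
    AtW = A.T * weights                  # same as A.T @ diag(weights)
    # Setting the gradient to zero gives a linear system in x.
    lhs = AtW @ A / sigma_y**2 + np.eye(A.shape[1]) / sigma_prox**2
    rhs = AtW @ y / sigma_y**2 + prox_input / sigma_prox**2
    return np.linalg.solve(lhs, rhs)

A = np.array([[1.0, 0.2], [0.1, 0.9], [0.3, 0.4]])
y = A @ np.array([2.0, -1.0])
v = np.array([0.5, 0.5])                 # plays the role of prox_input
x_tight = toy_prox_map(v, A, y, sigma_prox=1e-4)
x_loose = toy_prox_map(v, A, y, sigma_prox=1e4)
```

Shrinking sigma_prox pulls the result toward prox_input; growing it recovers the weighted least-squares fit to the sinogram.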
- TomographyModel.forward_project(recon)
Perform a full forward projection at all voxels in the field-of-view.
Note
This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().
- Parameters:
recon (jnp array) – The 3D reconstruction array.
- Returns:
jnp array – The resulting 3D sinogram after projection.
- TomographyModel.back_project(sinogram)
Perform a full back projection at all voxels in the field-of-view.
Note
This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().
- Parameters:
sinogram (jnp array) – 3D jax array containing sinogram.
- Returns:
jnp array – The resulting 3D reconstruction array after back projection.
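Back projection is commonly implemented as the adjoint (transpose) of forward projection, which is why back_project maps a sinogram into the reconstruction space. The NumPy sketch below uses a small random dense matrix as a hypothetical stand-in for the projector and checks the dot-product identity <A x, y> = <x, A^T y>; treating MBIRJAX's back_project as the exact adjoint of forward_project is an assumption here.

```python
import numpy as np

recon_shape = (2, 2, 3)        # (num_rows, num_cols, num_slices)
sinogram_shape = (4, 1, 5)     # (num_views, num_det_rows, num_det_channels)
rng = np.random.default_rng(0)

# Hypothetical dense system matrix: entry [i, j] is the contribution of voxel j to detector bin i.
A = rng.random((np.prod(sinogram_shape), np.prod(recon_shape)))

def forward_project(recon):
    return (A @ recon.ravel()).reshape(sinogram_shape)

def back_project(sinogram):
    # Transpose of the forward projector: the output lives in reconstruction space.
    return (A.T @ sinogram.ravel()).reshape(recon_shape)

x = rng.standard_normal(recon_shape)
y = rng.standard_normal(sinogram_shape)
lhs = np.vdot(forward_project(x), y)   # <A x, y>
rhs = np.vdot(x, back_project(y))      # <x, A^T y>
```

The same check could be run against a model's forward_project and back_project using random arrays of the recon and sinogram shapes.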
Parameter Handling
- TomographyModel.set_params(no_warning=False, no_compile=False, **kwargs)
Update parameters using keyword arguments.
This method updates internal model parameters. If any key geometry-related parameters are modified, it triggers recompilation of the projector system unless suppressed via the no_compile flag.
- Parameters:
no_warning (bool, optional) – If True, disables validity checking and warning messages. Defaults to False.
no_compile (bool, optional) – If True, suppresses projector recompilation after updates. Defaults to False.
**kwargs – Arbitrary keyword arguments specifying parameter names and values to update.
- Returns:
bool – True if projector recompilation is required and not suppressed by no_compile, otherwise False.
Example
>>> import mbirjax as mj
>>> ct_model = mj.ParallelBeamModel(sinogram_shape, angles)
>>> ct_model.set_params(recon_shape=(128, 128, 128), sharpness=0.7)
- ParameterHandler.get_params(parameter_names) → Any
Valid values for parameter_names (a single name or a list of names): 'geometry_type', 'file_format', 'sinogram_shape', 'delta_det_channel', 'delta_det_row', 'det_row_offset', 'det_channel_offset', 'sigma_y', 'alu_unit', 'alu_value', 'recon_shape', 'delta_voxel', 'sigma_x', 'sigma_prox', 'p', 'q', 'T', 'qggmrf_nbr_wts', 'auto_regularize_flag', 'positivity_flag', 'snr_db', 'sharpness', 'granularity', 'partition_sequence', 'verbose', 'use_gpu', 'max_overrelaxation'.
Get the values of the listed parameter names from the internal parameter dictionary.
This method retrieves the current values of one or more parameters managed by the model.
- Parameters:
parameter_names (str or list of str) – Name of a parameter, or a list of parameter names.
- Returns:
Any or list – Single parameter value if a string is passed, or a list of values if a list is passed.
- Raises:
NameError – If any of the provided parameter names are not recognized.
Example
>>> sharpness = model.get_params('sharpness')
>>> recon_shape, sharpness = model.get_params(['recon_shape', 'sharpness'])
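The documented return conventions of set_params and get_params can be summarized with a toy class. ToyParamHandler is hypothetical and only mirrors the behavior described above: set_params returns True when a geometry-related parameter changed and no_compile is False, and get_params returns a single value for a string and a list for a list of names. Which parameters count as geometry-related here is an assumption, and validity checking (no_warning) is omitted.

```python
from typing import Any

class ToyParamHandler:
    """Hypothetical stand-in for the documented contract; not MBIRJAX's ParameterHandler."""

    # Assumption for this sketch: changing any of these would require recompiling projectors.
    GEOMETRY_PARAMS = {'sinogram_shape', 'recon_shape', 'delta_voxel', 'delta_det_channel'}

    def __init__(self, **params: Any):
        self.params = dict(params)

    def set_params(self, no_compile: bool = False, **kwargs: Any) -> bool:
        """Update parameters; return True if recompilation is needed and not suppressed."""
        self.params.update(kwargs)
        needs_recompile = any(name in self.GEOMETRY_PARAMS for name in kwargs)
        return needs_recompile and not no_compile

    def get_params(self, parameter_names):
        """Return one value for a str, or a list of values for a list of names."""
        names = [parameter_names] if isinstance(parameter_names, str) else list(parameter_names)
        unknown = [name for name in names if name not in self.params]
        if unknown:
            raise NameError(f"Unrecognized parameter name(s): {unknown}")
        values = [self.params[name] for name in names]
        return values[0] if isinstance(parameter_names, str) else values

handler = ToyParamHandler(recon_shape=(64, 64, 64), sharpness=0.0)
```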
- ParameterHandler.print_params()
Print the current parameter values in the model.
This method prints all parameters stored in the model’s internal dictionary. If the model’s verbosity level is less than 3, it prints only the parameter names and values. If verbosity is 3 or higher, it also includes the recompile_flag status for each parameter.
Example
>>> ct_model = mj.ParallelBeamModel(sinogram_shape, angles)
>>> ct_model.set_params(sharpness=0.7, recon_shape=(128, 128, 128))
>>> ct_model.print_params()
- TomographyModel.get_recon_dict(recon_params=None, notes=None, save_log=True, save_model=True, str_format=False)
Encapsulate the recon parameters, logs, notes, and optionally all model parameters into a text-based dict with entries 'recon_params', 'recon_log', 'notes', and optionally 'model_params'. This dict can be used with mbirjax.viewer.slice_viewer() and TomographyModel.save_recon_hdf5(). The dict from this function is returned by TomographyModel.recon().
- Parameters:
recon_params (dict, optional) – dict of reconstruction parameters. Defaults to None.
notes (str, optional) – User-supplied notes to attach to the dataset. Defaults to None.
save_log (bool, optional) – If True, saves the internal log buffer (if available). Defaults to True.
save_model (bool, optional) – If True, saves the model parameters as a YAML string. Defaults to True.
str_format (bool, optional) – If True, each top-level entry is a string; entries that would otherwise be dicts are stored as YAML strings. Defaults to False.
- Returns:
dict – A dict with entries 'recon_params', 'notes', 'recon_log', and 'model_params'.
Example
>>> recon, recon_dict = ct_model.recon(sinogram)
>>> print(recon_dict['recon_log'])
Recon Shape and Voxel Spacing
- TomographyModel.auto_set_recon_geometry(no_compile=True, no_warning=False)
Automatically set the recon shape from the geometry parameters and the sinogram shape.
- TomographyModel.scale_recon_shape(row_scale=1.0, col_scale=1.0, slice_scale=1.0)
Scale the reconstruction shape by the given scale factors.
This can be used before starting a reconstruction to improve results when part of the object projects outside the detector. The method updates the internal recon_shape parameter.
- Parameters:
row_scale (float, optional) – Scale factor for the number of rows in the reconstruction. Defaults to 1.0.
col_scale (float, optional) – Scale factor for the number of columns in the reconstruction. Defaults to 1.0.
slice_scale (float, optional) – Scale factor for the number of slices in the reconstruction. Defaults to 1.0.
- Returns:
tuple[int, int, int] – A 3-tuple representing the number of pixels added to the (rows, columns, slices) dimensions due to scaling.
Example
>>> old_shape = model.get_params('recon_shape')
>>> added_padding = model.scale_recon_shape(row_scale=1.2, col_scale=1.1)
>>> new_shape = model.get_params('recon_shape')
>>> print(f"Shape increased by: {added_padding}")
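The returned padding is the difference between the scaled and original shapes. A minimal sketch of that arithmetic follows; it assumes each dimension is rounded to the nearest integer, which is not stated above, and scale_shape is a hypothetical helper rather than part of MBIRJAX.

```python
def scale_shape(recon_shape, row_scale=1.0, col_scale=1.0, slice_scale=1.0):
    """Return (new_shape, added), where added is new_shape - recon_shape per axis.

    Rounding to the nearest integer is an assumption made for this sketch.
    """
    scales = (row_scale, col_scale, slice_scale)
    new_shape = tuple(int(round(n * s)) for n, s in zip(recon_shape, scales))
    added = tuple(new - old for new, old in zip(new_shape, recon_shape))
    return new_shape, added

new_shape, added = scale_shape((128, 128, 64), row_scale=1.2, col_scale=1.1)
```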
- TomographyModel.get_magnification()
Compute the scale factor from a voxel at iso (at the origin on the center of rotation) to its projection on the detector. For parallel beam, this is 1, but it may be parameter-dependent for other geometries.
- Returns:
(float) – magnification
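For a cone-beam geometry, similar triangles give the magnification as the source-to-detector distance divided by the source-to-iso distance. The sketch below assumes the parameter names source_detector_dist and source_iso_dist and uses a hypothetical helper function; it is not the library's get_magnification.

```python
def magnification(geometry_type, source_detector_dist=None, source_iso_dist=None):
    """Scale factor from a voxel at iso to its projection on the detector."""
    if geometry_type == 'parallel':
        return 1.0  # parallel rays: no geometric magnification
    if geometry_type == 'cone':
        # Similar triangles, with both distances measured from the source.
        return source_detector_dist / source_iso_dist
    raise ValueError(f"Unknown geometry_type: {geometry_type!r}")
```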
Saving and Loading
- TomographyModel.save_recon_hdf5(filepath, recon, recon_dict=None)
Save the reconstruction array and, optionally, the recon_dict from recon(). This method creates a file that contains a single dataset named 'recon', with the entries in recon_dict serialized to strings and saved as HDF5 dataset attributes.
The resulting file can be loaded with load_recon_hdf5() or mbirjax.viewer.slice_viewer().
- Parameters:
filepath (str or Path) – Path to the output HDF5 file. Should typically end with a .h5 extension.
recon (array-like) – The reconstruction volume as a NumPy or JAX array.
recon_dict (dict or None, optional) – The dictionary of recon attributes from get_recon_dict(). Defaults to None.
- Raises:
Exception – If saving the file or directory creation fails.
Example
>>> recon, recon_dict = ct_model.recon(sinogram)
>>> recon_dict['notes'] += 'Test scan'
>>> ct_model.save_recon_hdf5("output/my_recon.h5", recon, recon_dict=recon_dict)
- static TomographyModel.load_recon_hdf5(filepath, recreate_model=False)
This function loads a NumPy array stored in an HDF5 file created by save_recon_hdf5(). It also loads the associated attribute dict, whose model parameters can be used to create a new model.
- Parameters:
filepath (str) – Path to the HDF5 file containing the reconstructed volume.
recreate_model (bool, optional) – Deprecated. Will raise a ValueError if set to True.
- Returns:
- (recon, recon_dict)
recon (ndarray): The array saved by save_recon_hdf5().
recon_dict (dict): A dict with the attributes for the data array as in get_recon_dict().
- Raises:
FileNotFoundError – If the file does not exist.
ValueError – If the file does not contain exactly one dataset, or if recreate_model is set to True.
Example
>>> recon, recon_dict = ct_model.load_recon_hdf5("output/recon_volume.h5")
>>> recon.shape
(64, 256, 256)
Data Generation
- TomographyModel.gen_modified_3d_sl_phantom()
DEPRECATED: This method has been deprecated and will be removed in a future release. Instead, use mbirjax.generate_3d_shepp_logan_low_dynamic_range().
Generates a simplified, low-dynamic range version of the 3D Shepp-Logan phantom.
- Returns:
ndarray – A 3D numpy array of shape specified by TomographyModel class parameters.
Parameter Documentation
See the Primary Parameters page.