TomographyModel Developer Docs

The TomographyModel provides the basic interface for all specific geometries for tomographic projection and reconstruction.

class mbirjax.TomographyModel(sinogram_shape, **kwargs)

Bases: object

Represents a general model for tomographic reconstruction using MBIRJAX. This class encapsulates the parameters and methods for the forward and back projection processes required in tomographic imaging.

Note that this class is a template for specific subclasses. TomographyModel by itself does not implement projectors or recon. Use self.print_params() to print the parameters of the model after initialization.

Parameters:
  • sinogram_shape (tuple) – The shape of the sinogram array expected (num_views, num_det_rows, num_det_channels).

  • **kwargs (dict) – Arbitrary keyword arguments for setting model parameters dynamically.

Sets up the reconstruction size and parameters.

__init__(sinogram_shape, **kwargs)

forward_project(recon)

Perform a full forward projection at all voxels in the field-of-view.

Parameters:

recon (jnp array) – The 3D reconstruction array.

Returns:

jnp array – The resulting 3D sinogram after projection.

back_project(sinogram)

Perform a full back projection at all voxels in the field-of-view.

Parameters:

sinogram (jnp array) – 3D jax array containing sinogram.

Returns:

jnp array – The resulting 3D reconstruction after back projection.
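Because back_project applies the transpose (adjoint) of the operator used by forward_project, a standard sanity check for a new subclass is the dot-product test <A x, y> = <x, A^T y>. The sketch below is a minimal illustration, assuming a small dense random matrix A as a stand-in for a real projector; toy_forward and toy_back are hypothetical names, not mbirjax functions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy shapes (not an mbirjax geometry).
recon_shape = (4, 5, 3)      # (num_recon_rows, num_recon_cols, num_recon_slices)
sinogram_shape = (6, 3, 7)   # (num_views, num_det_rows, num_det_channels)

# Dense random matrix standing in for the projector.
A = rng.standard_normal((np.prod(sinogram_shape), np.prod(recon_shape)))

def toy_forward(recon):
    """Stand-in for forward_project: recon -> sinogram."""
    return (A @ recon.ravel()).reshape(sinogram_shape)

def toy_back(sinogram):
    """Stand-in for back_project: applies A^T, sinogram -> recon."""
    return (A.T @ sinogram.ravel()).reshape(recon_shape)

x = rng.standard_normal(recon_shape)
y = rng.standard_normal(sinogram_shape)

# Dot-product (adjoint) test: both inner products agree up to round-off.
lhs = np.vdot(toy_forward(x), y)
rhs = np.vdot(x, toy_back(y))
print(np.isclose(lhs, rhs))
```

With a real model, the same check would call forward_project and back_project on a model instance in place of the toy functions.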

sparse_forward_project(voxel_values, indices)

Forward project the given voxel values to a sinogram. The indices are into a flattened 2D array of shape (recon_rows, recon_cols), and the projection is done using all voxels with those indices across all the slices.

Parameters:
  • voxel_values (jax.numpy.DeviceArray) – 2D array of voxel values to project, of size (len(indices), num_recon_slices).

  • indices (numpy.ndarray) – Array of indices specifying which voxels to project.

Returns:

jnp array – The resulting 3D sinogram after projection.
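The indices used by the sparse projectors refer to positions in the (recon_rows, recon_cols) plane flattened in row-major order, and each index selects a whole column of voxels through all slices. The sketch below illustrates that convention with NumPy and a toy per-slice projector built from a hypothetical dense matrix A (an assumption that fits a parallel-beam-like geometry, where each detector row sees one slice). It checks that projecting only the selected voxels matches a full projection of a recon that is zero everywhere else.

```python
import numpy as np

rng = np.random.default_rng(1)
num_rows, num_cols, num_slices = 4, 5, 3
num_views, num_channels = 6, 7

# Hypothetical per-slice system matrix: (num_views * num_channels, num_rows * num_cols).
A = rng.standard_normal((num_views * num_channels, num_rows * num_cols))

def toy_sparse_forward(voxel_values, indices):
    """voxel_values has shape (len(indices), num_slices); returns a sinogram
    of shape (num_views, num_det_rows=num_slices, num_channels)."""
    sino = A[:, indices] @ voxel_values                # (views*channels, slices)
    sino = sino.reshape(num_views, num_channels, num_slices)
    return sino.transpose(0, 2, 1)                     # (views, det_rows, channels)

recon = rng.standard_normal((num_rows, num_cols, num_slices))
indices = np.array([0, 7, 19])

# Index i refers to (row, col) = np.unravel_index(i, (num_rows, num_cols)).
voxel_values = recon.reshape(num_rows * num_cols, num_slices)[indices]
rows, cols = np.unravel_index(indices, (num_rows, num_cols))
assert np.array_equal(voxel_values, recon[rows, cols, :])

# Sparse projection equals a full projection of a recon that is zero off the selected voxels.
masked = np.zeros_like(recon)
masked[rows, cols, :] = recon[rows, cols, :]
full = toy_sparse_forward(masked.reshape(-1, num_slices), np.arange(num_rows * num_cols))
print(np.allclose(toy_sparse_forward(voxel_values, indices), full))
```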

sparse_back_project(sinogram, indices)

Back project the given sinogram to the voxels given by the indices. The indices are into a flattened 2D array of shape (recon_rows, recon_cols), and the back projection is done using all voxels with those indices across all the slices.

Parameters:
  • sinogram (jnp array) – 3D jax array containing sinogram.

  • indices (jnp array) – Array of indices specifying which voxels to back project.

Returns:

jnp array – Back-projected voxel values with shape (len(indices), num_recon_slices).
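Sparse back projection restricts the transpose of the projector to the selected voxel columns. The sketch below assumes the same kind of hypothetical per-slice dense matrix A as a toy parallel-beam-like geometry (not an mbirjax geometry) and checks that the sparse result equals a full back projection sampled at the same indices.

```python
import numpy as np

rng = np.random.default_rng(2)
num_rows, num_cols, num_slices = 4, 5, 3
num_views, num_channels = 6, 7

# Hypothetical per-slice system matrix for a toy geometry.
A = rng.standard_normal((num_views * num_channels, num_rows * num_cols))

def toy_sparse_back(sinogram, indices):
    """Back project a (num_views, num_slices, num_channels) sinogram onto the
    voxel columns given by flattened (row, col) indices.
    Returns an array of shape (len(indices), num_slices)."""
    sino = sinogram.transpose(0, 2, 1).reshape(num_views * num_channels, num_slices)
    return A[:, indices].T @ sino

sinogram = rng.standard_normal((num_views, num_slices, num_channels))
indices = np.array([3, 8, 12, 19])

# Full back projection reshaped to 3D, then sampled at the same (row, col) positions.
all_indices = np.arange(num_rows * num_cols)
full_back = toy_sparse_back(sinogram, all_indices).reshape(num_rows, num_cols, num_slices)
rows, cols = np.unravel_index(indices, (num_rows, num_cols))
print(np.allclose(toy_sparse_back(sinogram, indices), full_back[rows, cols, :]))
```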

compute_hessian_diagonal(weights=None)

Computes the diagonal elements of the Hessian matrix for given weights.

Parameters:

weights (jnp array, optional) – Sinogram weights for the Hessian computation.

Returns:

jnp array – Diagonal of the Hessian matrix with same shape as recon.
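For a weighted quadratic forward-model loss (1/2)(y - A x)^T W (y - A x) with diagonal weights W, the Hessian is A^T W A, so its diagonal entry for voxel j is sum_i w_i A_ij^2. The sketch below computes that diagonal with a hypothetical dense matrix A and compares it with the explicit Hessian. It omits any additional scaling the library may apply, and mbirjax works with projectors rather than an explicit matrix.

```python
import numpy as np

rng = np.random.default_rng(3)
num_meas, num_voxels = 12, 5
A = rng.standard_normal((num_meas, num_voxels))   # hypothetical dense system matrix
w = rng.uniform(0.5, 2.0, size=num_meas)          # sinogram weights (flattened)

# Diagonal of A^T W A without forming the full Hessian: sum_i w_i * A_ij^2.
hess_diag = (A ** 2).T @ w

# Explicit Hessian, for comparison only (far too large for real problems).
hess_full = A.T @ (w[:, None] * A)
print(np.allclose(hess_diag, np.diag(hess_full)))
```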

auto_set_regularization_params(sinogram, weights=1)

Automatically sets the regularization parameters (self.sigma_y, self.sigma_x, and self.sigma_p) used in MBIR reconstruction based on the provided sinogram and optional weights.

Parameters:
  • sinogram (jnp.array) – 3D jax array containing the sinogram with shape (num_views, num_det_rows, num_det_channels).

  • weights (scalar or jnp.array, optional) – Scalar value or 3D weights array with the same shape as the sinogram. Defaults to 1.

The method adjusts the regularization parameters only if auto_regularize_flag is set to True within the model’s parameters.

auto_set_sigma_y(sinogram, weights=1)

Sets the value of the parameter sigma_y used in MBIR reconstruction.

Parameters:
  • sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

  • weights (scalar or 3D jax array) – scalar value or 3D weights array with the same shape as sinogram.

auto_set_sigma_x(sinogram)

Compute the automatic value of sigma_x for use in MBIR reconstruction with qGGMRF prior.

Parameters:

sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

auto_set_sigma_p(sinogram)

Compute the automatic value of sigma_p for use in MBIR reconstruction with proximal map prior.

Parameters:

sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

auto_set_recon_size(sinogram_shape, magnification=1.0, no_compile=True)

Compute the default recon size using the internal parameters delta_channel and delta_pixel plus the number of channels from the sinogram.

print_params()

Prints out the parameters of the model.

save_params(fname='param_dict.npy', binaries=False)

Save the parameter dict to a numpy/pickle or YAML file.

load_params(fname)

Load a parameter dict from a numpy/pickle or YAML file and merge it into the instance parameters.

set_params(no_warning=False, no_compile=False, **kwargs)

Updates parameters using keyword arguments. After setting parameters, it checks if key geometry-related parameters have changed and, if so, recompiles the projectors.

Parameters:
  • no_warning (bool, optional, default=False) – This is used internally to allow for some initial parameter setting.

  • no_compile (bool, optional, default=False) – Prevent (re)compiling the projectors. Used for initialization.

  • **kwargs – Arbitrary keyword arguments where keys are parameter names and values are the new parameter values.

Raises:

NameError – If any key provided in kwargs is not a recognized parameter.

get_params(parameter_names)

Get the values of the listed parameter names. Raises an exception if a parameter name is not defined in parameters.

Parameters:

parameter_names (str or list of str) – A single parameter name or a list of parameter names.

Returns:

A single value if parameter_names is a string, otherwise a list of values.

add_new_params(**kwargs)

Add parameters using keyword arguments.

Parameters:

**kwargs – Arbitrary keyword arguments where keys are parameter names and values are the new parameter values.
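The parameter methods above share simple semantics: set_params only updates names that already exist (raising NameError otherwise), add_new_params introduces new names, and get_params accepts a single name or a list. The class below is a hypothetical, stripped-down sketch of those semantics for illustration only; it is not the mbirjax implementation, the GEOMETRY_PARAMS set is an assumption, and projector recompilation is represented by a flag.

```python
class ToyParams:
    """Hypothetical sketch of the set/get/add parameter semantics."""

    # Hypothetical geometry-related names whose change triggers recompilation.
    GEOMETRY_PARAMS = {"sinogram_shape", "delta_pixel"}

    def __init__(self, **kwargs):
        self.params = dict(kwargs)
        self.needs_recompile = False

    def set_params(self, no_compile=False, **kwargs):
        for name, value in kwargs.items():
            if name not in self.params:
                raise NameError(f'"{name}" is not a recognized parameter')
            changed = self.params[name] != value
            if name in self.GEOMETRY_PARAMS and changed and not no_compile:
                self.needs_recompile = True   # stands in for recompiling the projectors
            self.params[name] = value

    def get_params(self, parameter_names):
        if isinstance(parameter_names, str):
            return self.params[parameter_names]   # KeyError if undefined
        return [self.params[name] for name in parameter_names]

    def add_new_params(self, **kwargs):
        self.params.update(kwargs)


p = ToyParams(sinogram_shape=(10, 1, 20), delta_pixel=1.0, sigma_y=0.1)
p.set_params(sigma_y=0.2)
print(p.get_params(["sigma_y", "delta_pixel"]))   # [0.2, 1.0]
```

Keeping set_params strict catches misspelled parameter names early, while add_new_params is the explicit route for subclasses that need extra parameters.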

verify_valid_params()

Verify any conditions that must be satisfied among parameters for correct projections. Subclasses of TomographyModel should call super().verify_valid_params() before checking any subclass-specific conditions.

get_voxels_at_indices(recon, indices)

Retrieves voxel values from a reconstruction array at specified indices.

Parameters:
  • recon (jnp array) – The 3D reconstruction array.

  • indices (jnp array) – Indices for which voxel values are required.

Returns:

numpy.ndarray or jax.numpy.DeviceArray – Array of voxel values at the specified indices.

get_forward_model_loss(error_sinogram, weights=1.0, normalize=True)

Calculate the forward-model loss from error_sinogram and weights. The error sinogram should be error_sinogram = measured_sinogram - forward_proj(recon).

Parameters:
  • error_sinogram (jax array) – 3D error sinogram with shape (num_views, num_det_rows, num_det_channels).

  • weights (scalar or jax array) – Scalar or 3D weights array with the same shape as the sinogram.

Returns:

loss – The forward-model loss.
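With error_sinogram = measured_sinogram - forward_proj(recon), a weighted quadratic loss is one natural reading of this function. The sketch below computes the unnormalized form (1/2) * sum(weights * error_sinogram**2). The normalization applied when normalize=True is not specified in this section, so the normalized branch shown (a weighted RMSE) is an assumption for illustration only, and toy_forward_model_loss is a hypothetical name.

```python
import numpy as np

def toy_forward_model_loss(error_sinogram, weights=1.0, normalize=True):
    """Hypothetical weighted quadratic loss; not the mbirjax formula."""
    weights = np.broadcast_to(weights, error_sinogram.shape)
    weighted_sq = np.sum(weights * error_sinogram ** 2)
    if normalize:
        # Assumed normalization: weighted root-mean-square error.
        return np.sqrt(weighted_sq / np.sum(weights))
    return 0.5 * weighted_sq

error = np.full((2, 3, 4), 0.5)
print(toy_forward_model_loss(error))                    # 0.5
print(toy_forward_model_loss(error, normalize=False))   # 3.0
```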

recon(sinogram, weights=1.0, num_iterations=13, init_recon=None)

Perform MBIR reconstruction using the Multi-Granular Vector Coordinate Descent algorithm. This function takes care of generating its own partitions and partition sequence.

Parameters:
  • sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • weights (scalar or jax array) – scalar or 3D positive weights with the same shape as sinogram.

  • num_iterations (int) – number of iterations of the VCD algorithm to perform.

  • init_recon (jax array) – optional reconstruction to be used for initialization.

Returns:

[recon, fm_rmse] – reconstruction and array of loss for each iteration.

prox_map(prox_input, sinogram, weights=1.0, num_iterations=3, init_recon=None)

Proximal Map function for use in Plug-and-Play applications. This function is similar to recon, but it essentially uses a prior with a mean of prox_input and a standard deviation of sigma_p.

Parameters:
  • prox_input (jax array) – proximal map input with same shape as reconstruction.

  • sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • weights (scalar or jax array) – scalar or 3D positive weights with the same shape as sinogram.

  • num_iterations (int) – number of iterations of the VCD algorithm to perform.

  • init_recon (jax array) – optional reconstruction to be used for initialization.

Returns:

[recon, fm_rmse] – reconstruction and array of loss for each iteration.
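Read literally, the proximal map minimizes the forward-model loss plus a Gaussian penalty that pulls the result toward prox_input with standard deviation sigma_p. For a tiny dense problem this cost has a closed-form minimizer, which can serve as a reference when testing an iterative prox_map. The sketch assumes the cost (1/(2 sigma_y^2)) ||y - A x||_W^2 + (1/(2 sigma_p^2)) ||x - v||^2 with a hypothetical dense matrix A; the scaling by sigma_y is an assumption.

```python
import numpy as np

rng = np.random.default_rng(4)
num_meas, num_voxels = 15, 6
A = rng.standard_normal((num_meas, num_voxels))   # hypothetical system matrix
y = rng.standard_normal(num_meas)                 # measured sinogram (flattened)
w = rng.uniform(0.5, 2.0, size=num_meas)          # sinogram weights
v = rng.standard_normal(num_voxels)               # prox_input (flattened)
sigma_y, sigma_p = 0.5, 2.0

# Setting the gradient of the assumed cost to zero gives the linear system
# (A^T W A / sigma_y^2 + I / sigma_p^2) x = A^T W y / sigma_y^2 + v / sigma_p^2.
lhs = A.T @ (w[:, None] * A) / sigma_y**2 + np.eye(num_voxels) / sigma_p**2
rhs = A.T @ (w * y) / sigma_y**2 + v / sigma_p**2
x_prox = np.linalg.solve(lhs, rhs)

# Check: the gradient of the assumed cost vanishes at x_prox.
grad = -A.T @ (w * (y - A @ x_prox)) / sigma_y**2 + (x_prox - v) / sigma_p**2
print(np.allclose(grad, 0.0))
```

As sigma_p grows, x_prox approaches the weighted least-squares solution; as sigma_p shrinks, x_prox approaches prox_input.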

vcd_recon(sinogram, partitions, partition_sequence, weights=1.0, init_recon=None, prox_input=None)

Perform MBIR reconstruction using the Multi-Granular Vector Coordinate Descent algorithm for a given set of partitions and a prescribed partition sequence.

Parameters:
  • sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • partitions (tuple) – A collection of K partitions, with each partition being a 2D integer index array of shape (num_subsets, N_indices) whose rows index the voxels to be updated in a flattened recon.

  • partition_sequence (jax array) – A sequence of integers that specify which partition should be used at each iteration.

  • weights (scalar or jax array) – scalar or 3D positive weights with the same shape as sinogram.

  • init_recon (jax array) – Initial reconstruction to use in reconstruction.

  • prox_input (jax array) – optional input for proximal map with same shape as reconstruction.

Returns:

[recon, fm_rmse] – reconstruction and array of loss for each iteration.
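The control flow described for vcd_recon and vcd_partition_iteration can be sketched independently of the projector: each entry of partition_sequence picks one of the partitions, and every subset (row) of that partition is updated in turn. In the sketch, toy_vcd_loop and subset_update are hypothetical names; subset_update stands in for vcd_subset_iteration and here only records which voxels it visits.

```python
import numpy as np

def toy_vcd_loop(partitions, partition_sequence, subset_update, state):
    """Hypothetical outline of the VCD control flow (not the mbirjax code)."""
    for partition_index in partition_sequence:
        partition = partitions[partition_index]   # shape (num_subsets, N_indices)
        for indices in partition:                 # one subset at a time
            state = subset_update(state, indices)
    return state

# Two partitions of 4 voxels: one subset of all voxels, and two subsets of two voxels.
partitions = (np.array([[0, 1, 2, 3]]), np.array([[0, 2], [1, 3]]))
partition_sequence = np.array([0, 1, 1])

visits = toy_vcd_loop(partitions, partition_sequence,
                      lambda state, idx: state + [idx.tolist()], [])
print(visits)   # [[0, 1, 2, 3], [0, 2], [1, 3], [0, 2], [1, 3]]
```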

vcd_partition_iteration(error_sinogram, recon, partition, fm_hessian, weights=1.0, prox_input=None)

Calculate an iteration of the VCD algorithm for each subset of the partition. Each iteration of the algorithm should return an improved reconstruction. The error_sinogram should always satisfy error_sinogram = measured_sinogram - forward_proj(recon), where measured_sinogram is the measured sinogram and recon is the current reconstruction.

Parameters:
  • error_sinogram (jax array) – 3D error sinogram with shape (num_views, num_det_rows, num_det_channels).

  • recon (jax array) – 3D array reconstruction with shape (num_recon_rows, num_recon_cols, num_recon_slices).

  • partition (int array) – (K, N_indices) integer index array that partitions the voxels into K subsets, each of which indexes into a flattened recon.

  • fm_hessian (jax array) – Array with same shape as recon containing diagonal of hessian for forward model loss.

  • weights (scalar or jax array) – scalar or 3D positive weights with same shape as error_sinogram.

  • prox_input (jax array) – optional input for proximal map with same shape as reconstruction.

Returns:

[error_sinogram, recon] – Both have the same shape as above, but are updated to reduce overall loss function.

vcd_subset_iteration(error_sinogram, recon, indices, fm_hessian, weights=1.0, prox_input=None)

Calculate an iteration of the VCD algorithm on a single subset of the partition. Each iteration of the algorithm should return an improved reconstruction. The pair (error_sinogram, recon) forms an overcomplete state that makes computation efficient. However, at each application the state must satisfy the constraint error_sinogram = measured_sinogram - forward_proj(recon), where measured_sinogram is the measured sinogram and forward_proj() is whatever forward projection is being used in reconstruction.

Parameters:
  • error_sinogram (jax array) – 3D error sinogram with shape (num_views, num_det_rows, num_det_channels).

  • recon (jax array) – 3D array reconstruction with shape (num_recon_rows, num_recon_cols, num_recon_slices).

  • indices (int array) – (N_indices) integer index array of voxels to be updated in a flattened recon.

  • fm_hessian (jax array) – Array with same shape as recon containing diagonal of hessian for forward model loss.

  • weights (scalar or jax array) – scalar or 3D positive weights with same shape as error_sinogram.

  • prox_input (jax array) – optional input for proximal map with same shape as reconstruction.

Returns:

[error_sinogram, recon] – Both have the same shape as above, but are updated to reduce overall loss function.
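The key idea behind the overcomplete state is that changing a subset of voxels by delta changes the error sinogram by -A_S delta, so (error_sinogram, recon) can be updated together without a full forward projection. The sketch below shows that bookkeeping for a purely quadratic forward-model loss with no prior, assuming a hypothetical dense matrix A, a diagonal-Hessian-scaled gradient direction, and an exact line search; mbirjax's actual update also involves the prior and its own step rule, and toy_subset_iteration is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(5)
num_meas, num_voxels = 20, 8
A = rng.standard_normal((num_meas, num_voxels))   # hypothetical dense system matrix
w = rng.uniform(0.5, 2.0, size=num_meas)          # sinogram weights
y = rng.standard_normal(num_meas)                 # measured sinogram (flattened)
fm_hessian = (A ** 2).T @ w                       # diagonal of A^T W A

def toy_subset_iteration(error_sinogram, recon, indices):
    """One quadratic VCD-style update on recon[indices]; keeps
    error_sinogram == y - A @ recon without a full forward projection."""
    A_s = A[:, indices]
    neg_grad = A_s.T @ (w * error_sinogram)        # minus gradient of 0.5*||e||_W^2
    direction = neg_grad / fm_hessian[indices]     # diagonal-Hessian scaling
    proj_dir = A_s @ direction                     # sparse forward projection of the step
    denom = np.sum(w * proj_dir ** 2)
    if denom == 0.0:
        return error_sinogram, recon
    alpha = np.dot(direction, neg_grad) / denom    # exact line search for a quadratic
    recon = recon.copy()
    recon[indices] += alpha * direction
    error_sinogram = error_sinogram - alpha * proj_dir
    return error_sinogram, recon

recon = np.zeros(num_voxels)
error_sinogram = y - A @ recon
losses = [0.5 * np.sum(w * error_sinogram ** 2)]
for indices in (np.array([0, 3, 5]), np.array([1, 2, 4, 6, 7])):
    error_sinogram, recon = toy_subset_iteration(error_sinogram, recon, indices)
    losses.append(0.5 * np.sum(w * error_sinogram ** 2))

print(np.allclose(error_sinogram, y - A @ recon))   # invariant maintained
print(losses[0] >= losses[1] >= losses[2])          # loss never increases
```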

static gen_weights(sinogram, weight_type)

Compute the weights used in MBIR reconstruction.

Parameters:
  • sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

  • weight_type (string) – Type of noise model used for the data:
    - weight_type = ‘unweighted’ => return numpy.ones(sinogram.shape).
    - weight_type = ‘transmission’ => return numpy.exp(-sinogram).
    - weight_type = ‘transmission_root’ => return numpy.exp(-sinogram/2).
    - weight_type = ‘emission’ => return 1/(numpy.absolute(sinogram) + 0.1).

Returns:

(jax array) – Weights used in MBIR reconstruction, with the same array shape as sinogram.

Raises:

Exception – Raised if weight_type is not one of the above options.
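The weight_type formulas listed above translate directly into array code. The sketch below is a NumPy transcription of that list under the hypothetical name toy_gen_weights; the actual method returns a JAX array and may differ in details such as the exception message.

```python
import numpy as np

def toy_gen_weights(sinogram, weight_type):
    """NumPy sketch of the weight formulas listed for gen_weights."""
    if weight_type == 'unweighted':
        return np.ones(sinogram.shape)
    elif weight_type == 'transmission':
        return np.exp(-sinogram)
    elif weight_type == 'transmission_root':
        return np.exp(-sinogram / 2)
    elif weight_type == 'emission':
        return 1 / (np.absolute(sinogram) + 0.1)
    raise Exception(f'weight_type "{weight_type}" is not one of the supported options')

sino = np.array([[[0.0, 2.0]]])   # shape (1, 1, 2): (num_views, num_det_rows, num_det_channels)
print(toy_gen_weights(sino, 'transmission_root'))
```

Transmission weights shrink rapidly for large sinogram values (highly attenuated rays carry few photons), while emission weights are roughly inverse to the measured counts, with the 0.1 offset avoiding division by zero.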

gen_set_of_voxel_partitions()

Generates a collection of voxel partitions for an array of specified partition sizes. This function creates a tuple of randomly generated 2D voxel partitions.

Returns:

tuple – A tuple of 2D arrays each representing a partition of voxels into the specified number of subsets.
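One way to build such a tuple is to randomly permute the voxel indices once per partition and reshape each permutation into (num_subsets, N_indices). The sketch below does this under two assumptions not stated in this section: the partition sizes are passed explicitly (the hypothetical parameter partition_sizes), and the number of voxels is divisible by each size so that every 2D array is rectangular and covers each voxel exactly once.

```python
import numpy as np

def toy_gen_set_of_voxel_partitions(num_voxels, partition_sizes, seed=0):
    """Hypothetical sketch: one random 2D partition per requested number of subsets."""
    rng = np.random.default_rng(seed)
    partitions = []
    for num_subsets in partition_sizes:
        if num_voxels % num_subsets != 0:
            raise ValueError('this sketch assumes num_voxels is divisible by each size')
        perm = rng.permutation(num_voxels)
        partitions.append(perm.reshape(num_subsets, -1))   # (num_subsets, N_indices)
    return tuple(partitions)

parts = toy_gen_set_of_voxel_partitions(12, (1, 3, 12))
print([p.shape for p in parts])   # [(1, 12), (3, 4), (12, 1)]
```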

gen_full_indices()

Generates a full array of voxels in the region of reconstruction. This is useful for computing forward projections.

gen_partition_sequence(num_iterations)

Generates a sequence of voxel partitions of the specified length by extending the sequence with the last element if necessary.
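The extension rule is short in code: copy a base sequence of partition indices and, if it is shorter than num_iterations, pad it with its last element. The base sequence passed in below is a hypothetical example (this section does not say where it comes from), and truncating a longer base sequence is an assumption; toy_gen_partition_sequence is a hypothetical name.

```python
import numpy as np

def toy_gen_partition_sequence(base_sequence, num_iterations):
    """Hypothetical sketch: extend (or truncate) base_sequence to num_iterations entries."""
    base = list(base_sequence)
    if num_iterations > len(base):
        base += [base[-1]] * (num_iterations - len(base))
    return np.array(base[:num_iterations])

print(toy_gen_partition_sequence([0, 1, 2, 3], 7))   # [0 1 2 3 3 3 3]
```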

gen_3d_sl_phantom()

Generates a 3D Shepp-Logan phantom.

Returns:

ndarray – A 3D numpy array of shape specified by TomographyModel class parameters.

reshape_recon(recon)

Reshape recon into its 3D form.

Parameters:

recon (ndarray or jnp.array) – Reconstruction array to be reshaped to the 3D shape (num_recon_rows, num_recon_cols, num_recon_slices).