TomographyModel#

The TomographyModel provides the basic interface for all specific geometries for tomographic projection and reconstruction.

Constructor#

class mbirjax.TomographyModel(sinogram_shape, **kwargs)[source]#

Bases: ParameterHandler

Represents a general model for tomographic reconstruction using MBIRJAX. This class encapsulates the parameters and methods for the forward and back projection processes required in tomographic imaging.

Note that this class is a template for specific subclasses. TomographyModel by itself does not implement projectors or recon. Use self.print_params() to print the parameters of the model after initialization.

Parameters:
  • sinogram_shape (tuple) – The shape of the sinogram array expected (num_views, num_det_rows, num_det_channels).

  • recon_shape (tuple) – The shape of the reconstruction array (num_rows, num_cols, num_slices).

  • **kwargs (dict) – Arbitrary keyword arguments for setting model parameters dynamically. See the full list of parameters and their descriptions at Parameter Documentation.

Sets up the reconstruction size and parameters.

Reconstruction and Projection#

TomographyModel.recon(sinogram, weights=None, init_recon=None, max_iterations=15, stop_threshold_change_pct=0.2, first_iteration=0, compute_prior_loss=False, num_iterations=None)[source]#

Perform MBIR reconstruction using the Multi-Granular Vector Coordinate Descent algorithm. This function generates its own partitions and partition sequence. To restart a recon using the same partition sequence, set first_iteration to the number of iterations completed so far and set init_recon to the output of the previous recon. The reconstruction then continues through the partition sequence from where the previous recon left off.

Parameters:
  • sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • weights (ndarray or jax array, optional) – 3D positive weights with the same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1.

  • init_recon (jax array or None, optional) – Initial reconstruction used to start the iterations. If None, direct_recon is called with default arguments. Defaults to None.

  • max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 15.

  • stop_threshold_change_pct (float, optional) – Stop reconstruction when the relative change 100 * ||delta_recon||_1 / ||recon||_1 from one iteration to the next falls below stop_threshold_change_pct. Defaults to 0.2. Set this to 0 to guarantee exactly max_iterations iterations.

  • first_iteration (int, optional) – Set this to be the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.

  • compute_prior_loss (bool, optional) – Set to True to calculate and return the prior model loss. This slows the reconstruction and is meant only for small recons. Defaults to False.

  • num_iterations (int, optional) – Deprecated. If not None, this value is used to set max_iterations. Defaults to None.

Returns:

[recon, recon_params] – The reconstruction and a named tuple containing the recon parameters. The fields of recon_params (namedtuple) are: max_iterations, granularity, partition_sequence, fm_rmse, prior_loss, regularization_params.
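The restart procedure described above can be sketched as a small helper. This is a minimal sketch, not part of MBIRJAX: the function name recon_in_two_stages and its arguments are hypothetical, and model is assumed to be an instance of a TomographyModel subclass that implements recon. This page does not state whether max_iterations counts from 0 or from first_iteration on a restart, so the second-stage value is left to the caller.

```python
def recon_in_two_stages(model, sinogram, first_stage_iterations,
                        resume_max_iterations, weights=None):
    """Run a recon, then resume it with the same partition sequence.

    Hypothetical helper: `model` is assumed to expose the recon() method
    documented above.
    """
    # Stage 1: a stop threshold of 0 guarantees exactly
    # first_stage_iterations iterations, so the iteration count passed to
    # the restart below is known.
    first_recon, _ = model.recon(
        sinogram,
        weights=weights,
        max_iterations=first_stage_iterations,
        stop_threshold_change_pct=0,
    )

    # Stage 2: restart from the previous output. first_iteration is the
    # number of iterations already completed.
    # Assumption: the caller chooses resume_max_iterations after checking
    # how max_iterations is counted on a restart.
    recon, recon_params = model.recon(
        sinogram,
        weights=weights,
        init_recon=first_recon,
        first_iteration=first_stage_iterations,
        max_iterations=resume_max_iterations,
    )
    return recon, recon_params
```

The second call keeps the default stop_threshold_change_pct, so it may stop before reaching resume_max_iterations.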

TomographyModel.direct_recon(sinogram, filter_name=None, view_batch_size=100)[source]#

Do a direct (non-iterative) reconstruction, typically using a form of filtered backprojection. The implementation details are geometry-specific, and direct_recon may not be available for all geometries.

Parameters:
  • sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • filter_name (string or None, optional) – The name of the filter to use. Defaults to None, in which case the geometry-specific method chooses a default, typically ‘ramp’.

  • view_batch_size (int, optional) – An integer specifying the size of a view batch to limit memory use. Defaults to 100.

Returns:

recon (jax array) – The reconstructed volume after direct reconstruction.

TomographyModel.scale_recon_shape(row_scale=1.0, col_scale=1.0, slice_scale=1.0)[source]#

Scale the recon shape by the given factors. This can be used before starting a reconstruction to improve the reconstruction when part of the object projects outside the detector.

Parameters:
  • row_scale (float) – Scale for the recon rows.

  • col_scale (float) – Scale for the recon columns.

  • slice_scale (float) – Scale for the recon slices.

TomographyModel.prox_map(prox_input, sinogram, weights=None, init_recon=None, stop_threshold_change_pct=0.2, max_iterations=3, first_iteration=0)[source]#

Proximal Map function for use in Plug-and-Play applications. This function is similar to recon, but it essentially uses a prior with a mean of prox_input and a standard deviation of sigma_prox.

Parameters:
  • prox_input (jax array) – proximal map input with same shape as reconstruction.

  • sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • weights (jax array, optional) – 3D positive weights with same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1s.

  • init_recon (jax array, optional) – optional reconstruction to be used for initialization. Defaults to None, in which case the initial recon is determined by vcd_recon.

  • stop_threshold_change_pct (float, optional) – Stop reconstruction when NMAE percent change from one iteration to the next is below stop_threshold_change_pct. Defaults to 0.2.

  • max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 3.

  • first_iteration (int, optional) – Set this to be the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.

Returns:

[recon, fm_rmse] – reconstruction and array of loss for each iteration.
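In standard Plug-and-Play notation, the description above corresponds to a proximal map of the following form. This is a sketch under assumptions: v stands for prox_input, f(x) stands for the data-fidelity (forward model) loss determined by sinogram and weights, and the exact scaling of f used inside MBIRJAX is not given on this page.

```latex
\[
F(v) \;=\; \arg\min_{x} \left\{ f(x) + \frac{1}{2\sigma_{\mathrm{prox}}^{2}} \,\lVert x - v \rVert_2^{2} \right\}
\]
```

The quadratic term is the negative log of a Gaussian prior with mean v and standard deviation sigma_prox, which matches the description of prox_map as a recon with that prior.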

TomographyModel.forward_project(recon)[source]#

Perform a full forward projection at all voxels in the field-of-view.

Note

This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().

Parameters:

recon (jnp array) – The 3D reconstruction array.

Returns:

jnp array – The resulting 3D sinogram after projection.

TomographyModel.back_project(sinogram)[source]#

Perform a full back projection at all voxels in the field-of-view.

Note

This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().

Parameters:

sinogram (jnp array) – 3D jax array containing sinogram.

Returns:

jnp array – The resulting 3D reconstruction array after back projection.
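One common sanity check for a projector pair is the adjoint (dot-product) test, which compares <A x, y> with <x, Aᵀ y>. The sketch below assumes that back_project is meant to be the adjoint of forward_project, which this page does not state explicitly. The helper name adjoint_mismatch is hypothetical, and model is assumed to be an instance of a subclass that implements both projectors.

```python
import numpy as np


def adjoint_mismatch(model, recon, sinogram):
    """Relative difference between <A x, y> and <x, A^T y>.

    Hypothetical helper: `recon` (x) and `sinogram` (y) are arrays with the
    model's recon and sinogram shapes. A value near 0 is consistent with
    back_project being the adjoint of forward_project.
    """
    projected = np.asarray(model.forward_project(recon))    # A x
    back_projected = np.asarray(model.back_project(sinogram))  # A^T y

    lhs = float(np.sum(projected * np.asarray(sinogram)))    # <A x, y>
    rhs = float(np.sum(np.asarray(recon) * back_projected))  # <x, A^T y>

    # Normalize by the larger magnitude; the small floor avoids division by 0.
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs), 1e-12)
```

With single-precision arithmetic the mismatch is not expected to be exactly zero, so any threshold is a judgment call.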

TomographyModel.gen_weights(sinogram, weight_type)[source]#

Compute the optional weights used in MBIR reconstruction.

Parameters:
  • sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

  • weight_type (string) – Type of noise model used for the data:

      - weight_type = ‘unweighted’ => return numpy.ones(sinogram.shape).
      - weight_type = ‘transmission’ => return numpy.exp(-sinogram).
      - weight_type = ‘transmission_root’ => return numpy.exp(-sinogram/2).
      - weight_type = ‘emission’ => return 1/(numpy.absolute(sinogram) + 0.1).

Returns:

(jax array) – Weights used in MBIR reconstruction, with the same array shape as sinogram.

Raises:

Exception – Raised if weight_type is not one of the above options.
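The four weight_type options listed above can be written out in NumPy as follows. This is an illustrative stand-in rather than the MBIRJAX implementation: the function name gen_weights_sketch is hypothetical, it returns NumPy arrays instead of jax arrays, and it raises ValueError (a subclass of Exception) for an unrecognized weight_type.

```python
import numpy as np


def gen_weights_sketch(sinogram, weight_type):
    """Return weights with the same shape as `sinogram` (hypothetical helper)."""
    sinogram = np.asarray(sinogram, dtype=float)

    if weight_type == 'unweighted':
        return np.ones(sinogram.shape)
    if weight_type == 'transmission':
        return np.exp(-sinogram)
    if weight_type == 'transmission_root':
        return np.exp(-sinogram / 2)
    if weight_type == 'emission':
        # The 0.1 offset keeps the weight finite where the sinogram is 0.
        return 1 / (np.absolute(sinogram) + 0.1)

    raise ValueError(f"Unrecognized weight_type: {weight_type!r}")
```

Note that the ‘transmission_root’ weights are the square root of the ‘transmission’ weights, so they fall off more slowly for large sinogram values.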

TomographyModel.gen_weights_mar(sinogram, init_recon=None, metal_threshold=None, beta=1.0, gamma=3.0)[source]#

Generates the weights used for reducing metal artifacts in MBIR reconstruction.

This function computes sinogram weights that help to reduce metal artifacts. More specifically, it computes weights with the form:

weights = exp( -(sinogram/beta) * ( 1 + gamma * delta(metal) ) )

Here delta(metal) denotes a binary mask indicating the sinogram entries that contain projections of metal. Providing init_recon yields better metal artifact reduction. If it is not provided, the metal segmentation is generated directly from the sinogram.

Parameters:
  • sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).

  • init_recon (jax array, optional) – An initial reconstruction used to identify metal voxels. If not provided, Otsu’s method is used to directly segment sinogram into metal regions.

  • metal_threshold (float, optional) – Values in init_recon above metal_threshold are classified as metal. If not provided, Otsu’s method is used to segment init_recon.

  • beta (float, optional) – Scalar value greater than 0. A larger beta improves the noise uniformity, but too large a value may increase the overall noise level. Defaults to 1.0.

  • gamma (float, optional) – Scalar value greater than or equal to 0. A larger gamma reduces the weight of sinogram entries with metal, but too large a value may reduce image quality inside the metal regions. Defaults to 3.0.

Returns:

(jax array) – Weights used in MBIR reconstruction, with the same array shape as sinogram.
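The weight formula above is easy to evaluate once a metal mask is available. The NumPy sketch below takes the mask as an input, because the segmentation step (Otsu's method or metal_threshold) is not reproduced here. The function name mar_weights_sketch and the metal_mask argument are hypothetical and not part of MBIRJAX.

```python
import numpy as np


def mar_weights_sketch(sinogram, metal_mask, beta=1.0, gamma=3.0):
    """Evaluate exp(-(sinogram / beta) * (1 + gamma * delta(metal))).

    Hypothetical helper: `metal_mask` is a 0/1 array with the same shape as
    `sinogram`, equal to 1 where the sinogram entry contains metal.
    """
    if beta <= 0:
        raise ValueError("beta must be greater than 0")
    if gamma < 0:
        raise ValueError("gamma must be greater than or equal to 0")

    sinogram = np.asarray(sinogram, dtype=float)
    metal_mask = np.asarray(metal_mask, dtype=float)
    return np.exp(-(sinogram / beta) * (1 + gamma * metal_mask))
```

For a sinogram value of 1 with the default beta and gamma, a metal-free entry gets weight exp(-1) while an entry containing metal gets exp(-4), so the metal entry contributes far less to the reconstruction.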

Saving and Loading#

TomographyModel.to_file(filename)[source]#

Save parameters to yaml file.

Parameters:

filename (str) – Path to file to store the parameter dictionary. Must end in .yml or .yaml

Returns:

None. Creates or overwrites the specified file.

classmethod TomographyModel.from_file(filename)[source]#

Construct a TomographyModel (or a subclass) from parameters saved using to_file()

Parameters:

filename (str) – Name of the file containing parameters to load.

Returns:

An instance of TomographyModel (or of the subclass on which from_file is called) with the specified parameters.
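A save-and-reload round trip using to_file and from_file might look like the sketch below. The helper name save_and_reload is hypothetical, and it assumes that calling from_file on the model's own class returns an instance of that class. The extension check mirrors the requirement stated for to_file.

```python
def save_and_reload(model, filename):
    """Save a model's parameters and build a new model from the file.

    Hypothetical helper: `model` is assumed to be an instance of a
    TomographyModel subclass.
    """
    # to_file requires a .yml or .yaml extension, so fail early otherwise.
    if not filename.endswith(('.yml', '.yaml')):
        raise ValueError("filename must end in .yml or .yaml")

    model.to_file(filename)

    # from_file is a classmethod, so call it on the model's class.
    return type(model).from_file(filename)
```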

Parameter Handling#

TomographyModel.set_params(no_warning=False, no_compile=False, **kwargs)[source]#

Updates parameters using keyword arguments. After setting parameters, it checks if key geometry-related parameters have changed and, if so, recompiles the projectors.

Parameters:
  • no_warning (bool, optional, default=False) – This is used internally to allow for some initial parameter setting.

  • no_compile (bool, optional, default=False) – Prevent (re)compiling the projectors. Used for initialization.

  • **kwargs – Arbitrary keyword arguments where keys are parameter names and values are the new parameter values.

Raises:

NameError – If any key provided in kwargs is not a recognized parameter.

ParameterHandler.get_params(parameter_names)[source]#

Get the values of the listed parameter names from the internal parameter dict. Raises an exception if a parameter name is not defined in parameters.

Parameters:

parameter_names (str or list of str) – Name or list of names of the parameters to retrieve.

Returns:

Single value or list of values
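As one way to combine get_params and set_params, the sketch below changes parameters temporarily and restores the previous values afterwards. The context manager temporary_params is hypothetical and not part of MBIRJAX. It assumes that every value returned by get_params can be passed back to set_params, and each set_params call may recompile the projectors if geometry-related parameters change.

```python
from contextlib import contextmanager


@contextmanager
def temporary_params(model, **overrides):
    """Apply parameter overrides for the duration of a with-block.

    Hypothetical helper: `model` is assumed to expose get_params and
    set_params as documented above.
    """
    # Look up each name as a single string, which returns a single value.
    saved = {name: model.get_params(name) for name in overrides}

    model.set_params(**overrides)
    try:
        yield model
    finally:
        # Restore the original values even if the with-block raises.
        model.set_params(**saved)
```

A caller would wrap the affected code in a with-block, for example running one recon with modified parameters and leaving the model unchanged afterwards.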

ParameterHandler.print_params()[source]#

Prints out the parameters of the model.

Data Generation#

TomographyModel.gen_modified_3d_sl_phantom()[source]#

Generates a simplified, low-dynamic range version of the 3D Shepp-Logan phantom.

Returns:

ndarray – A 3D numpy array of shape specified by TomographyModel class parameters.
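The methods on this page can be chained into a small simulation: generate a phantom, forward project it to get a synthetic sinogram, compute weights, and reconstruct. The sketch below uses only calls documented above. The helper name simulate_and_reconstruct is hypothetical, and model is assumed to be an instance of a subclass that implements the projectors and recon.

```python
def simulate_and_reconstruct(model, weight_type='unweighted'):
    """Phantom -> sinogram -> weights -> recon (hypothetical helper)."""
    # The phantom shape is determined by the model's parameters.
    phantom = model.gen_modified_3d_sl_phantom()

    # Synthetic, noise-free sinogram from a full forward projection.
    sinogram = model.forward_project(phantom)

    # Weights with the same shape as the sinogram.
    weights = model.gen_weights(sinogram, weight_type)

    # Iterative MBIR reconstruction with default iteration settings.
    recon, recon_params = model.recon(sinogram, weights=weights)
    return phantom, sinogram, recon, recon_params
```

Comparing the returned recon against the phantom gives a quick check that a model's geometry parameters are set up sensibly.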

Parameter Documentation#

See the Primary Parameters page.