TomographyModel
The TomographyModel class provides the basic interface shared by all geometry-specific models for tomographic projection and reconstruction.
Constructor
- class mbirjax.TomographyModel(sinogram_shape, **kwargs)
Bases: ParameterHandler
Represents a general model for tomographic reconstruction using MBIRJAX. This class encapsulates the parameters and methods for the forward and back projection processes required in tomographic imaging.
Note that this class is a template for specific subclasses. TomographyModel by itself does not implement projectors or recon. Use self.print_params() to print the parameters of the model after initialization.
- Parameters:
sinogram_shape (tuple) – The shape of the sinogram array expected (num_views, num_det_rows, num_det_channels).
recon_shape (tuple) – The shape of the reconstruction array (num_rows, num_cols, num_slices), passed as a keyword argument.
**kwargs (dict) – Arbitrary keyword arguments for setting model parameters dynamically. See the full list of parameters and their descriptions at Parameter Documentation.
Sets up the reconstruction size and parameters.
Reconstruction and Projection
- TomographyModel.recon(sinogram, weights=None, init_recon=None, max_iterations=15, stop_threshold_change_pct=0.2, first_iteration=0, compute_prior_loss=False, num_iterations=None)
Perform MBIR reconstruction using the Multi-Granular Vector Coordinate Descent algorithm. This function generates its own partitions and partition sequence. To restart a recon using the same partition sequence, set first_iteration to the number of iterations completed so far and set init_recon to the output of the previous recon. The recon then continues the same partition sequence from where the previous recon left off.
- Parameters:
sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
weights (ndarray or jax array, optional) – 3D positive weights with the same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1.
init_recon (jax array or None, optional) – Initial reconstruction used to start the iterations. If None, direct_recon is called with default arguments. Defaults to None.
max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 15.
stop_threshold_change_pct (float, optional) – Stop reconstruction when the percent change from one iteration to the next, 100 * ||delta_recon||_1 / ||recon||_1, is below stop_threshold_change_pct. Defaults to 0.2. Set this to 0 to guarantee exactly max_iterations iterations.
first_iteration (int, optional) – Set this to the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.
compute_prior_loss (bool, optional) – Set to True to calculate and return the prior model loss. This slows reconstruction and is meant only for small recons. Defaults to False.
num_iterations (int, optional) – Deprecated. If not None, its value is used to set max_iterations. Defaults to None.
- Returns:
[recon, recon_params] – reconstruction and a named tuple containing the recon parameters. recon_params (namedtuple): max_iterations, granularity, partition_sequence, fm_rmse, prior_loss, regularization_params
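The stopping rule described for stop_threshold_change_pct can be sketched in NumPy as follows. This is an illustration, not mbirjax code: the helper names percent_change and should_stop are hypothetical, and the sketch assumes that ||recon||_1 in the denominator refers to the newer of the two iterates.

```python
import numpy as np


def percent_change(prev_recon, new_recon):
    """Relative L1 change in percent: 100 * ||new - prev||_1 / ||new||_1.

    Assumption: the denominator uses the newer iterate.
    """
    delta_l1 = np.sum(np.abs(new_recon - prev_recon))
    recon_l1 = np.sum(np.abs(new_recon))
    if recon_l1 == 0:
        return np.inf  # avoid dividing by zero for an all-zero recon
    return 100.0 * delta_l1 / recon_l1


def should_stop(prev_recon, new_recon, stop_threshold_change_pct=0.2):
    """Hypothetical helper: True when the change falls below the threshold."""
    return percent_change(prev_recon, new_recon) < stop_threshold_change_pct


prev = np.ones((2, 2, 2))
new = prev.copy()
new[0, 0, 0] = 1.01  # a single voxel changes by 0.01
print(percent_change(prev, new))  # about 0.125 percent
```

Because the percent change is never negative, a threshold of 0 can never be undercut, which is why setting stop_threshold_change_pct to 0 runs all max_iterations iterations.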
- TomographyModel.direct_recon(sinogram, filter_name=None, view_batch_size=100)
Do a direct (non-iterative) reconstruction, typically using a form of filtered backprojection. The implementation details are geometry specific, and direct_recon may not be available for all geometries.
- Parameters:
sinogram (ndarray or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
filter_name (string or None, optional) – The name of the filter to use. Defaults to None, in which case the geometry-specific method chooses a default, typically ‘ramp’.
view_batch_size (int, optional) – Number of views processed per batch, used to limit memory use. Defaults to 100.
- Returns:
recon (jax array) – The reconstructed volume after direct reconstruction.
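Since the default filter is typically ‘ramp’, the filtering step of filtered backprojection can be illustrated with a short NumPy sketch. It assumes an ideal, unwindowed ramp response |f| applied along the detector-channel (last) axis; the function name ramp_filter_rows is hypothetical, the backprojection step is omitted, and mbirjax's geometry-specific discretization and scaling very likely differ.

```python
import numpy as np


def ramp_filter_rows(sinogram, pad=True):
    """Apply an ideal ramp filter |f| along the last (detector-channel) axis.

    sinogram: array of shape (num_views, num_det_rows, num_det_channels).
    pad: zero-pad to twice the length to reduce circular-convolution wraparound.
    """
    num_channels = sinogram.shape[-1]
    padded_len = 2 * num_channels if pad else num_channels
    freqs = np.fft.rfftfreq(padded_len)  # frequencies in cycles per sample, 0 to 0.5
    ramp = np.abs(freqs)  # ideal ramp: zero response at DC
    spectrum = np.fft.rfft(sinogram, n=padded_len, axis=-1)
    filtered = np.fft.irfft(spectrum * ramp, n=padded_len, axis=-1)
    return filtered[..., :num_channels]  # drop the zero padding
```

Because the ramp is zero at DC, a constant detector row filters to zero when no padding is used; padding turns the row into a finite rectangle, so the cropped output then keeps nonzero edge responses.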
- TomographyModel.scale_recon_shape(row_scale=1.0, col_scale=1.0, slice_scale=1.0)
Scale the recon shape by the given factors. This can be used before starting a reconstruction to improve the reconstruction when part of the object projects outside the detector.
- Parameters:
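row_scale (float, optional) – Scale factor for the recon rows. Defaults to 1.0.
col_scale (float, optional) – Scale factor for the recon columns. Defaults to 1.0.
slice_scale (float, optional) – Scale factor for the recon slices. Defaults to 1.0.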
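The arithmetic behind scale_recon_shape can be sketched as follows. The helper scaled_recon_shape is hypothetical, and it assumes each dimension is multiplied by its factor and rounded to the nearest integer; the library's actual rounding rule may differ.

```python
def scaled_recon_shape(recon_shape, row_scale=1.0, col_scale=1.0, slice_scale=1.0):
    """Hypothetical helper: scale each recon dimension and round to the nearest integer."""
    num_rows, num_cols, num_slices = recon_shape
    return (
        int(round(num_rows * row_scale)),
        int(round(num_cols * col_scale)),
        int(round(num_slices * slice_scale)),
    )


print(scaled_recon_shape((256, 256, 64), row_scale=1.5, col_scale=1.5))  # (384, 384, 64)
```

Scaling rows and columns by 1.5 multiplies the voxel count, and with it the memory and compute per iteration, by 2.25.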
- TomographyModel.prox_map(prox_input, sinogram, weights=None, init_recon=None, stop_threshold_change_pct=0.2, max_iterations=3, first_iteration=0)
Proximal map function for use in Plug-and-Play applications. This function is similar to recon, but it uses a prior with mean prox_input and standard deviation sigma_prox.
- Parameters:
prox_input (jax array) – proximal map input with same shape as reconstruction.
sinogram (jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
weights (jax array, optional) – 3D positive weights with same shape as sinogram. Defaults to None, in which case the weights are implicitly all 1s.
init_recon (jax array, optional) – optional reconstruction to be used for initialization. Defaults to None, in which case the initial recon is determined by vcd_recon.
stop_threshold_change_pct (float, optional) – Stop reconstruction when NMAE percent change from one iteration to the next is below stop_threshold_change_pct. Defaults to 0.2.
max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 3.
first_iteration (int, optional) – Set this to be the number of iterations previously completed when restarting a recon using init_recon. This defines the first index in the partition sequence. Defaults to 0.
- Returns:
[recon, fm_rmse] – reconstruction and array of loss for each iteration.
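To show what "a prior with mean prox_input and standard deviation sigma_prox" means, here is a closed-form proximal map on a tiny dense problem. This is not how mbirjax computes it (prox_map runs VCD iterations on the real projector); the cost function, with a data-term noise level sigma_y, and the helper prox_map_dense are assumptions chosen for illustration:
minimize (1/(2 sigma_y^2)) ||y - A x||_W^2 + (1/(2 sigma_prox^2)) ||x - v||^2, where v is the prox input.

```python
import numpy as np


def prox_map_dense(A, y, v, sigma_y, sigma_prox, w=None):
    """Solve the assumed prox-map cost exactly for a small dense system matrix.

    A: (m, n) system matrix, y: (m,) data, v: (n,) prox input, w: (m,) weights.
    """
    if w is None:
        w = np.ones_like(y)  # unweighted case: W = I
    AtW = A.T * w  # A^T diag(w) without forming the diagonal matrix
    H = AtW @ A / sigma_y**2 + np.eye(A.shape[1]) / sigma_prox**2
    b = AtW @ y / sigma_y**2 + v / sigma_prox**2
    return np.linalg.solve(H, b)  # zero of the gradient of the quadratic cost
```

As sigma_prox shrinks, the solution is pulled toward prox_input; as it grows, the data term dominates.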
- TomographyModel.forward_project(recon)
Perform a full forward projection at all voxels in the field-of-view.
Note
This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().
- Parameters:
recon (jnp array) – The 3D reconstruction array.
- Returns:
jnp array – The resulting 3D sinogram after projection.
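As a toy picture of what a forward projector computes, the sketch below forms line integrals of a square recon at two parallel-beam views (0 and 90 degrees) by summing along one in-plane axis. The function toy_forward_project is hypothetical; which recon axis maps to which detector axis, and the orientation of each view, are assumptions of the toy, not mbirjax conventions. Only the output axis order (num_views, num_det_rows, num_det_channels) follows the documentation above.

```python
import numpy as np


def toy_forward_project(recon):
    """Two-view parallel-beam projector (0 and 90 degrees) for a square recon.

    recon: (num_rows, num_cols, num_slices) with num_rows == num_cols.
    Returns a sinogram of shape (2, num_slices, num_cols).
    """
    num_rows, num_cols, _ = recon.shape
    assert num_rows == num_cols, "toy projector assumes a square recon"
    view_0 = recon.sum(axis=0).T  # integrate along rows -> (num_slices, num_cols)
    view_90 = recon.sum(axis=1).T  # integrate along cols -> (num_slices, num_rows)
    return np.stack([view_0, view_90])
```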
- TomographyModel.back_project(sinogram)
Perform a full back projection at all voxels in the field-of-view.
Note
This method should generally not be used directly for iterative reconstruction. For iterative reconstruction, use recon().
- Parameters:
sinogram (jnp array) – 3D jax array containing sinogram.
- Returns:
jnp array – The resulting 3D reconstruction-shaped array after back projection.
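In model-based reconstruction the back projector is normally the transpose (adjoint) of the forward projector, which is why back projection maps a sinogram to a recon-shaped volume. The sketch below pairs a hypothetical two-view toy forward projector with its adjoint, which smears each view back along the direction it was integrated, and checks the adjoint relation <A x, y> = <x, A^T y>. Both functions are illustrations, not mbirjax's projectors.

```python
import numpy as np


def toy_forward_project(recon):
    """Two views (0 and 90 degrees): sums along rows and along cols."""
    view_0 = recon.sum(axis=0).T  # (num_slices, num_cols)
    view_90 = recon.sum(axis=1).T  # (num_slices, num_rows)
    return np.stack([view_0, view_90])


def toy_back_project(sinogram):
    """Adjoint of toy_forward_project: copy each view back along its integration axis."""
    view_0, view_90 = sinogram
    num_slices, num_cols = view_0.shape
    num_rows = view_90.shape[1]
    # view_0[s, c] was a sum over rows, so it is copied to every row
    back_0 = np.broadcast_to(view_0.T[None, :, :], (num_rows, num_cols, num_slices))
    # view_90[s, r] was a sum over cols, so it is copied to every col
    back_90 = np.broadcast_to(view_90.T[:, None, :], (num_rows, num_cols, num_slices))
    return back_0 + back_90


rng = np.random.default_rng(1)
x = rng.standard_normal((4, 4, 3))
y = rng.standard_normal((2, 3, 4))
lhs = np.sum(toy_forward_project(x) * y)
rhs = np.sum(x * toy_back_project(y))
print(np.isclose(lhs, rhs))  # dot-product (adjoint) test
```

This dot-product test is a common way to check that a hand-written back projector matches its forward projector.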
- TomographyModel.gen_weights(sinogram, weight_type)
Compute the optional weights used in MBIR reconstruction.
- Parameters:
sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).
weight_type (string) – Type of noise model used for data:
  - weight_type = ‘unweighted’ => return numpy.ones(sinogram.shape)
  - weight_type = ‘transmission’ => return numpy.exp(-sinogram)
  - weight_type = ‘transmission_root’ => return numpy.exp(-sinogram/2)
  - weight_type = ‘emission’ => return 1/(numpy.absolute(sinogram) + 0.1)
- Returns:
(jax array) – Weights used in MBIR reconstruction, with the same array shape as sinogram.
- Raises:
Exception – Raised if weight_type is not one of the above options.
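The weight formulas listed for gen_weights can be restated directly in NumPy. The function gen_weights_sketch is hypothetical and mirrors only the formulas above; it raises ValueError (a subclass of Exception) for an unknown weight_type, whereas the exact exception class mbirjax raises is not specified here.

```python
import numpy as np


def gen_weights_sketch(sinogram, weight_type):
    """NumPy restatement of the weight formulas listed for gen_weights."""
    if weight_type == 'unweighted':
        return np.ones(sinogram.shape)
    if weight_type == 'transmission':
        return np.exp(-sinogram)
    if weight_type == 'transmission_root':
        return np.exp(-sinogram / 2)
    if weight_type == 'emission':
        return 1 / (np.absolute(sinogram) + 0.1)
    # ValueError is a subclass of Exception
    raise ValueError(f"Unknown weight_type: {weight_type!r}")
```

For transmission data, where the sinogram holds negative log attenuation values, the detected photon count is proportional to exp(-sinogram), so the ‘transmission’ weights give more influence to less attenuated, less noisy measurements.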
- TomographyModel.gen_weights_mar(sinogram, init_recon=None, metal_threshold=None, beta=1.0, gamma=3.0)
Generates the weights used for reducing metal artifacts in MBIR reconstruction.
This function computes sinogram weights that help to reduce metal artifacts. More specifically, it computes weights with the form:
weights = exp( -(sinogram/beta) * ( 1 + gamma * delta(metal) ) )
delta(metal) denotes a binary mask indicating the sinogram entries that contain projections of metal. Providing init_recon yields better metal artifact reduction. If it is not provided, the metal segmentation is generated directly from the sinogram.
- Parameters:
sinogram (jax array) – 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).
init_recon (jax array, optional) – An initial reconstruction used to identify metal voxels. If not provided, Otsu’s method is used to directly segment sinogram into metal regions.
metal_threshold (float, optional) – Values in init_recon above metal_threshold are classified as metal. If not provided, Otsu’s method is used to segment init_recon.
beta (float, optional) – Scalar value in range \(>0\). A larger beta improves the noise uniformity, but too large a value may increase the overall noise level. Defaults to 1.0.
gamma (float, optional) – Scalar value in range \(\geq 0\). A larger gamma reduces the weight of sinogram entries with metal, but too large a value may reduce image quality inside the metal regions. Defaults to 3.0.
- Returns:
(jax array) – Weights used in MBIR reconstruction, with the same array shape as sinogram.
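Given a metal mask, the weight formula for gen_weights_mar is a single elementwise expression. In this sketch the mask is an input, which is an assumption for illustration: the segmentation step (thresholding init_recon, or Otsu’s method on the sinogram) is not reproduced, and the name mar_weights is hypothetical.

```python
import numpy as np


def mar_weights(sinogram, metal_mask, beta=1.0, gamma=3.0):
    """weights = exp(-(sinogram/beta) * (1 + gamma * delta(metal))).

    metal_mask: boolean array, same shape as sinogram, True where the entry
    contains a projection of metal (delta(metal) = 1).
    """
    delta = metal_mask.astype(sinogram.dtype)
    return np.exp(-(sinogram / beta) * (1.0 + gamma * delta))


sino = np.array([[[1.0, 1.0]]])
mask = np.array([[[False, True]]])
print(mar_weights(sino, mask))  # exp(-1) for the non-metal entry, exp(-4) for the metal entry
```

Entries flagged as metal are down-weighted by an extra factor of exp(-gamma * sinogram / beta), and with gamma = 0 the result reduces to exp(-sinogram/beta).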
Saving and Loading
Parameter Handling
- TomographyModel.set_params(no_warning=False, no_compile=False, **kwargs)
Updates parameters using keyword arguments. After setting parameters, it checks if key geometry-related parameters have changed and, if so, recompiles the projectors.
- Parameters:
no_warning (bool, optional, default=False) – This is used internally to allow for some initial parameter setting.
no_compile (bool, optional, default=False) – Prevent (re)compiling the projectors. Used for initialization.
**kwargs – Arbitrary keyword arguments where keys are parameter names and values are the new parameter values.
- Raises:
NameError – If any key provided in kwargs is not a recognized parameter.
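The set_params contract (reject unknown names with NameError, recompile only when a geometry-related parameter actually changes, and skip recompiling when no_compile is set) can be pictured with a small stand-in class. ToyParams, its GEOMETRY_KEYS set, and the compile counter are all hypothetical; they are not mbirjax internals.

```python
class ToyParams:
    """Hypothetical stand-in illustrating the set_params contract described above."""

    GEOMETRY_KEYS = {'sinogram_shape', 'recon_shape'}  # assumed set, for illustration

    def __init__(self, **params):
        self.params = dict(params)
        self.compile_count = 0

    def set_params(self, no_compile=False, **kwargs):
        # Validate all names first so a bad call leaves the parameters unchanged.
        unknown = [k for k in kwargs if k not in self.params]
        if unknown:
            raise NameError(f"Unrecognized parameter(s): {unknown}")
        geometry_changed = any(
            k in self.GEOMETRY_KEYS and self.params[k] != v for k, v in kwargs.items()
        )
        self.params.update(kwargs)
        if geometry_changed and not no_compile:
            self.compile_count += 1  # placeholder for recompiling the projectors
```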
- ParameterHandler.get_params(parameter_names)
Get the values of the listed parameter names from the internal parameter dict. Raises an exception if a parameter name is not defined in parameters.
- Parameters:
parameter_names (str or list of str) – A single parameter name or a list of parameter names.
- Returns:
A single value if parameter_names is a string, otherwise a list of values in the same order.
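The str-or-list behavior of get_params can be sketched with a plain dict. The function get_params_sketch is hypothetical, and raising NameError for an undefined name is an assumption; the documentation above only says that an exception is raised.

```python
def get_params_sketch(params, parameter_names):
    """Return one value for a str, or a list of values for a list of names."""
    if isinstance(parameter_names, str):
        if parameter_names not in params:
            raise NameError(f"'{parameter_names}' is not a recognized parameter")
        return params[parameter_names]
    return [get_params_sketch(params, name) for name in parameter_names]


params = {'sharpness': 1.0, 'snr_db': 30.0}
print(get_params_sketch(params, ['snr_db', 'sharpness']))  # [30.0, 1.0]
```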
Data Generation
Parameter Documentation
See the Primary Parameters page.