Preprocessing#

The preprocess module provides scanner-specific preprocessing and more general preprocessing to compute and correct the sinogram data. See demo_nsi.py in the mbirjax_applications repo for example uses.

NorthStar Instrument (NSI) reader#

mbirjax.preprocess.nsi.compute_sino_and_params(dataset_dir, downsample_factor=(1, 1), subsample_view_factor=1, crop_pixels_sides=None, crop_pixels_top=None, crop_pixels_bottom=None, verbose=1, offset_correction=True)[source]#

Load NSI sinogram data and prepare arrays and parameters for ConeBeamModel reconstruction.

This function computes the sinogram and geometry parameters from an NSI scan directory. It performs the following:

  1. Loads object, blank, and dark scans, and geometry parameters from the dataset.

  2. Computes the sinogram from the scan images.

  3. Replaces defective pixels with interpolated values.

  4. Corrects for detector rotation.

  5. Applies background offset correction.

Parameters:
  • dataset_dir (str) – Path to the NSI scan directory. Expected structure:

    • *.nsipro (NSI config file)

    • Geometry*.rtf (geometry report)

    • Radiographs*/ (radiograph images)

    • **/gain0.tif (blank scan)

    • **/offset.tif (dark scan)

    • **/*.defect (defective pixel info)

  • downsample_factor (Tuple[int, int], optional) – Downsample factors for detector rows and channels. Defaults to (1, 1).

  • subsample_view_factor (int, optional) – Factor by which to subsample views. Defaults to 1.

  • crop_pixels_sides (int, optional) – Pixels to crop from each side of the sinogram. If None, uses NSI config file.

  • crop_pixels_top (int, optional) – Pixels to crop from the top. If None, uses NSI config file.

  • crop_pixels_bottom (int, optional) – Pixels to crop from the bottom. If None, uses NSI config file.

  • verbose (int, optional) – Verbosity level. Defaults to 1.

  • offset_correction (bool) – Whether to apply detector offset correction using values from the Geometry Report. Defaults to True.

Returns:

tuple

(sino, cone_beam_params, optional_params)
  • sino (jax.numpy.ndarray): Sinogram of shape (num_views, num_det_rows, num_det_channels).

  • cone_beam_params (dict): Parameters for initializing ConeBeamModel.

  • optional_params (dict): Parameters to be passed via set_params().

Example

# Get data and reconstruction parameters
sino, cone_beam_params, optional_params = mbirjax.preprocess.nsi.compute_sino_and_params(
    dataset_dir, downsample_factor=downsample_factor, subsample_view_factor=subsample_view_factor)

# Create the model and set parameters
ct_model = mbirjax.ConeBeamModel(**cone_beam_params)
ct_model.set_params(**optional_params)
ct_model.set_params(sharpness=sharpness, verbose=1)

# Generate weights and run reconstruction
weights = mbirjax.gen_weights(sino, weight_type='transmission_root')
recon, recon_dict = ct_model.recon(sino, weights=weights)
mbirjax.preprocess.nsi.load_scans_and_params(dataset_dir, view_id_start=0, view_id_end=None, subsample_view_factor=1, verbose=1, offset_correction=True)[source]#

Load the object scan, blank scan, dark scan, view angles, defective pixel information, and geometry parameters from an NSI scan directory.

Parameters:
  • dataset_dir (string) – Path to an NSI scan directory. The directory is assumed to have the following structure:

    • *.nsipro (NSI config file)

    • Geometry*.rtf (geometry report)

    • Radiographs*/ (directory containing all radiograph images)

    • **/gain0.tif (blank scan image)

    • **/offset.tif (dark scan image)

    • **/*.defect (defective pixel information)

  • view_id_start (int, optional) – View index corresponding to the first view. Defaults to 0.

  • view_id_end (int, optional) – View index corresponding to the last view. If None, this will be equal to the total number of object scan images in the radiograph directory.

  • subsample_view_factor (int, optional) – View subsample factor. Defaults to 1.

  • verbose (int, optional) – Verbosity level. Defaults to 1.

  • offset_correction (bool) – Whether to apply detector offset correction using values from the Geometry Report. Defaults to True.

Returns:

tuple

(obj_scan, blank_scan, dark_scan, nsi_params, defective_pixel_array)

  • obj_scan (numpy.ndarray): 3D object scan with shape (num_views, num_det_rows, num_det_channels).

  • blank_scan (numpy.ndarray): 3D blank scan with shape (1, num_det_rows, num_det_channels).

  • dark_scan (numpy.ndarray): 3D dark scan with shape (1, num_det_rows, num_det_channels).

  • nsi_params (dict): Required parameters needed for convert_nsi_to_mbirjax_params() (e.g., geometry vectors, spacings, and angles).

  • defective_pixel_array (numpy.ndarray | tuple): If a defective-pixel file is present, an (N, 2) integer array of (detector_row_idx, detector_channel_idx) pairs; otherwise an empty tuple ().

Zeiss Versa cone beam reader#

mbirjax.preprocess.zeiss_cb.compute_sino_and_params(dataset_dir, downsample_factor=(1, 1), subsample_view_factor=1, crop_pixels_sides=0, crop_pixels_top=0, crop_pixels_bottom=0, alu_unit='mm', bg_option='global', verbose=1)[source]#

Compute sinogram and parameters from txrm file generated by a Zeiss Versa scanner.

Notes

Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL). Portions of this code are adapted from the DXchange library: data-exchange/dxchange

Steps:
  1. Load object, blank, and dark scans and geometry.

  2. Compute the sinogram from the scans.

  3. Apply background offset correction.

  4. Apply sinogram offset correction.

Parameters:
  • dataset_dir (str) – Path to the Zeiss dataset. Accepts a .txrm file containing:

    • ImageData*/Image* (scan data)

    • Zeiss OLE metadata streams

  • downsample_factor (Tuple[int, int], optional) – Downsample factors for detector rows and channels. Defaults to (1, 1).

  • subsample_view_factor (int, optional) – Factor by which to subsample views. Defaults to 1.

  • crop_pixels_sides (int, optional) – Pixels to crop from each lateral side of the detector. Defaults to 0.

  • crop_pixels_top (int, optional) – Pixels to crop from the top of the detector. Defaults to 0.

  • crop_pixels_bottom (int, optional) – Pixels to crop from the bottom of the detector. Defaults to 0.

  • alu_unit (str, optional) – The physical unit used to define 1 ALU (Arbitrary Length Unit). Defaults to ‘mm’. Supported units input: ‘um’, ‘mm’, ‘cm’, ‘m’.

  • bg_option (str or None) – Option for background offset correction. Defaults to ‘global’. Supported options:

    • None: No correction; return the input sinogram unchanged.

    • “global”: Estimate one scalar offset from edge regions across all views.

    • “per_view”: Estimate one offset per view from edge regions.

  • verbose (int, optional) – Verbosity level. Defaults to 1.

Returns:

tuple

(sino, cone_beam_params, optional_params, zeiss_params)

  • sino (numpy.ndarray): Sinogram of shape (num_views, num_det_rows, num_channels).

  • cone_beam_params (dict): Parameters for initializing ConeBeamModel.

  • optional_params (dict): Additional parameters to be set via ConeBeamModel.set_params.

Example

from mbirjax.preprocess.zeiss_cb import compute_sino_and_params
sino, cone_beam_params, optional_params, zeiss_params = compute_sino_and_params(
    dataset_dir, verbose=1
)
ct_model = mbirjax.ConeBeamModel(**cone_beam_params)
ct_model.set_params(**optional_params)
recon, recon_dict = ct_model.recon(sino)
mbirjax.preprocess.zeiss_cb.load_scans_and_params(dataset_dir, subsample_view_factor, verbose=1)[source]#

Load the scan data and geometry from a Zeiss scan directory.

Notes

Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL). Portions of this code are adapted from the DXchange library: data-exchange/dxchange

Parameters:
  • dataset_dir (str) – Path to a Zeiss scan directory (a .txrm file is expected). Expected structure:

    • ImageData*/Image* (scan data)

    • **/** (Zeiss metadata/parameters)

  • subsample_view_factor (int) – View subsample factor.

  • verbose (int, optional) – Verbosity level. Defaults to 1.

Returns:

tuple

(obj_scan, blank_scan, dark_scan, zeiss_params)

  • obj_scan (numpy.ndarray): 3D object scan with shape (num_views, num_det_rows, num_channels).

  • blank_scan (numpy.ndarray): 3D blank scan with shape (1, num_det_rows, num_channels).

  • dark_scan (numpy.ndarray): 3D dark scan with shape (1, num_det_rows, num_channels).

    If no dark scan is available, returns a zero array of the same shape.

Zeiss translation tomography functions#

mbirjax.preprocess.zeiss_tct.compute_sino_and_params(dataset_dir, crop_pixels_sides=0, crop_pixels_top=0, crop_pixels_bottom=0, alu_unit='mm', verbose=1)[source]#

Load Zeiss sinogram data and prepare arrays and parameters for TranslationModel reconstruction.

This function computes the sinogram and geometry parameters from a Zeiss scan directory. It performs the following:

  1. Loads object, blank, and dark scans, and geometry parameters from the dataset.

  2. Computes the sinogram from the scan images.

  3. Applies background offset correction.

Parameters:
  • dataset_dir (str) – Path to the Zeiss scan directory. Expected structure:

    • obj_scan (a subfolder containing the object scan)

    • blank_scan (a subfolder containing the blank scan)

    • dark_scan (a subfolder containing the dark scan)

  • crop_pixels_sides (int, optional) – Pixels to crop from each side of the sinogram. Defaults to 0.

  • crop_pixels_top (int, optional) – Pixels to crop from the top of the sinogram. Defaults to 0.

  • crop_pixels_bottom (int, optional) – Pixels to crop from the bottom of the sinogram. Defaults to 0.

  • alu_unit (str, optional) – The physical unit used to define 1 ALU (Arbitrary Length Unit). Defaults to ‘mm’. Supported units input: ‘um’, ‘mm’, ‘cm’, ‘m’.

  • verbose (int, optional) – Verbosity level. Defaults to 1.

Returns:

tuple

(sino, translation_params, optional_params)
  • sino (jax.numpy.ndarray): Sinogram of shape (num_views, num_det_rows, num_channels)

  • translation_params (dict): Parameters for initializing TranslationModel.

  • optional_params (dict): Parameters to be passed via set_params.

Example

# Get data and reconstruction parameters
sino, translation_params, optional_params = mbirjax.preprocess.zeiss_tct.compute_sino_and_params(dataset_dir)

# Create the model and set parameters
tct_model = mbirjax.TranslationModel(**translation_params)
tct_model.set_params(**optional_params)
tct_model.set_params(sharpness=sharpness, verbose=1)

# Run reconstruction
recon, recon_dict = tct_model.recon(sino)
mbirjax.preprocess.zeiss_tct.load_scans_and_params(dataset_dir, verbose=1)[source]#

Load the object scan, blank scan, dark scan, and geometry from a Zeiss scan directory.

Parameters:
  • dataset_dir (str) – Path to a Zeiss scan directory. Expected structure:

    • obj_scan — subfolder containing the object scan

    • blank_scan — subfolder containing the blank scan

    • dark_scan — subfolder containing the dark scan

  • verbose (int, optional) – Verbosity level. Defaults to 1.

Returns:

tuple

(obj_scan, blank_scan, dark_scan, zeiss_params)

  • obj_scan (numpy.ndarray): 3D object scan with shape (num_views, num_det_rows, num_channels).

  • blank_scan (numpy.ndarray): 3D blank scan with shape (1, num_det_rows, num_channels).

  • dark_scan (numpy.ndarray): 3D dark scan with shape (1, num_det_rows, num_channels).

  • zeiss_params (dict): Required parameters for convert_zeiss_to_mbirjax_params (e.g., geometry vectors, spacings, and angles).

PYMBIR functions#

mbirjax.preprocess.pymbir.compute_sino_and_params(filename, bh_correction=True)[source]#

Load ORNL sinogram data from an HDF5 file and prepare all required parameters for cone‐beam reconstruction.

This function performs the following steps in one call:

  1. Extracts geometry parameters and defaults via create_proj_params_dict_ornl.

  2. Reads the raw sinogram data via load_projection_data_ornl.

  3. Optionally applies beam hardening correction to the sinogram via apply_bh_correction.

Parameters:
  • filename (str) – Path to the ORNL HDF5 file containing projection data and geometry attributes.

  • bh_correction (bool, optional) – If True, apply beam hardening correction using the file’s stored parameters. Defaults to True.

Returns:

tuple

(sino, cone_beam_params, optional_params)

  • sino (numpy.ndarray): Sinogram data of shape (num_views, num_det_rows, num_det_channels).

  • cone_beam_params (dict): Mandatory parameters for mbirjax.ConeBeamModel, including "sinogram_shape", "angles", "source_detector_dist", and "source_iso_dist".

  • optional_params (dict): Additional model settings for set_params(), including "delta_det_channel", "delta_det_row", "delta_voxel", "det_channel_offset", and "det_row_offset".

Example

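import mbirjax
from mbirjax.preprocess.pymbir import compute_sino_and_params
sino, cone_beam_params, optional_params = compute_sino_and_params("scan.h5", bh_correction=True)
ct_model = mbirjax.ConeBeamModel(**cone_beam_params)
ct_model.set_params(**optional_params)
weights = mbirjax.gen_weights(sino, weight_type='transmission_root')
recon, recon_dict = ct_model.recon(sino, weights=weights)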

General preprocess functions#

mbirjax.preprocess.compute_sino_transmission(obj_scan, blank_scan, dark_scan, defective_pixel_array=(), batch_size=90)[source]#

Compute sinogram from object, blank, and dark scans.

This function computes a sinogram by taking the negative logarithm of the normalized transmission image: -log((obj - dark) / (blank - dark)). It supports correction for defective pixels.

The invalid sinogram entries are defined as:

  • Any values resulting in inf or NaN

  • Any indices listed in the defective_pixel_array (if provided)

Parameters:
  • obj_scan (ndarray) – A 3D object scan of shape (num_views, num_det_rows, num_det_channels).

  • blank_scan (ndarray, optional) – A 3D blank scan of shape (num_blank_scans, num_det_rows, num_det_channels). If num_blank_scans > 1, a pixel-wise mean will be computed.

  • dark_scan (ndarray, optional) – A 3D dark scan of shape (num_dark_scans, num_det_rows, num_det_channels). If num_dark_scans > 1, a pixel-wise mean will be computed.

  • defective_pixel_array (ndarray, optional) – An array of defective pixel indices. Format can be either (view_idx, row_idx, channel_idx) or (row_idx, channel_idx), if shared across views. If None, invalid pixels are inferred from NaN or inf values.

  • batch_size (int) – Number of views to process in each GPU batch.

Returns:

ndarray – The computed sinogram, with shape (num_views, num_det_rows, num_det_channels).
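The transmission formula above can be illustrated with a minimal NumPy sketch. This is not the library implementation: the helper name sino_transmission_sketch is hypothetical, batching and defective-pixel interpolation are omitted, and invalid entries are only flagged rather than corrected.

```python
import numpy as np

def sino_transmission_sketch(obj_scan, blank_scan, dark_scan):
    """Illustrative -log((obj - dark) / (blank - dark)); not the library code."""
    # Average multiple blank/dark scans pixel-wise, keeping a leading axis of 1.
    blank = np.mean(blank_scan, axis=0, keepdims=True)
    dark = np.mean(dark_scan, axis=0, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        sino = -np.log((obj_scan - dark) / (blank - dark))
    # Entries that came out inf or NaN are the invalid ones described above.
    invalid = ~np.isfinite(sino)
    return sino, invalid

obj = np.full((2, 3, 4), 50.0)
obj[0, 1, 2] = 10.0  # equals the dark level, so the ratio is 0 and -log is inf
blank = np.full((1, 3, 4), 110.0)
dark = np.full((1, 3, 4), 10.0)
sino, invalid = sino_transmission_sketch(obj, blank, dark)
```

In this toy input every valid pixel has transmission 0.4, and the single pixel at the dark level is flagged as invalid.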

mbirjax.preprocess.auto_crop_sino_conebeam(sino, cone_beam_params, optional_params, safety_buffer=20)[source]#

Automatically crop unused sinogram margins and update cone-beam geometry parameters.

This reduces the reconstruction volume by removing blank detector margins in the sinogram and updating the corresponding geometry offsets so the physical coordinate system remains consistent.

Parameters:
  • sino (np.ndarray) – Input sinogram array with shape (num_views, num_det_rows, num_det_channels).

  • cone_beam_params (dict) – Cone-beam geometry parameters that can be passed to the model constructor.

  • optional_params (dict) – Optional geometry parameters set after the model is constructed.

  • safety_buffer (int, optional) – Safety buffer (in pixels) to keep around the detected object region. Defaults to 20.

Returns:

tuple – A 3-tuple (sino, cone_beam_params, optional_params) where:

  • sino (np.ndarray): Cropped sinogram with updated shape.

  • cone_beam_params (dict): Updated parameters with adjusted 'sinogram_shape'.

  • optional_params (dict): Updated parameters with adjusted 'det_row_offset', 'det_channel_offset', and 'recon_slice_offset'.

mbirjax.preprocess.align_sino_views(ct_model, sino, direct_recon)[source]#

Align each sinogram view using estimated per-view shifts.

This function performs sinogram alignment in two steps:

  1. Estimate a 2D shift for each sinogram view.

  2. Align each sinogram view with the forward-projected reconstruction using the estimated shift.

The alignment helps correct small per-view misalignments between the measured sinogram and the forward projection of a preliminary reconstruction.

Parameters:
  • ct_model (mj.TomographyModel) – A CT model object that defines the CT geometry.

  • sino (numpy array or jax array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • direct_recon (numpy array or jax array) – A preliminary 3D reconstruction of the sinogram.

Returns:

jax array – Aligned sinogram with the same shape as the input sinogram (num_views, num_det_rows, num_det_channels).

mbirjax.preprocess.interpolate_defective_pixels(sino, defective_pixel_array=())[source]#

Interpolates defective sinogram entries with the mean of neighboring pixels.

Parameters:
  • sino (jax array, float) – Sinogram data with 3D shape (num_views, num_det_rows, num_det_channels).

  • defective_pixel_array (jax array) – A list of tuples containing indices of invalid sinogram pixels, with the format (detector_row_idx, detector_channel_idx) or (view_idx, detector_row_idx, detector_channel_idx).

Returns:

2-element tuple containing

  • sino (jax array, float): Corrected sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • defective_pixel_list (list(tuple)): Updated defective_pixel_list with the format (detector_row_idx, detector_channel_idx) or (view_idx, detector_row_idx, detector_channel_idx).
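As a rough illustration of neighbor-mean interpolation, the sketch below replaces each (row, channel) defect with the mean of its valid neighbors in a 3x3 window, for all views at once. The window size, the helper name interpolate_defects_sketch, and the choice to return the pixels that could not be fixed are assumptions made for this example, not details taken from the library.

```python
import numpy as np

def interpolate_defects_sketch(sino, defective_pixels):
    """Replace (row, channel) defects with the mean of valid 3x3 neighbors."""
    sino = np.array(sino, dtype=float)  # work on a copy
    _, num_rows, num_channels = sino.shape
    bad = np.zeros((num_rows, num_channels), dtype=bool)
    for r, c in defective_pixels:
        bad[r, c] = True
    unresolved = []
    for r, c in defective_pixels:
        r0, r1 = max(r - 1, 0), min(r + 2, num_rows)
        c0, c1 = max(c - 1, 0), min(c + 2, num_channels)
        good = ~bad[r0:r1, c0:c1]  # the defect itself is never a valid neighbor
        if good.any():
            # Mean over valid neighbors, computed for every view at once.
            sino[:, r, c] = sino[:, r0:r1, c0:c1][:, good].mean(axis=1)
        else:
            unresolved.append((r, c))
    return sino, unresolved

demo = np.arange(2 * 4 * 4, dtype=float).reshape(2, 4, 4)
demo[:, 1, 1] = np.nan
fixed, unresolved = interpolate_defects_sketch(demo, [(1, 1)])
```

Pixels whose entire neighborhood is defective are left untouched and reported back, which is one plausible reading of the updated defect list in the return value.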

mbirjax.preprocess.correct_det_rotation(sino, det_rotation=0.0, batch_size=30)[source]#

Correct sinogram data to account for detector rotation, using JAX for batch processing and GPU acceleration. Weights are not modified.

Parameters:
  • sino (numpy.ndarray) – Sinogram data with 3D shape (num_views, num_det_rows, num_det_channels).

  • det_rotation (float, optional) – Tilt angle between the rotation axis and the detector columns, in radians. Defaults to 0.0.

  • batch_size (int) – Number of views to process in each batch to avoid memory overload.

Returns:

numpy.ndarray – The corrected sinogram data, with the same shape as the input.

mbirjax.preprocess.correct_background_offset(sino, edge_width=9, option='global')[source]#

Correct background offset in a sinogram.

Parameters:
  • sino (numpy.ndarray) – Sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • edge_width (int, optional) – Width of the edge regions in pixels. Must be an integer >= 1. Defaults to 9.

  • option (str or None) – Background correction mode. Defaults to ‘global’. One of:

    • None: No correction; return the input sinogram unchanged.

    • “global”: Estimate one scalar offset from edge regions across all views.

    • “per_view”: Estimate one offset per view from edge regions.

Returns:

sino_corrected (numpy.ndarray)
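The difference between the "global" and "per_view" options can be sketched as follows. This example assumes that the edge regions are the left and right detector-channel strips and that the offset is their median. Both choices, and the helper name background_offset_sketch, are assumptions for illustration; the library may use different regions or statistics.

```python
import numpy as np

def background_offset_sketch(sino, edge_width=9, option="global"):
    """Subtract an offset estimated from the left/right edge strips (assumed)."""
    if option is None:
        return sino
    edges = np.concatenate(
        [sino[:, :, :edge_width], sino[:, :, -edge_width:]], axis=2)
    if option == "global":
        offset = np.median(edges)  # one scalar for the whole sinogram
    elif option == "per_view":
        offset = np.median(edges, axis=(1, 2), keepdims=True)  # one per view
    else:
        raise ValueError("option must be None, 'global', or 'per_view'")
    return sino - offset

demo = np.zeros((3, 8, 32))
demo[:, :, 12:20] = 1.0  # object in the middle of the detector
demo += np.array([0.1, 0.2, 0.3])[:, None, None]  # view-dependent background
per_view = background_offset_sketch(demo, option="per_view")
global_corr = background_offset_sketch(demo, option="global")
```

With a background that drifts from view to view, "per_view" brings every edge strip to zero, while "global" removes only the typical offset (here 0.2).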

mbirjax.preprocess.downsample_view_data(obj_scan, blank_scan, dark_scan, downsample_factor, defective_pixel_array=(), batch_size=90)[source]#

Performs down-sampling of the scan images in the detector plane. This is done for the object, blank_scan, and dark_scan data, and the defective_pixel_array is updated to reflect the new pixel grid.

Parameters:
  • obj_scan (ndarray) – A stack of sinograms. 3D NumPy array of shape (num_views, num_det_rows, num_det_channels).

  • blank_scan (ndarray) – Blank scan(s). 3D NumPy array of shape (num_blank_views, num_det_rows, num_det_channels).

  • dark_scan (ndarray) – Dark scan(s). 3D NumPy array of shape (num_dark_views, num_det_rows, num_det_channels).

  • downsample_factor (tuple of int) – Two integers defining the down-sample factor. Must be ≥ 1 in each dimension.

  • defective_pixel_array (ndarray) – Array of shape (num_defective_pixels, 2) indicating defective pixel coordinates.

  • batch_size (int) – Number of views to include in one JAX batch. Controls memory usage.

Notes

This function supports both singleton blank/dark scans (shape (1, H, W)) and multi-view scans (shape (N, H, W), where N > 1). Downsampling is applied independently to each view.

Returns:

tuple

  • obj_scan (ndarray): Downsampled object scan. Shape (num_views, new_rows, new_cols).

  • blank_scan (ndarray): Downsampled blank scan(s). Shape (num_blank_views, new_rows, new_cols).

  • dark_scan (ndarray): Downsampled dark scan(s). Shape (num_dark_views, new_rows, new_cols).

  • defective_pixel_array (ndarray): Updated defective pixel coordinates. Shape (N_def, 2).
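Detector-plane downsampling can be pictured as block averaging applied to each view independently. The sketch below assumes a plain block mean and drops trailing rows or channels that do not fill a complete block. Both are assumptions, as is the helper name block_mean_downsample, and the defective-pixel handling of the real function is not reproduced.

```python
import numpy as np

def block_mean_downsample(scan, downsample_factor):
    """Average non-overlapping (f_r, f_c) blocks in each view."""
    f_r, f_c = downsample_factor
    num_views, num_rows, num_cols = scan.shape
    new_rows, new_cols = num_rows // f_r, num_cols // f_c
    # Drop trailing rows/channels that do not fill a complete block (assumed).
    trimmed = scan[:, :new_rows * f_r, :new_cols * f_c]
    blocks = trimmed.reshape(num_views, new_rows, f_r, new_cols, f_c)
    return blocks.mean(axis=(2, 4))

scan = np.arange(2 * 4 * 6, dtype=float).reshape(2, 4, 6)
small = block_mean_downsample(scan, (2, 3))
```

The same call works for singleton blank or dark scans of shape (1, H, W), since the view axis is never reduced.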

mbirjax.preprocess.crop_view_data(obj_scan, blank_scan, dark_scan, crop_pixels_sides=0, crop_pixels_top=0, crop_pixels_bottom=0, defective_pixel_array=())[source]#

Crop obj_scan, blank_scan, and dark_scan by the specified pixel amounts and update defective_pixel_array.

The same number of pixels is cropped from the left and right sides (via crop_pixels_sides) to preserve the detector center/rotation axis. Top and bottom cropping are controlled independently by crop_pixels_top and crop_pixels_bottom. Any defective pixels that fall outside the cropped region are removed; remaining coordinates are shifted to the new origin of the cropped images.

Parameters:
  • obj_scan (np.ndarray) – Sinogram stack of shape (num_views, num_det_rows, num_det_channels).

  • blank_scan (np.ndarray) – Blank scan(s) of shape (num_blank_views, num_det_rows, num_det_channels).

  • dark_scan (np.ndarray) – Dark scan(s) of shape (num_dark_views, num_det_rows, num_det_channels).

  • crop_pixels_sides (int, optional) – Number of pixels to remove from each side (left and right) of the detector channels. Defaults to 0.

  • crop_pixels_top (int, optional) – Number of pixels to remove from the top (small row indices). Defaults to 0.

  • crop_pixels_bottom (int, optional) – Number of pixels to remove from the bottom (large row indices). Defaults to 0.

  • defective_pixel_array (np.ndarray | tuple, optional) – Array of shape (num_defective_pixels, 2) containing (row, col) pixel coordinates that are known to be defective in detector coordinates shared across views. May be an empty tuple () if no defects are provided. Defaults to ().

Returns:

tuple – A 4-tuple (obj_scan, blank_scan, dark_scan, defective_pixel_array) where

  • obj_scan (np.ndarray): Cropped object scan of shape (num_views, new_rows, new_cols).

  • blank_scan (np.ndarray): Cropped blank scan(s) of shape (num_blank_views, new_rows, new_cols).

  • dark_scan (np.ndarray): Cropped dark scan(s) of shape (num_dark_views, new_rows, new_cols).

  • defective_pixel_array (np.ndarray | tuple): Updated defective-pixel coordinates in the cropped detector grid (shape (N_def, 2)), or () if no defects remain.

Raises:

AssertionError – If any crop amount is negative, or if crop_pixels_top + crop_pixels_bottom >= num_det_rows, or if 2 * crop_pixels_sides >= num_det_channels.

Notes

This function supports both singleton and multi-view blank_scan/dark_scan. Cropping is applied identically across all views.
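The cropping and defective-pixel bookkeeping described above can be sketched in NumPy for a single scan array. The helper name crop_sketch is hypothetical, and for simplicity the sketch returns an empty (0, 2) array, not an empty tuple, when no defects remain.

```python
import numpy as np

def crop_sketch(scan, defects, crop_sides=0, crop_top=0, crop_bottom=0):
    """Crop a (views, rows, channels) scan and shift defect coordinates."""
    _, num_rows, num_cols = scan.shape
    assert min(crop_sides, crop_top, crop_bottom) >= 0
    assert crop_top + crop_bottom < num_rows
    assert 2 * crop_sides < num_cols
    cropped = scan[:, crop_top:num_rows - crop_bottom,
                   crop_sides:num_cols - crop_sides]
    # Shift (row, col) defects to the new origin, then drop those outside.
    defects = np.asarray(defects, dtype=int).reshape(-1, 2)
    shifted = defects - np.array([crop_top, crop_sides])
    new_rows, new_cols = cropped.shape[1:]
    keep = ((shifted[:, 0] >= 0) & (shifted[:, 0] < new_rows)
            & (shifted[:, 1] >= 0) & (shifted[:, 1] < new_cols))
    return cropped, shifted[keep]

scan = np.zeros((3, 10, 12))
cropped, kept = crop_sketch(scan, [(0, 5), (4, 6), (9, 11)],
                            crop_sides=2, crop_top=1, crop_bottom=2)
```

Here the defect at (4, 6) survives and moves to (3, 4); the other two fall in the removed top and bottom margins.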

mbirjax.preprocess.apply_cylindrical_mask(recon, radial_margin=0, top_margin=0, bottom_margin=0)[source]#

Applies a cylindrical mask to a 3D reconstruction volume.

This function zeros out all voxels outside a centered cylindrical region in the (row, col) plane and also zeroes a specified number of slices from the top and bottom along the Z-axis (slice axis).

This function is useful for removing flash that typically accumulates on the boundaries of an MBIR reconstruction volume.

Note

This function may need to be converted to batch over slices for very large recons.

Parameters:
  • recon (jnp.ndarray) – 3D volume with shape (num_rows, num_cols, num_slices).

  • radial_margin (int) – Margin to subtract from the cylinder radius in pixels.

  • top_margin (int) – Number of top slices to set to zero along the Z-axis.

  • bottom_margin (int) – Number of bottom slices to set to zero along the Z-axis.

Returns:

jnp.ndarray – Masked 3D volume of the same shape as recon.

Example

>>> import jax.numpy as jnp
>>> from mbirjax.preprocess import apply_cylindrical_mask
>>> vol = jnp.ones((128, 128, 64))
>>> masked_vol = apply_cylindrical_mask(vol, radial_margin=10, top_margin=4, bottom_margin=4)
>>> masked_vol.shape
(128, 128, 64)
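For intuition, the masking can be sketched in plain NumPy. The sketch assumes the cylinder radius is half the smaller of num_rows and num_cols minus radial_margin, and that "top" means the lowest slice indices. Neither convention is confirmed by the text above, and cylindrical_mask_sketch is a hypothetical name.

```python
import numpy as np

def cylindrical_mask_sketch(recon, radial_margin=0, top_margin=0, bottom_margin=0):
    """Zero voxels outside a centered cylinder and in the end slices."""
    recon = np.asarray(recon, dtype=float)
    num_rows, num_cols, num_slices = recon.shape
    # Assumed radius: half the smaller in-plane dimension, minus the margin.
    radius = max(min(num_rows, num_cols) / 2 - radial_margin, 0)
    rows = np.arange(num_rows) - (num_rows - 1) / 2
    cols = np.arange(num_cols) - (num_cols - 1) / 2
    inside = rows[:, None] ** 2 + cols[None, :] ** 2 <= radius ** 2
    masked = recon * inside[:, :, None]  # new array; input is not modified
    masked[:, :, :top_margin] = 0  # assumed: top = first slices
    if bottom_margin > 0:
        masked[:, :, num_slices - bottom_margin:] = 0
    return masked

vol = np.ones((16, 16, 8))
masked = cylindrical_mask_sketch(vol, radial_margin=2, top_margin=1, bottom_margin=2)
```

The output keeps the input shape; only voxels inside the shrunken cylinder and between the trimmed end slices remain nonzero.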
mbirjax.preprocess.read_tif_stack_dir(scan_dir, view_ids=None)[source]#

Reads a tif stack of scan images from a directory. This function is a subroutine to load_scans_and_params.

Parameters:
  • scan_dir (string) – Path to a ConeBeam Scan directory. Example: “<absolute_path_to_dataset>/Radiographs”

  • view_ids (ndarray of ints, optional, default=None) – List of view indices to specify which scans to read.

Returns:

ndarray (float) – 3D numpy array, (num_views, num_det_rows, num_det_channels). A stack of scan images.

mbirjax.preprocess.read_tif_img(img_path)[source]#

Reads a scan image from a TIFF file. Supports both 2D and 3D TIFFs.

This function loads a TIFF image using tifffile.imread(), then calls _normalize_to_float32() to normalize it to float32 format if the input is of integer type. If the image has more than two dimensions (e.g., 3D volumes or RGB channels), the returned array preserves that shape.

Parameters:

img_path (str) – Path to the image file. The file must be readable by tifffile.

Returns:

np.ndarray – Image data as a float32 NumPy array. Can be 2D or higher dimensional depending on the input.

MAR utilities#

mbirjax.preprocess.gen_huber_weights(weights, sino_error, T=1.0, delta=1.0, epsilon=1e-06)[source]#

This function generates generalized Huber weights based on the method described in the referenced notes. It adds robustness by treating any element where |sino_error / weights| > T as an outlier, down-weighting it according to the generalized Huber function.

The function returns new ghuber_weights.

Typically, to obtain the final robust weights, the ghuber_weights should be multiplied by the original weights:

final_weights = weights * ghuber_weights

Parameters:
  • weights – jnp.ndarray of shape (views, rows, cols): Initial weights, typically derived from inverse variance estimates.

  • sino_error – jnp.ndarray of shape (views, rows, cols): Sinogram error array representing deviations from the model.

  • T – float, optional (default=1.0): Threshold parameter; values greater than T are treated as outliers.

  • delta – float, optional (default=1.0): Controls the strength of the generalized Huber function (delta=1 corresponds to the conventional Huber).

  • epsilon – float, optional (default=1e-6): Small number to avoid division by zero.

Returns:

huber_weights

jnp.ndarray of shape (views, rows, cols)

The computed generalized Huber weights.

Notes

The generalized Huber function used in this function is based on: Venkatakrishnan, S. V., Drummy, L. F., Jackson, M., De Graef, M., Simmons, J. P., and Bouman, C. A., “Model-Based Iterative Reconstruction for Bright-Field Electron Tomography,” IEEE Transactions on Computational Imaging, vol. 1, no. 1, pp. 1–15, 2015. DOI: 10.1109/TCI.2014.2371751

Example

>>> from mbirjax.preprocess import gen_huber_weights
>>> huber_weights = gen_huber_weights(weights, sino_error)
>>> final_weights = weights * huber_weights
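The down-weighting idea can be sketched with the standard surrogate weights of the generalized Huber function: weight 1 for inliers and delta * T / |error| for outliers. The sketch takes an already-normalized error as input, because the exact normalization applied by gen_huber_weights is not spelled out here. The helper name ghuber_weights_sketch is hypothetical and this is not the library code.

```python
import numpy as np

def ghuber_weights_sketch(normalized_error, T=1.0, delta=1.0, epsilon=1e-6):
    """Surrogate weights for a generalized Huber penalty (illustrative only)."""
    abs_err = np.abs(normalized_error)
    # Inliers keep weight 1; outliers are scaled by delta * T / |error|.
    outlier_weight = delta * T / np.maximum(abs_err, epsilon)
    return np.where(abs_err > T, outlier_weight, 1.0)

normalized_error = np.array([0.5, 1.0, 2.0, -4.0])
ghuber = ghuber_weights_sketch(normalized_error, T=1.0, delta=1.0)
```

With delta=1 the weights are continuous at the threshold; smaller values of delta suppress outliers more strongly.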
mbirjax.preprocess.BH_correction(sino, alpha, batch_size=64)[source]#

Apply a polynomial beam hardening correction to a sinogram.

This function applies a polynomial correction to each view of the sinogram by evaluating powers of the sinogram values and weighting them by the coefficients in alpha, while also including the original linear term (the sinogram itself).

The corrected sinogram is computed as:

corrected_sino = sino + alpha[0] * sino**2 + alpha[1] * sino**3 + …

It processes the sinogram in batches of views for memory efficiency.

Parameters:
  • sino (jnp.ndarray or np.ndarray of shape (views, rows, cols)) – Input sinogram to correct.

  • alpha (list or array of floats) – Coefficients for the polynomial correction. In the formula above, alpha[k] multiplies sino**(k+2), so alpha[0] is the quadratic coefficient.

  • batch_size (int, optional, default=64) – Number of views to process in a single batch.

Returns:

corrected_sino

jnp.ndarray of shape (views, rows, cols)

Beam hardening corrected sinogram.

Example

>>> import mbirjax.preprocess as mjp
>>> alpha = [0.2, 0.1]  # Correction: sino + 0.2 * sino**2 + 0.1 * sino**3
>>> corrected_sino = mjp.BH_correction(sino, alpha)
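The correction formula corrected_sino = sino + alpha[0] * sino**2 + alpha[1] * sino**3 + … can be checked with a few lines of NumPy. This sketch follows that formula literally and ignores batching; bh_correction_sketch is a hypothetical name, not the library function.

```python
import numpy as np

def bh_correction_sketch(sino, alpha):
    """sino + alpha[0] * sino**2 + alpha[1] * sino**3 + ... (no batching)."""
    sino = np.asarray(sino, dtype=float)
    corrected = sino.copy()  # the linear term is always included
    for k, coeff in enumerate(alpha):
        corrected = corrected + coeff * sino ** (k + 2)
    return corrected

sino = np.array([[[0.0, 1.0, 2.0]]])  # shape (views, rows, cols) = (1, 1, 3)
corrected = bh_correction_sketch(sino, alpha=[0.2, 0.1])
```

For the value 2.0 this gives 2 + 0.2 * 4 + 0.1 * 8 = 3.6, and an empty alpha returns the sinogram unchanged.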
mbirjax.preprocess.recon_plastic_metal(ct_model, sino, weights, num_BH_iterations=3, num_constraint_update_iter=10, stop_threshold_change_pct=0.5, num_metal=1, order=3, alpha=1, beta=0.002, gamma=0.1, verbose=0)[source]#

Perform iterative metal artifact reduction using plastic-metal beam hardening correction. If num_metal is 0, then this performs a standard MBIR recon.

This function alternates between adaptive beam hardening correction (via correct_sino_plastic_metal) and reconstruction, refining the image over several iterations to suppress metal-induced artifacts.

Parameters:
  • ct_model – MBIRJAX cone beam model instance with direct_recon and recon methods.

  • sino (jnp.ndarray) – Input sinogram data to be corrected.

  • weights (jnp.ndarray) – Transmission weights used in the reconstruction algorithm.

  • num_BH_iterations (int, optional) – Number of correction-reconstruction iterations. Defaults to 3.

  • num_constraint_update_iter (int, optional) – Number of iterations for updating constraints. At each iteration, the most violated constraints are activated and the quadratic program is re-solved via OSQP. Defaults to 10.

  • stop_threshold_change_pct (float, optional) – Relative change threshold (%) for early stopping in MBIR. Defaults to 0.5.

  • num_metal (int, optional) – Number of metal materials to segment and correct for. Defaults to 1.

  • order (int, optional) – Maximum total degree of the beam hardening correction polynomial. Defaults to 3.

  • alpha (float, optional) – Degree-dependent scaling factor for regularization weights. Higher values penalize higher-order terms more strongly. Defaults to 1.

  • beta (float, optional) – Regularization strength for ridge regression. Defaults to 0.002.

  • gamma (float, optional) – Stabilization factor used in plastic correction. Multiplies the median of s_p to set a positive floor in the denominator, preventing division by near-zero or negative values. Defaults to 0.1.

  • verbose (int, optional) – Verbosity level for printing intermediate information. Defaults to 0.

Returns:

jnp.ndarray – The final corrected reconstruction after iterative beam hardening correction.

Example

>>> recon = recon_plastic_metal(
...     ct_model, sino, weights,
...     num_BH_iterations=3,
...     stop_threshold_change_pct=0.5,
...     num_metal=1,
...     order=3,
...     alpha=1,
...     beta=0.005,
...     verbose=1
... )
>>> mj.slice_viewer(recon)

Stripe/Ring/Offset Removal#

mbirjax.preprocess.remove_all_stripe(sino, snr=3, large_filter_size=61, small_filter_size=21)[source]#

Removes all types of stripe artifacts from a sinogram using a combination of three algorithms:

  1. Interpolation-based removal of unresponsive and fluctuating stripes.

  2. Sorting-based removal of large partial and full stripes.

  3. Sorting-based removal of small to medium partial and full stripes.

This method is adapted from tomopy.remove_all_stripe() and is based on: Vo N, Atwood RC, Drakopoulos M. “Superior techniques for eliminating ring artifacts in x-ray micro-tomography.” Optics Express, 26(22):28396–28412, 2018.

Parameters:
  • sino (jax.Array) – A 3D sinogram array with shape (num_views, num_det_rows, num_det_channels).

  • snr (float, optional) – Signal-to-noise ratio used for stripe detection. A typical value is 3.0. Defaults to 3.

  • large_filter_size (int, optional) – Median filter window size for removing large stripes. Defaults to 61.

  • small_filter_size (int, optional) – Median filter window size for removing small-to-medium stripes. Defaults to 21.

Returns:

jax.Array – Corrected 3D sinogram array after removing all stripe artifacts.

Example

>>> import jax.numpy as jnp
>>> import mbirjax.preprocess as mjp
>>> sino = jnp.ones((180, 128, 256))  # Simulated 3D sinogram
>>> cleaned_sino = mjp.remove_all_stripe(sino)
mbirjax.preprocess.remove_stripe_fw(sino, wavelet_filter_name='db5', sigma=2)[source]#

Removes vertical stripe artifacts from a 3D sinogram using a combined wavelet-Fourier filtering technique.

This method uses a 2D Discrete Wavelet Transform followed by a 2D Fourier transform to suppress vertical stripes, as described in: Beat Münch et al., “Stripe and ring artifact removal with combined wavelet—Fourier filtering”, Optics Express, 2009.

This implementation is adapted from the Tomopy library’s remove_stripe_fw(): tomopy/tomopy.git

Parameters:
  • sino (jax.Array) – 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).

  • wavelet_filter_name (str, optional) – Wavelet filter type (e.g., ‘db5’, ‘haar’). Defaults to ‘db5’.

  • sigma (float, optional) – Damping parameter in the Fourier domain. Controls the strength of stripe suppression. Defaults to 2.

Returns:

jax.Array – Corrected sinogram data with reduced vertical stripe artifacts.

Example

>>> import jax.numpy as jnp
>>> import mbirjax.preprocess as mjp
>>> sino = jnp.ones((180, 128, 256))  # Simulated sinogram
>>> cleaned_sino = mjp.remove_stripe_fw(sino)
mbirjax.preprocess.remove_sino_offset(sino)[source]#

Remove additive offsets in the sinogram caused by material outside the field of view.

This function corrects each row of the sinogram so that the sum over channels is constant across views and equal to the minimum sum observed across all views.

Parameters:

sino (jax.Array) – Sinogram with shape (num_views, num_rows, num_channels).

Returns:

jax.Array – Corrected sinogram with the same shape as the input.

Example

>>> import jax.numpy as jnp
>>> import mbirjax.preprocess as mjp
>>> sino = jnp.ones((180, 128, 256)) + jnp.linspace(0, 1, 180)[:, None, None]
>>> corrected_sino = mjp.remove_sino_offset(sino)
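One way to realize the stated goal (a channel sum per detector row that is constant across views and equal to the minimum over views) is to subtract the excess evenly from every channel. The sketch below does exactly that. Spreading the excess uniformly is an assumption about how the additive offset is modeled, and remove_sino_offset_sketch is a hypothetical name.

```python
import numpy as np

def remove_sino_offset_sketch(sino):
    """Make each row's channel sum equal to its minimum over views."""
    num_channels = sino.shape[2]
    row_sums = sino.sum(axis=2, keepdims=True)    # (views, rows, 1)
    target = row_sums.min(axis=0, keepdims=True)  # (1, rows, 1)
    # Assumed model: a per-view, per-row additive offset shared by all channels.
    return sino - (row_sums - target) / num_channels

demo = np.ones((6, 4, 8)) + np.linspace(0, 1, 6)[:, None, None]
corrected = remove_sino_offset_sketch(demo)
```

In this toy case the view-dependent ramp is removed completely, leaving a sinogram of ones.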

Segmentation functions#

mbirjax.preprocess.multi_threshold_otsu(image, classes=2, num_bins=1024)[source]#

Segment an image into multiple intensity classes using Otsu’s method.

This function computes optimal threshold values that divide an image into the specified number of classes by minimizing the intra-class variance. It returns classes - 1 thresholds that can be used to partition the image intensity range into classes distinct segments.

Parameters:
  • image (np.ndarray) – Input image as a NumPy array of floating-point values.

  • classes (int, optional) – Number of classes to divide the image into. Must be ≥ 2. Defaults to 2.

  • num_bins (int, optional) – Number of bins to use when constructing the image histogram. Defaults to 1024.

Returns:

list of float – A list of classes - 1 threshold values, given in increasing order. These thresholds can be used to separate the image into classes distinct intensity regions.

Example

>>> thresholds = multi_threshold_otsu(image, classes=4)
>>> # Resulting thresholds will split image into 4 intensity regions
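For the two-class case, Otsu's criterion can be sketched directly from a histogram: minimizing the intra-class variance is equivalent to maximizing the between-class variance, which is what the code below does. It assumes the image contains at least two distinct values. The multi-class search performed by multi_threshold_otsu is not reproduced, and otsu_threshold_sketch is a hypothetical name.

```python
import numpy as np

def otsu_threshold_sketch(image, num_bins=256):
    """Single Otsu threshold (classes=2) from a histogram."""
    counts, edges = np.histogram(np.ravel(image), bins=num_bins)
    centers = 0.5 * (edges[:-1] + edges[1:])
    prob = counts / counts.sum()
    cum_mean = np.cumsum(prob * centers)
    total_mean = cum_mean[-1]
    # Candidate thresholds are the interior bin edges; the last bin is
    # excluded so both classes are always non-empty.
    w0 = np.cumsum(prob)[:-1]
    w1 = 1.0 - w0
    between = (total_mean * w0 - cum_mean[:-1]) ** 2 / (w0 * w1)
    return edges[1:-1][np.argmax(between)]

low = np.linspace(0.1, 0.3, 500)
high = np.linspace(0.7, 0.9, 500)
image = np.concatenate([low, high]).reshape(20, 50)
threshold = otsu_threshold_sketch(image)
mask = image > threshold
```

For this bimodal test image the threshold lands in the empty gap between the two intensity groups, so the mask selects the brighter group.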
mbirjax.preprocess.segment_plastic_metal(recon, num_metal, radial_margin=10, top_margin=10, bottom_margin=10)[source]#

Segment a reconstruction into plastic and multiple metal masks using multi-threshold Otsu.

Parameters:
  • recon (jnp.ndarray) – Reconstructed volume array.

  • num_metal (int) – Number of metal materials to segment.

  • radial_margin (int, optional) – Margin in pixels to subtract from the cylindrical mask radius.

  • top_margin (int, optional) – Number of slices to mask out from the top of the volume.

  • bottom_margin (int, optional) – Number of slices to mask out from the bottom of the volume.

Returns:

Tuple[jnp.ndarray, List[jnp.ndarray], float, List[float]] – A 4-tuple (plastic_mask, metal_masks, plastic_scale, metal_scales) where:

  • plastic_mask (jnp.ndarray): Binary mask for plastic regions.

  • metal_masks (List[jnp.ndarray]): List of binary masks for each metal region.

  • plastic_scale (float): Scaling factor for plastic region.

  • metal_scales (List[float]): List of scaling factors for each metal region.
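As a rough picture of how thresholds become masks, the sketch below labels voxels with np.digitize. It assumes num_metal + 1 increasing thresholds that separate background, plastic, and then each metal in order of increasing intensity. That class layout is an assumption for illustration, masks_from_thresholds is a hypothetical helper, and the cylindrical masking and scale factors are not computed.

```python
import numpy as np

def masks_from_thresholds(recon, thresholds):
    """Split a volume into plastic and metal masks given sorted thresholds."""
    # Assumed labels: 0 = background, 1 = plastic, 2 and above = metals.
    labels = np.digitize(recon, thresholds)
    plastic_mask = labels == 1
    metal_masks = [labels == k for k in range(2, len(thresholds) + 1)]
    return plastic_mask, metal_masks

recon = np.array([0.0, 0.02, 0.05, 0.3])
plastic_mask, metal_masks = masks_from_thresholds(recon, thresholds=[0.01, 0.1])
```

With one metal there are two thresholds, and the brightest class ends up in metal_masks[0].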