Source code for mbirjax.translation_model

import warnings
import jax
import jax.numpy as jnp
from functools import partial
from collections import namedtuple

import numpy as np

import mbirjax as mj


[docs] class TranslationModel(mj.TomographyModel): """ This class implements the translation tomography geometry in which each view is a cone beam projection of a translated object. This geometry is useful for 3D imaging of thin objects. This class inherits all methods and properties from the :ref:`TomographyModelDocs` and may override some to suit the translation tomography geometrical requirements. See the documentation of the parent class for standard methods like setting parameters and performing projections and reconstructions. Args: sinogram_shape (tuple): Shape of the sinogram as (num_views, num_rows, num_channels), where 'num_views' is the number of translation steps, 'num_rows' is the number of detector rows, and 'num_channels' is the number of detector columns. translation_vectors (jax array or numpy array): A (num_views, 3) array of translations (x, y, z) in ALUs. Each vector specifies how the object is translated for each view. Positive x shifts the object left, z shifts up, and y shifts away from the source. source_detector_dist (float): Distance from the X-ray source to the detector. source_iso_dist (float): Distance from the X-ray source to the isocenter. Note: Additional parameter: **delta_recon_row** (float, default=0) - This parameter controls the row spacing in ALU, while the base parameter `delta_voxel` controls the voxel column and slice spacing. See Also: mbirjax.TomographyModel: Base class with standard methods like `set_params` and `reconstruct`. Example: .. code-block:: python import jax.numpy as jnp from mbirjax.translation_model import TranslationModel sinogram_shape = (180, 256, 256) translation_vectors = jnp.zeros((180, 3)) model = TranslationModel( sinogram_shape=sinogram_shape, translation_vectors=translation_vectors, source_detector_dist=500.0, source_iso_dist=500.0 ) model.set_params(delta_recon_row=2.0) model.auto_set_recon_geometry() """ DIRECT_RECON_VIEW_BATCH_SIZE = mj.TomographyModel.DIRECT_RECON_VIEW_BATCH_SIZE def __init__(self, sinogram_shape, translation_vectors, source_detector_dist, source_iso_dist): self.use_ror_mask = False self.entries_per_cylinder_batch = 128 translation_vectors = jnp.asarray(translation_vectors) view_params_name = 'translation_vectors' super().__init__(sinogram_shape, translation_vectors=translation_vectors, source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist, view_params_name=view_params_name, qggmrf_nbr_wts=(0.1, 1, 1,)) self.set_params(max_overrelaxation=1.3) # We override this value due to observed instabilities with larger values def get_magnification(self): """ Returns the magnification for the cone beam geometry. Returns: magnification = source_detector_dist / source_iso_dist """ source_detector_dist, source_iso_dist = self.get_params(['source_detector_dist', 'source_iso_dist']) if jnp.isinf(source_detector_dist): raise ValueError('Distance from source to detector is infinite, which means all translated projections have the same information.') else: magnification = source_detector_dist / source_iso_dist return magnification def get_geometry_parameters(self): """ Function to get a list of the primary geometry parameters for translation model projection. Returns: namedtuple of required geometry parameters. """ # First get the parameters managed by ParameterHandler geometry_param_names = \ ['delta_det_row', 'delta_det_channel', 'det_row_offset', 'det_channel_offset', 'source_detector_dist', 'delta_voxel', 'delta_recon_row'] geometry_param_values = self.get_params(geometry_param_names) # Then get additional parameters: geometry_param_names += ['magnification', 'psf_radius', 'entries_per_cylinder_batch'] geometry_param_values.append(self.get_magnification()) geometry_param_values.append(self.get_psf_radius()) geometry_param_values.append(self.entries_per_cylinder_batch) # Then create a namedtuple to access parameters by name in a way that can be jit-compiled. GeometryParams = namedtuple('GeometryParams', geometry_param_names) geometry_params = GeometryParams(*tuple(geometry_param_values)) return geometry_params def get_psf_radius(self): """ Compute the integer radius of the PSF kernel for cone beam projection. """ delta_det_row, delta_det_channel, source_iso_dist, source_detector_dist, recon_shape, delta_voxel, delta_recon_row, translation_vectors = self.get_params( ['delta_det_row', 'delta_det_channel', 'source_iso_dist', 'source_detector_dist', 'recon_shape', 'delta_voxel', 'delta_recon_row', 'translation_vectors']) magnification = self.get_magnification() # Compute minimum detector pitch delta_det = jnp.minimum(delta_det_row, delta_det_channel) # Find the maximum and minimum translation vectors max_translation = jnp.amax(translation_vectors, axis=0) min_translation = jnp.amin(translation_vectors, axis=0) # Compute maximum magnification if jnp.isinf(source_detector_dist): raise ValueError('Distance from source to detector is infinite, which means all translated projections have the same information.') else: # Determine the distance from the source to the closest voxel. source_to_closest_pixel = source_iso_dist - (0.5 * recon_shape[0] * delta_recon_row) - max_translation[1] # Determine the maximum magnification. max_magnification = source_detector_dist / source_to_closest_pixel if source_to_closest_pixel < 0: raise ValueError('Reconstruction volume extends into source - no valid projection in this case.') # Compute the maximum number of detector rows/channels on either side of the center detector hit by a voxel psf_radius = int(jnp.ceil(jnp.ceil((delta_voxel * max_magnification / delta_det)) / 2)) if psf_radius > 4: warnings.warn('A single voxel may project onto 100 or more detector elements, which may lead to artifacts. Consider using smaller voxels.') return psf_radius def auto_set_recon_geometry(self, no_compile=True, no_warning=False): """ Compute the automatic recon shape translation reconstruction. """ # Get model parameters sinogram_shape = self.get_params('sinogram_shape') source_detector_dist, source_iso_dist = self.get_params(['source_detector_dist', 'source_iso_dist']) delta_det_row, delta_det_channel = self.get_params(['delta_det_row', 'delta_det_channel']) translation_vectors = self.get_params('translation_vectors') # Calculate the reconstruction geometry parameters recon_shape, delta_voxel, delta_recon_row = mj.utilities.calc_tct_recon_params(source_detector_dist, source_iso_dist, delta_det_row, delta_det_channel, sinogram_shape, translation_vectors) self.set_params(no_compile=no_compile, no_warning=no_warning, recon_shape=recon_shape, delta_recon_row=delta_recon_row, delta_voxel=delta_voxel) @staticmethod @partial(jax.jit, static_argnames='projector_params') def forward_project_pixel_batch_to_one_view(voxel_values, pixel_indices, translation_vector, projector_params): """ Forward project a set of voxels determined by indices into the flattened array of size num_rows x num_cols. Args: voxel_values (jax array): 2D array of shape (num_indices, num_slices) of voxel values, where voxel_values[i, j] is the value of the voxel in slice j at the location determined by indices[i]. pixel_indices (jax array of int): 1D vector of indices into flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()) Returns: jax array of shape (num_det_rows, num_det_channels) """ recon_shape = projector_params.recon_shape num_recon_slices = recon_shape[2] if voxel_values.shape[0] != pixel_indices.shape[0] or len(voxel_values.shape) < 2 or \ voxel_values.shape[1] != num_recon_slices: raise ValueError('voxel_values must have shape[0:2] = (num_indices, num_slices)') vertical_fan_projector = TranslationModel.forward_vertical_fan_pixel_batch_to_one_view horizontal_fan_projector = TranslationModel.forward_horizontal_fan_pixel_batch_to_one_view new_voxel_values = vertical_fan_projector(voxel_values, pixel_indices, translation_vector, projector_params) sinogram_view = horizontal_fan_projector(new_voxel_values, pixel_indices, translation_vector, projector_params) return sinogram_view @staticmethod def forward_vertical_fan_pixel_batch_to_one_view(voxel_values, pixel_indices, translation_vector, projector_params): """ Apply a fan beam forward projection in the vertical direction separately to each voxel determined by indices into the flattened array of size num_rows x num_cols. This returns an array corresponding to the same pixel locations, but using values obtained from the projection of the original voxel cylinders onto a detector column, so the output array has size (len(pixel_indices), num_det_rows). Args: voxel_values (jax array): 2D array of shape (num_pixels, num_recon_slices) of voxel values, where voxel_values[i, j] is the value of the voxel in slice j at the location determined by indices[i]. pixel_indices (jax array of int): 1D vector of shape (len(pixel_indices), ) holding the indices into the flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()) Returns: jax array of shape (num_pixels, num_det_rows) """ pixel_map = jax.vmap(TranslationModel.forward_vertical_fan_one_pixel_to_one_view, in_axes=(0, 0, None, None)) new_pixels = pixel_map(voxel_values, pixel_indices, translation_vector, projector_params) return new_pixels @staticmethod def forward_horizontal_fan_pixel_batch_to_one_view(voxel_values, pixel_indices, translation_vector, projector_params): """ Apply a horizontal fan beam transformation to a set of voxel cylinders. These cylinders are assumed to have slices aligned with detector rows, so that a horizontal fan beam maps a cylinder slice to a detector row. This function returns the resulting sinogram view. Args: voxel_values (jax array): 2D array of shape (num_pixels, num_recon_slices) of voxel values, where voxel_values[i, j] is the value of the voxel in slice j at the location determined by indices[i]. pixel_indices (jax array of int): 1D vector of shape (len(pixel_indices), ) holding the indices into the flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()) Returns: jax array of shape (num_det_rows, num_det_channels) """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape # Get the data needed for horizontal projection n_p, n_p_center, W_p_c, cos_theta_p = TranslationModel.compute_horizontal_data(pixel_indices, translation_vector, projector_params) L_max = jnp.minimum(1, W_p_c) # Allocate the sinogram array sinogram_view = jnp.zeros((num_det_rows, num_det_channels)) # Do the horizontal projection for n_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1): n = n_p_center + n_offset abs_delta_p_c_n = jnp.abs(n_p - n) L_p_c_n = jnp.clip((W_p_c + 1) / 2 - abs_delta_p_c_n, 0, L_max) A_chan_n = gp.delta_recon_row * L_p_c_n / cos_theta_p A_chan_n *= (n >= 0) * (n < num_det_channels) sinogram_view = sinogram_view.at[:, n].add(A_chan_n.reshape((1, -1)) * voxel_values.T) return sinogram_view @staticmethod def forward_vertical_fan_one_pixel_to_one_view(voxel_cylinder, pixel_index, translation_vector, projector_params): """ Apply a fan beam forward projection in the vertical direction to the pixel determined by indices into the flattened array of size num_rows x num_cols. This returns a vector obtained from the projection of the original voxel cylinder onto a detector column, so the output vector has length num_det_rows. Args: voxel_cylinder (jax array): 1D array of shape (num_recon_slices, ) of voxel values, where voxel_cylinder[j] is the value of the voxel in slice j at the location determined by pixel_index. pixel_index (int): Index into the flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()) Returns: jax array of shape (num_det_rows,) Note: This is a helper function used in vmap in :meth:`TranslationModel.forward_vertical_fan_pixel_batch_to_one_view` This method has the same signature and output as that method, except single int pixel_index is used in place of the 1D pixel_indices, and likewise only a single voxel cylinder is returned. """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape recon_shape = projector_params.recon_shape num_recon_rows, num_recon_cols, num_recon_slices = recon_shape num_slices = voxel_cylinder.shape[0] # From pixel index, compute y and pixel_mag y, pixel_mag = TranslationModel.compute_y_mag_for_pixel(pixel_index, translation_vector, recon_shape, projector_params) # The code above depends only on the pixel - a single point. z is a potentially large vector # Here we compute cos_phi_p: 1 / cos_phi_p determines the projection length through a voxel # For computational efficiency, we use that to scale the voxel_cylinder values. ### possibly convert to a jitted function with donate_argnames to avoid copies for z, v, phi_p, cos_phi_p k = jnp.arange(len(voxel_cylinder)) z = gp.delta_voxel * (k - (num_recon_slices - 1) / 2.0) - translation_vector[2] # recon_ijk_to_xyz v = pixel_mag * z # geometry_xyz_to_uv_mag # Compute vertical cone angle of voxels phi_p = jnp.arctan2(v, gp.source_detector_dist) # compute_vertical_data_single_pixel cos_phi_p = jnp.cos(phi_p) # We assume the vertical angle |phi_p| < 45 degrees so cos_alpha_p_z = cos_phi_p scaled_voxel_values = voxel_cylinder / cos_phi_p # End # Get the length of projection of detector on vertical voxel profile (in fraction of voxel size) # This is also the slope of the map from voxel index to detector index W_p_r = (pixel_mag * gp.delta_voxel) / gp.delta_det_row slope_k_to_m = W_p_r L_max = jnp.minimum(1, W_p_r) # Maximum fraction of a detector that can be covered by one voxel. # Set up detector row indices array (0, 10, 20, ..., 10*num_slice_batches) det_rows_per_batch = gp.entries_per_cylinder_batch det_rows_per_batch = min(det_rows_per_batch, num_det_rows) num_det_row_batches = (num_det_rows + det_rows_per_batch - 1) // det_rows_per_batch det_row_indices = det_rows_per_batch * jnp.arange(num_det_row_batches) det_center_row = (num_det_rows - 1) / 2.0 row_batch = jnp.arange(det_rows_per_batch) # Set up a function to map over subsets of detector rows def create_det_column_rows(start_index): # We need to match the back projector, so we have to determine the fraction of each voxel that projects # to each detector. # First project the detector centers to the voxel cylinder m_center = start_index + row_batch # Center of detector elements v_m = (m_center - det_center_row) * gp.delta_det_row - gp.det_row_offset # Detector center in ALUs z_m = v_m / pixel_mag # z coordinate of the projection of the center of the first detector element in this batch # Convert to voxel fractional index and find the center of each voxel k_m = (z_m + translation_vector[2]) / gp.delta_voxel + (num_slices - 1) / 2.0 k_m_center = jnp.round(k_m).astype(int) # Center of the voxel hit by the center of the detector # Then map the center of the voxels back to the detector. m_p = slope_k_to_m * (k_m_center - k_m[0]) + m_center[0] # Projection to detector of voxel centers # Allocate space new_column_batch = jnp.zeros(det_rows_per_batch) # Do the vertical projection for k_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1): k_ind = k_m_center + k_offset # Indices of the current set of voxels touched by the detector elements # The projection of these centers is the projection of k_m_center (which is m_p) plus # the offset times the slope of the map from voxel index to detector index abs_delta_p_r_m = jnp.abs(m_p + slope_k_to_m * k_offset - m_center) # Distance from projection of center of voxel to center of detector A_row_k = jnp.clip((W_p_r + 1) / 2 - abs_delta_p_r_m, 0, L_max) # Fraction of the detector hit by this voxel A_row_k *= (k_ind >= 0) * (k_ind < num_slices) new_column_batch = jnp.add(new_column_batch, A_row_k * scaled_voxel_values[k_ind]) return new_column_batch, None det_column, _ = jax.lax.map(create_det_column_rows, det_row_indices) det_column = det_column.flatten() det_column = jax.lax.slice_in_dim(det_column, 0, num_det_rows) return det_column @staticmethod @partial(jax.jit, static_argnames='projector_params') def back_project_one_view_to_pixel_batch(sinogram_view, pixel_indices, single_view_params, projector_params, coeff_power=1): """ Use vmap to do a backprojection from one view to multiple pixels (voxel cylinders). Args: sinogram_view (2D jax array): one view of the sinogram to be back projected. 2D jax array of shape (num_det_rows)x(num_det_channels) pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols. single_view_params: These are the view dependent parameters for the view being back projected. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). Normally 1, but should be 2 when computing Hessian diagonal. Returns: The voxel values for all slices at the input index (i.e., a voxel cylinder) obtained by backprojecting the input sinogram view. """ vertical_fan_projector = TranslationModel.back_vertical_fan_one_view_to_pixel_batch horizontal_fan_projector = TranslationModel.back_horizontal_fan_one_view_to_pixel_batch det_voxel_cylinder = horizontal_fan_projector(sinogram_view, pixel_indices, single_view_params, projector_params, coeff_power=coeff_power) back_projection = vertical_fan_projector(det_voxel_cylinder, pixel_indices, single_view_params, projector_params, coeff_power=coeff_power) return back_projection @staticmethod def back_horizontal_fan_one_view_to_pixel_batch(sinogram_view, pixel_indices, translation_vector, projector_params, coeff_power=1): """ Apply the back projection of a horizontal fan beam transformation to a single sinogram view and return the resulting voxel cylinders. Args: sinogram_view (2D jax array): one view of the sinogram to be back projected. 2D jax array of shape (num_det_rows)x(num_det_channels) pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). Normally 1, but should be 2 when computing Hessian diagonal. Returns: jax array of shape (len(pixel_indices), num_det_rows) """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape num_pixels = pixel_indices.shape[0] # Get the data needed for horizontal projection n_p, n_p_center, W_p_c, cos_theta_p = TranslationModel.compute_horizontal_data(pixel_indices, translation_vector, projector_params) L_max = jnp.minimum(1, W_p_c) # Allocate the voxel cylinder array det_voxel_cylinder = jnp.zeros((num_pixels, num_det_rows)) # Do the horizontal projection for n_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1): n = n_p_center + n_offset abs_delta_p_c_n = jnp.abs(n_p - n) L_p_c_n = jnp.clip((W_p_c + 1) / 2 - abs_delta_p_c_n, 0, L_max) A_chan_n = gp.delta_recon_row * L_p_c_n / cos_theta_p A_chan_n *= (n >= 0) * (n < num_det_channels) A_chan_n = A_chan_n ** coeff_power det_voxel_cylinder = jnp.add(det_voxel_cylinder, A_chan_n.reshape((-1, 1)) * sinogram_view[:, n].T) return det_voxel_cylinder @staticmethod def back_vertical_fan_one_view_to_pixel_batch(det_voxel_cylinder, pixel_indices, single_view_params, projector_params, coeff_power=1): """ Apply a fan beam backward projection in the vertical direction to the pixel determined by indices into the flattened array of size num_rows x num_cols. This returns a vector obtained from the projection of the detector-based voxel cylinders onto voxel cylinders in recon space, so the output vector has length num_recon_slices. Args: det_voxel_cylinder (2D jax array): 2D array of shape (num_pixels, num_det_rows) of voxel values, where det_voxel_cylinder[i, j] is the value of the voxel in row j at the location determined by indices[i]. pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols. single_view_params: These are the view dependent parameters for the view being back projected. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). Normally 1, but should be 2 when computing Hessian diagonal. Returns: 2D jax array of shape (num_pixels, num_recon_slices) of voxel values. """ pixel_map = jax.vmap(TranslationModel.back_vertical_fan_one_view_to_one_pixel, in_axes=(0, 0, None, None, None)) new_pixels = pixel_map(det_voxel_cylinder, pixel_indices, single_view_params, projector_params, coeff_power) return new_pixels @staticmethod def back_vertical_fan_one_view_to_one_pixel(detector_column_values, pixel_index, translation_vector, projector_params, coeff_power=1): """ Apply the back projection of a vertical fan beam transformation to a single voxel cylinder and return the column vector of the resulting values. Args: detector_column_values (1D jax array): 1D array of shape (num_det_rows,) of voxel values, where detector_column_values[i, j] is the value of the voxel in row j at the location determined by indices[i]. pixel_index (int): Index into flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power). Normally 1, but should be 2 when computing Hessian diagonal. Returns: 1D jax array of shape (num_recon_slices,) of voxel values. """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape recon_shape = projector_params.recon_shape num_recon_rows, num_recon_cols, num_recon_slices = recon_shape # Set up slice indices array (0, slices_per_batch, 2*slices_per_batch, ..., num_slice_batches*slices_per_batch) num_slices = num_recon_slices slices_per_batch = gp.entries_per_cylinder_batch slices_per_batch = min(slices_per_batch, num_slices) num_slice_batches = (num_slices + slices_per_batch - 1) // slices_per_batch slice_indices = slices_per_batch * jnp.arange(num_slice_batches) # Set up a function to map over the slices of the cylinder # Here we can use a map over subsections of the voxel cylinder because we are indexing by slice, # so there is no overlap from one section to the next. def create_voxel_cylinder_slices(start_index): # Allocate space new_cylinder = jnp.zeros(slices_per_batch) # Get the data needed for vertical projection cur_slice_indices = start_index + jnp.arange(slices_per_batch) m_p, m_p_center, W_p_r, cos_alpha_p_z = TranslationModel.compute_vertical_data_single_pixel(pixel_index, cur_slice_indices, translation_vector, projector_params) L_max = jnp.minimum(1, W_p_r) # Maximum fraction of a detector that can be covered by one voxel. # Do the vertical projection for m_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1): m = m_p_center + m_offset abs_delta_p_r_m = jnp.abs(m_p - m) # Distance from projection of center of voxel to center of detector L_p_r_m = jnp.clip((W_p_r + 1) / 2 - abs_delta_p_r_m, 0, L_max) A_row_m = L_p_r_m / cos_alpha_p_z A_row_m *= (m >= 0) * (m < num_det_rows) A_row_m = A_row_m ** coeff_power new_cylinder = jnp.add(new_cylinder, A_row_m * detector_column_values[m]) return new_cylinder, None recon_voxel_cylinder, _ = jax.lax.map(create_voxel_cylinder_slices, slice_indices) recon_voxel_cylinder = recon_voxel_cylinder.flatten() recon_voxel_cylinder = jax.lax.slice_in_dim(recon_voxel_cylinder, 0, num_recon_slices) return recon_voxel_cylinder @staticmethod def compute_vertical_data_single_pixel(pixel_index, slice_indices, translation_vector, projector_params): """ Compute the quantities m_p, m_p_center, W_p_r, cos_alpha_p_z needed for vertical projection. Args: pixel_index (int): Index into flattened array of size num_rows x num_cols. slice_indices (array of int): Indices into the recon slices. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). Returns: m_p, m_p_center, W_p_r, cos_alpha_p_z """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape recon_shape = projector_params.recon_shape num_recon_rows, num_recon_cols, num_recon_slices = recon_shape # Convert the index into (i,j,k) coordinates corresponding to the indices into the 3D voxel array row_index, col_index = jnp.unravel_index(pixel_index, recon_shape[:2]) # slice_indices = jnp.arange(num_recon_slices) x_p, y_p, z_p = TranslationModel.recon_ijk_to_xyz(row_index, col_index, slice_indices, recon_shape, gp.delta_voxel, gp.delta_recon_row, translation_vector) # Convert from xyz to coordinates on detector u_p, v_p, pixel_mag = TranslationModel.geometry_xyz_to_uv_mag(x_p, y_p, z_p, gp.source_detector_dist, gp.magnification) # Convert from uv to index coordinates in detector and get the vector of center detector rows for this cylinder m_p, _ = TranslationModel.detector_uv_to_mn(u_p, v_p, gp.delta_det_channel, gp.delta_det_row, gp.det_channel_offset, gp.det_row_offset, num_det_rows, num_det_channels) m_p_center = jnp.round(m_p).astype(int) # Compute vertical cone angle of pixel phi_p = jnp.arctan2(v_p, gp.source_detector_dist) # Compute cos alpha for row and columns cos_phi_p = jnp.cos(phi_p) # We assume the vertical angle |phi_p| < 45 degrees so cos_alpha_p_z = cos_phi_p # cos_alpha_p_z = jnp.maximum(jnp.abs(cos_phi_p), jnp.abs(jnp.sin(phi_p))) # Get the length of projection of flattened voxel on detector (in fraction of detector size) W_p_r = pixel_mag * (gp.delta_voxel / gp.delta_det_row) # * cos_alpha_p_z / cos_phi_p vertical_data = (m_p, m_p_center, W_p_r, cos_phi_p) # cos_alpha_p_z) return vertical_data @staticmethod def compute_horizontal_data(pixel_indices, translation_vector, projector_params): """ Compute the quantities n_p, n_p_center, W_p_c, cos_alpha_p_xy needed for vertical projection. Args: pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols. translation_vector (jax array of floats): 1D translation vector in ALU units for this view. projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()). Returns: n_p, n_p_center, W_p_c, cos_theta_p """ # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy. gp = projector_params.geometry_params num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape recon_shape = projector_params.recon_shape num_recon_rows, num_recon_cols, num_recon_slices = recon_shape # Convert the index into (i,j,k) coordinates corresponding to the indices into the 3D voxel array row_index, col_index = jnp.unravel_index(pixel_indices, recon_shape[:2]) slice_index = jnp.arange(1) x_p, y_p, _ = TranslationModel.recon_ijk_to_xyz(row_index, col_index, slice_index, recon_shape, gp.delta_voxel, gp.delta_recon_row, translation_vector) # Convert from xyz to coordinates on detector # pixel_mag should be kept in terms of magnification to allow for source_detector_dist = jnp.Inf pixel_mag = 1 / (1 / gp.magnification - y_p / gp.source_detector_dist) # Compute the physical position that this voxel projects onto the detector u_p = pixel_mag * x_p det_center_channel = (num_det_channels - 1) / 2.0 # num_of_cols # Calculate indices on the detector grid n_p = (u_p + gp.det_channel_offset) / gp.delta_det_channel + det_center_channel # Sync with detector_uv_to_mn n_p_center = jnp.round(n_p).astype(int) # Compute horizontal and vertical cone angle of pixel theta_p = jnp.arctan2(u_p, gp.source_detector_dist) # Compute cosine alpha cos_theta_p = jnp.cos(theta_p) # Compute projected voxel width along columns and rows (in fraction of detector size) W_p_c = pixel_mag * (gp.delta_voxel / gp.delta_det_channel) horizontal_data = (n_p, n_p_center, W_p_c, cos_theta_p) return horizontal_data @staticmethod def recon_ijk_to_xyz(i, j, k, recon_shape, delta_voxel, delta_recon_row, translation_vector): """ Convert (i, j, k) indices into the recon volume to corresponding (x, y, z) coordinates. """ num_recon_rows, num_recon_cols, num_recon_slices = recon_shape y = delta_recon_row * (i - (num_recon_rows - 1) / 2.0) - translation_vector[1] x = delta_voxel * (j - (num_recon_cols - 1) / 2.0) - translation_vector[0] z = delta_voxel * (k - (num_recon_slices - 1) / 2.0) - translation_vector[2] return x, y, z @staticmethod def geometry_xyz_to_uv_mag(x, y, z, source_detector_dist, magnification): """ Convert (x, y, z) coordinates to to (u, v) detector coordinates plus the pixel-dependent magnification. """ # Compute the magnification at this specific voxel # The following expression is valid even when source_detector_dist = jnp.Inf pixel_mag = 1 / (1 / magnification - y / source_detector_dist) # Compute the physical position that this voxel projects onto the detector u = pixel_mag * x v = pixel_mag * z return u, v, pixel_mag @staticmethod @jax.jit def detector_uv_to_mn(u, v, delta_det_channel, delta_det_row, det_channel_offset, det_row_offset, num_det_rows, num_det_channels): """ Convert (u, v) detector coordinates to fractional indices (m, n) into the detector. Note: This version does not account for nonzero detector rotation. """ # Get the center of the detector grid for columns and rows det_center_row = (num_det_rows - 1) / 2.0 # num_of_rows det_center_channel = (num_det_channels - 1) / 2.0 # num_of_cols # Calculate indices on the detector grid m = (v + det_row_offset) / delta_det_row + det_center_row n = (u + det_channel_offset) / delta_det_channel + det_center_channel # Sync with compute_horizontal_data return m, n @staticmethod @jax.jit def compute_y_mag_for_pixel(pixel_index, translation_vector, recon_shape, projector_params): gp = projector_params.geometry_params row_index, col_index = jnp.unravel_index(pixel_index, recon_shape[:2]) num_recon_rows, num_recon_cols, num_recon_slices = recon_shape y = gp.delta_recon_row * (row_index - (num_recon_rows - 1) / 2.0) - translation_vector[1] # Convert from xyz to coordinates on detector pixel_mag = 1 / (1 / gp.magnification - y / gp.source_detector_dist) return y, pixel_mag def direct_recon(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE): return self.fdk_recon(sinogram, filter_name=filter_name, view_batch_size=view_batch_size) def direct_filter(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE): """ Perform filtering on the given sinogram as needed for an FBP/FDK or other direct recon. Args: sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels). filter_name (string, optional): Name of the filter to be used. Defaults to "ramp" view_batch_size (int, optional): Size of view batches (used to limit memory use) Returns: filtered_sinogram (jax array): The sinogram after FBP filtering. """ return self.fdk_filter(sinogram, filter_name=filter_name, view_batch_size=view_batch_size) def fdk_filter(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE): """ Perform FDK filtering on the given sinogram. Args: sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels). filter_name (string, optional): Name of the filter to be used. Defaults to "ramp" view_batch_size (int, optional): Size of view batches (used to limit memory use) Returns: filtered_sinogram (jax array): The sinogram after FDK filtering. """ # Get parameters num_views, num_rows, num_channels = sinogram.shape source_detector_dist, source_iso_dist = self.get_params(['source_detector_dist', 'source_iso_dist']) delta_voxel, delta_det_row, delta_det_channel = self.get_params(['delta_voxel', 'delta_det_row', 'delta_det_channel']) det_row_offset, det_channel_offset = self.get_params(['det_row_offset', 'det_channel_offset']) if view_batch_size is None: view_batch_size = self.view_batch_size_for_vmap max_view_batch_size = 128 # Limit the view batch size here and ParallelBeam due to https://github.com/jax-ml/jax/issues/27591 view_batch_size = min(view_batch_size, max_view_batch_size) # Magnification factor M_0 = Source-Detector Distance / Source-Isocenter Distance M_0 = self.get_magnification() # Define the index arrays for channels and rows m = jnp.arange(num_rows) # Column vector for rows n = jnp.arange(num_channels) # Row vector for channels m_grid, n_grid = jnp.meshgrid(m, n, indexing='ij') # Coordinate transformation to physical distances: u_grid, v_grid = mj.ConeBeamModel.detector_mn_to_uv(m_grid, n_grid, delta_det_channel, delta_det_row, det_channel_offset, det_row_offset, num_rows, num_channels) # Compute the weight weight_map = source_detector_dist / jnp.sqrt(source_detector_dist ** 2 + u_grid**2 + v_grid**2) # Apply the pre-weighting factor to the sinogram weighted_sinogram = jax.device_put(sinogram * weight_map[None, :, :], self.sinogram_device) # Compute the scaled filter # Scaling factor alpha adjusts the filter to account for voxel size, ensuring consistent reconstruction. # For a detailed theoretical derivation of this scaling factor, please refer to the zip file linked at # https://mbirjax.readthedocs.io/en/latest/theory.html recon_filter = mj.tomography_utils.generate_direct_recon_filter(num_channels, filter_name=filter_name) alpha = delta_det_row / (delta_voxel**3 * M_0) recon_filter = alpha * recon_filter # Define convolution for a single row (across its channels) def convolve_row(row): return jax.scipy.signal.fftconvolve(row, recon_filter, mode="valid") # Apply above convolve func across each row of a view def apply_convolution_to_view(view): return jax.vmap(convolve_row)(view) # Apply convolution across the channels of the weighted sinogram per each fixed view & row num_views = sinogram.shape[0] filtered_sino_list = [] for i in range(0, num_views, view_batch_size): sino_batch = jax.device_put(weighted_sinogram[i:min(i + view_batch_size, num_views)], self.sinogram_device) filtered_sinogram_batch = jax.lax.map(apply_convolution_to_view, sino_batch, batch_size=view_batch_size) filtered_sinogram_batch.block_until_ready() filtered_sino_list.append(jax.device_put(filtered_sinogram_batch, self.sinogram_device)) filtered_sinogram = jnp.concatenate(filtered_sino_list, axis=0) filtered_sinogram *= jnp.pi / num_views return filtered_sinogram def fdk_recon(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE): """ Perform FDK reconstruction on the given sinogram. Our implementation uses standard filtering of the sinogram, then uses the adjoint of the forward projector to perform the backprojection. This is different from many implementations, in which the backprojection is not exactly the adjoint of the forward projection. For a detailed theoretical derivation of this implementation, see the zip file linked at this page: https://mbirjax.readthedocs.io/en/latest/theory.html Args: sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels). filter_name (string, optional): Name of the filter to be used. Defaults to "ramp" view_batch_size (int, optional): Size of view batches (used to limit memory use) Returns: recon (jax array): The reconstructed volume after FDK reconstruction. """ filtered_sinogram = self.fdk_filter(sinogram, filter_name=filter_name, view_batch_size=view_batch_size) # Apply backprojection recon = self.back_project(filtered_sinogram) return recon def _get_estimate_of_recon_std(self, sinogram, sino_indicator): """ Estimate the standard deviation of the reconstruction from the sinogram. This is used to scale sigma_prox and sigma_x in MBIR reconstruction. This version accounts for anisotropic row pitch in translation geometry Args: sinogram (ndarray): 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels). sino_indicator (ndarray): a binary mask that indicates the region of sinogram support; same shape as sinogram. """ # Get parameters delta_recon_row = self.get_params('delta_recon_row') recon_shape = self.get_params('recon_shape') # Compute the typical magnitude of a sinogram value typical_sinogram_value = np.average(np.abs(sinogram), weights=sino_indicator) # Compute a typical projection path length # For TCT, we will assume that the projections are along the row direction, # and we will assume that the object fills approximately half the distance along the rows. fraction_of_fill = 0.5 typical_path_length = fraction_of_fill * recon_shape[0] * delta_recon_row # Compute a typical recon value by dividing average sinogram value by a typical projection path length recon_std = typical_sinogram_value / typical_path_length return recon_std