import jax
import jax.numpy as jnp
from functools import partial
from collections import namedtuple
import warnings
from typing import Literal, Union, overload, Any
import numpy as np
import mbirjax as mj
from mbirjax import TomographyModel, tomography_utils, ParameterHandler
# Type of the names accepted by ConeBeamModel.get_params: the names in mj.ParamNames plus the cone-beam
# specific parameters that ConeBeamModel.__init__ passes to TomographyModel.__init__.
ConeBeamParamNames = mj.ParamNames | Literal['view_params_array', 'source_detector_dist', 'source_iso_dist',
                                             'recon_slice_offset', 'use_curved_detector']
class ConeBeamModel(TomographyModel):
"""
A class designed for handling forward and backward projections in a cone beam geometry, extending the
:ref:`TomographyModelDocs`. This class offers specialized methods and parameters tailored for cone beam setups.
This class inherits all methods and properties from the :ref:`TomographyModelDocs` and may override some
to suit cone beam geometrical requirements. See the documentation of the parent class for standard methods
like setting parameters and performing projections and reconstructions.
Parameters not included in the constructor can be set using the set_params method of :ref:`TomographyModelDocs`.
Refer to :ref:`TomographyModelDocs` documentation for a detailed list of possible parameters.
Args:
sinogram_shape (tuple):
Shape of the sinogram as a tuple in the form `(views, rows, channels)`, where 'views' is the number of
different projection angles, 'rows' correspond to the number of detector rows, and 'channels' index columns of
the detector that are assumed to be aligned with the rotation axis.
angles (ndarray or jax array):
A 1D array of projection angles, in radians, specifying the angle of each projection relative to the origin.
source_detector_dist (float): Distance between the X-ray source and the detector in units of ALU.
source_iso_dist (float): Distance between the X-ray source and the center of rotation in units of ALU.
helical_z_shifts (ndarray or jax array, optional): Per-view axial shifts (ALU; same length as angles) for helical mode.
use_curved_detector (bool, optional): Detector geometry type. False (default) means a flat detector; True means a curved detector on which every element of a detector row is the same distance from the source.
Note:
Additional parameter:
**recon_slice_offset** (float, default=0) -
This parameter controls the vertical offset of the reconstruction in ALU. If recon_slice_offset is positive, the region below iso is reconstructed.
See Also
--------
TomographyModel : The base class from which this class inherits.
"""
DIRECT_RECON_VIEW_BATCH_SIZE = TomographyModel.DIRECT_RECON_VIEW_BATCH_SIZE
def __init__(self, sinogram_shape, angles, source_detector_dist, source_iso_dist, helical_z_shifts=None,
             use_curved_detector=False):
    """
    Create a cone beam model from the scan geometry.

    Args:
        sinogram_shape (tuple): Shape of the sinogram as (views, rows, channels).
        angles (array-like): Projection angle of each view, in radians.
        source_detector_dist (float): Distance between the X-ray source and the detector in ALU.
        source_iso_dist (float): Distance between the X-ray source and the center of rotation in ALU.
        helical_z_shifts (array-like, optional): Per-view axial shifts in ALU, same length as angles.
            If None, all shifts are zero (circular scan).
        use_curved_detector (bool, optional): True for a curved detector, False (default) for a flat one.

    Raises:
        ValueError: If angles and helical_z_shifts do not contain the same number of entries.
    """
    # Projector tuning values read by get_psf_radius() / get_geometry_parameters(). They are assigned before
    # super().__init__() -- presumably because the parent constructor may already query them; keep this order.
    self.bp_psf_radius = 1
    self.entries_per_cylinder_batch = 128
    self.slice_range_length = 0

    # Accept any array-like input (lists, tuples, numpy or jax arrays); .flatten() below needs an array.
    angles = jnp.asarray(angles)
    if helical_z_shifts is None:
        # A circular scan is represented as a helical scan with all-zero z shifts.
        helical_z_shifts = jnp.zeros_like(angles)
    else:
        helical_z_shifts = jnp.asarray(helical_z_shifts)

    # One row per view: [angle, helical_z_shift].
    view_dependent_vecs = [vec.flatten() for vec in (angles, helical_z_shifts)]
    try:
        view_params_array = jnp.stack(view_dependent_vecs, axis=1)
    except ValueError as e:
        # Chain the original error so its details remain in the traceback.
        raise ValueError("Incompatible view dependent vector lengths: all view-dependent vectors must have the "
                         "same length.") from e

    view_params_name = 'view_params_array'
    view_params_component_names = ['angles', 'helical_z_shifts']
    super().__init__(sinogram_shape, view_params_array=view_params_array,
                     source_detector_dist=source_detector_dist, source_iso_dist=source_iso_dist,
                     view_params_name=view_params_name, view_params_component_names=view_params_component_names,
                     recon_slice_offset=0.0, use_curved_detector=use_curved_detector)
@overload
def get_params(self, parameter_names: Union[ConeBeamParamNames, list[ConeBeamParamNames]]) -> Any: ...
def get_params(self, parameter_names) -> Any:
    """
    Return the value of one parameter, or the values of a list of parameters.

    The lookup itself is delegated to the parent class. The @overload signature directly above exists only so
    that static type checkers and IDEs know which names (ConeBeamParamNames) are accepted.

    Args:
        parameter_names (str or list of str): Name(s) of the parameter(s) to look up.

    Returns:
        A single value when given one name, or the values in order when given a list of names.
    """
    requested_names = parameter_names
    return super().get_params(requested_names)
def get_magnification(self):
    """
    Return the cone beam magnification factor.

    Returns:
        source_detector_dist / source_iso_dist, or 1 when source_detector_dist is infinite.
    """
    source_detector_dist, source_iso_dist = self.get_params(['source_detector_dist', 'source_iso_dist'])
    # An infinitely distant source produces no magnification.
    return 1 if jnp.isinf(source_detector_dist) else source_detector_dist / source_iso_dist
def verify_valid_params(self):
    """
    Check that all parameters are compatible for a reconstruction.

    Runs the parent class checks, then the cone beam specific ones: view_params_array must be set and have
    shape (num_views, 2), and a warning is issued when the cone angle exceeds 45 degrees.

    Raises:
        ValueError: For invalid parameters.
    """
    super().verify_valid_params()
    sinogram_shape, view_params_array = self.get_params(['sinogram_shape', 'view_params_array'])
    num_views, num_det_rows = sinogram_shape[:2]

    if view_params_array is None:
        raise ValueError("view_params_array was not set. This should be created in ConeBeamModel.__init__.")

    expected_shape = (num_views, 2)  # one (angle, helical_z_shift) row per view
    if view_params_array.shape != expected_shape:
        # Build the complete message with a single format() call so every placeholder is filled.
        error_message = ("Number of view-dependent parameter vectors must equal the number of views. \n"
                         "Got {} for length of view-dependent parameters and {} for number of views "
                         "(view_params_array has shape {}; expected {}).").format(
            view_params_array.shape[0], num_views, tuple(view_params_array.shape), expected_shape)
        raise ValueError(error_message)

    # Check for cone angle > 45 degrees: the ray to the detector edge (including the row offset) makes an
    # angle of more than 45 degrees with the central ray when half the detector height exceeds
    # source_detector_dist.
    source_detector_dist, delta_det_row, det_row_offset = \
        self.get_params(['source_detector_dist', 'delta_det_row', 'det_row_offset'])
    half_detector_height = delta_det_row * num_det_rows / 2 + jnp.abs(det_row_offset)
    if half_detector_height > source_detector_dist:
        warnings.warn('Cone angle is more than 45 degrees. This will likely produce recon artifacts.')

    # TODO: Check that the recon volume does not extend into (or too close to) the source, e.g. raise
    #  ValueError("Invalid geometry: Recon volume extends too close to source.") when
    #  source_iso_dist - y < 1e-3 for some voxel coordinate y.
def get_geometry_parameters(self):
    """
    Gather the primary geometry parameters used by the cone beam projectors.

    Returns:
        namedtuple: A ``GeometryParams`` tuple. Its fields are the ParameterHandler-managed values
        (delta_det_row, delta_det_channel, det_row_offset, det_channel_offset, source_detector_dist,
        delta_voxel, recon_slice_offset), followed by the derived values (magnification, psf_radius,
        bp_psf_radius, entries_per_cylinder_batch, slice_range_length, use_curved_detector). A namedtuple
        lets jit-compiled code access the parameters by name.
    """
    stored_names = ['delta_det_row', 'delta_det_channel', 'det_row_offset', 'det_channel_offset',
                    'source_detector_dist', 'delta_voxel', 'recon_slice_offset']
    stored_values = self.get_params(stored_names)

    # Derived values, evaluated left to right. get_psf_radius() updates self.bp_psf_radius and
    # self.slice_range_length, so it must run before those two attributes are read.
    derived = [
        ('magnification', self.get_magnification()),
        ('psf_radius', self.get_psf_radius()),
        ('bp_psf_radius', self.bp_psf_radius),
        ('entries_per_cylinder_batch', self.entries_per_cylinder_batch),
        ('slice_range_length', self.slice_range_length),
        ('use_curved_detector', self.get_params('use_curved_detector')),
    ]

    field_names = stored_names + [name for name, _ in derived]
    field_values = list(stored_values) + [value for _, value in derived]
    GeometryParams = namedtuple('GeometryParams', field_names)
    return GeometryParams(*field_values)
def get_psf_radius(self):
    """
    Compute the integer radius of the PSF kernel for cone beam projection.

    The radius is the number of detector rows/channels on either side of the center detector element that a
    single voxel can reach in forward projection. As a side effect, this also refreshes
    ``self.bp_psf_radius`` (the matching radius, in voxels, for back projection) and
    ``self.slice_range_length``.

    Returns:
        int: The forward-projection PSF radius.
    """
    def _half_width(extent):
        # Number of elements on each side of center needed to cover a footprint spanning `extent` elements.
        return int(jnp.ceil(jnp.ceil(extent) / 2))

    delta_det_row, delta_det_channel, source_detector_dist, recon_shape, delta_voxel = self.get_params(
        ['delta_det_row', 'delta_det_channel', 'source_detector_dist', 'recon_shape', 'delta_voxel'])
    magnification = self.get_magnification()

    # The smaller detector pitch gives the widest footprint measured in detector elements.
    delta_det = jnp.minimum(delta_det_row, delta_det_channel)

    if jnp.isinf(source_detector_dist):
        # With an infinitely distant source every voxel has magnification 1.
        max_magnification, min_magnification = 1, 1
    else:
        source_to_iso_dist = source_detector_dist / magnification
        # Half of the larger in-plane recon extent. Rotation is ignored, so this is not exactly the
        # nearest/farthest voxel, but it is adequate for realistic geometries.
        half_recon_extent = 0.5 * jnp.maximum(recon_shape[0], recon_shape[1]) * delta_voxel
        # Voxels nearest the source are magnified most; those farthest from it, least.
        max_magnification = source_detector_dist / (source_to_iso_dist - half_recon_extent)
        min_magnification = source_detector_dist / (source_to_iso_dist + half_recon_extent)

    # Forward projection: one voxel spans delta_voxel * magnification / delta_det detector elements.
    psf_radius = _half_width(delta_voxel * max_magnification / delta_det)

    # Back projection: one detector element spans delta_det / (magnification * delta_voxel) voxels,
    # which is largest for the least magnified voxels.
    max_voxels_per_detector = delta_det / (min_magnification * delta_voxel)
    self.bp_psf_radius = _half_width(max_voxels_per_detector)
    self.slice_range_length = int(1 + 2 * self.bp_psf_radius
                                  + jnp.ceil(self.entries_per_cylinder_batch * max_voxels_per_detector))
    return psf_radius
def auto_set_recon_geometry(self, no_compile=True, no_warning=False):
    """
    Compute and set an automatic recon geometry (recon_shape, delta_voxel, recon_slice_offset) for cone beam
    reconstruction.

    The voxel pitch is the detector channel pitch scaled to the iso plane. The in-plane size covers the detector
    width at iso. The number of slices covers the detector height at iso plus the total helical z travel, so
    every slice that projects onto at least one view is included.

    Args:
        no_compile (bool, optional): Passed through to set_params.
        no_warning (bool, optional): Passed through to set_params.
    """
    delta_det_row, delta_det_channel = self.get_params(['delta_det_row', 'delta_det_channel'])
    magnification = self.get_magnification()

    # Voxel pitch = detector channel pitch mapped back to the iso plane.
    delta_voxel = delta_det_channel / magnification

    # In-plane size: enough rows/columns to cover the detector width at iso.
    num_det_rows, num_det_channels = self.get_params('sinogram_shape')[1:3]
    num_recon_rows = int(jnp.round(num_det_channels * ((delta_det_channel / delta_voxel) / magnification)))
    num_recon_cols = num_recon_rows

    # Axial size: detector height at iso plus the full helical travel (zero travel for a circular scan).
    z_shifts = self.get_params('view_params_array')[:, 1]
    z_min = jnp.min(z_shifts)
    z_max = jnp.max(z_shifts)
    z_travel = z_max - z_min
    detector_height_at_iso = num_det_rows * (delta_det_row / magnification)
    num_recon_slices = max(1, int(jnp.ceil((detector_height_at_iso + z_travel) / delta_voxel)))
    recon_shape = (num_recon_rows, num_recon_cols, num_recon_slices)

    # Center the recon on the middle of the helical travel (for circular scans z_min = z_max = 0 -> offset 0).
    recon_slice_offset = float(0.5 * (z_min + z_max))

    self.set_params(no_compile=no_compile, no_warning=no_warning, recon_shape=recon_shape,
                    delta_voxel=delta_voxel, recon_slice_offset=recon_slice_offset)
@staticmethod
@partial(jax.jit, static_argnames='projector_params')
def forward_project_pixel_batch_to_one_view(voxel_values, pixel_indices, single_view_params, projector_params):
    """
    Forward project a batch of voxel cylinders to a single sinogram view.

    The projection is done as two successive fan projections. A vertical fan maps recon slices to detector rows
    for each pixel, then a horizontal fan maps pixels to detector channels.

    Args:
        voxel_values (jax array): 2D array of shape (num_indices, num_slices); voxel_values[i, j] is the value
            of the voxel in slice j at the location given by pixel_indices[i].
        pixel_indices (jax array of int): 1D vector of indices into the flattened num_rows x num_cols array.
        single_view_params: The angle and helical_z_shift of the view being forward projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params())

    Returns:
        jax array of shape (num_det_rows, num_det_channels)

    Raises:
        ValueError: If voxel_values is not of shape (len(pixel_indices), num_recon_slices, ...).
    """
    num_recon_slices = projector_params.recon_shape[2]
    # Shapes are static under jit, so this check happens at trace time.
    shape_ok = (voxel_values.shape[0] == pixel_indices.shape[0]
                and len(voxel_values.shape) >= 2
                and voxel_values.shape[1] == num_recon_slices)
    if not shape_ok:
        raise ValueError('voxel_values must have shape[0:2] = (num_indices, num_slices)')

    detector_row_cylinders = ConeBeamModel.forward_vertical_fan_pixel_batch_to_one_view(
        voxel_values, pixel_indices, single_view_params, projector_params)
    return ConeBeamModel.forward_horizontal_fan_pixel_batch_to_one_view(
        detector_row_cylinders, pixel_indices, single_view_params, projector_params)
@staticmethod
def forward_vertical_fan_pixel_batch_to_one_view(voxel_values, pixel_indices, single_view_params, projector_params):
    """
    Vertically fan-project each voxel cylinder in a batch onto a detector column.

    Each cylinder voxel_values[i] (located at pixel_indices[i]) is mapped from recon slices to detector rows by
    forward_vertical_fan_one_pixel_to_one_view. The pixel locations are unchanged.

    Args:
        voxel_values (jax array): 2D array of shape (num_pixels, num_recon_slices); voxel_values[i, j] is the
            value of the voxel in slice j at the location determined by pixel_indices[i].
        pixel_indices (jax array of int): 1D vector of shape (num_pixels,) of indices into the flattened
            num_rows x num_cols array.
        single_view_params: The angle and helical_z_shift of the view being forward projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params())

    Returns:
        jax array of shape (num_pixels, num_det_rows)
    """
    # Batch over axis 0 of the voxel values and the pixel indices; view and projector parameters are shared.
    project_each_pixel = jax.vmap(ConeBeamModel.forward_vertical_fan_one_pixel_to_one_view,
                                  in_axes=(0, 0, None, None))
    return project_each_pixel(voxel_values, pixel_indices, single_view_params, projector_params)
@staticmethod
def forward_horizontal_fan_pixel_batch_to_one_view(voxel_values, pixel_indices, single_view_params, projector_params):
    """
    Apply a horizontal fan beam transformation to a set of voxel cylinders. These cylinders are assumed to have
    slices aligned with detector rows, so that a horizontal fan beam maps a cylinder slice to a detector row.
    This function returns the resulting sinogram view.

    Args:
        voxel_values (jax array): 2D array of shape (num_pixels, num_det_rows), where voxel_values[i, j] is the
            value for detector row j at the location determined by pixel_indices[i]. In
            forward_project_pixel_batch_to_one_view this is the output of the vertical fan projection. The
            transpose of this array is added into columns of the (num_det_rows, num_det_channels) view, which
            is what requires the second dimension to be num_det_rows.
        pixel_indices (jax array of int): 1D vector of shape (len(pixel_indices), ) holding the indices into
            the flattened array of size num_rows x num_cols.
        single_view_params: These are the angle and helical_z_shift for the view being forward projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params())

    Returns:
        jax array of shape (num_det_rows, num_det_channels)
    """
    # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access
    # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy.
    gp = projector_params.geometry_params
    num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape

    # Get the data needed for horizontal projection.
    # NOTE(review): by analogy with compute_vertical_data_single_pixel, n_p is presumably the fractional detector
    # channel hit by each pixel, n_p_center its rounded value, and W_p_c the projected voxel width in channels --
    # confirm against compute_horizontal_data.
    n_p, n_p_center, W_p_c, cos_alpha_p_xy = ConeBeamModel.compute_horizontal_data(pixel_indices, single_view_params, projector_params)
    L_max = jnp.minimum(1, W_p_c)  # Maximum fraction of a detector that can be covered by one voxel.

    # Allocate the sinogram array
    sinogram_view = jnp.zeros((num_det_rows, num_det_channels))

    # Do the horizontal projection: visit every channel within psf_radius of each pixel's center channel.
    # gp.psf_radius is a Python int, so this Python loop runs once per offset.
    for n_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1):
        n = n_p_center + n_offset  # Channel examined for each pixel at this offset
        abs_delta_p_c_n = jnp.abs(n_p - n)  # Distance from the projected voxel center to the center of channel n
        L_p_c_n = jnp.clip((W_p_c + 1) / 2 - abs_delta_p_c_n, 0, L_max)  # Fraction of channel n covered by the voxel
        A_chan_n = gp.delta_voxel * L_p_c_n / cos_alpha_p_xy  # Horizontal projection coefficient for channel n
        A_chan_n *= (n >= 0) * (n < num_det_channels)  # Zero the coefficient for channels off the detector
        # Add each pixel's column of row values into channel n. Off-detector n carry a zero coefficient, so
        # whichever element such an index maps to only receives zero.
        sinogram_view = sinogram_view.at[:, n].add(A_chan_n.reshape((1, -1)) * voxel_values.T)

    return sinogram_view
@staticmethod
def forward_vertical_fan_one_pixel_to_one_view(voxel_cylinder, pixel_index, single_view_params, projector_params):
    """
    Apply a fan beam forward projection in the vertical direction to the pixel determined by indices
    into the flattened array of size num_rows x num_cols. This returns a vector obtained from the projection of
    the original voxel cylinder onto a detector column, so the output vector has length num_det_rows.

    Args:
        voxel_cylinder (jax array): 1D array of shape (num_recon_slices, ) of voxel values, where
            voxel_cylinder[j] is the value of the voxel in slice j at the location determined by pixel_index.
        pixel_index (int): Index into the flattened array of size num_rows x num_cols.
        single_view_params: These are the angle and helical_z_shift for the view being forward projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params())

    Returns:
        jax array of shape (num_det_rows,)

    Note:
        This is a helper function used in vmap in :meth:`ConeBeamModel.forward_vertical_fan_pixel_batch_to_one_view`
        This method has the same signature and output as that method, except single int pixel_index is used
        in place of the 1D pixel_indices, and likewise only a single voxel cylinder is returned.
    """
    angle = single_view_params[0]
    helical_z_shift = single_view_params[1]

    # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access
    # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy.
    gp = projector_params.geometry_params
    num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape
    recon_shape = projector_params.recon_shape
    num_slices = voxel_cylinder.shape[0]

    # From pixel index, compute y and pixel_mag (y itself is not used below; only pixel_mag is).
    y, pixel_mag = ConeBeamModel.compute_y_mag_for_pixel(pixel_index, angle, recon_shape, projector_params)

    # The code above depends only on the pixel - a single point. z is a potentially large vector
    # Here we compute cos_phi_p: 1 / cos_phi_p determines the projection length through a voxel
    # For computational efficiency, we use that to scale the voxel_cylinder values.
    # The trailing comments name the general routine each inlined step corresponds to.
    # TODO: possibly convert to a jitted function with donate_argnames to avoid copies for z, v, phi_p, cos_phi_p
    k = jnp.arange(len(voxel_cylinder))
    z = gp.delta_voxel * (k - (num_slices - 1) / 2.0) + (gp.recon_slice_offset - helical_z_shift)  # recon_ijk_to_xyz
    v = pixel_mag * z  # geometry_xyz_to_uv_mag
    # Compute vertical cone angle of voxels
    phi_p = jnp.arctan2(v, gp.source_detector_dist)  # compute_vertical_data_single_pixel
    cos_phi_p = jnp.cos(phi_p)  # We assume the vertical angle |phi_p| < 45 degrees so cos_alpha_p_z = cos_phi_p
    scaled_voxel_values = voxel_cylinder / cos_phi_p
    # End TODO

    # Height of one voxel projected onto the detector, in units of detector rows.
    # This is also the slope dm/dk of the map from voxel index k to detector row index m.
    W_p_r = (pixel_mag * gp.delta_voxel) / gp.delta_det_row
    slope_k_to_m = W_p_r
    L_max = jnp.minimum(1, W_p_r)  # Maximum fraction of a detector that can be covered by one voxel.

    # Split the detector column into batches of det_rows_per_batch rows; det_row_indices holds the first row of
    # each batch: (0, det_rows_per_batch, 2*det_rows_per_batch, ...).
    det_rows_per_batch = gp.entries_per_cylinder_batch
    det_rows_per_batch = min(det_rows_per_batch, num_det_rows)
    num_det_row_batches = (num_det_rows + det_rows_per_batch - 1) // det_rows_per_batch  # Ceiling division
    det_row_indices = det_rows_per_batch * jnp.arange(num_det_row_batches)
    det_center_row = (num_det_rows - 1) / 2.0
    row_batch = jnp.arange(det_rows_per_batch)

    # Set up a function to map over subsets of detector rows
    def create_det_column_rows(start_index):
        # We need to match the back projector, so we have to determine the fraction of each voxel that projects
        # to each detector.
        # First project the detector centers to the voxel cylinder
        m_center = start_index + row_batch  # Center of detector elements
        v_m = (m_center - det_center_row) * gp.delta_det_row - gp.det_row_offset  # Detector center in ALUs
        z_m = v_m / pixel_mag  # z coordinate on this cylinder of the back-projected center of each detector element in this batch

        # Convert to voxel fractional index and find the center of each voxel
        k_m = (z_m - (gp.recon_slice_offset - helical_z_shift)) / gp.delta_voxel + (num_slices - 1) / 2.0
        k_m_center = jnp.round(k_m).astype(int)  # Center of the voxel hit by the center of the detector

        # Then map the center of the voxels back to the detector.
        m_p = slope_k_to_m * (k_m_center - k_m[0]) + m_center[0]  # Projection to detector of voxel centers

        # Allocate space
        new_column_batch = jnp.zeros(det_rows_per_batch)

        # Do the vertical projection
        for k_offset in jnp.arange(start=-gp.bp_psf_radius, stop=gp.bp_psf_radius+1):
            k_ind = k_m_center + k_offset  # Indices of the current set of voxels touched by the detector elements
            # The projection of these centers is the projection of k_m_center (which is m_p) plus
            # the offset times the slope of the map from voxel index to detector index
            abs_delta_p_r_m = jnp.abs(m_p + slope_k_to_m * k_offset - m_center)  # Distance from projection of center of voxel to center of detector
            A_row_k = jnp.clip((W_p_r + 1) / 2 - abs_delta_p_r_m, 0, L_max)  # Fraction of the detector hit by this voxel
            A_row_k *= (k_ind >= 0) * (k_ind < num_slices)  # Zero the weight for voxels outside the cylinder
            # Out-of-range k_ind still index scaled_voxel_values, but the gathered value is multiplied by a zero weight.
            new_column_batch = jnp.add(new_column_batch, A_row_k * scaled_voxel_values[k_ind])
        return new_column_batch, None

    # lax.map stacks the per-batch results into shape (num_det_row_batches, det_rows_per_batch); flatten and
    # drop the padding rows of the last batch that fall beyond num_det_rows.
    det_column, _ = jax.lax.map(create_det_column_rows, det_row_indices)
    det_column = det_column.flatten()
    det_column = jax.lax.slice_in_dim(det_column, 0, num_det_rows)

    return det_column
@staticmethod
@partial(jax.jit, static_argnames='projector_params')
def back_project_one_view_to_pixel_batch(sinogram_view, pixel_indices, single_view_params, projector_params, coeff_power=1):
    """
    Back project one sinogram view onto a batch of voxel cylinders.

    The back projection runs in the reverse order of the forward projector. First a horizontal fan back
    projection maps detector channels to one column of detector-row values per pixel. Then a vertical fan back
    projection maps detector rows to recon slices.

    Args:
        sinogram_view (2D jax array): One view of the sinogram, of shape (num_det_rows, num_det_channels).
        pixel_indices (1D jax array of int): Indices into the flattened array of size num_rows x num_cols.
        single_view_params: The view-dependent parameters of the view being back projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).
        coeff_power (int): Back project using the coefficients A_ij ** coeff_power. Normally 1; use 2 when
            computing the Hessian diagonal.

    Returns:
        2D jax array of shape (num_pixels, num_recon_slices): the back projected voxel cylinder at each pixel.
    """
    per_pixel_detector_rows = ConeBeamModel.back_horizontal_fan_one_view_to_pixel_batch(
        sinogram_view, pixel_indices, single_view_params, projector_params, coeff_power=coeff_power)
    return ConeBeamModel.back_vertical_fan_one_view_to_pixel_batch(
        per_pixel_detector_rows, pixel_indices, single_view_params, projector_params, coeff_power=coeff_power)
@staticmethod
def back_horizontal_fan_one_view_to_pixel_batch(sinogram_view, pixel_indices, single_view_params,
                                                projector_params, coeff_power=1):
    """
    Apply the back projection of a horizontal fan beam transformation to a single sinogram view
    and return the resulting voxel cylinders.

    Args:
        sinogram_view (2D jax array): one view of the sinogram to be back projected.
            2D jax array of shape (num_det_rows)x(num_det_channels)
        pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols.
        single_view_params: These are the angle and helical_z_shift for the view being back projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).
        coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power).
            Normally 1, but should be 2 when computing Hessian diagonal.

    Returns:
        jax array of shape (len(pixel_indices), num_det_rows)
    """
    # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access
    # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy.
    gp = projector_params.geometry_params
    num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape
    num_pixels = pixel_indices.shape[0]

    # Get the data needed for horizontal projection
    n_p, n_p_center, W_p_c, cos_alpha_p_xy = ConeBeamModel.compute_horizontal_data(pixel_indices, single_view_params, projector_params)
    L_max = jnp.minimum(1, W_p_c)  # Maximum fraction of a detector that can be covered by one voxel.

    # Allocate the voxel cylinder array: one column of detector-row values per pixel
    det_voxel_cylinder = jnp.zeros((num_pixels, num_det_rows))

    # Do the horizontal back projection. The coefficients are computed with the same formulas as in
    # forward_horizontal_fan_pixel_batch_to_one_view, then raised to coeff_power.
    for n_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius + 1):
        n = n_p_center + n_offset  # Channel examined for each pixel at this offset
        abs_delta_p_c_n = jnp.abs(n_p - n)
        L_p_c_n = jnp.clip((W_p_c + 1) / 2 - abs_delta_p_c_n, 0, L_max)
        A_chan_n = gp.delta_voxel * L_p_c_n / cos_alpha_p_xy
        A_chan_n *= (n >= 0) * (n < num_det_channels)  # Zero the coefficient for channels off the detector
        A_chan_n = A_chan_n ** coeff_power
        # sinogram_view[:, n] has shape (num_det_rows, num_pixels); its transpose matches det_voxel_cylinder.
        det_voxel_cylinder = jnp.add(det_voxel_cylinder, A_chan_n.reshape((-1, 1)) * sinogram_view[:, n].T)

    return det_voxel_cylinder
@staticmethod
def back_vertical_fan_one_view_to_pixel_batch(det_voxel_cylinder, pixel_indices, single_view_params,
                                              projector_params, coeff_power=1):
    """
    Vertically fan back-project a batch of detector-row columns onto recon voxel cylinders.

    Each column det_voxel_cylinder[i] (for the pixel at pixel_indices[i]) is mapped from detector rows to recon
    slices by back_vertical_fan_one_view_to_one_pixel.

    Args:
        det_voxel_cylinder (2D jax array): Shape (num_pixels, num_det_rows); det_voxel_cylinder[i, j] is the
            value for detector row j at the location determined by pixel_indices[i].
        pixel_indices (1D jax array of int): Indices into the flattened array of size num_rows x num_cols.
        single_view_params: The view-dependent parameters of the view being back projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).
        coeff_power (int): Back project using the coefficients A_ij ** coeff_power. Normally 1; use 2 when
            computing the Hessian diagonal.

    Returns:
        2D jax array of shape (num_pixels, num_recon_slices) of voxel values.
    """
    # Batch over axis 0 of the detector columns and the pixel indices; everything else is shared.
    back_project_each_pixel = jax.vmap(ConeBeamModel.back_vertical_fan_one_view_to_one_pixel,
                                       in_axes=(0, 0, None, None, None))
    return back_project_each_pixel(det_voxel_cylinder, pixel_indices, single_view_params, projector_params,
                                   coeff_power)
@staticmethod
def back_vertical_fan_one_view_to_one_pixel(detector_column_values, pixel_index, single_view_params, projector_params,
                                            coeff_power=1):
    """
    Apply the back projection of a vertical fan beam transformation to a single voxel cylinder and return the column
    vector of the resulting values.

    Args:
        detector_column_values (1D jax array): 1D array of shape (num_det_rows,), where
            detector_column_values[j] is the value for detector row j at the location determined by pixel_index.
        pixel_index (int): Index into flattened array of size num_rows x num_cols.
        single_view_params: These are the angle and helical_z_shift for the view being back projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).
        coeff_power (int): backproject using the coefficients of (A_ij ** coeff_power).
            Normally 1, but should be 2 when computing Hessian diagonal.

    Returns:
        1D jax array of shape (num_recon_slices,) of voxel values.
    """
    # Get all the geometry parameters - we use gp since geometry parameters is a named tuple and we'll access
    # elements using, for example, gp.delta_det_channel, so a longer name would be clumsy.
    gp = projector_params.geometry_params
    num_views, num_det_rows, num_det_channels = projector_params.sinogram_shape
    recon_shape = projector_params.recon_shape
    num_recon_rows, num_recon_cols, num_recon_slices = recon_shape

    # Split the recon cylinder into batches of slices_per_batch slices; slice_indices holds the first slice of
    # each batch: (0, slices_per_batch, 2*slices_per_batch, ...).
    slices_per_batch = gp.entries_per_cylinder_batch
    slices_per_batch = min(slices_per_batch, num_recon_slices)
    num_slice_batches = (num_recon_slices + slices_per_batch - 1) // slices_per_batch  # Ceiling division
    slice_indices = slices_per_batch * jnp.arange(num_slice_batches)

    # Set up a function to map over the slices of the cylinder
    # Here we can use a map over subsections of the voxel cylinder because we are indexing by slice,
    # so there is no overlap from one section to the next.
    def create_voxel_cylinder_slices(start_index):
        # Allocate space
        new_cylinder = jnp.zeros(slices_per_batch)

        # Get the data needed for vertical projection
        cur_slice_indices = start_index + jnp.arange(slices_per_batch)
        m_p, m_p_center, W_p_r, cos_alpha_p_z = ConeBeamModel.compute_vertical_data_single_pixel(pixel_index, cur_slice_indices, single_view_params,
                                                                                                projector_params)
        L_max = jnp.minimum(1, W_p_r)  # Maximum fraction of a detector that can be covered by one voxel.

        # Do the vertical projection: visit every detector row within psf_radius of each voxel's center row.
        for m_offset in jnp.arange(start=-gp.psf_radius, stop=gp.psf_radius+1):
            m = m_p_center + m_offset
            abs_delta_p_r_m = jnp.abs(m_p - m)  # Distance from projection of center of voxel to center of detector
            L_p_r_m = jnp.clip((W_p_r + 1) / 2 - abs_delta_p_r_m, 0, L_max)  # Fraction of detector row m covered by the voxel
            A_row_m = L_p_r_m / cos_alpha_p_z
            A_row_m *= (m >= 0) * (m < num_det_rows)  # Zero the weight for rows off the detector
            A_row_m = A_row_m ** coeff_power
            # Out-of-range m still index detector_column_values, but the gathered value is multiplied by a zero weight.
            new_cylinder = jnp.add(new_cylinder, A_row_m * detector_column_values[m])
        return new_cylinder, None

    # lax.map stacks the batches into shape (num_slice_batches, slices_per_batch); flatten and drop the padding
    # slices of the last batch that fall beyond num_recon_slices.
    recon_voxel_cylinder, _ = jax.lax.map(create_voxel_cylinder_slices, slice_indices)
    recon_voxel_cylinder = recon_voxel_cylinder.flatten()
    recon_voxel_cylinder = jax.lax.slice_in_dim(recon_voxel_cylinder, 0, num_recon_slices)

    return recon_voxel_cylinder
@staticmethod
def compute_vertical_data_single_pixel(pixel_index, slice_indices, single_view_params, projector_params):
    """
    Compute the quantities m_p, m_p_center, W_p_r, cos_alpha_p_z needed for vertical projection of one
    voxel cylinder.

    Args:
        pixel_index (int): Index into flattened array of size num_rows x num_cols.
        slice_indices (array of int): Indices into the recon slices.
        single_view_params: The angle and helical_z_shift of the view, in that order.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).

    Returns:
        tuple (m_p, m_p_center, W_p_r, cos_alpha_p_z):
            m_p: fractional detector row onto which each requested voxel projects.
            m_p_center: m_p rounded to the nearest integer row.
            W_p_r: projected voxel height in units of detector rows (pixel_mag * delta_voxel / delta_det_row).
            cos_alpha_p_z: cosine of each voxel's vertical cone angle phi_p. The vertical angle is assumed to
                stay below 45 degrees, so this equals cos(phi_p).
    """
    angle, helical_z_shift = single_view_params[0], single_view_params[1]
    gp = projector_params.geometry_params
    _, num_det_rows, num_det_channels = projector_params.sinogram_shape
    recon_shape = projector_params.recon_shape

    # Recon-grid (row, col) of this pixel, then world coordinates of each requested slice for this view.
    row_index, col_index = jnp.unravel_index(pixel_index, recon_shape[:2])
    effective_slice_offset = gp.recon_slice_offset - helical_z_shift
    x_p, y_p, z_p = ConeBeamModel.recon_ijk_to_xyz(row_index, col_index, slice_indices, gp.delta_voxel,
                                                   recon_shape, effective_slice_offset, angle)

    # World coordinates -> detector (u, v) and per-voxel magnification -> fractional detector row.
    u_p, v_p, pixel_mag = ConeBeamModel.geometry_xyz_to_uv_mag(x_p, y_p, z_p, gp.source_detector_dist,
                                                               gp.magnification, gp.use_curved_detector)
    m_p, _ = ConeBeamModel.detector_uv_to_mn(u_p, v_p, gp.delta_det_channel, gp.delta_det_row,
                                             gp.det_channel_offset, gp.det_row_offset,
                                             num_det_rows, num_det_channels)

    # Vertical cone angle phi_p of each voxel; with |phi_p| < 45 degrees assumed, cos_alpha_p_z = cos(phi_p).
    cos_phi_p = jnp.cos(jnp.arctan2(v_p, gp.source_detector_dist))

    # Projected voxel height measured in detector rows.
    W_p_r = pixel_mag * (gp.delta_voxel / gp.delta_det_row)

    return m_p, jnp.round(m_p).astype(int), W_p_r, cos_phi_p
@staticmethod
def compute_horizontal_data(pixel_indices, single_view_params, projector_params):
    """
    Compute the quantities n_p, n_p_center, W_p_c, cos_alpha_p_xy needed for horizontal projection.

    Behavior differs by detector type:

    - **Flat** (use_curved_detector=False): Uses theta_p = arctan2(u_p, source_detector_dist) and
      includes a 1/cos(theta_p) foreshortening correction in W_p_c.
    - **Curved** (use_curved_detector=True): Uses theta_p = u_p / source_detector_dist
      (arc-length parameterisation) and omits the 1/cos(theta_p) term from W_p_c.

    Args:
        pixel_indices (1D jax array of int): indices into flattened array of size num_rows x num_cols.
        single_view_params: The angle (entry 0) and helical_z_shift (entry 1) for the view being projected.
        projector_params (namedtuple): tuple of (sinogram_shape, recon_shape, get_geometry_params()).

    Returns:
        tuple: (n_p, n_p_center, W_p_c, cos_alpha_p_xy)
    """
    view_angle = single_view_params[0]
    z_shift = single_view_params[1]
    # Short alias for the geometry-parameter namedtuple.
    geom = projector_params.geometry_params
    _, _, num_det_channels = projector_params.sinogram_shape
    recon_shape = projector_params.recon_shape

    # Horizontal quantities depend only on the (row, col) position, so a single slice index suffices.
    rows, cols = jnp.unravel_index(pixel_indices, recon_shape[:2])
    x_p, y_p, z_p = ConeBeamModel.recon_ijk_to_xyz(rows, cols, jnp.arange(1), geom.delta_voxel,
                                                   recon_shape, geom.recon_slice_offset - z_shift, view_angle)

    # pixel_mag is kept as a magnification so that source_detector_dist = jnp.Inf remains valid.
    u_p, _, pixel_mag = ConeBeamModel.geometry_xyz_to_uv_mag(x_p, y_p, z_p, geom.source_detector_dist,
                                                            geom.magnification, geom.use_curved_detector)

    # Fractional channel index on the detector; must stay in sync with detector_uv_to_mn.
    channel_center = (num_det_channels - 1) / 2.0
    n_p = (u_p + geom.det_channel_offset) / geom.delta_det_channel + channel_center
    n_p_center = jnp.round(n_p).astype(int)

    # Horizontal cone angle: arc-length based for a curved detector, exact arctan for a flat one.
    if geom.use_curved_detector:
        theta_p = u_p / geom.source_detector_dist
    else:
        theta_p = jnp.arctan2(u_p, geom.source_detector_dist)

    relative_angle = view_angle - theta_p
    cos_alpha_p_xy = jnp.maximum(jnp.abs(jnp.cos(relative_angle)),
                                 jnp.abs(jnp.sin(relative_angle)))

    footprint_scale = pixel_mag * (geom.delta_voxel / geom.delta_det_channel)
    if geom.use_curved_detector:
        # The arc parameterisation absorbs the foreshortening, so there is no cos(theta_p) denominator.
        W_p_c = footprint_scale * cos_alpha_p_xy
    else:
        # Flat detector: the voxel footprint widens at oblique angles.
        W_p_c = footprint_scale * (cos_alpha_p_xy / jnp.cos(theta_p))

    return n_p, n_p_center, W_p_c, cos_alpha_p_xy
@staticmethod
def recon_ijk_to_xyz(i, j, k, delta_voxel, recon_shape, recon_slice_offset, angle):
    """
    Convert (i, j, k) indices into the recon volume to corresponding (x, y, z) coordinates.

    The row index i gives the un-rotated y coordinate and the column index j gives the un-rotated
    x coordinate, both centered on the volume. The pair (x, y) is then rotated by ``angle``. The slice
    index k gives z, centered on the volume and shifted by ``recon_slice_offset``.

    Args:
        i, j, k: Row, column and slice indices (scalars or broadcast-compatible arrays).
        delta_voxel (float): Voxel pitch.
        recon_shape (tuple): (num_recon_rows, num_recon_cols, num_recon_slices).
        recon_slice_offset (float): Additive offset applied to z.
        angle (float): Rotation angle in radians.

    Returns:
        tuple: (x, y, z)
    """
    n_rows, n_cols, n_slices = recon_shape

    # Centered, un-rotated coordinates -- note that rows map to y and columns map to x.
    x_unrotated = delta_voxel * (j - (n_cols - 1) / 2.0)
    y_unrotated = delta_voxel * (i - (n_rows - 1) / 2.0)

    cos_a = jnp.cos(angle)
    sin_a = jnp.sin(angle)
    x = cos_a * x_unrotated - sin_a * y_unrotated
    y = sin_a * x_unrotated + cos_a * y_unrotated

    z = delta_voxel * (k - (n_slices - 1) / 2.0) + recon_slice_offset
    return x, y, z
@staticmethod
def geometry_xyz_to_uv_mag(x, y, z, source_detector_dist, magnification, use_curved_detector=False):
    """
    Convert (x, y, z) coordinates to (u, v) detector coordinates plus the pixel-dependent magnification.

    Behavior differs by detector type:

    - **Flat** (use_curved_detector=False): u = pixel_mag * x (linear perspective projection).
    - **Curved** (use_curved_detector=True): u = source_detector_dist * arctan2(x, source_iso_dist - y)
      (arc-length on a cylinder of radius source_detector_dist), with
      source_iso_dist = source_detector_dist / magnification.

    In both cases v = pixel_mag * z.

    Returns:
        tuple: (u, v, pixel_mag)
    """
    # Voxel-specific magnification. Written as 1 / (1/M - y/sdd) so that it stays finite
    # (equal to magnification) when source_detector_dist = jnp.Inf.
    pixel_mag = 1 / (1 / magnification - y / source_detector_dist)

    if use_curved_detector:
        # Arc-length on a cylinder of radius source_detector_dist.
        iso_dist = source_detector_dist / magnification
        u = source_detector_dist * jnp.arctan2(x, (iso_dist - y))
    else:
        # Standard flat-panel perspective projection.
        u = pixel_mag * x

    # NOTE(review): v uses the flat-panel magnification for both detector types. For a cylindrical
    # detector one might expect scaling by the in-plane source-to-voxel distance
    # sqrt(x**2 + (source_iso_dist - y)**2) instead of (source_iso_dist - y) -- confirm this is intended.
    v = pixel_mag * z
    return u, v, pixel_mag
@staticmethod
@jax.jit
def detector_uv_to_mn(u, v, delta_det_channel, delta_det_row, det_channel_offset, det_row_offset, num_det_rows,
                      num_det_channels):
    """
    Convert (u, v) detector coordinates to fractional indices (m, n) into the detector.

    m = (v + det_row_offset) / delta_det_row + (num_det_rows - 1) / 2
    n = (u + det_channel_offset) / delta_det_channel + (num_det_channels - 1) / 2

    Note:
        This version does not account for nonzero detector rotation.

    Returns:
        tuple: (m, n) fractional row and channel indices.
    """
    # TODO: Supporting detector rotation would require rotating (u, v) here and also adjusting the
    # channel calculation as a function of slice. For now (u, v) are used unrotated.
    row_center = (num_det_rows - 1) / 2.0
    channel_center = (num_det_channels - 1) / 2.0

    m = (v + det_row_offset) / delta_det_row + row_center
    # Keep the channel formula in sync with compute_horizontal_data.
    n = (u + det_channel_offset) / delta_det_channel + channel_center
    return m, n
@staticmethod
@jax.jit
def detector_mn_to_uv(m, n, delta_det_channel, delta_det_row, det_channel_offset, det_row_offset, num_det_rows,
                      num_det_channels):
    """
    Convert fractional detector grid indices (m, n) into detector coordinates (u, v).

    This is the inverse of detector_uv_to_mn:
    v = (m - (num_det_rows - 1) / 2) * delta_det_row - det_row_offset
    u = (n - (num_det_channels - 1) / 2) * delta_det_channel - det_channel_offset

    Parameters:
        m: Fractional row index on the detector grid (vertical direction).
        n: Fractional channel index on the detector grid (horizontal direction).
        delta_det_channel: Spacing (pitch) of the detector channels (horizontal direction).
        delta_det_row: Spacing (pitch) of the detector rows (vertical direction).
        det_channel_offset: Offset in the detector channel (horizontal) direction.
        det_row_offset: Offset in the detector row (vertical) direction.
        num_det_rows: Total number of rows in the detector.
        num_det_channels: Total number of channels in the detector.

    Returns:
        u: Physical detector coordinate in the channel direction.
        v: Physical detector coordinate in the row direction.
    """
    row_center = (num_det_rows - 1) / 2.0
    channel_center = (num_det_channels - 1) / 2.0

    u = (n - channel_center) * delta_det_channel - det_channel_offset
    v = (m - row_center) * delta_det_row - det_row_offset
    return u, v
@staticmethod
@jax.jit
def compute_y_mag_for_pixel(pixel_index, angle, recon_shape, projector_params):
    """
    Return the rotated y coordinate of a recon pixel and its pixel-dependent magnification for one view.

    Args:
        pixel_index: Index into the flattened (num_recon_rows x num_recon_cols) pixel grid.
        angle (float): View angle in radians.
        recon_shape (tuple): Recon shape; only the first two entries (rows, cols) are used.
        projector_params (namedtuple): Must provide geometry_params with delta_voxel, magnification and
            source_detector_dist.

    Returns:
        tuple: (y, pixel_mag)
    """
    geom = projector_params.geometry_params
    row, col = jnp.unravel_index(pixel_index, recon_shape[:2])

    # Centered, un-rotated coordinates -- note that rows map to y and columns map to x.
    x_unrotated = geom.delta_voxel * (col - (recon_shape[1] - 1) / 2.0)
    y_unrotated = geom.delta_voxel * (row - (recon_shape[0] - 1) / 2.0)

    # Only the rotated y coordinate is needed.
    y = jnp.sin(angle) * x_unrotated + jnp.cos(angle) * y_unrotated

    # Same magnification formula as geometry_xyz_to_uv_mag.
    pixel_mag = 1 / (1 / geom.magnification - y / geom.source_detector_dist)
    return y, pixel_mag
def direct_recon(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE):
    """
    Perform a direct (non-iterative) reconstruction; for cone beam geometry this delegates to fdk_recon.

    Args:
        sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels).
        filter_name (string, optional): Name of the filter to be used. Defaults to "ramp".
        view_batch_size (int, optional): Size of view batches (used to limit memory use).

    Returns:
        recon (jax array): The volume returned by fdk_recon.
    """
    recon = self.fdk_recon(sinogram, filter_name=filter_name, view_batch_size=view_batch_size)
    return recon
def direct_filter(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE):
    """
    Perform filtering on the given sinogram as needed for an FBP/FDK or other direct recon.

    For cone beam geometry this delegates to fdk_filter.

    Args:
        sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels).
        filter_name (string, optional): Name of the filter to be used. Defaults to "ramp".
        view_batch_size (int, optional): Size of view batches (used to limit memory use).

    Returns:
        filtered_sinogram (jax array): The sinogram after FDK filtering.
    """
    filtered_sinogram = self.fdk_filter(sinogram, filter_name=filter_name, view_batch_size=view_batch_size)
    return filtered_sinogram
def fdk_filter(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE):
    """
    Perform FDK filtering on the given sinogram.

    Steps:
      1. Pre-weight each detector element by
         source_detector_dist / sqrt(source_detector_dist**2 + u**2 + v**2).
      2. Convolve every detector row (along channels) with the direct-recon filter scaled by
         alpha = delta_det_row / (delta_voxel**3 * magnification).
      3. Multiply the result by pi / num_views.

    Views are processed in batches of at most 128 to limit memory use.

    Args:
        sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels).
        filter_name (string, optional): Name of the filter to be used. Defaults to "ramp".
        view_batch_size (int, optional): Size of view batches (used to limit memory use). If None,
            self.view_batch_size_for_vmap is used.

    Returns:
        filtered_sinogram (jax array): The sinogram after FDK filtering.

    Raises:
        ValueError: If the sinogram contains no views.
    """
    num_views, num_rows, num_channels = sinogram.shape
    if num_views == 0:
        raise ValueError("sinogram must contain at least one view.")

    # Get parameters
    source_detector_dist = self.get_params('source_detector_dist')
    delta_voxel, delta_det_row, delta_det_channel = self.get_params(['delta_voxel', 'delta_det_row', 'delta_det_channel'])
    det_row_offset, det_channel_offset = self.get_params(['det_row_offset', 'det_channel_offset'])

    if view_batch_size is None:
        view_batch_size = self.view_batch_size_for_vmap
    max_view_batch_size = 128  # Limit the view batch size here and ParallelBeam due to https://github.com/jax-ml/jax/issues/27591
    view_batch_size = min(view_batch_size, max_view_batch_size)

    # Magnification factor M_0 = Source-Detector Distance / Source-Isocenter Distance
    M_0 = self.get_magnification()

    # Physical (u, v) coordinates of every detector element; grids have shape (num_rows, num_channels).
    m_grid, n_grid = jnp.meshgrid(jnp.arange(num_rows), jnp.arange(num_channels), indexing='ij')
    u_grid, v_grid = self.detector_mn_to_uv(m_grid, n_grid, delta_det_channel, delta_det_row,
                                            det_channel_offset, det_row_offset, num_rows, num_channels)

    # Cosine pre-weighting.
    # NOTE(review): this treats u as a flat-panel coordinate; if use_curved_detector=True, u is arc-length,
    # so confirm this weighting (and the filter below) is intended for curved detectors.
    weight_map = source_detector_dist / jnp.sqrt(source_detector_dist ** 2 + u_grid ** 2 + v_grid ** 2)
    weighted_sinogram = jax.device_put(sinogram * weight_map[None, :, :], self.sinogram_device)

    # Compute the scaled filter.
    # Scaling factor alpha adjusts the filter to account for voxel size, ensuring consistent reconstruction.
    # For a detailed theoretical derivation of this scaling factor, please refer to the zip file linked at
    # https://mbirjax.readthedocs.io/en/latest/theory.html
    recon_filter = tomography_utils.generate_direct_recon_filter(num_channels, filter_name=filter_name)
    alpha = delta_det_row / (delta_voxel ** 3 * M_0)
    recon_filter = alpha * recon_filter

    def convolve_row(row):
        # Filter one detector row across its channels.
        return jax.scipy.signal.fftconvolve(row, recon_filter, mode="valid")

    def apply_convolution_to_view(view):
        # Filter every row of a single view.
        return jax.vmap(convolve_row)(view)

    # Filter in view batches to bound peak memory; each batch is synchronized before moving on.
    filtered_sino_list = []
    for start in range(0, num_views, view_batch_size):
        stop = min(start + view_batch_size, num_views)
        sino_batch = jax.device_put(weighted_sinogram[start:stop], self.sinogram_device)
        filtered_batch = jax.lax.map(apply_convolution_to_view, sino_batch, batch_size=view_batch_size)
        filtered_batch.block_until_ready()
        filtered_sino_list.append(jax.device_put(filtered_batch, self.sinogram_device))

    filtered_sinogram = jnp.concatenate(filtered_sino_list, axis=0)
    filtered_sinogram *= jnp.pi / num_views
    return filtered_sinogram
[docs]
def fdk_recon(self, sinogram, filter_name="ramp", view_batch_size=DIRECT_RECON_VIEW_BATCH_SIZE):
    """
    Perform FDK reconstruction on the given sinogram.

    Our implementation uses standard filtering of the sinogram, then uses the adjoint of the forward projector to
    perform the backprojection. This is different from many implementations, in which the backprojection is not
    exactly the adjoint of the forward projection. For a detailed theoretical derivation of this implementation,
    see the zip file linked at this page: https://mbirjax.readthedocs.io/en/latest/theory.html

    Args:
        sinogram (jax array): The input sinogram with shape (num_views, num_rows, num_channels).
        filter_name (string, optional): Name of the filter to be used. Defaults to "ramp".
        view_batch_size (int, optional): Size of view batches (used to limit memory use).

    Returns:
        recon (jax array): The reconstructed volume after FDK reconstruction.
    """
    # Step 1: FDK filtering; step 2: back projection with the adjoint of the forward projector.
    return self.back_project(
        self.fdk_filter(sinogram, filter_name=filter_name, view_batch_size=view_batch_size)
    )
[docs]
def split_sino_recon(self, sino, weights=None, half_overlap=5, init_recon=None, max_iterations=15, stop_threshold_change_pct=0.2,
                     first_iteration=0, compute_prior_loss=False, logfile_path='./logs/recon.log', print_logs=True):
    """
    Reconstruct in two overlapping halves to reduce cone beam MBIR memory usage.

    This function reduces memory usage for cone beam MBIR reconstruction by approximately a factor of 2
    by splitting the detector rows into two overlapping halves, reconstructing each half separately,
    and stitching the reconstructions together.

    The function can be called with the same arguments as TomographyModel.recon(), and it should return a
    reconstruction which is approximately equal to the reconstruction returned by TomographyModel.recon().

    If the recon slice nearest to iso is within one slice of either end of the volume, a UserWarning is
    issued and a standard self.recon() is performed instead. If the overlap, converted to recon slices, is
    larger than the number of slices available on one side of the split, a UserWarning is issued and the
    recon overlap is reduced to that number.

    Args:
        sino (jnp.ndarray | np.ndarray): Full sinogram of shape (num_views, num_rows, num_cols).
        weights (jnp.ndarray | np.ndarray, optional): Optional sinogram weights with the same shape as `sino`.
        half_overlap (int): Number of overlapping detector rows and recon slices per half. (total overlap = 2 * half_overlap)
        init_recon (optional): Same as in the recon method.
        max_iterations (int, optional): Same as in the recon method.
        stop_threshold_change_pct (float, optional): Same as in the recon method.
        first_iteration (int, optional): Same as in the TomographyModel.recon() method.
        compute_prior_loss (bool, optional): Same as in the TomographyModel.recon() method.
        logfile_path (str, optional): Same as in the TomographyModel.recon() method.
        print_logs (bool, optional): Same as in the TomographyModel.recon() method.

    Returns:
        Tuple[jnp.ndarray, dict]:
            - Reconstructed volume (jax array).
            - Dictionary of metadata containing recon and model parameters for each half.

    Raises:
        ValueError: If sino is None, half_overlap is negative, or the detector row nearest to iso lies
            outside the detector.
        AssertionError: If sino is not 3D or weights does not match the shape of sino.
        TypeError: If half_overlap is not an integer.

    Example:
        >>> import jax.numpy as jnp
        >>> import mbirjax as mj
        >>> sino = jnp.ones((180, 64, 64))  # (views, rows, cols)
        >>> model = mj.ConeBeamModel(sinogram_shape=sino.shape,
        ...                          angles=jnp.linspace(0, jnp.pi, 180),
        ...                          source_detector_dist=1000.0,
        ...                          source_iso_dist=500.0)
        >>> recon, recon_info = model.split_sino_recon(sino, half_overlap=4)
        >>> recon.shape  # Quilted reconstruction volume
        (64, 64, 64)
    """
    # -------- Basic validation --------
    if sino is None:
        raise ValueError("sino must be provided.")
    if not (hasattr(sino, "ndim") and sino.ndim == 3):
        raise AssertionError("sino must be a 3D array shaped (num_views, num_rows, num_cols).")
    if weights is not None and getattr(weights, "shape", None) != sino.shape:
        raise AssertionError("weights, if provided, must have the same shape as sino.")
    # bool is an int subclass, so reject it explicitly.
    if isinstance(half_overlap, bool) or not isinstance(half_overlap, (int, np.integer)):
        raise TypeError(f"half_overlap must be an integer, got {type(half_overlap).__name__}.")
    if half_overlap < 0:
        raise ValueError(f"half_overlap must be non-negative, got {half_overlap}.")

    _, full_num_rows, _ = sino.shape

    # -------- Parameters needed to create top and bottom models --------
    delta_det_row = self.get_params('delta_det_row')
    full_det_row_offset = self.get_params('det_row_offset')
    delta_voxel = self.get_params('delta_voxel')
    full_recon_shape = self.get_params('recon_shape')
    full_recon_slice_offset = self.get_params('recon_slice_offset')
    magnification = self.get_magnification()

    # -------- Express the overlap in detector rows and in recon slices --------
    # The coarser grid gets `half_overlap` units. The finer grid gets the number of its own units spanning
    # the same height at iso.
    delta_detector_row_at_iso = max(delta_det_row / magnification, 1e-12)
    ratio_pixel_to_sino_pitch = delta_voxel / delta_detector_row_at_iso
    if ratio_pixel_to_sino_pitch > 1:
        half_overlap_sino = int(jnp.round(half_overlap * ratio_pixel_to_sino_pitch))
        half_overlap_recon = int(half_overlap)
    else:
        half_overlap_sino = int(half_overlap)
        half_overlap_recon = int(jnp.round(half_overlap / ratio_pixel_to_sino_pitch))

    # -------- Choose the detector row nearest to iso (consistent with detector_uv_to_mn at v = 0) --------
    det_iso_row_float = ((full_num_rows - 1) / 2.0) + (full_det_row_offset / delta_det_row)
    det_iso_row_index = int(jnp.round(det_iso_row_float))
    if not (0 < det_iso_row_index < full_num_rows):
        raise ValueError(
            f"Computed det_iso_row_index={det_iso_row_index} is out of the valid range "
            f"1 to {full_num_rows - 1} (inclusive)."
        )

    # -------- Compute the recon slice nearest to iso (z = 0) --------
    _, _, full_recon_slices = full_recon_shape
    full_recon_iso_slice_index_float = (full_recon_slices - 1) / 2.0 - full_recon_slice_offset / delta_voxel
    split_index = int(jnp.round(full_recon_iso_slice_index_float))

    # Fallback: if the split would leave an empty (or single-slice) top or bottom recon, do a normal MBIR recon.
    # This is checked before any models are built so no work is wasted.
    if (split_index < 1) or (split_index > full_recon_slices - 2):
        warnings.warn(
            "split_index is too close to the volume boundary; falling back to standard MBIR reconstruction.",
            UserWarning,
        )
        return self.recon(
            sino,
            weights=weights,
            init_recon=init_recon,
            max_iterations=max_iterations,
            stop_threshold_change_pct=stop_threshold_change_pct,
            first_iteration=first_iteration,
            compute_prior_loss=compute_prior_loss,
            logfile_path=logfile_path,
            print_logs=print_logs,
        )

    # Each half must stay inside the full volume and be at least 2 * half_overlap_recon slices long for stitching.
    max_half_overlap_recon = min(split_index + 1, full_recon_slices - split_index - 1)
    if half_overlap_recon > max_half_overlap_recon:
        warnings.warn(
            f"half_overlap_recon={half_overlap_recon} exceeds the {max_half_overlap_recon} slices available "
            f"on one side of the split; reducing it to {max_half_overlap_recon}.",
            UserWarning,
        )
        half_overlap_recon = max_half_overlap_recon

    # -------- Detector row ranges for top and bottom sinogram halves --------
    top_lo = 0
    top_hi = min(det_iso_row_index + half_overlap_sino, full_num_rows)
    bot_lo = max(det_iso_row_index - half_overlap_sino, 0)
    bot_hi = full_num_rows

    # -------- Split sinogram (and weights) into top and bottom halves --------
    sino_top_half = sino[:, top_lo:top_hi, :]
    sino_bot_half = sino[:, bot_lo:bot_hi, :]
    weights_top_half = None
    weights_bot_half = None
    if weights is not None:
        weights_top_half = weights[:, top_lo:top_hi, :]
        weights_bot_half = weights[:, bot_lo:bot_hi, :]

    # -------- Detector row offsets so each half's rows keep their physical position --------
    top_num_rows = top_hi - top_lo
    bot_num_rows = bot_hi - bot_lo
    full_det_center = (full_num_rows - 1) / 2.0
    top_det_center = (top_num_rows - 1) / 2.0
    bot_det_center = (bot_num_rows - 1) / 2.0
    top_det_row_offset = full_det_row_offset + (full_det_center - (top_det_center + top_lo)) * delta_det_row
    bot_det_row_offset = full_det_row_offset + (full_det_center - (bot_det_center + bot_lo)) * delta_det_row

    # -------- Recon shapes and slice offsets for the two halves --------
    # The top half covers full slices [0, split_index + half_overlap_recon].
    # The bottom half covers full slices [split_index + 1 - half_overlap_recon, full_recon_slices - 1].
    top_num_slices = split_index + 1
    top_recon_shape = (full_recon_shape[0], full_recon_shape[1], top_num_slices + half_overlap_recon)
    bot_recon_shape = (full_recon_shape[0], full_recon_shape[1], (full_recon_shape[2] - top_num_slices) + half_overlap_recon)
    # Offset of the split from iso; used to shift the slices so they align with a standard reconstruction.
    split_offset = split_index - full_recon_iso_slice_index_float
    top_recon_slice_offset = (+half_overlap_recon - (top_recon_shape[2]-1)/2 + 0 + split_offset) * delta_voxel
    bot_recon_slice_offset = (-half_overlap_recon + (bot_recon_shape[2]-1)/2 + 1 + split_offset) * delta_voxel

    # Set the regularization parameters from the full sinogram; the copies below then have auto-regularization off.
    # NOTE(review): weights are not passed here; confirm whether auto-regularization should use them.
    self.auto_set_regularization_params(sino)

    # -------- Build top-half model --------
    ct_model_top_half = mj.copy_ct_model(self, new_num_det_rows=top_num_rows)
    ct_model_top_half.set_params(det_row_offset=top_det_row_offset)
    ct_model_top_half.set_params(auto_regularize_flag=False)
    ct_model_top_half.set_params(recon_shape=top_recon_shape)
    ct_model_top_half.set_params(recon_slice_offset=top_recon_slice_offset)

    # -------- Build bottom-half model --------
    ct_model_bot_half = mj.copy_ct_model(self, new_num_det_rows=bot_num_rows)
    ct_model_bot_half.set_params(det_row_offset=bot_det_row_offset)
    ct_model_bot_half.set_params(auto_regularize_flag=False)
    ct_model_bot_half.set_params(recon_shape=bot_recon_shape)
    ct_model_bot_half.set_params(recon_slice_offset=bot_recon_slice_offset)

    # -------- Reconstruct halves (pass weights if provided) --------
    top_init_recon = init_recon[:, :, :top_recon_shape[2]] if init_recon is not None else None
    recon_top_half, recon_top_dict = ct_model_top_half.recon(sino_top_half, weights=weights_top_half,
                                                             init_recon=top_init_recon, max_iterations=max_iterations,
                                                             stop_threshold_change_pct=stop_threshold_change_pct,
                                                             first_iteration=first_iteration,
                                                             compute_prior_loss=compute_prior_loss,
                                                             logfile_path=logfile_path, print_logs=print_logs)
    recon_top_half = jax.device_get(recon_top_half)

    bot_init_recon = init_recon[:, :, -bot_recon_shape[2]:] if init_recon is not None else None
    recon_bot_half, recon_bot_dict = ct_model_bot_half.recon(sino_bot_half, weights=weights_bot_half,
                                                             init_recon=bot_init_recon, max_iterations=max_iterations,
                                                             stop_threshold_change_pct=stop_threshold_change_pct,
                                                             first_iteration=first_iteration,
                                                             compute_prior_loss=compute_prior_loss,
                                                             logfile_path=logfile_path, print_logs=print_logs)
    recon_bot_half = jax.device_get(recon_bot_half)

    # -------- Stitch together top and bottom reconstructions --------
    recon_full = mj.stitch_arrays([recon_top_half, recon_bot_half], overlap=2 * half_overlap_recon, axis=2)

    # -------- Construct full reconstruction dictionary --------
    recon_full_dict = {'recon_params_top': recon_top_dict['recon_params'],
                       'recon_params_bottom': recon_bot_dict['recon_params'],
                       'recon_log_top': recon_top_dict['recon_log'],
                       'recon_log_bottom': recon_bot_dict['recon_log'],
                       'notes_top': recon_top_dict['notes'],
                       'notes_bottom': recon_bot_dict['notes'],
                       'model_params_top': recon_top_dict['model_params'],
                       'model_params_bottom': recon_bot_dict['model_params'], }
    return recon_full, recon_full_dict