import warnings
import numpy as np
import jax.numpy as jnp
import jax
import mbirjax.bn256 as bn
import mbirjax.preprocess as mjp
def get_2d_ror_mask(recon_shape, *, crop_radius_pixels=0, crop_radius_fraction=0.0):
    """
    Get a binary mask for the region of reconstruction (RoR).

    By default, the mask is a disk centered on the 2D image recon_shape[0:2] whose radius is the larger of
    (num_rows - 1) / 2 and (num_cols - 1) / 2, i.e., the circle is sized by the longest edge.
    The radius can be reduced by setting crop_radius_pixels or crop_radius_fraction, either of which is
    subtracted from the radius.  Only one of these can be nonzero.  Negative values are clipped to 0.
    A fractional crop is converted to a whole number of pixels by truncation.
    If the crop is larger than the radius, the mask is empty (all False).

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices).  Only the first two entries are used.
        crop_radius_pixels (int): Number of pixels to subtract from the radius before creating the mask.
        crop_radius_fraction (float): Fraction of the radius to subtract from the radius before creating the mask.

    Raises:
        ValueError: If both crop_radius_pixels and crop_radius_fraction are nonzero.

    Returns:
        A boolean numpy array of shape (rows, columns) that is True inside the region of reconstruction.
    """
    if crop_radius_pixels != 0 and crop_radius_fraction != 0.0:
        raise ValueError('Only one of crop_radius_pixels and crop_radius_fraction can be nonzero.')

    num_recon_rows, num_recon_cols = recon_shape[:2]
    row_center = (num_recon_rows - 1) / 2
    col_center = (num_recon_cols - 1) / 2
    radius = max(row_center, col_center)

    # At most one of the two crop values is nonzero here, so the max simply selects it (or 0 if it is negative).
    crop_radius = max(int(radius * crop_radius_fraction), crop_radius_pixels, 0)
    radius -= crop_radius

    # Squared distance of each pixel center from the image center, shape (rows, columns) by broadcasting.
    row_coords = np.arange(num_recon_rows) - row_center
    col_coords = np.arange(num_recon_cols) - col_center
    dist_squared = col_coords[None, :] ** 2 + row_coords[:, None] ** 2

    # A negative radius means the crop removed the whole disk.  Without the explicit sign check, squaring the
    # negative radius would produce a positive threshold and the mask would wrongly grow again.
    mask = (dist_squared <= radius ** 2) & (radius >= 0)
    return mask
def gen_set_of_pixel_partitions(recon_shape, granularity, output_device=None, use_ror_mask=True):
    """
    Build one random pixel partition for each requested number of subsets.

    Each entry of granularity is passed to gen_pixel_partition, and the resulting partition is placed on
    output_device with jax.device_put.

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices)
        granularity (list or tuple): List of num_subsets to use for each partition
        output_device (jax device): Device on which to place the output of the partition
        use_ror_mask (bool): Flag to indicate whether to mask out a circular RoR

    Returns:
        list: One 2D array per entry of granularity, in the same order, each representing a partition of pixels
        into the specified number of subsets.
    """
    return [
        jax.device_put(
            gen_pixel_partition(recon_shape, subset_count, use_ror_mask=use_ror_mask),
            output_device,
        )
        for subset_count in granularity
    ]
def gen_pixel_partition_grid(recon_shape, num_subsets, use_ror_mask=True):
    """
    Generate a partition of pixel indices by tiling a small square grid of subset labels over the image.

    The requested number of subsets is rounded up to the next perfect square s*s.  The labels 0, ..., s*s-1 are
    randomly arranged in an s x s tile, and this tile is repeated to cover the image, so pixels that share a label
    form a regular grid with spacing s.  Subsets that have fewer pixels than the largest subset (because of cropping
    to the image size or because of the RoR mask) are padded with pixels borrowed from other subsets so that every
    subset has the same number of entries.

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices)
        num_subsets (int): The requested number of subsets; rounded up to a perfect square.
        use_ror_mask (bool): Flag to indicate whether to mask out a circular RoR

    Returns:
        jnp.array: A 2D JAX array in which each row holds the flat (row-major) pixel indices of one subset, sorted
        within each subset.  If there are more subsets than pixels in the region of reconstruction, a warning is
        issued and each pixel is returned as its own subset, giving shape (num_pixels, 1).
    """
    small_tile_side = int(np.ceil(np.sqrt(num_subsets)))
    num_subsets = small_tile_side ** 2
    num_small_tiles = [int(np.ceil(recon_shape[k] / small_tile_side)) for k in (0, 1)]

    # Randomly assign the labels to one small tile, then repeat the tile and crop to the image size.
    single_subset_inds = np.random.permutation(num_subsets).reshape((small_tile_side, small_tile_side))
    subset_inds = np.tile(single_subset_inds, num_small_tiles)
    subset_inds = subset_inds[:recon_shape[0], :recon_shape[1]]

    ror_mask = get_2d_ror_mask(recon_shape[:2]) if use_ror_mask else 1
    # Get a -1 at each location outside the mask, subset_ind at other points
    subset_inds = (subset_inds + 1) * ror_mask - 1
    subset_inds = subset_inds.flatten()

    valid_pixel_inds = np.where(subset_inds > -1)[0]
    if num_subsets > valid_pixel_inds.size:
        warning = '\nThe number of partition subsets is greater than the number of pixels in the region of '
        warning += 'reconstruction. \nReducing the number of subsets to equal the number of indices.'
        warnings.warn(warning)
        # One pixel per subset.  Return the pixel indices themselves, not the subset labels at those pixels.
        return jnp.array(valid_pixel_inds).reshape((-1, 1))

    # Collect the flat pixel indices of each subset that has at least one pixel.
    nonempty_subsets = np.unique(subset_inds[subset_inds >= 0])
    flat_inds = [np.where(subset_inds == k)[0] for k in nonempty_subsets]
    subset_sizes = [cur_inds.size for cur_inds in flat_inds]
    max_points = max(subset_sizes)
    min_points = min(subset_sizes)

    # Random positions used to borrow points.  Each position is < min_points, so it is valid in every subset.
    extra_point_inds = np.random.randint(min_points, size=(max_points - min_points + 1,))
    num_nonempty = len(flat_inds)
    padded_inds = []
    for k, cur_inds in enumerate(flat_inds):
        num_extra_points = max_points - cur_inds.size
        if num_extra_points > 0:
            # Borrow one point from each of the following subsets (cyclically), starting with subset k + 1.
            donor_subsets = (k + 1 + np.arange(num_extra_points, dtype=int)) % num_nonempty
            new_point_inds = [flat_inds[donor_subsets[j]][extra_point_inds[j]] for j in range(num_extra_points)]
            cur_inds = np.concatenate((cur_inds, new_point_inds))
        padded_inds.append(cur_inds)

    # Reorganize into subsets, then sort each subset
    indices = jnp.array(np.array(padded_inds))
    indices = jnp.sort(indices, axis=1)
    return indices
def gen_pixel_partition(recon_shape, num_subsets, use_ror_mask=True):
    """
    Randomly split the flat pixel indices of a 2D image into equally sized subsets.

    The pixel indices (optionally restricted to the circular region of reconstruction) are randomly permuted and
    divided into num_subsets rows.  When the number of pixels is not divisible by num_subsets, the array is padded
    with indices drawn without replacement from the subsets other than the last one, so every subset has the same
    length.  If num_subsets exceeds the number of available pixels, a warning is issued and num_subsets is reduced
    to the number of pixels.

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices)
        num_subsets (int): The number of subsets to divide the pixel indices into.
        use_ror_mask (bool): Flag to indicate whether to mask out a circular RoR

    Returns:
        jnp.array: A JAX array where each row corresponds to a subset of pixel indices, sorted within each subset.
    """
    rows, cols = recon_shape[:2]
    pixel_inds = np.arange(rows * cols, dtype=np.int32)

    # Keep only the pixels inside the region of reconstruction
    if use_ror_mask:
        inside_ror = get_2d_ror_mask(recon_shape).flatten()
        pixel_inds = pixel_inds[inside_ror == 1]

    num_pixels = len(pixel_inds)
    if num_subsets > num_pixels:
        num_subsets = num_pixels
        warnings.warn(
            '\nThe number of partition subsets is greater than the number of pixels in the region of '
            'reconstruction. \nReducing the number of subsets to equal the number of indices.'
        )

    # Subset length after rounding up, and the number of padding entries needed to fill the last subset
    subset_length = int(np.ceil(num_pixels / num_subsets))
    num_padding = num_subsets * subset_length - num_pixels

    shuffled = np.random.permutation(pixel_inds)

    # The padding lands in the final subset, so draw it only from entries belonging to the earlier subsets
    donor_pool = shuffled[:(num_subsets - 1) * subset_length]
    padding = np.random.choice(donor_pool, size=num_padding, replace=False)

    # One row per subset, each row sorted
    subset_table = np.concatenate((shuffled, padding)).reshape(num_subsets, subset_length)
    return jnp.array(jnp.sort(subset_table, axis=1))
def gen_pixel_partition_blue_noise(recon_shape, num_subsets, use_ror_mask=True):
    """
    Generates a partition of pixel indices into a specified number of subsets using a tiled blue noise pattern.

    The pattern bn.bn256 is quantized into num_subsets levels, tiled to cover the image, and cropped to the image
    size; the quantized level at each pixel is the subset that pixel belongs to.  Subsets with fewer pixels than the
    largest subset are padded with pixels borrowed from other subsets so that each subset contains an equal number
    of pixels, suitable for VCD reconstruction.

    If there are more subsets than pixels in the region of reconstruction, or if some subset would be empty, the
    function falls back to gen_pixel_partition with the same arguments.

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices)
        num_subsets (int): The number of subsets to divide the pixel indices into.
        use_ror_mask (bool): Flag to indicate whether to mask out a circular RoR

    Returns:
        jnp.array: A JAX array where each row corresponds to a subset of pixel indices, sorted within each subset.
    """
    pattern = bn.bn256
    num_tiles = [np.ceil(recon_shape[k] / pattern.shape[k]).astype(int) for k in [0, 1]]
    ror_mask = get_2d_ror_mask(recon_shape) if use_ror_mask else 1

    # Quantize the pattern into subset labels.
    # NOTE(review): this presumes the pattern values lie in [0, 2**16) so the labels fall in
    # 0, ..., num_subsets - 1 -- confirm against bn.bn256.
    single_subset_inds = np.floor(pattern / (2**16 / num_subsets)).astype(int)

    # Repeat each bn subset to do the tiling
    subset_inds = np.tile(single_subset_inds, num_tiles)
    subset_inds = subset_inds[:recon_shape[0], :recon_shape[1]]
    # Get a -1 at each location outside the mask, subset_ind at other points
    subset_inds = (subset_inds + 1) * ror_mask - 1
    subset_inds = subset_inds.flatten()

    num_valid_inds = np.sum(subset_inds >= 0)
    if num_subsets > num_valid_inds:
        # Forward use_ror_mask so the fallback partitions the same set of pixels the caller asked for.
        return gen_pixel_partition(recon_shape, num_subsets, use_ror_mask=use_ror_mask)

    # Get all the indices for each subset and track the largest and smallest subset sizes
    flat_inds = []
    max_points = 0
    min_points = subset_inds.size
    for k in range(num_subsets):
        cur_inds = np.where(subset_inds == k)[0]
        flat_inds.append(cur_inds)
        max_points = max(max_points, cur_inds.size)
        min_points = min(min_points, cur_inds.size)

    if min_points == 0:
        # An empty subset cannot be padded by the scheme below, so use the purely random partition instead.
        return gen_pixel_partition(recon_shape, num_subsets, use_ror_mask=use_ror_mask)

    # Random positions used to borrow points.  Each position is < min_points, so it is valid in every subset.
    extra_point_inds = np.random.randint(low=0, high=min_points, size=(max_points - min_points + 1,))
    for k in range(num_subsets):
        cur_inds = flat_inds[k]
        num_extra_points = max_points - cur_inds.size
        if num_extra_points > 0:
            # Borrow one point from each of the following subsets (cyclically), starting with subset k + 1.
            extra_subset_inds = (k + 1 + np.arange(num_extra_points, dtype=int)) % num_subsets
            new_point_inds = [flat_inds[extra_subset_inds[j]][extra_point_inds[j]] for j in range(num_extra_points)]
            flat_inds[k] = np.concatenate((cur_inds, new_point_inds))

    flat_inds = jnp.array(flat_inds)
    flat_inds = jnp.sort(flat_inds, axis=1)
    return flat_inds
def gen_partition_sequence(partition_sequence, max_iterations):
    """
    Generates a sequence of partition indices of the specified length, either by extending the given sequence
    with copies of its last element or by truncating it.

    Args:
        partition_sequence (array-like): 1D sequence of partition indices.
        max_iterations (int): Desired length of the returned sequence.

    Returns:
        numpy.ndarray: Array of length max_iterations.  The same array type is returned whether the input was
        extended or truncated.
    """
    # Convert through jnp first so the dtype follows the JAX default rules, then keep the sequence as a
    # host-side numpy array.  Both branches below then return the same array type.
    sequence = np.array(jnp.array(partition_sequence))

    current_length = sequence.size
    if max_iterations <= current_length:
        # No extension is needed, so slice the original array to the desired length
        return sequence[:max_iterations]

    # Repeat the last element as many times as needed to reach the desired length
    extra_elements_needed = max_iterations - current_length
    extension_array = np.full(extra_elements_needed, sequence[-1], dtype=sequence.dtype)
    return np.concatenate((sequence, extension_array))
def gen_full_indices(recon_shape, use_ror_mask=True):
    """
    Return the indices of all pixels in the region of reconstruction as a single array.

    This is the one and only subset of a partition generated with num_subsets=1, and is useful for computing
    forward projections.

    Args:
        recon_shape (tuple): Shape of recon in (rows, columns, slices)
        use_ror_mask (bool): Flag to indicate whether to mask out a circular RoR

    Returns:
        The first (and only) row of the single-subset partition produced by gen_pixel_partition.
    """
    single_subset_partition = gen_pixel_partition(recon_shape, num_subsets=1, use_ror_mask=use_ror_mask)
    return single_subset_partition[0]
# (Sphinx "[docs]" link marker left over from the rendered documentation page; not part of the code.)
def gen_weights_mar(ct_model, sinogram, init_recon=None, metal_threshold=None, beta=1.0, gamma=3.0):
    """
    Generates the weights used for reducing metal artifacts in MBIR reconstruction.

    The weights have the form:

        weights = exp( -(sinogram/beta) * ( 1 + gamma * delta(metal) ) )

    where delta(metal) is a binary mask marking the sinogram entries that contain projections of metal.

    If ``init_recon`` is provided, the metal voxels are found by thresholding ``init_recon``, and the resulting
    voxel mask is forward projected with ``ct_model`` to obtain delta(metal).  This yields better metal artifact
    reduction.  If ``init_recon`` is not provided, delta(metal) is obtained by thresholding the sinogram directly.

    Args:
        ct_model: Model object providing ``forward_project`` and ``main_device``; the masks are created on
            ``ct_model.main_device``.
        sinogram (jax array): 3D jax array containing sinogram with shape (num_views, num_det_rows, num_det_channels).
        init_recon (jax array, optional): An initial reconstruction used to identify metal voxels. If not provided,
            Otsu's method is used to directly segment sinogram into metal regions.
        metal_threshold (float, optional): Values in ``init_recon`` above ``metal_threshold`` are classified as metal.
            If not provided, Otsu's method is used to segment ``init_recon``.  Only used when ``init_recon`` is given.
        beta (float, optional): Scalar value in range :math:`>0`.
            A larger ``beta`` improves the noise uniformity, but too large a value may increase the overall noise level.
        gamma (float, optional): Scalar value in range :math:`>=0`.
            A larger ``gamma`` reduces the weight of sinogram entries with metal, but too large a value may reduce
            image quality inside the metal regions.

    Returns:
        (jax array): Weights with the same array shape as ``sinogram``
    """
    def as_indicator(condition):
        # Convert a boolean array into a float32 0/1 mask on the model's main device.
        return jnp.array(condition, dtype=jnp.dtype(jnp.float32), device=ct_model.main_device)

    if init_recon is not None:
        # Segment the metal in the recon domain, then forward project to find the affected sinogram entries.
        if metal_threshold is None:
            print("Metal_threshold calculated with Otsu's method.")
            # Three classes are assumed: metal, non_metal, and background.  Only the upper threshold is used.
            _, metal_threshold = mjp.multi_threshold_otsu(init_recon, classes=3)
        print("metal_threshold = ", metal_threshold)

        metal_voxels = as_indicator(init_recon > metal_threshold)
        projected_metal = ct_model.forward_project(metal_voxels)
        # Any sinogram entry that receives a contribution from a metal voxel is marked as distorted.
        delta_metal = as_indicator(projected_metal > 0.0)
    else:
        # No recon is available, so segment the sinogram itself.
        print("init_recon is not provided. Automatically determine distorted sinogram entries with Otsu's method.")
        # Three classes are assumed: metal, non_metal, and background.  Only the upper threshold is used.
        _, sino_metal_threshold = mjp.multi_threshold_otsu(sinogram, classes=3)
        print("Distorted sinogram threshold = ", sino_metal_threshold)
        delta_metal = as_indicator(sinogram > sino_metal_threshold)

    # Entries marked as metal have their exponent scaled by (1 + gamma), which lowers their weight.
    return jnp.exp(-sinogram * (1 + gamma * delta_metal) / beta)
# (Sphinx "[docs]" link marker left over from the rendered documentation page; not part of the code.)
def gen_weights(sinogram, weight_type):
    """
    Compute optional weights used in MBIR reconstruction based on the noise model.

    The weights should be proportional to the inverse variance of the noise for each sinogram entry.
    They can be used to improve reconstruction quality.

    The sinogram is processed in batches of views on the first GPU if one is available (otherwise on the CPU),
    and the result is assembled on the CPU.

    Args:
        sinogram (jax.Array): A 3D JAX array of shape (num_views, num_det_rows, num_det_channels)
            representing the sinogram.
        weight_type (str): The type of noise model to use for weighting. Must be one of:

            - 'unweighted': Use uniform weights (all ones).
            - 'transmission': Use exponential decay, `exp(-sinogram)`.
            - 'transmission_root': Use square-root decay, `exp(-sinogram / 2)`.
            - 'emission': Use reciprocal decay, `1 / (abs(sinogram) + 0.1)`.

    Returns:
        jax.Array: A 3D array of weights with the same shape as the input sinogram.

    Raises:
        ValueError: If `weight_type` is not one of the supported options.  ValueError is a subclass of Exception,
            so code that catches Exception continues to work.

    Note:
        For transmission noise models, sinogram values should not be excessively large (e.g., > 5),
        as this corresponds to near-zero transmission, which is not physically meaningful in typical X-ray imaging.

    Example:
        >>> sinogram = jnp.ones((180, 64, 128))
        >>> weights = gen_weights(sinogram, weight_type='transmission_root')
        >>> weights.shape
        (180, 64, 128)
    """
    weight_functions = {
        'unweighted': lambda sino: jnp.ones(sino.shape),
        'transmission': lambda sino: jnp.exp(-sino),
        'transmission_root': lambda sino: jnp.exp(-sino / 2),
        'emission': lambda sino: 1.0 / (jnp.absolute(sino) + 0.1),
    }
    # Validate before touching any data so that a bad weight_type is reported even for an empty sinogram.
    if not isinstance(weight_type, str) or weight_type not in weight_functions:
        raise ValueError("gen_weights: undefined weight_type {}".format(weight_type))
    weight_function = weight_functions[weight_type]

    batch_size = 128
    main_device = jax.devices('cpu')[0]
    try:
        worker_device = jax.devices('gpu')[0]
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is not available.
        worker_device = main_device

    num_views = sinogram.shape[0]
    batch_starts = list(range(0, num_views, batch_size))
    if not batch_starts:
        # With zero views, process one (empty) batch so the result is an empty array of the
        # matching shape instead of an error from concatenating an empty list.
        batch_starts = [0]

    weight_list = []
    for start in batch_starts:
        sino_batch = jax.device_put(sinogram[start:min(start + batch_size, num_views)], worker_device)
        weight_list.append(jax.device_put(weight_function(sino_batch), main_device))

    return jnp.concatenate(weight_list, axis=0)