Source code for mbirjax.preprocess.stripe

import jax
import jax.numpy as jnp


def generate_column_index_matrix(num_rows, num_cols):
    """
    Create a 2D array of indexes used for the sorting technique.

    This code is based on the method described in:
    [Nghia T. Vo et al., 2018] - "Superior techniques for eliminating ring artifacts in x-ray micro-tomography"

    This code is adapted from the Tomopy library:
    https://github.com/tomopy/tomopy.git

    References:
    [1] Vo N, and Atwood RC, and Drakopoulos M. Superior techniques for eliminating ring artifacts in x-ray micro-tomography. Optics Express, 26(22):28396–28412, 2018.
    [2] Tomopy library https://github.com/tomopy/tomopy.git

    Args:
        num_rows (int): number of detector rows in the sinogram
        num_cols (int): number of detector channels in the sinogram

    Returns:
        index_matrix(jax array): a 2D jax array of indexes
    """
    list_index = jnp.arange(0.0, num_cols, 1.0)
    index_matrix = jnp.tile(list_index, (num_rows, 1))
    return index_matrix


def remove_small_stripes_sorting(sino, filter_size, index_matrix):
    """
    Remove small-to-medium partial and fulll stripes using the sorting technique.

    This code is based on the method described in:
    [Nghia T. Vo et al., 2018] - "Superior techniques for eliminating ring artifacts in x-ray micro-tomography"

    This code is adapted from the Tomopy library:
    https://github.com/tomopy/tomopy.git

    References:
    [1] Vo N, and Atwood RC, and Drakopoulos M. Superior techniques for eliminating ring artifacts in x-ray micro-tomography. Optics Express, 26(22):28396–28412, 2018.
    [2] Tomopy library https://github.com/tomopy/tomopy.git

    Args:
        sino (jax array): a 2D slice of the sinogram data with shape (num_views, num_det_channels)
        filter_size (int): window size of the median filter
        index_matrix (jax array): a 2D array of indexes used for the sorting technique

    Return:
        corrected_sino (jax array): corrected 2D slice of the sinogram data after stripes removal
    """

    from scipy.ndimage import median_filter
    # Sort each column of the sinogram by its grayscale values
    sino = jnp.transpose(sino)
    stacked_matrix = jnp.stack([index_matrix, sino], axis=2)
    sorted_indices = jnp.argsort(stacked_matrix[:, :, 1], axis=1)
    sorted_indices_expanded = sorted_indices[:, :, None]
    sorted_stacked_matrix = jnp.take_along_axis(stacked_matrix, sorted_indices_expanded, axis=1)

    # Apply the median filter on teh sorted sinogram along each row
    sorted_stacked_matrix = sorted_stacked_matrix.at[:, :, 1].set(median_filter(sorted_stacked_matrix[:, :, 1], (filter_size, 1)))

    # Re-sort the smoothed image columns to the original rows to get the corrected sinogram
    sorted_indices = jnp.argsort(sorted_stacked_matrix[:, :, 0], axis=1)
    sorted_indices_expanded = sorted_indices[:, :, None]
    sort_back_matrix = jnp.take_along_axis(sorted_stacked_matrix, sorted_indices_expanded, axis=1)
    corrected_sino = sort_back_matrix[:, :, 1]
    corrected_sino = jnp.transpose(corrected_sino)

    return corrected_sino

def detect_stripe(list_data, snr):
    """
    Used to locate stripes.
    A segmentation algorithm to separate the extremely positive and negative defects from the normal values in the
    sinogram.

    This code is based on the method described in:
    [Nghia T. Vo et al., 2018] - "Superior techniques for eliminating ring artifacts in x-ray micro-tomography"

    This code is adapted from the Tomopy library:
    https://github.com/tomopy/tomopy.git

    Args:
        list_data (jax array): a normalized 1D array
        snr (float): a ratio between the defective value and the background value. a reasonable choice of snr should be around 3.0 or above

    Returns:
        list_mask (jax array): a 2D binary array denoting the stripes detected
    """

    num_data = list_data.shape[0]

    # Sort the 1D array
    sorted_list = jnp.sort(list_data)[::-1]
    x_list = jnp.arange(0, num_data, 1.0)

    # Apply a linear fit to values around the middle of the sorted array
    # Calculating the noise level to avoid false positives caused by minor background variations
    num_data_drop = jnp.int16(0.25 * num_data)
    (_slope, _intercept) = jnp.polyfit(x_list[num_data_drop:-num_data_drop - 1], sorted_list[num_data_drop:-num_data_drop - 1], 1)
    fitted_value_last_index = _intercept + _slope * x_list[-1]
    noise_level = jnp.abs(fitted_value_last_index - _intercept)
    noise_level = jnp.clip(noise_level, 1e-6, None)
    val1 = jnp.abs(sorted_list[0] - _intercept) / noise_level
    val2 = jnp.abs(sorted_list[-1] - fitted_value_last_index) / noise_level

    # Calculate the upper threshold and the lower threshold
    # Binarize the array by replacing all values between lower and upper threshold with 0 and others with 1.
    list_mask = jnp.zeros_like(list_data)
    if (val1 >= snr):
        upper_threshold = _intercept + noise_level * snr * 0.5
        list_mask = jnp.where(list_data > upper_threshold, 1.0, list_mask)
    if (val2 >= snr):
        lower_threshold = fitted_value_last_index - noise_level * snr * 0.5
        list_mask = jnp.where(list_data <= lower_threshold, 1.0, list_mask)

    return list_mask


def remove_large_stripes_sorting(sino, snr, filter_size, index_matrix, drop_ratio=0.1):
    """
    Remove large partial and full stripes using the sorting technique.

    This code is based on the method described in:
    [Nghia T. Vo et al., 2018] - "Superior techniques for eliminating ring artifacts in x-ray micro-tomography"

    This code is adapted from the Tomopy library:
    https://github.com/tomopy/tomopy.git

    References:
    [1] Vo N, and Atwood RC, and Drakopoulos M. Superior techniques for eliminating ring artifacts in x-ray micro-tomography. Optics Express, 26(22):28396–28412, 2018.
    [2] Tomopy library https://github.com/tomopy/tomopy.git

    Args:
        sino (jax array): a 2D slice of the sinogram data with shape (num_views, num_det_channels)
        snr (float): a ratio between the defective value and the background value
        filter_size (int): window size of the median filter
        index_matrix (jnp array): a 2D array of indexes used for the sorting technique
        drop_ratio (float, optional): ratio of pixels at the top and the bottom of the sinogram to be removed. Defaults to 0.1.

    Returns:
        sino (jax array): corrected 2D slice of the sinogram data after stripes removal
    """

    from scipy.ndimage import median_filter, binary_dilation
    drop_ratio = jnp.clip(drop_ratio, 0.0, 0.8)
    (num_rows, num_cols) = sino.shape
    num_rows_drop = jnp.int16(0.5 * drop_ratio * num_rows)

    # Sorting the columns of the sinogram and apply the median filter on the sorted image along each row
    sorted_sino = jnp.sort(sino, axis=0)
    smoothed_sino = jnp.array(median_filter(sorted_sino, (1, filter_size)))

    # Compute the column-wise average of the sorted and smoothed sinogram
    # Compute the normalized 1D array
    raw_list = jnp.mean(sorted_sino[num_rows_drop:num_rows - num_rows_drop], axis=0)
    smoothed_list = jnp.mean(smoothed_sino[num_rows_drop:num_rows - num_rows_drop], axis=0)
    normalized_list = jnp.where(
        smoothed_list != 0,
        jnp.divide(raw_list, smoothed_list),
        jnp.ones_like(raw_list)
    )

    # Locate the large stripes
    list_mask = detect_stripe(normalized_list, snr)
    list_mask = jnp.array(binary_dilation(list_mask, iterations=1).astype(list_mask.dtype))
    normalized_factor = jnp.tile(normalized_list, (num_rows, 1))

    # Apply pre-correction to the original sinogram
    sino = sino / normalized_factor

    # Apply the sorting-based algorithm again to get the corrected columns
    sino_T = jnp.transpose(sino)
    stacked_matrix = jnp.stack([index_matrix, sino_T], axis=2)

    sorted_indices = jnp.argsort(stacked_matrix[:, :, 1], axis=1)
    sorted_indices_expanded = sorted_indices[:, :, None]
    sorted_matrix = jnp.take_along_axis(stacked_matrix, sorted_indices_expanded, axis=1)

    sorted_matrix = sorted_matrix.at[:, :, 1].set(jnp.transpose(smoothed_sino))

    sorted_indices = jnp.argsort(sorted_matrix[:, :, 0], axis=1)
    sorted_indices_expanded = sorted_indices[:, :, None]
    sort_back_matrix = jnp.take_along_axis(sorted_matrix, sorted_indices_expanded, axis=1)

    corrected_sino = jnp.transpose(sort_back_matrix[:, :, 1])

    list_x_miss = jnp.where(list_mask > 0.0)[0]

    # Selective Replacement of Defective Columns with corrected columns
    sino = sino.at[:, list_x_miss].set(corrected_sino[:, list_x_miss])
    return sino


def remove_dead_fluctuating_stripes_interpolation(sino, snr, filter_size, index_matrix):
    """
    Remove unresponsive and fluctuating stripes using the interpolation technique.
    Sorting approach does not work here because the rankings of the grayscales are significantly different between
    pixels inside the stripes and outside the stripes. Instead, interpolation is an appropriate choice.

    This code is based on the method described in:
    [Nghia T. Vo et al., 2018] - "Superior techniques for eliminating ring artifacts in x-ray micro-tomography"

    This code is adapted from the Tomopy library:
    https://github.com/tomopy/tomopy.git

    References:
    [1] Vo N, and Atwood RC, and Drakopoulos M. Superior techniques for eliminating ring artifacts in x-ray micro-tomography. Optics Express, 26(22):28396–28412, 2018.
    [2] Tomopy library https://github.com/tomopy/tomopy.git

    Args:
        sino (jax array): a 2D slice of the sinogram data with shape (num_views, num_det_channels)
        snr (float): a ratio between the defective value and the background value
        filter_size (int): window size of the median filter
        index_matrix (jax array): a 2D array of indexes used for the sorting technique

    Returns:
        sino (jax array): corrected 2D slice of the sinogram data after stripes removal

    """
    from scipy import interpolate
    from scipy.ndimage import uniform_filter1d, median_filter, binary_dilation
    num_rows = sino.shape[0]

    # Compute the column-wise absolute difference between the original sinogram and the smoothed sinogram
    # Help detect the unresponsive and fluctuating stripes (large diff -> fluctuating stripes; small diff -> unresponsive stripes)
    smoothed_sino = jnp.array(uniform_filter1d(sino, 10, axis=0))
    difference_list = jnp.sum(jnp.abs(sino - smoothed_sino), axis=0)

    # Compute normalized 1D array where the large value correspond to defective pixels
    difference_list_filtered = jnp.array(median_filter(difference_list, size=filter_size))
    normalized_list = jnp.where(
        difference_list_filtered != 0,
        jnp.divide(difference_list, difference_list_filtered),
        jnp.ones_like(difference_list_filtered)
    )

    # Generate binary mask
    list_mask = detect_stripe(normalized_list, snr)
    list_mask = jnp.array(binary_dilation(list_mask, iterations=1).astype(list_mask.dtype))
    list_mask = list_mask.at[0:2].set(0.0)
    list_mask = list_mask.at[-2:].set(0.0)

    # Interpolation
    x_list = jnp.where(list_mask < 1.0)[0]
    y_list = jnp.arange(num_rows)
    z_matrix = sino[:, x_list]
    fit_function = interpolate.RectBivariateSpline(y_list, x_list, z_matrix, kx=1, ky=1)

    # Apply interpolation to defective columns
    list_x_miss = jnp.where(list_mask > 0.0)[0]
    if len(list_x_miss) > 0:
        matrix_x_miss, matrix_y = jnp.meshgrid(list_x_miss, y_list)
        estimate_output = fit_function.ev(jnp.ravel(matrix_y), jnp.ravel(matrix_x_miss))
        sino = sino.at[:, list_x_miss].set(jnp.reshape(estimate_output, matrix_x_miss.shape))

    # Remove residual large stripes
    corrected_sino = remove_large_stripes_sorting(sino, snr, filter_size, index_matrix)

    return corrected_sino

[docs] def remove_all_stripe(sino, snr=3, large_filter_size=61, small_filter_size=21): """ Removes all types of stripe artifacts from a sinogram using a combination of three algorithms: 1. Interpolation-based removal of unresponsive and fluctuating stripes. 2. Sorting-based removal of large partial and full stripes. 3. Sorting-based removal of small to medium partial and full stripes. This method is adapted from `tomopy.remove_all_stripes()` and is based on: Vo N, Atwood RC, Drakopoulos M. "Superior techniques for eliminating ring artifacts in x-ray micro-tomography." Optics Express, 26(22):28396–28412, 2018. Args: sino (jax.Array): A 3D sinogram array with shape (num_views, num_det_rows, num_det_channels). snr (float, optional): Signal-to-noise ratio used for stripe detection. A typical value is 3.0. Defaults to 3. large_filter_size (int, optional): Median filter window size for removing large stripes. Defaults to 61. small_filter_size (int, optional): Median filter window size for removing small-to-medium stripes. Defaults to 21. Returns: jax.Array: Corrected 3D sinogram array after removing all stripe artifacts. Example: >>> import jax.numpy as jnp >>> import mbirjax.preprocess as mjp >>> sino = jnp.ones((180, 128, 256)) # Simulated 3D sinogram >>> cleaned_sino = mjp.remove_all_stripe(sino) """ from concurrent.futures import ThreadPoolExecutor index_matrix = generate_column_index_matrix(sino.shape[2], sino.shape[0]) index_matrix_cpu = jax.device_put(index_matrix, device=jax.devices("cpu")[0]) sino_cpu = jax.device_put(sino, device=jax.devices("cpu")[0]) result = jnp.zeros_like(sino) def process_slice(m): sino_slice = sino_cpu[:, m, :] sino_slice = remove_dead_fluctuating_stripes_interpolation(sino_slice, snr, large_filter_size, index_matrix_cpu) sino_slice = remove_small_stripes_sorting(sino_slice, small_filter_size, index_matrix_cpu) return sino_slice with ThreadPoolExecutor() as executor: processed_slices = list(executor.map(process_slice, range(sino.shape[1]))) for m, processed_slice in enumerate(processed_slices): result = result.at[:, m, :].set(processed_slice) return jax.device_put(result)
[docs] def remove_stripe_fw(sino, wavelet_filter_name="db5", sigma=2): """ Removes vertical stripe artifacts from a 3D sinogram using a combined wavelet-Fourier filtering technique. This method uses a 2D Discrete Wavelet Transform followed by a 2D Fourier transform to suppress vertical stripes, as described in: Beat Münch et al., "Stripe and ring artifact removal with combined wavelet—Fourier filtering", Optics Express, 2009. This implementation is adapted from the Tomopy library's `remove_stripe_fw()`: https://github.com/tomopy/tomopy.git Args: sino (jax.Array): 3D sinogram data with shape (num_views, num_det_rows, num_det_channels). wavelet_filter_name (str, optional): Wavelet filter type (e.g., 'db5', 'haar'). Defaults to 'db5'. sigma (float, optional): Damping parameter in the Fourier domain. Controls the strength of stripe suppression. Defaults to 2. Returns: jax.Array: Corrected sinogram data with reduced vertical stripe artifacts. Example: >>> import jax.numpy as jnp >>> import mbirjax.preprocess as mjp >>> sino = jnp.ones((180, 128, 256)) # Simulated sinogram >>> cleaned_sino = mjp.remove_stripe_fw(sino) """ # Determine decomposition level L import pywt level = int(jnp.ceil(jnp.log2(jnp.max(jnp.array(sino.shape))))) views, num_rows, num_columns = sino.shape padded_views = views + views // 8 shift_val = views // 16 for m in range(sino.shape[1]): sino_slice = jnp.zeros((padded_views, num_columns), dtype=jnp.float32) sino_slice = sino_slice.at[shift_val:views + shift_val].set(sino[:, m, :]) # 2D Discrete Wavelte Transform cH, cV, cD = {}, {}, {} for n in range(level): sino_slice, (cH[n], cV[n], cD[n]) = pywt.dwt2(sino_slice, wavelet_filter_name) # FFT transform of horizontal frequency bands for n in range(level): # FFT fcV = jnp.fft.fftshift(jnp.fft.fft(cV[n], axis=0)) my, mx = fcV.shape # Damping of vertical stripe information y_hat = (jnp.arange(-my, my, 2, dtype='float32') + 1) / 2 damp = -jnp.expm1(-jnp.square(y_hat) / (2 * jnp.square(sigma))) fcV *= jnp.transpose(jnp.tile(damp, (mx, 1))) # Inverse FFT cV[n] = jnp.real(jnp.fft.ifft(jnp.fft.ifftshift(fcV), axis=0)) # 2D inverse discrete wavelet transform for n in range(level)[::-1]: sino_slice = sino_slice[0:cH[n].shape[0], 0:cH[n].shape[1]] sino_slice = pywt.idwt2((sino_slice, (cH[n], cV[n], cD[n])), wavelet_filter_name) sino = sino.at[:, m, :].set(sino_slice[shift_val:views + shift_val, 0:num_columns]) return sino
[docs] def remove_sino_offset(sino): """ Remove additive offsets in the sinogram caused by material outside the field of view. This function corrects each row of the sinogram so that the sum over channels is constant across views and equal to the minimum sum observed across all views. Args: sino (jax.Array): Sinogram with shape (num_views, num_rows, num_channels). Returns: jax.Array: Corrected sinogram with the same shape as the input. Example: >>> import jax.numpy as jnp >>> import mbirjax.preprocess as mjp >>> sino = jnp.ones((180, 128, 256)) + jnp.linspace(0, 1, 180)[:, None, None] >>> corrected_sino = mjp.remove_sino_offset(sino) """ # Compute the average over channels: shape [view, row] sino_channel_avg = jnp.mean(sino, axis=2) # Compute the minimum of the channel average across views: shape [row] sino_min_channel_avg = jnp.min(sino_channel_avg, axis=0) # Compute corrected sinogram sino_corrected = sino - sino_channel_avg[:, :, None] + sino_min_channel_avg[None, :, None] return sino_corrected