Denoising
MBIRJAX includes a Bayesian MAP denoiser based on the qGGMRF prior, as well as a 3D median filter.
QGGMRFDenoiser
The QGGMRFDenoiser class implements a 3D volume denoiser for additive white Gaussian noise (AWGN).
More specifically, it computes the MAP estimate assuming additive white Gaussian noise and a qGGMRF prior distribution.
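Using \(H(x)\) to denote the denoising function and \(h(v)\) to denote the qGGMRF regularizer (the negative log of the prior, up to an additive constant), the denoiser is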
\[H(x) = \arg\min_v \left\{ \frac{1}{2 \sigma_{noise}^2}\|x - v\|^{2} + h(v) \right\}.\]
The denoiser automatically estimates the noise level in the image, or the user can set the noise standard deviation directly through the parameter sigma_noise. Larger values of sigma_noise lead to smoother images. Alternatively, the amount of denoising can be adjusted using the parameter sharpness (default=0.0). This class inherits many of the behaviors and attributes of TomographyModel.
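To see why larger values of sigma_noise give smoother results, the sketch below minimizes an objective of the same form as \(H(x)\) on a 1D signal. As a simplifying assumption, a quadratic (Gaussian MRF) penalty on neighboring differences stands in for the qGGMRF prior \(h(v)\), so the minimizer is a single linear solve. This is an illustration only, not MBIRJAX's VCD solver; the names quadratic_prox_denoise, cost, roughness, and sigma_x are hypothetical.

```python
import numpy as np


def quadratic_prox_denoise(x, sigma_noise, sigma_x):
    """Minimize ||x - v||^2 / (2 sigma_noise^2) + sum_i (v[i+1] - v[i])^2 / (2 sigma_x^2).

    Hypothetical stand-in: a quadratic prior replaces qGGMRF so the answer is closed form.
    """
    n = x.size
    # D is the (n-1) x n first-difference matrix: (D @ v)[i] = v[i+1] - v[i].
    D = np.diff(np.eye(n), axis=0)
    # Setting the gradient to zero gives (I / sigma_noise^2 + D^T D / sigma_x^2) v = x / sigma_noise^2.
    A = np.eye(n) / sigma_noise**2 + D.T @ D / sigma_x**2
    return np.linalg.solve(A, x / sigma_noise**2)


def cost(v, x, sigma_noise, sigma_x):
    """Value of the objective minimized by quadratic_prox_denoise."""
    data_term = np.sum((x - v) ** 2) / (2 * sigma_noise**2)
    prior_term = np.sum(np.diff(v) ** 2) / (2 * sigma_x**2)
    return data_term + prior_term


def roughness(v):
    """Sum of squared neighbor differences; smaller means smoother."""
    return float(np.sum(np.diff(v) ** 2))


rng = np.random.default_rng(0)
clean = np.sin(np.linspace(0, 2 * np.pi, 64))
noisy = clean + 0.2 * rng.standard_normal(64)

for s in [0.05, 0.2, 0.8]:
    v = quadratic_prox_denoise(noisy, sigma_noise=s, sigma_x=0.1)
    print(f"sigma_noise={s}: roughness={roughness(v):.4f}")
```

Because sigma_noise divides the data-fidelity term, increasing it shifts the balance toward the prior, which penalizes differences between neighbors, so the printed roughness decreases as sigma_noise grows.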
Constructor
- class mbirjax.QGGMRFDenoiser(image_shape)
Bases: TomographyModel
The QGGMRFDenoiser uses the MBIRJAX recon framework to implement a qGGMRF proximal map denoiser. The primary interface is through denoise().
Denoise
- QGGMRFDenoiser.denoise(image, sigma_noise=None, use_ror_mask=False, init_image=None, max_iterations=15, stop_threshold_change_pct=0.2, first_iteration=0, logfile_path='./logs/recon.log', print_logs=True)
Compute the MAP denoiser assuming AWGN and the 3D qGGMRF prior.
With default settings, and with X a clean image and W equal to AWGN of standard deviation sigma_noise, the result of denoise() applied to X+W is the MAP estimate of the denoised image using the qGGMRF prior function.
The amount of denoising can be changed by changing sigma_noise. If sigma_noise is None, then sigma_noise is estimated from a set of samples from the image.
Denoising strength can also be adjusted using the parameter sharpness (default=0.0).
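The text above does not say how sigma_noise is estimated when it is None. As one illustration of how such an estimate can work, the sketch below uses a common robust heuristic: differences between neighboring voxels largely cancel slowly varying structure, and the median absolute deviation (MAD) of those differences yields a Gaussian standard deviation. The function estimate_sigma_noise is hypothetical and is not claimed to be the method MBIRJAX uses.

```python
import numpy as np


def estimate_sigma_noise(image, axis=0):
    """Hypothetical robust estimate of the AWGN standard deviation.

    For pure AWGN with std sigma, each neighbor difference has std sqrt(2) * sigma.
    MAD / 0.6745 estimates a Gaussian's std, so dividing by sqrt(2) recovers sigma.
    """
    d = np.diff(np.asarray(image, dtype=float), axis=axis).ravel()
    mad = np.median(np.abs(d - np.median(d)))  # subtracting the median removes any constant slope
    return mad / 0.6745 / np.sqrt(2)


rng = np.random.default_rng(1)
# Smooth ramp along axis 0 plus AWGN with standard deviation 0.1.
clean = np.linspace(0.0, 1.0, 32)[:, None, None] * np.ones((32, 32, 32))
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
print(estimate_sigma_noise(noisy))  # approximately 0.1
```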
- Parameters:
image (numpy or jax array) – The 3D volume to be denoised.
sigma_noise (float, optional) – The estimated noise standard deviation in the noisy image. If None, then the noise level is estimated from the image.
use_ror_mask (bool, optional) – Set to True to restrict denoising to an inscribed circle in the image. Defaults to False.
init_image (numpy or jax array, optional) – An initial image for the minimization. Defaults to image.
max_iterations (int, optional) – Maximum number of iterations of the VCD algorithm to perform. Defaults to 15.
stop_threshold_change_pct (float, optional) – Stop when the percent change 100 * ||delta_recon||_1 / ||recon||_1 from one iteration to the next is below stop_threshold_change_pct. Defaults to 0.2. Set this to 0 to guarantee exactly max_iterations.
first_iteration (int, optional) – Set this to the number of iterations previously completed when restarting a denoising run using init_image. This defines the first index in the partition sequence. Defaults to 0.
logfile_path (str, optional) – Path to the output log file. Defaults to ‘./logs/recon.log’.
print_logs (bool, optional) – If True, then print logs to the console. Defaults to True.
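The interaction between max_iterations and stop_threshold_change_pct can be sketched as a simple loop. This assumes the percent change is normalized by the current iterate (the text does not say which iterate ||recon||_1 refers to); pct_change, run_until_converged, and the toy update are hypothetical stand-ins for MBIRJAX's VCD iterations.

```python
import numpy as np


def pct_change(prev, curr):
    """100 * ||curr - prev||_1 / ||curr||_1 (normalizing by the current iterate is an assumption)."""
    return 100.0 * np.sum(np.abs(curr - prev)) / np.sum(np.abs(curr))


def run_until_converged(update, init, max_iterations=15, stop_threshold_change_pct=0.2):
    """Apply a hypothetical one-iteration `update` until the percent change drops
    below stop_threshold_change_pct or max_iterations is reached."""
    x = init
    for it in range(max_iterations):
        x_new = update(x)
        change = pct_change(x, x_new)
        x = x_new
        if change < stop_threshold_change_pct:
            return x, it + 1  # converged early
    return x, max_iterations


target = np.ones(10)


def halfway_step(x):
    """Toy update that moves halfway to `target` each iteration."""
    return 0.5 * (x + target)


x, n_iter = run_until_converged(halfway_step, np.zeros(10))
print(n_iter)  # 9
```

Setting stop_threshold_change_pct=0 makes the strict inequality impossible to satisfy, so the loop always runs exactly max_iterations times, matching the behavior described for that parameter.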
- Returns:
tuple – (denoised_image, denoiser_dict)
- denoised_image (jax array): A denoised image of the same shape as image.
- denoiser_dict (dict): A dict obtained from get_recon_dict() with entries 'recon_params', 'notes', 'recon_logs', and 'model_params'.
Example
>>> denoiser = mj.QGGMRFDenoiser(noisy_image.shape)
>>> denoiser.set_params(sharpness=0.5)  # Increase sharpness a little over the default of 0.0
>>> denoised_image, denoised_dict = denoiser.denoise(noisy_image)  # Estimate the noise level from the image
>>> mj.slice_viewer(noisy_image, denoised_image, data_dicts=[None, denoised_dict], title='Noisy and denoised images')
See also
TomographyModel: The base class from which this class inherits.
Median Filter
MBIRJAX also includes a 3x3x3 median filter, which can be used as a simple denoiser. The median filter can optionally return the min and max in each 3x3x3 neighborhood as well.
- mbirjax.median_filter3d(x, max_block_gb=4.0, return_min_max=False) → Array | tuple
Apply a 27-point (3x3x3) median filter to a 3-D JAX array using replicated (edge) boundary conditions. Optionally return the min and max in each 27-point neighborhood.
The volume is processed in blocks along axis 0 (d0-blocks) so that the kernel can be jax.jit-compiled while limiting peak device memory. Each block is padded with a one-voxel halo; at the array boundaries, the halo duplicates the nearest edge voxel so that the result matches NumPy's "edge" padding mode.
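A plain NumPy reference for the same 27-point, edge-padded median can be handy for checking results on small arrays. This is a sketch, not MBIRJAX's blocked JAX implementation; median_filter3d_reference is a hypothetical name, and because it forms every 3x3x3 window at once it is only suitable for small volumes.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_filter3d_reference(x, return_min_max=False):
    """Hypothetical NumPy reference: 3x3x3 median with edge-replicated boundaries."""
    x = np.asarray(x)
    padded = np.pad(x, 1, mode="edge")  # one-voxel halo copying the nearest edge voxel
    # windows has shape x.shape + (3, 3, 3): one 27-voxel neighborhood per output voxel
    windows = sliding_window_view(padded, (3, 3, 3))
    # The median of 27 (an odd count) values is one of them, so casting back is exact.
    med = np.median(windows, axis=(-3, -2, -1)).astype(x.dtype)
    if return_min_max:
        return med, windows.min(axis=(-3, -2, -1)), windows.max(axis=(-3, -2, -1))
    return med


vol = np.arange(27.0).reshape(3, 3, 3)
print(median_filter3d_reference(vol)[0])  # first slice: [[3, 3, 4], [6, 6, 7], [6, 7, 8]]
```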
- Parameters:
x (jax array or ndarray) – Input array. Any numeric dtype supported by JAX is allowed.
max_block_gb (float, optional) – A rough upper bound on the amount of memory in GB to use for the filtering. Defaults to 4.0.
return_min_max (bool, optional) – If True, the output is a tuple (median, min, max). Defaults to False.
- Returns:
jax.numpy.ndarray or tuple – An array of the same shape and dtype as x containing the median-filtered result, or, if return_min_max is True, a tuple of 3 such arrays containing the median, min, and max.
Notes
The function automatically splits the 0-dimension into blocks so that at most roughly max_block_gb GB of temporary data are materialised. If the array is large and the 0 dimension is small relative to another dimension, it may be more memory efficient to apply jnp.swapaxes(x, 0, long_dim) before applying median_filter3d, although swapaxes will make a copy of x.
Within each block, the filter is computed by rolling the data in all 26 neighbour directions, stacking the 27 volumes, and taking jnp.median() along the new axis.
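The Notes above describe two implementation ideas: splitting axis 0 into blocks that each carry a one-voxel halo, and computing each block's median by stacking 27 shifted copies. The NumPy sketch below mimics both so you can check that blocking does not change the result. It uses shifted slices of an edge-padded block in place of jnp.roll, and the block-size rule in slices_per_block is a hypothetical estimate, not MBIRJAX's actual memory accounting.

```python
import numpy as np


def slices_per_block(shape, itemsize, max_block_gb):
    """Hypothetical rule: the 27 stacked copies of each axis-0 slice dominate temporary memory."""
    bytes_per_slice = 27 * shape[1] * shape[2] * itemsize
    return max(1, int(max_block_gb * 1e9 // bytes_per_slice))


def _median27_shift_stack(block):
    """3x3x3 median of a block that already has a one-voxel halo on every side."""
    n0, n1, n2 = (s - 2 for s in block.shape)
    shifted = [
        block[1 + d0 : 1 + d0 + n0, 1 + d1 : 1 + d1 + n1, 1 + d2 : 1 + d2 + n2]
        for d0 in (-1, 0, 1)
        for d1 in (-1, 0, 1)
        for d2 in (-1, 0, 1)
    ]  # 27 shifted views: 26 neighbor directions plus the center
    return np.median(np.stack(shifted, axis=0), axis=0)


def blocked_median_filter3d(x, block_size):
    """Filter x in blocks of `block_size` slices along axis 0 (a hypothetical knob)."""
    x = np.asarray(x)
    n = x.shape[0]
    blocks = []
    for start in range(0, n, block_size):
        stop = min(start + block_size, n)
        lo, hi = max(start - 1, 0), min(stop + 1, n)
        # Axis-0 halo: real neighbor slices where they exist, edge replication at the array ends.
        pad0 = (1 if start == 0 else 0, 1 if stop == n else 0)
        block = np.pad(x[lo:hi], (pad0, (1, 1), (1, 1)), mode="edge")
        blocks.append(_median27_shift_stack(block))
    return np.concatenate(blocks, axis=0).astype(x.dtype)


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 6, 7)).astype(np.float32)
full = blocked_median_filter3d(x, block_size=x.shape[0])  # a single block
print(np.array_equal(blocked_median_filter3d(x, block_size=3), full))  # True
```

Because each block sees exactly the neighbors it would see in the unblocked array, the block size only affects peak memory, not the output.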
Examples
>>> import jax.numpy as jnp
>>> import mbirjax as mj
>>> vol = jnp.arange(27.).reshape(3, 3, 3)
>>> mj.median_filter3d(vol)
Array([[[3., 3., 4.],
        [6., 6., 7.],
        [6., 7., 8.]],
...
      dtype=float32)