import os, sys
from operator import itemgetter
import numpy as np
import jax
import jax.numpy as jnp
import warnings
import mbirjax.preprocess as mjp
import pprint
import logging
import olefile
import struct
from pathlib import Path
# Module-level logger; used below by _check_read and _log_imported_data.
logger = logging.getLogger(__name__)
# Shared pretty-printer with 4-space indentation (not referenced in the visible part of this module).
pp = pprint.PrettyPrinter(indent=4)
def compute_sino_and_params(dataset_dir, downsample_factor=(1, 1), subsample_view_factor=1, crop_pixels_sides=0,
                            crop_pixels_top=0, crop_pixels_bottom=0, alu_unit='mm', bg_option="global", verbose=1):
    """
    Compute sinogram and parameters from txrm file generated by a Zeiss Versa scanner.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).

        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Steps:
        1. Load object, blank, and dark scans and geometry.
        2. Crop and (optionally) downsample the scans.
        3. Compute the sinogram from the scans.
        4. Apply background offset correction.
        5. Apply per-view sinogram shift correction using the x/y shifts stored in the txrm file.

    Args:
        dataset_dir (str): Path to the Zeiss dataset. Accepts a ``.txrm`` file with:
            - ``ImageData*/Image*`` (scan data)
            - Zeiss OLE metadata streams
        downsample_factor (Tuple[int, int], optional): Downsample factors for detector rows and channels. Defaults to (1, 1).
        subsample_view_factor (int, optional): Factor by which to subsample views. Defaults to 1.
        crop_pixels_sides (int, optional): Pixels to crop from each lateral side of the detector. Defaults to ``0``.
        crop_pixels_top (int, optional): Pixels to crop from the top of the detector. Defaults to ``0``.
        crop_pixels_bottom (int, optional): Pixels to crop from the bottom of the detector. Defaults to ``0``.
        alu_unit (str, optional): The physical unit used to define 1 ALU (Arbitrary Length Unit). Defaults to 'mm'.
            Supported units input: 'um', 'mm', 'cm', 'm'.
        bg_option (str or None, optional): Option for background offset correction. Defaults to 'global'.
            Supported options:
                - None: No correction; return the input sinogram unchanged.
                - "global": Estimate one scalar offset from edge regions across all views.
                - "per_view": Estimate one offset per view from edge regions.
        verbose (int, optional): Verbosity level. Values <= 0 suppress the progress messages printed here and
            while loading the scans. Defaults to ``1``.

    Returns:
        tuple: ``(sino, cone_beam_params, optional_params)``
            - ``sino`` (numpy.ndarray): Sinogram of shape ``(num_views, num_det_rows, num_channels)``.
            - ``cone_beam_params`` (dict): Parameters for initializing ``ConeBeamModel``.
            - ``optional_params`` (dict): Additional parameters to be set via ``ConeBeamModel.set_params``.

    Example:
        .. code-block:: python

            import mbirjax
            from mbirjax.preprocess.zeiss_cb import compute_sino_and_params

            sino, cone_beam_params, optional_params = compute_sino_and_params(dataset_dir, verbose=1)

            ct_model = mbirjax.ConeBeamModel(**cone_beam_params)
            ct_model.set_params(**optional_params)
            recon, recon_dict = ct_model.recon(sino)
    """
    if verbose > 0:
        print("\n\n########## Loading object, blank, dark scans, and geometry parameters from Zeiss dataset directory")
    # Forward verbose so that verbose=0 also silences the loader's parameter summary.
    obj_scan, blank_scan, dark_scan, zeiss_params = load_scans_and_params(dataset_dir, subsample_view_factor,
                                                                          verbose=verbose)

    cone_beam_params, optional_params = convert_zeiss_to_mbirjax_params(zeiss_params,
                                                                        downsample_factor=downsample_factor,
                                                                        crop_pixels_sides=crop_pixels_sides,
                                                                        crop_pixels_top=crop_pixels_top,
                                                                        crop_pixels_bottom=crop_pixels_bottom,
                                                                        alu_unit=alu_unit)

    if verbose > 0:
        print("\n\n########## Cropping and downsampling scans")
    # Crop the scans based on the input params.
    obj_scan, blank_scan, dark_scan, defective_pixel_array = mjp.crop_view_data(obj_scan, blank_scan, dark_scan,
                                                                                crop_pixels_sides=crop_pixels_sides,
                                                                                crop_pixels_top=crop_pixels_top,
                                                                                crop_pixels_bottom=crop_pixels_bottom)
    # Downsample the scans with block-averaging (skipped when both factors are 1).
    if downsample_factor[0] * downsample_factor[1] > 1:
        obj_scan, blank_scan, dark_scan, defective_pixel_array = mjp.downsample_view_data(
            obj_scan, blank_scan, dark_scan,
            downsample_factor=downsample_factor,
            defective_pixel_array=defective_pixel_array)

    if verbose > 0:
        print("\n\n########## Computing sinogram from object, blank, and dark scans")
    sino = mjp.compute_sino_transmission(obj_scan, blank_scan, dark_scan, defective_pixel_array)

    if verbose > 0:
        print("\n\n########## Correcting sinogram data to account for background offset and sino offset")
    sino = mjp.correct_background_offset(sino, option=bg_option)

    # Correct per-view sino offset using the alignment shifts from the txrm file.
    sino = correct_sino_shifts(sino, zeiss_params, downsample_factor, subsample_view_factor)

    if verbose > 0:
        print('obj_scan shape = ', obj_scan.shape)
        print('blank_scan shape = ', blank_scan.shape)
        print('dark_scan shape = ', dark_scan.shape)

    return sino, cone_beam_params, optional_params
def load_scans_and_params(dataset_dir, subsample_view_factor=1, verbose=1):
    """
    Load the scan data and geometry from a Zeiss txrm file.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).

        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        dataset_dir (str): Path to a Zeiss ``.txrm`` file. Expected OLE structure:
            - ``ImageData*/Image*`` (scan data)
            - Zeiss metadata/parameter streams (e.g. ``ImageInfo/*``, ``PositionInfo/*``, ``ReferenceData/*``)
        subsample_view_factor (int, optional): Keep every ``subsample_view_factor``-th view. Defaults to 1.
        verbose (int, optional): Verbosity level. Defaults to 1.

    Returns:
        tuple: (obj_scan, blank_scan, dark_scan, zeiss_params)
            - obj_scan (numpy.ndarray): 3D object scan with shape ``(num_kept_views, num_det_rows, num_channels)``.
            - blank_scan (numpy.ndarray): 3D blank (reference) scan with shape
              ``(num_reference, num_det_rows, num_channels)`` as produced by ``read_metadata``.
            - dark_scan (numpy.ndarray): 3D dark scan with shape ``(1, num_det_rows, num_channels)``.
              The txrm file is assumed to contain no dark scan, so this is all zeros.
            - zeiss_params (dict): Geometry parameters (distances, pixel pitches, optical magnification, detector
              size, subsampled ``angles``, detector offsets) with their units, plus the per-view ``x_shifts`` and
              ``y_shifts``. ``num_views`` is the number of views in the file, before subsampling.

    Raises:
        FileNotFoundError: If ``dataset_dir`` is not an existing file.
        ValueError: If ``dataset_dir`` is not a ``.txrm`` file.
        IOError: If the file cannot be opened as an OLE file.
    """
    # Automatically parse the path to the Zeiss scan from dataset_dir.
    data_dir = _parse_filenames_from_dataset_dir(dataset_dir)

    if verbose > 0:
        print("The following files will be used to compute the Zeiss reconstruction:\n",
              f" - txrm file: {data_dir}\n")

    # Read object scans and metadata.
    file_name = _check_read(data_dir)
    try:
        ole = olefile.OleFileIO(file_name)
    except IOError:
        print(f'No such file or directory: {file_name}')
        raise

    # Close the OLE file even if reading metadata or an image stream fails.
    try:
        zeiss_params = read_metadata(ole)

        # Read the (subsampled) scan data from the txrm file.
        view_indices = np.arange(zeiss_params["num_views"], step=subsample_view_factor)
        obj_scan = np.zeros((len(view_indices), zeiss_params["num_det_rows"], zeiss_params["num_det_channels"]),
                            dtype=_get_ole_data_type(zeiss_params))
        for i, idx in enumerate(view_indices):
            # Stream paths hold 100 images per ImageData storage: Image1..Image100 in ImageData1, and so on.
            img_string = "ImageData{}/Image{}".format(int(np.ceil((idx + 1) / 100.0)), int(idx + 1))
            obj_scan[i] = _read_ole_image(ole, img_string, zeiss_params)
    finally:
        ole.close()

    # Read blank scans.
    blank_scan = zeiss_params["reference"]

    # Read dark scans.
    # TODO: Currently we assume that there is no dark scan for txrm file.
    # A single all-zero frame (the documented shape) avoids allocating a second array the size of obj_scan.
    # NOTE(review): downstream preprocessing is assumed to accept a single dark frame, as it does for the
    # (num_reference, ...) blank scan -- confirm.
    dark_scan = np.zeros((1,) + obj_scan.shape[1:], dtype=obj_scan.dtype)

    try:
        # Names and units of the measurement axes, e.g. ['Sample X', 'Sample Y', ..., 'CCD_X', ...] and
        # ['um', 'um', ...]. Absent entries are treated as empty lists so the defaults below apply.
        axis_names = zeiss_params.get("axis_names") or []
        axis_units = zeiss_params.get("axis_units") or []

        def _axis_unit(axis_name, default_unit):
            # Unit recorded for axis_name, or default_unit if that axis is not listed.
            index = get_index_in_list(axis_names, axis_name)
            return axis_units[index] if index > -1 else default_unit

        # Source to iso distance.
        source_iso_dist = float(np.abs(zeiss_params["source_iso_dist"]))
        source_iso_dist_unit = _axis_unit('Source Z', 'mm')

        # Iso to detector distance (0 if absent); shares the source-distance unit.
        iso_det_dist = zeiss_params["iso_det_dist"]
        iso_det_dist = float(np.abs(iso_det_dist)) if iso_det_dist is not None else 0.0
        iso_det_dist_unit = source_iso_dist_unit

        # Detector pixel pitch; Zeiss detector pixels have equal width and height.
        det_pixel_pitch = float(np.abs(zeiss_params["det_pixel_pitch"]))
        delta_det_row = det_pixel_pitch
        delta_det_channel = det_pixel_pitch
        delta_det_row_unit = _axis_unit('CCD_X', 'um')
        delta_det_channel_unit = delta_det_row_unit

        # Pixel pitch at iso.
        iso_pixel_pitch = float(np.abs(zeiss_params["iso_pixel_pitch"]))
        iso_pixel_pitch_unit = _axis_unit('Sample X', 'um')

        angle_unit = _axis_unit('Sample Theta', 'deg')

        x_shifts = zeiss_params["x_shifts"]
        y_shifts = zeiss_params["y_shifts"]
    except (ValueError, TypeError, IndexError):
        print("Unable to determine units for geometry parameters; cannot safely convert to mbirjax format.")
        raise

    # Optical magnification (1 if not recorded).
    opt_mag = zeiss_params["opt_mag"]
    opt_mag = 1 if opt_mag is None else opt_mag

    # Dimensions of the radiographs.
    num_views = zeiss_params["num_views"]
    num_det_channels = zeiss_params["num_det_channels"]
    num_det_rows = zeiss_params["num_det_rows"]

    # Rotation angles (sign flipped), restricted to the kept views.
    angles = -np.array(zeiss_params['thetas'], dtype=float).ravel()[view_indices]

    # Detector offset parameters.
    # MBIRJAX has the reverse convention for the channel shift; a missing CenterShift entry means no shift.
    center_shift = zeiss_params["center_shift"]
    det_channel_offset = -center_shift if center_shift is not None else 0.0
    det_row_offset = 0.0  # There doesn't appear to be a Zeiss parameter for detector row offset

    zeiss_params = {
        'source_iso_dist': source_iso_dist,
        'iso_det_dist': iso_det_dist,
        'delta_det_channel': delta_det_channel,
        'delta_det_row': delta_det_row,
        'iso_pixel_pitch': iso_pixel_pitch,
        'opt_mag': opt_mag,
        'num_views': num_views,
        'num_det_channels': num_det_channels,
        'num_det_rows': num_det_rows,
        'angles': angles,
        'det_row_offset': det_row_offset,
        'det_channel_offset': det_channel_offset,
        'source_iso_dist_unit': source_iso_dist_unit,
        'iso_det_dist_unit': iso_det_dist_unit,
        'delta_det_row_unit': delta_det_row_unit,
        'delta_det_channel_unit': delta_det_channel_unit,
        'iso_pixel_pitch_unit': iso_pixel_pitch_unit,
        'angle_unit': angle_unit,
        'x_shifts': x_shifts,
        'y_shifts': y_shifts,
    }

    if verbose > 0:
        print("############ Zeiss geometry parameters ############")
        print(f"Source to iso distance: {source_iso_dist} [{source_iso_dist_unit}]")
        print(f"Iso to detector distance: {iso_det_dist} [{iso_det_dist_unit}]")
        print(f"Detector pixel pitch: {det_pixel_pitch:.3f} [{delta_det_row_unit}]")
        print(f"Pixel pitch at iso: {iso_pixel_pitch} [{iso_pixel_pitch_unit}]")
        print(f"Optical magnification: {opt_mag}")
        print(f"Number of views: {num_views}")
        print(f"Detector size: (num_det_rows, num_det_channels) = ({num_det_rows}, {num_det_channels})")
        print("############ End Zeiss geometry parameters ############")
    ### END load Zeiss parameters from scan data

    return obj_scan, blank_scan, dark_scan, zeiss_params
def convert_zeiss_to_mbirjax_params(zeiss_params, downsample_factor=(1, 1), crop_pixels_sides=0, crop_pixels_top=0,
                                    crop_pixels_bottom=0, alu_unit='mm'):
    """
    Convert geometry parameters from zeiss into mbirjax format, including modifications to reflect crop
    and downsampling.

    The Zeiss optics are modeled as an equivalent cone-beam system. The total magnification is
    ``opt_mag * (source_iso_dist + iso_det_dist) / source_iso_dist``. The equivalent detector is placed at
    ``magnification * source_iso_dist`` from the source, with pixel pitch ``magnification * iso_pixel_pitch``.
    The physical detector pixel pitch in ``zeiss_params`` is therefore not used.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).

    Args:
        zeiss_params (dict): Zeiss geometry parameters, as returned by ``load_scans_and_params``. Keys read:
            ``source_iso_dist``, ``iso_det_dist``, ``iso_pixel_pitch`` and their ``*_unit`` entries,
            ``opt_mag``, ``num_det_rows``, ``num_det_channels``, ``angles``, ``angle_unit``, ``det_row_offset``
            and ``det_channel_offset``. The offsets are assumed to be in full-resolution detector pixels.
        downsample_factor ((int, int), optional): Down-sample factors along the detector rows and channels respectively.
            If scan size is not divisible by `downsample_factor`, the scans will be first truncated to a size that is
            divisible by `downsample_factor`.
        crop_pixels_sides (int, optional): The number of pixels to crop from each side of the sinogram. Defaults to 0.
        crop_pixels_top (int, optional): The number of pixels to crop from top of the sinogram. Defaults to 0.
        crop_pixels_bottom (int, optional): The number of pixels to crop from bottom of the sinogram. Defaults to 0.
        alu_unit (str, optional): The physical unit used to define 1 ALU (Arbitrary Length Unit). Defaults to 'mm'.
            Supported units input: 'um', 'mm', 'cm', 'm'.

    Returns:
        tuple: ``(cone_beam_params, optional_params)``
            - cone_beam_params (dict): Required parameters for the ConeBeamModel constructor.
            - optional_params (dict): Additional ConeBeamModel parameters to be set using set_params().

    Raises:
        KeyError: If ``alu_unit`` or a length unit in ``zeiss_params`` is not one of 'um', 'mm', 'cm', 'm'.
    """
    # Get zeiss parameters.
    source_iso_dist, iso_det_dist, source_iso_dist_unit, iso_det_dist_unit = itemgetter(
        'source_iso_dist', 'iso_det_dist', 'source_iso_dist_unit', 'iso_det_dist_unit')(zeiss_params)
    iso_pixel_pitch, iso_pixel_pitch_unit = itemgetter('iso_pixel_pitch', 'iso_pixel_pitch_unit')(zeiss_params)
    opt_mag = zeiss_params['opt_mag']
    num_det_rows, num_det_channels = itemgetter('num_det_rows', 'num_det_channels')(zeiss_params)
    angles, angle_unit = itemgetter('angles', 'angle_unit')(zeiss_params)
    det_row_offset, det_channel_offset = itemgetter('det_row_offset', 'det_channel_offset')(zeiss_params)

    # Length of one unit in micrometers, for every unit used in the txrm files.
    unit_conversion = {'um': 1.0, 'mm': 1000.0, 'cm': 1e4, 'm': 1e6}
    # Define 1 ALU as 1 unit of alu_unit.
    alu_value = 1

    # Convert physical lengths to ALU.
    source_iso_dist *= unit_conversion[source_iso_dist_unit] / unit_conversion[alu_unit]
    iso_det_dist *= unit_conversion[iso_det_dist_unit] / unit_conversion[alu_unit]
    iso_pixel_pitch *= unit_conversion[iso_pixel_pitch_unit] / unit_conversion[alu_unit]

    # Convert angles to radians.
    if angle_unit == 'deg':
        angles = np.deg2rad(angles)

    # Total magnification = (optical magnification) * (geometric magnification onto the scintillator).
    # In this case the "detector" is actually a scintillator.
    if opt_mag is not None:
        scintillator_mag = (source_iso_dist + iso_det_dist) / source_iso_dist
        magnification = opt_mag * scintillator_mag
    else:
        magnification = 1.0

    # Equivalent cone-beam quantities accounting for the total magnification.
    source_detector_dist = magnification * source_iso_dist
    delta_det_channel = magnification * iso_pixel_pitch
    delta_det_row = magnification * iso_pixel_pitch

    # Convert offsets from pixels to ALU with the full-resolution pitch. This must happen before
    # downsampling enlarges the pitch, since the offsets count original (not downsampled) pixels.
    det_channel_offset *= delta_det_channel
    det_row_offset *= delta_det_row

    # Adjust detector size params w.r.t. cropping arguments.
    # NOTE(review): an asymmetric top/bottom crop moves the row center, but det_row_offset is not adjusted
    # for it here -- confirm the sign convention before adding that correction.
    num_det_rows = num_det_rows - (crop_pixels_top + crop_pixels_bottom)
    num_det_channels = num_det_channels - 2 * crop_pixels_sides

    # Adjust detector size and pixel pitch params w.r.t. downsampling arguments.
    num_det_rows = num_det_rows // downsample_factor[0]
    num_det_channels = num_det_channels // downsample_factor[1]
    delta_det_row *= downsample_factor[0]
    delta_det_channel *= downsample_factor[1]
    iso_pixel_pitch *= downsample_factor[0]

    num_views = len(angles)
    cone_beam_params = {
        'sinogram_shape': (num_views, num_det_rows, num_det_channels),
        'angles': angles,
        'source_detector_dist': source_detector_dist,
        'source_iso_dist': source_iso_dist,
    }

    optional_params = {
        'delta_det_channel': delta_det_channel,
        'delta_det_row': delta_det_row,
        'delta_voxel': iso_pixel_pitch,
        'det_row_offset': det_row_offset,
        'det_channel_offset': det_channel_offset,
        'alu_unit': alu_unit,
        'alu_value': alu_value,
    }

    return cone_beam_params, optional_params
######## subroutines for parsing Zeiss object scan, blank scan, and dark scan
def _parse_filenames_from_dataset_dir(dataset_dir):
    """
    Resolve the path to the Zeiss ``.txrm`` file that stores the projection data.

    Despite the parameter name, ``dataset_dir`` must point directly at the ``.txrm`` file, not at a directory.

    Args:
        dataset_dir (str or os.PathLike): Path to the ``.txrm`` file.

    Returns:
        str: Path to the txrm file (a path-like argument is converted to ``str``).

    Raises:
        FileNotFoundError: If ``dataset_dir`` is not an existing file.
        ValueError: If the file extension is not ``.txrm`` (compared case-insensitively, like ``_check_read``).
    """
    path = os.fspath(dataset_dir)
    if not os.path.isfile(path):
        raise FileNotFoundError(path)
    if os.path.splitext(path)[1].lower() != ".txrm":
        raise ValueError(f"Unsupported file type {path}; only .txrm files are supported.")
    return path
def _check_read(fname):
    """
    Turn a data-file path into an absolute path, logging (not raising) when the input looks wrong.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str): Path to the file to be read. Expected to be a string whose extension (compared
            case-insensitively) is one of:
            ['.edf', '.tiff', '.tif', '.h5', '.hdf', '.npy', '.nc', '.xrm', '.txrm', '.txm', '.xmt', '.nxs'].

    Returns:
        str: ``os.path.abspath(fname)``, returned even when an error was logged.
    """
    recognized = frozenset((
        '.edf', '.tiff', '.tif', '.h5', '.hdf', '.npy', '.nc',
        '.xrm', '.txrm', '.txm', '.xmt', '.nxs',
    ))
    if isinstance(fname, str):
        extension = os.path.splitext(fname)[1].lower()
        if extension not in recognized:
            logger.error('Unknown file extension')
    else:
        logger.error('File name must be a string')
    return os.path.abspath(fname)
def read_xrm(fname):
    """
    Read a single 2D projection and its metadata from a Zeiss ``.xrm`` file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str): String defining the path of file or file name.

    Returns:
        tuple: ``(arr, metadata)``, or ``False`` if the file cannot be opened.
            - arr (np.ndarray): 2D image with shape (num_det_rows, num_det_channels), passed through
              ``_normalize_to_float32``.
            - metadata (dict): Output of ``read_metadata``.
    """
    fname = _check_read(fname)
    try:
        ole = olefile.OleFileIO(fname)
    except IOError:
        print(f'No such file or directory: {fname}')
        return False

    # Close the OLE file even if reading the metadata or the image stream fails.
    try:
        metadata = read_metadata(ole)
        # Only the first image stream is read.
        arr = _read_ole_image(ole, "ImageData1/Image1", metadata)
    finally:
        ole.close()

    _log_imported_data(fname, arr)

    # Normalize the scan data.
    arr = mjp.utilities._normalize_to_float32(arr)
    return arr, metadata
def read_xrm_dir(dir_path):
    """
    Read all .xrm files in a directory, stack them into (num_views, num_det_rows, num_det_channels),
    and concatenate selected per-view metadata.

    Only regular files with a ``.xrm`` suffix (case-insensitive) are read, in sorted file-name order so the
    result is reproducible across platforms. Each view's positions and angle are taken from its own file, so
    ``metadata['thetas'][i]`` always describes ``arr[i]``.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        dir_path (str or os.PathLike): Path to the directory to be read.

    Returns:
        np.ndarray: Output 3D image with shape (num_views, num_det_rows, num_det_channels).
        dict: Metadata of the first file, with ``num_views`` replaced and ``x_positions``, ``y_positions``,
            ``z_positions`` and ``thetas`` replaced by lists holding one value per file.

    Raises:
        FileNotFoundError: If the directory contains no ``.xrm`` files.
    """
    dir_path = Path(dir_path)
    files = sorted(p for p in dir_path.iterdir() if p.is_file() and p.suffix.lower() == '.xrm')
    if not files:
        raise FileNotFoundError(f"No .xrm files found in {dir_path}")

    # Load the scan data and metadata from the first file to size the output.
    proj0, md0 = read_xrm(str(files[0]))
    num_views = len(files)
    num_det_rows, num_det_channels = proj0.shape
    arr = np.empty((num_views, num_det_rows, num_det_channels), dtype=proj0.dtype)
    arr[0] = proj0

    # Per-view entries: each file stores its own object position and rotation angle as the first value.
    per_view_keys = ('x_positions', 'y_positions', 'z_positions', 'thetas')
    metadata = dict(md0)
    metadata['num_views'] = num_views
    for key in per_view_keys:
        metadata[key] = [md0[key][0]]

    # Load the remaining files and stack them together.
    for i, p in enumerate(files[1:], start=1):
        proj, md = read_xrm(str(p))
        arr[i] = proj
        for key in per_view_keys:
            metadata[key].append(md[key][0])

    _log_imported_data(str(dir_path), arr)

    # Normalize the scan data.
    arr = mjp.utilities._normalize_to_float32(arr)
    return arr, metadata
def read_metadata(ole):
    """
    Read metadata from an xradia OLE file (.xrm, .txrm, .txm).

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO instance) : An ole file to read from.

    Returns:
        dict: A dictionary of image metadata. Entries whose stream is absent are ``None``. The ``reference`` entry
        is always a 3D array ``(num_reference, num_det_rows, num_det_channels)``. It is an all-ones single frame
        (with a warning) when no reference image can be read.
    """
    number_of_images = _read_ole_value(ole, "ImageInfo/NoOfImages", "<I")
    number_of_reference = _read_ole_reference(
        ole, ["ReferenceData/ImageInfo/NoOfImages", "MultiReferenceData/ImageInfo/NoOfImages"], "<I")

    metadata = {
        'num_det_channels': _read_ole_value(ole, 'ImageInfo/ImageWidth', '<I'),
        'num_det_rows': _read_ole_value(ole, 'ImageInfo/ImageHeight', '<I'),
        'data_type': _read_ole_value(ole, 'ImageInfo/DataType', '<1I'),
        'reference_data_type': _read_ole_reference(
            ole, ['ReferenceData/DataType', 'MultiReferenceData/DataType'], '<1I'),
        'num_views': number_of_images,
        'num_reference': number_of_reference,
        'iso_pixel_pitch': _read_ole_value(ole, 'ImageInfo/PixelSize', '<f'),
        'det_pixel_pitch': _read_ole_value(ole, 'ImageInfo/CamPixelSize', '<f'),
        'iso_det_dist': _read_ole_value(ole, 'ImageInfo/D2RADistance', '<f'),
        'source_iso_dist': _read_ole_value(ole, 'ImageInfo/StoRADistance', "<{0}f".format(number_of_images)),
        'thetas': _read_ole_arr(
            ole, 'ImageInfo/Angles', "<{0}f".format(number_of_images)),
        'x_positions': _read_ole_arr(
            ole, 'ImageInfo/XPosition', "<{0}f".format(number_of_images)),
        'y_positions': _read_ole_arr(
            ole, 'ImageInfo/YPosition', "<{0}f".format(number_of_images)),
        'z_positions': _read_ole_arr(
            ole, 'ImageInfo/ZPosition', "<{0}f".format(number_of_images)),
        'x_shifts': _read_ole_arr(
            ole, 'alignment/x-shifts', "<{0}f".format(number_of_images)),
        'y_shifts': _read_ole_arr(
            ole, 'alignment/y-shifts', "<{0}f".format(number_of_images)),
        'current': _read_ole_value(
            ole, "ImageInfo/XrayCurrent", "<{0}f".format(number_of_images)),
        'voltage': _read_ole_value(
            ole, "ImageInfo/XrayVoltage", "<{0}f".format(number_of_images)),
        'ExpTimes': _read_ole_value(
            ole, "ImageInfo/ExpTimes", "<{0}f".format(number_of_images)),
        'center_shift': _read_ole_value(ole, "ReconSettings/CenterShift", '<f'),
        'opt_mag': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/OpticalMagnification', 'MultiReferenceData/ImageInfo/OpticalMagnification'], '<f'),
        'fan_angle': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/FanAngle', 'MultiReferenceData/ImageInfo/FanAngle'], '<{0}f'.format(number_of_reference)),
        'cone_angle': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/ConeAngle', 'MultiReferenceData/ImageInfo/ConeAngle'], '<{0}f'.format(number_of_reference)),
        'axis_names': _read_ole_str(ole, 'PositionInfo/AxisNames'),
        'axis_units': _read_ole_str(ole, 'PositionInfo/AxisUnits')
    }

    reference = None
    if ole.exists('ReferenceData'):
        reference = _read_ole_image(ole, 'ReferenceData/Image', metadata, metadata['reference_data_type'])
    elif ole.exists('MultiReferenceData'):
        # Image streams are numbered from 1; missing ones are skipped. A missing count reads nothing.
        frames = []
        for idx in range(metadata['num_reference'] or 0):
            img_string = f"MultiReferenceData/Image{idx + 1}"
            if ole.exists(img_string):
                frames.append(_read_ole_image(ole, img_string, metadata, metadata['reference_data_type']))
        if frames:
            reference = np.stack(frames, axis=0)

    # Fall back to all-ones both when no reference storage exists and when a
    # MultiReferenceData storage yielded no readable images.
    if reference is None:
        warnings.warn('No reference data available. Using an array of all 1s.')
        reference = np.ones((metadata['num_det_rows'], metadata['num_det_channels']))

    if reference.ndim == 2:
        reference = reference[None, :, :]
    metadata['reference'] = reference

    return metadata
def _log_imported_data(fname, arr):
    """
    Emit log records describing freshly imported image data.

    A DEBUG record carries the array's shape and dtype; an INFO record names the source.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str) : Path of the file (or directory) from which data was imported.
        arr (np.ndarray) : Array containing the image data.
    """
    shape, dtype = arr.shape, arr.dtype
    logger.debug('Data shape & type: %s %s', shape, dtype)
    logger.info('Data successfully imported: %s', fname)
def get_index_in_list(input_list, target):
    """
    Find the index of the first element of ``input_list`` that matches ``target``.

    Args:
        input_list (Sequence or None): Sequence to search. ``None`` -- e.g. an axis-name list that is absent
            from the metadata -- is treated as empty instead of raising ``TypeError``.
        target: Value to look for (matched by identity or equality, like ``list.index``).

    Returns:
        int: Index of the first match, or -1 if ``target`` is not present.
    """
    if input_list is None:
        return -1
    return next((idx for idx, item in enumerate(input_list) if item is target or item == target), -1)
def _get_ole_data_type(metadata, datatype=None):
    """
    Determine the Numpy data type for image data stored in a Zeiss OLE (.xrm, .txrm, .txm) file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        metadata (dict) : Dictionary containing metadata extracted from the OLE file.
            Must include the key "data_type" which is an integer code indicating the pixel data format.
        datatype (int, optional): Integer code for the data type. If None, the function uses `metadata["data_type"]`.

    Returns:
        np.dtype: The data type of the image data.

    Raises:
        ValueError: If the code is not a supported one (10 or 5). ``ValueError`` subclasses ``Exception``, so
            callers that caught the previous bare ``Exception`` still work.
    """
    # Zeiss codes: 10 -> 32-bit float; 5 -> unsigned 16-bit (2-byte) integer.
    if datatype is None:
        datatype = metadata["data_type"]
    if datatype == 10:
        return np.dtype(np.float32)
    if datatype == 5:
        return np.dtype(np.uint16)
    raise ValueError("Unsupported XRM datatype: %s" % str(datatype))
def _read_ole_struct(ole, label, struct_fmt):
    """
    Unpack the binary stream stored under ``label`` in an OLE file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO) : An ole file to read from.
        label (str) : Stream path inside the OLE file.
        struct_fmt (str) : ``struct`` format string used to unpack the stream's bytes.

    Returns:
        tuple or None: The unpacked values, or ``None`` when ``label`` does not exist.
    """
    if not ole.exists(label):
        return None
    raw_bytes = ole.openstream(label).read()
    return struct.unpack(struct_fmt, raw_bytes)
def _read_ole_value(ole, label, struct_fmt):
    """
    Read the first value stored under ``label`` in an OLE file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO) : An ole file to read from.
        label (str) : Stream path inside the OLE file.
        struct_fmt (str) : ``struct`` format string used to unpack the stream's bytes.

    Returns:
        int or float or None: The first unpacked value, or ``None`` when ``label`` does not exist.
    """
    unpacked = _read_ole_struct(ole, label, struct_fmt)
    return None if unpacked is None else unpacked[0]
def _read_ole_arr(ole, label, struct_fmt):
    """
    Read every value stored under ``label`` in an OLE file as a NumPy array.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO) : An ole file to read from.
        label (str) : Stream path inside the OLE file.
        struct_fmt (str) : ``struct`` format string used to unpack the stream's bytes.

    Returns:
        np.ndarray or None: The unpacked values as an array, or ``None`` when ``label`` does not exist.
    """
    unpacked = _read_ole_struct(ole, label, struct_fmt)
    return None if unpacked is None else np.array(unpacked)
def _read_ole_image(ole, label, metadata, datatype=None):
    """
    Decode the little-endian 2D image stored under ``label`` in an OLE file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO) : An ole file to read from.
        label (str) : Stream path of the image inside the OLE file.
        metadata (dict) : Metadata supplying ``num_det_rows``, ``num_det_channels`` and, when ``datatype`` is
            None, ``data_type``.
        datatype: Zeiss integer data-type code; defaults to ``metadata["data_type"]`` when None.

    Returns:
        np.ndarray: Image with shape (num_det_rows, num_det_channels), viewing the stream's bytes.
    """
    raw_bytes = ole.openstream(label).read()
    pixel_dtype = _get_ole_data_type(metadata, datatype).newbyteorder('<')
    frame_shape = (metadata["num_det_rows"], metadata["num_det_channels"])
    return np.frombuffer(raw_bytes, pixel_dtype).reshape(frame_shape)
def _read_ole_str(ole, label):
    """
    Read the NUL-separated strings stored under ``label`` in an OLE file.

    Args:
        ole (OleFileIO) : An ole file to read from.
        label (str) : Stream path inside the OLE file.

    Returns:
        list of str or None: The non-empty UTF-8-decoded pieces between NUL bytes, in stream order, or ``None``
        when ``label`` does not exist.
    """
    if not ole.exists(label):
        return None
    data = ole.openstream(label).read()
    # Empty pieces (from consecutive or trailing NULs) are dropped.
    return [name.decode('utf-8') for name in data.split(b'\x00') if name]
def _read_ole_reference(ole, labels, struct_fmt):
    """
    Read a reference-related value from the first usable label among ``labels``.

    Each existing label's stream is unpacked once with ``struct_fmt``. A single unpacked value is returned as
    a scalar, and several values as a NumPy array. A label is skipped when it is missing, when its bytes do not
    fit ``struct_fmt``, or when the format itself is invalid (e.g. built from a missing count). Read errors are
    also skipped, so the lookup stays best-effort.

    Args:
        ole (OleFileIO): An ole file to read from.
        labels (list of str): Candidate stream paths, tried in order.
        struct_fmt (str): ``struct`` format string used to unpack the stream's bytes.

    Returns:
        int, float, or ndarray: The unpacked value(s) from the first usable label, or None if there is none.
    """
    for label in labels:
        if not ole.exists(label):
            continue
        try:
            values = struct.unpack(struct_fmt, ole.openstream(label).read())
        except (struct.error, OSError):
            continue
        return values[0] if len(values) == 1 else np.array(values)
    return None
######## END subroutines for parsing Zeiss object scan, blank scan, and dark scan
def correct_sino_shifts(sino, zeiss_params, downsample_factor=(1, 1), subsample_view_factor=1):
    """
    Align each sinogram view based on the per-view projection offset.

    The txrm file stores the horizontal (x-shift) and vertical (y-shift) offsets for each projection.
    This function compensates the object's vibration by shifting each view of the sinogram accordingly.

    Coordinate convention (view from source):
        - x-shift: shift should be applied in the horizontal direction. Positive x-shift means the view should be
          shifted to the right.
        - y-shift: shift should be applied in the vertical direction. Positive y-shift means the view should be
          shifted down.

    For each view, the function:
        1. Reads the corresponding offset (x-shift, y-shift), keeps every ``subsample_view_factor``-th entry and
           rescales it to downsampled pixels.
        2. Translates the view by (y-shift, x-shift) with linear interpolation.

    Before translating, the sinogram is edge-padded by the largest absolute shift (rounded up). Translated
    samples therefore come from edge-replicated data rather than from outside the array. The padding is removed
    afterward. ``zeiss_params`` is not modified.

    Args:
        sino (numpy array or jax array): 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
        zeiss_params (dict): parameters stored in Zeiss txrm file. ``x_shifts``/``y_shifts`` hold one offset per
            acquired view. A missing or ``None`` entry is treated as zero shift.
        downsample_factor (Tuple[int, int], optional): Downsample factors for detector rows and channels. Defaults to (1, 1).
        subsample_view_factor (int, optional): Factor by which to subsample views. Defaults to 1.

    Returns:
        corrected_sino (numpy.ndarray): 3D sinogram data after alignment, same shape as ``sino``.
    """
    num_views = sino.shape[0]

    def _view_offsets(key, factor):
        # Per-view offsets for the kept views, in downsampled pixels.
        # TODO: Currently the view offset is assumed to have units of (full-resolution) pixels;
        # testing so far supports this assumption.
        # A new float array is built so the caller's zeiss_params entry is never modified in place.
        shifts = zeiss_params.get(key)
        if shifts is None:
            return np.zeros(num_views)
        return np.asarray(shifts, dtype=float)[::subsample_view_factor] / factor

    sino_x_offset = _view_offsets("x_shifts", downsample_factor[1])
    sino_y_offset = _view_offsets("y_shifts", downsample_factor[0])

    # Pad by the largest absolute shift, so that every kept output pixel samples inside the padded view
    # even when all views are shifted the same way.
    max_shift = max(np.max(np.abs(sino_x_offset), initial=0.0), np.max(np.abs(sino_y_offset), initial=0.0))
    pad_size = int(np.ceil(max_shift))
    if pad_size > 0:
        sino_pad = np.pad(sino, ((0, 0), (pad_size, pad_size), (pad_size, pad_size)), mode='edge')
    else:
        sino_pad = sino

    # Apply the per-view translation (no scaling).
    corrected_sino = np.zeros_like(sino_pad)
    unit_scale = jnp.array([1.0, 1.0])
    for view in range(num_views):
        corrected_sino[view] = jax.image.scale_and_translate(sino_pad[view],
                                                             shape=sino_pad[view].shape,
                                                             spatial_dims=(0, 1),
                                                             scale=unit_scale,
                                                             translation=jnp.array([sino_y_offset[view],
                                                                                    sino_x_offset[view]]),
                                                             method='linear',
                                                             antialias=False)

    # Remove padding.
    if pad_size > 0:
        corrected_sino = corrected_sino[:, pad_size:-pad_size, pad_size:-pad_size]

    return corrected_sino