Source code for mbirjax.preprocess.zeiss_cb

import os, sys
from operator import itemgetter
import numpy as np
import jax
import jax.numpy as jnp
import warnings
import mbirjax.preprocess as mjp
import pprint
import logging
import olefile
import struct
from pathlib import Path
# Module-level pretty-printer (indent=4); NOTE(review): not referenced in the visible code, presumably a debugging aid.
pp = pprint.PrettyPrinter(indent=4)
# Module-level logger used by the file-reading helpers (e.g. _check_read, _log_imported_data).
logger = logging.getLogger(__name__)


def compute_sino_and_params(dataset_dir, downsample_factor=(1, 1), subsample_view_factor=1,
                            crop_pixels_sides=0, crop_pixels_top=0, crop_pixels_bottom=0,
                            alu_unit='mm', bg_option="global", verbose=1):
    """
    Compute sinogram and parameters from txrm file generated by a Zeiss Versa scanner.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Steps:
        1. Load object, blank, and dark scans and geometry.
        2. Crop and downsample the scans.
        3. Compute the sinogram from the scans.
        4. Apply background offset correction.
        5. Apply per-view sinogram shift correction.

    Args:
        dataset_dir (str): Path to the Zeiss dataset. Accepts a ``.txrm`` file with:
            - ``ImageData*/Image*`` (scan data)
            - Zeiss OLE metadata streams
        downsample_factor (Tuple[int, int], optional): Downsample factors for detector rows and channels.
            Defaults to (1, 1).
        subsample_view_factor (int, optional): Factor by which to subsample views. Defaults to 1.
        crop_pixels_sides (int, optional): Pixels to crop from each lateral side of the detector. Defaults to ``0``.
        crop_pixels_top (int, optional): Pixels to crop from the top of the detector. Defaults to ``0``.
        crop_pixels_bottom (int, optional): Pixels to crop from the bottom of the detector. Defaults to ``0``.
        alu_unit (str, optional): The physical unit used to define 1 ALU (Arbitrary Length Unit).
            Defaults to 'mm'. Supported units input: 'um', 'mm', 'cm', 'm'.
        bg_option (str or None): Option for background offset correction. Defaults to 'global'.
            Supported options:
            - None: No correction; return the input sinogram unchanged.
            - "global": Estimate one scalar offset from edge regions across all views.
            - "per_view": Estimate one offset per view from edge regions.
        verbose (int, optional): Verbosity level. Defaults to ``1``. Also forwarded to
            ``load_scans_and_params``.

    Returns:
        tuple: ``(sino, cone_beam_params, optional_params)``
            - ``sino`` (numpy.ndarray): Sinogram of shape ``(num_views, num_det_rows, num_channels)``.
            - ``cone_beam_params`` (dict): Parameters for initializing ``ConeBeamModel``.
            - ``optional_params`` (dict): Additional parameters to be set via ``ConeBeamModel.set_params``.

    Example:
        .. code-block:: python

            import mbirjax
            from mbirjax.preprocess.zeiss_cb import compute_sino_and_params

            sino, cone_beam_params, optional_params = compute_sino_and_params(dataset_dir, verbose=1)
            ct_model = mbirjax.ConeBeamModel(**cone_beam_params)
            ct_model.set_params(**optional_params)
            recon, recon_dict = ct_model.recon(sino)
    """
    if verbose > 0:
        print("\n\n########## Loading object, blank, dark scans, and geometry parameters from Zeiss dataset directory")
    # Forward verbose so that verbose=0 also silences the loader's geometry report.
    obj_scan, blank_scan, dark_scan, zeiss_params = load_scans_and_params(dataset_dir, subsample_view_factor,
                                                                          verbose=verbose)

    cone_beam_params, optional_params = convert_zeiss_to_mbirjax_params(zeiss_params,
                                                                        downsample_factor=downsample_factor,
                                                                        crop_pixels_sides=crop_pixels_sides,
                                                                        crop_pixels_top=crop_pixels_top,
                                                                        crop_pixels_bottom=crop_pixels_bottom,
                                                                        alu_unit=alu_unit)

    if verbose > 0:
        print("\n\n########## Cropping and downsampling scans")
    ### crop the scans based on input params
    obj_scan, blank_scan, dark_scan, defective_pixel_array = mjp.crop_view_data(obj_scan, blank_scan, dark_scan,
                                                                                crop_pixels_sides=crop_pixels_sides,
                                                                                crop_pixels_top=crop_pixels_top,
                                                                                crop_pixels_bottom=crop_pixels_bottom)

    ### downsample the scans with block-averaging
    if downsample_factor[0] * downsample_factor[1] > 1:
        obj_scan, blank_scan, dark_scan, defective_pixel_array = mjp.downsample_view_data(
            obj_scan, blank_scan, dark_scan,
            downsample_factor=downsample_factor,
            defective_pixel_array=defective_pixel_array)

    if verbose > 0:
        print("\n\n########## Computing sinogram from object, blank, and dark scans")
    sino = mjp.compute_sino_transmission(obj_scan, blank_scan, dark_scan, defective_pixel_array)

    if verbose > 0:
        print("\n\n########## Correcting sinogram data to account for background offset and sino offset")
    sino = mjp.correct_background_offset(sino, option=bg_option)

    # Correct per-view sino shifts stored in the txrm alignment streams
    sino = correct_sino_shifts(sino, zeiss_params, downsample_factor, subsample_view_factor)

    if verbose > 0:
        print('obj_scan shape = ', obj_scan.shape)
        print('blank_scan shape = ', blank_scan.shape)
        print('dark_scan shape = ', dark_scan.shape)

    return sino, cone_beam_params, optional_params
def load_scans_and_params(dataset_dir, subsample_view_factor, verbose=1):
    """
    Load the scan data and geometry from a Zeiss ``.txrm`` file.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        dataset_dir (str): Path to a Zeiss ``.txrm`` file. Expected OLE structure:
            - ``ImageData*/Image*`` (scan data)
            - ``ImageInfo/*``, ``PositionInfo/*``, ``ReconSettings/*``, ``alignment/*`` (metadata)
        subsample_view_factor (int): Keep every ``subsample_view_factor``-th view.
        verbose (int, optional): Verbosity level. Defaults to 1.

    Returns:
        tuple: ``(obj_scan, blank_scan, dark_scan, zeiss_params)``
            - obj_scan (numpy.ndarray): 3D object scan with shape
              ``(num_subsampled_views, num_det_rows, num_det_channels)``.
            - blank_scan (numpy.ndarray): Reference (blank) image(s) with shape
              ``(num_reference, num_det_rows, num_det_channels)``; ``num_reference`` is 1 when a single
              reference is stored or when no reference is found (an all-ones image is then used).
            - dark_scan (numpy.ndarray): All-zero array with shape ``(1, num_det_rows, num_det_channels)``,
              since no dark scan is read from the txrm file.
            - zeiss_params (dict): Geometry values (distances, pixel pitches, optical magnification,
              subsampled ``angles``, detector offsets, per-view ``x_shifts``/``y_shifts``) together with their
              units. ``num_views``, ``x_shifts`` and ``y_shifts`` refer to the full (non-subsampled) scan.

    Raises:
        FileNotFoundError, ValueError: If ``dataset_dir`` is not an existing ``.txrm`` file.
        OSError: If the file cannot be opened as an OLE file.
    """
    ### automatically parse the paths to Zeiss scans from dataset_dir
    data_dir = _parse_filenames_from_dataset_dir(dataset_dir)

    if verbose > 0:
        print("The following files will be used to compute the Zeiss reconstruction:\n",
              f"    - txrm file: {data_dir}\n")

    # Read object scans and metadata
    file_name = _check_read(data_dir)
    try:
        ole = olefile.OleFileIO(file_name)
    except IOError:
        print(f'No such file or directory: {file_name}')
        raise

    # Everything that touches the OLE handle lives inside try/finally so the file is always closed.
    try:
        metadata = read_metadata(ole)

        # Read only the (subsampled) views directly into their final array
        view_indices = np.arange(metadata["num_views"], step=subsample_view_factor)
        obj_scan = np.zeros((len(view_indices), metadata["num_det_rows"], metadata["num_det_channels"]),
                            dtype=_get_ole_data_type(metadata))
        for i, idx in enumerate(view_indices):
            # Image numbers are 1-based and grouped 100 per ImageData<k> storage
            img_string = "ImageData{}/Image{}".format(int(np.ceil((idx + 1) / 100.0)), int(idx + 1))
            obj_scan[i] = _read_ole_image(ole, img_string, metadata)
    finally:
        ole.close()

    # Blank scans come from the reference data gathered by read_metadata
    blank_scan = metadata["reference"]

    # TODO: Currently we assume that there is no dark scan for txrm file.
    # A single zero view is enough; a full obj_scan-sized zero array only wastes memory.
    dark_scan = np.zeros((1,) + obj_scan.shape[1:], dtype=obj_scan.dtype)

    try:
        # Lists of measurement axis names and units, e.g. ['Sample X', ..., 'CCD_X', ...] and ['um', ...]
        axis_names = metadata.get("axis_names")
        axis_units = metadata.get("axis_units")

        # Source to iso distance
        source_iso_dist = float(np.abs(metadata["source_iso_dist"]))
        source_iso_dist_index = get_index_in_list(axis_names, 'Source Z')
        source_iso_dist_unit = axis_units[source_iso_dist_index] if source_iso_dist_index > -1 else 'mm'

        # Iso to detector distance (assumed to share the source distance unit)
        iso_det_dist = metadata["iso_det_dist"]
        iso_det_dist = float(np.abs(iso_det_dist)) if iso_det_dist is not None else 0.0
        iso_det_dist_unit = source_iso_dist_unit

        # Detector pixel pitch; Zeiss detector pixels have equal width and height
        det_pixel_pitch = float(np.abs(metadata["det_pixel_pitch"]))
        delta_det_row = det_pixel_pitch
        delta_det_channel = det_pixel_pitch
        delta_det_index = get_index_in_list(axis_names, 'CCD_X')
        delta_det_row_unit = axis_units[delta_det_index] if delta_det_index > -1 else 'um'
        delta_det_channel_unit = delta_det_row_unit

        # Pixel pitch at iso
        iso_pixel_pitch = float(np.abs(metadata["iso_pixel_pitch"]))
        iso_pixel_pitch_index = get_index_in_list(axis_names, 'Sample X')
        iso_pixel_pitch_unit = axis_units[iso_pixel_pitch_index] if iso_pixel_pitch_index > -1 else 'um'

        angle_index = get_index_in_list(axis_names, 'Sample Theta')
        angle_unit = axis_units[angle_index] if angle_index > -1 else 'deg'

        x_shifts = metadata["x_shifts"]
        y_shifts = metadata["y_shifts"]
    except (ValueError, KeyError, TypeError, IndexError):
        # Missing/None metadata surfaces as TypeError/KeyError/IndexError, not only ValueError.
        print("Unable to determine units for geometry parameters; cannot safely convert to mbirjax format.")
        raise

    # Optical magnification (1 when not recorded)
    opt_mag = metadata["opt_mag"]
    opt_mag = 1 if opt_mag is None else opt_mag

    # Dimensions of radiograph
    num_views = metadata["num_views"]
    num_det_channels = metadata["num_det_channels"]
    num_det_rows = metadata["num_det_rows"]

    # Rotation angles (sign flipped relative to Zeiss), restricted to the subsampled views
    angles = -np.array(metadata['thetas'], dtype=float).ravel()[view_indices]

    # Detector offset parameters; MBIRJAX has the reverse convention for the channel shift
    det_channel_offset = -metadata["center_shift"]
    det_row_offset = 0.0  # There doesn't appear to be a Zeiss parameter for detector row offset

    zeiss_params = {
        'source_iso_dist': source_iso_dist,
        'iso_det_dist': iso_det_dist,
        'delta_det_channel': delta_det_channel,
        'delta_det_row': delta_det_row,
        'iso_pixel_pitch': iso_pixel_pitch,
        'opt_mag': opt_mag,
        'num_views': num_views,
        'num_det_channels': num_det_channels,
        'num_det_rows': num_det_rows,
        'angles': angles,
        'det_row_offset': det_row_offset,
        'det_channel_offset': det_channel_offset,
        'source_iso_dist_unit': source_iso_dist_unit,
        'iso_det_dist_unit': iso_det_dist_unit,
        'delta_det_row_unit': delta_det_row_unit,
        'delta_det_channel_unit': delta_det_channel_unit,
        'iso_pixel_pitch_unit': iso_pixel_pitch_unit,
        'angle_unit': angle_unit,
        'x_shifts': x_shifts,
        'y_shifts': y_shifts,
    }

    if verbose > 0:
        print("############ Zeiss geometry parameters ############")
        print(f"Source to iso distance: {source_iso_dist} [{source_iso_dist_unit}]")
        print(f"Iso to detector distance: {iso_det_dist} [{iso_det_dist_unit}]")
        print(f"Detector pixel pitch: {det_pixel_pitch:.3f} [{delta_det_row_unit}]")
        print(f"Iso pixel pitch: {iso_pixel_pitch} [{iso_pixel_pitch_unit}]")
        print(f"Optical magnification: {opt_mag}")
        print(f"Number of views: {num_views}")
        print(f"Detector size: (num_det_rows, num_det_channels) = ({num_det_rows}, {num_det_channels})")
        print("############ End Zeiss geometry parameters ############")
    ### END load Zeiss parameters from scan data

    return obj_scan, blank_scan, dark_scan, zeiss_params
def convert_zeiss_to_mbirjax_params(zeiss_params, downsample_factor=(1, 1), crop_pixels_sides=0, crop_pixels_top=0,
                                    crop_pixels_bottom=0, alu_unit='mm'):
    """
    Convert geometry parameters from zeiss into mbirjax format, including modifications to reflect crop
    and downsampling.

    Notes:
        Thanks to contributions of Amir Koushyar Ziabari of Oak Ridge National Laboratory (ORNL).

    Args:
        zeiss_params (dict): Zeiss geometry parameters as produced by ``load_scans_and_params``. Keys read:
            'source_iso_dist', 'iso_det_dist', 'iso_pixel_pitch' (+ matching '*_unit' keys), 'opt_mag',
            'num_det_rows', 'num_det_channels', 'angles', 'angle_unit', 'det_row_offset', 'det_channel_offset'.
            The detector offsets are expected in units of full-resolution detector pixels.
        downsample_factor ((int, int), optional): Down-sample factors along the detector rows and channels
            respectively. If scan size is not divisible by `downsample_factor`, the scans will be first
            truncated to a size that is divisible by `downsample_factor`.
        crop_pixels_sides (int, optional): The number of pixels to crop from each side of the sinogram. Defaults to 0.
        crop_pixels_top (int, optional): The number of pixels to crop from top of the sinogram. Defaults to 0.
        crop_pixels_bottom (int, optional): The number of pixels to crop from bottom of the sinogram. Defaults to 0.
        alu_unit (str, optional): The physical unit used to define 1 ALU (Arbitrary Length Unit).
            Defaults to 'mm'. Supported units input: 'um', 'mm', 'cm', 'm'.

    Returns:
        tuple: ``(cone_beam_params, optional_params)``
            - cone_beam_params (dict): Required parameters for the ConeBeamModel constructor.
            - optional_params (dict): Additional ConeBeamModel parameters to be set using set_params().
    """
    # Get zeiss parameters
    source_iso_dist, iso_det_dist, source_iso_dist_unit, iso_det_dist_unit = itemgetter(
        'source_iso_dist', 'iso_det_dist', 'source_iso_dist_unit', 'iso_det_dist_unit')(zeiss_params)
    iso_pixel_pitch, iso_pixel_pitch_unit = itemgetter('iso_pixel_pitch', 'iso_pixel_pitch_unit')(zeiss_params)
    opt_mag = zeiss_params['opt_mag']
    num_det_rows, num_det_channels = itemgetter('num_det_rows', 'num_det_channels')(zeiss_params)
    angles, angle_unit = itemgetter('angles', 'angle_unit')(zeiss_params)
    det_row_offset, det_channel_offset = itemgetter('det_row_offset', 'det_channel_offset')(zeiss_params)

    # Unit conversion table (values in um) for all units used in the txrm files
    unit_conversion = {'um': 1.0, 'mm': 1000.0, 'cm': 1e4, 'm': 1e6}
    # Define 1 ALU as 1 unit of alu_unit
    alu_value = 1

    # Convert physical units to ALU
    source_iso_dist = source_iso_dist * (unit_conversion[source_iso_dist_unit] / unit_conversion[alu_unit])
    iso_det_dist = iso_det_dist * (unit_conversion[iso_det_dist_unit] / unit_conversion[alu_unit])
    iso_pixel_pitch = iso_pixel_pitch * (unit_conversion[iso_pixel_pitch_unit] / unit_conversion[alu_unit])

    # Physical source to detector (scintillator) distance
    source_detector_dist = source_iso_dist + iso_det_dist

    # Convert angles to radians
    if angle_unit == 'deg':
        angles = np.deg2rad(angles)

    # The "detector" is a scintillator followed by an optical stage. The total magnification is
    # (optical magnification) * (geometric magnification to the scintillator). A missing optical
    # magnification means no optical stage, i.e. a factor of 1 -- the geometric part still applies.
    if opt_mag is None:
        opt_mag = 1.0
    scintillator_mag = source_detector_dist / source_iso_dist
    magnification = opt_mag * scintillator_mag

    # Equivalent geometry accounting for the total magnification (full-resolution detector pitch)
    source_detector_dist = magnification * source_iso_dist
    delta_det_channel = magnification * iso_pixel_pitch
    delta_det_row = magnification * iso_pixel_pitch

    # Convert offsets to ALU. They are in full-resolution pixels, so they must be scaled by the
    # full-resolution pitch, i.e. before downsampling enlarges the pitch below.
    det_channel_offset = det_channel_offset * delta_det_channel
    det_row_offset = det_row_offset * delta_det_row
    # NOTE(review): asymmetric top/bottom cropping moves the detector center; det_row_offset is not
    # adjusted for it here -- confirm against mjp.crop_view_data conventions.

    # Adjust detector size params w.r.t. cropping arguments
    num_det_rows = num_det_rows - (crop_pixels_top + crop_pixels_bottom)
    num_det_channels = num_det_channels - 2 * crop_pixels_sides

    # Adjust detector size and pixel pitch params w.r.t. downsampling arguments
    num_det_rows = num_det_rows // downsample_factor[0]
    num_det_channels = num_det_channels // downsample_factor[1]
    delta_det_row *= downsample_factor[0]
    delta_det_channel *= downsample_factor[1]
    iso_pixel_pitch *= downsample_factor[0]

    # Create dictionaries to store MBIR parameters
    num_views = len(angles)
    cone_beam_params = dict()
    cone_beam_params['sinogram_shape'] = (num_views, num_det_rows, num_det_channels)
    cone_beam_params["angles"] = angles
    cone_beam_params['source_detector_dist'] = source_detector_dist
    cone_beam_params['source_iso_dist'] = source_iso_dist

    optional_params = dict()
    optional_params['delta_det_channel'] = delta_det_channel
    optional_params['delta_det_row'] = delta_det_row
    optional_params['delta_voxel'] = iso_pixel_pitch
    optional_params['det_row_offset'] = det_row_offset
    optional_params['det_channel_offset'] = det_channel_offset
    optional_params['alu_unit'] = alu_unit
    optional_params['alu_value'] = alu_value

    return cone_beam_params, optional_params


######## subroutines for parsing Zeiss object scan, blank scan, and dark scan
def _parse_filenames_from_dataset_dir(dataset_dir):
    """
    Validate that ``dataset_dir`` points to an existing ``.txrm`` file and return it.

    Args:
        dataset_dir (str or os.PathLike): Path to the Zeiss ``.txrm`` file.

    Returns:
        str: Path to the txrm file storing the projection data.

    Raises:
        FileNotFoundError: If the path is not an existing file.
        ValueError: If the file does not have a ``.txrm`` extension (case-insensitive).
    """
    dataset_dir = os.fspath(dataset_dir)
    if not os.path.isfile(dataset_dir):
        raise FileNotFoundError(dataset_dir)
    if not dataset_dir.lower().endswith(".txrm"):
        raise ValueError(f"Unsupported file type {dataset_dir}; only .txrm files are supported.")
    return dataset_dir


def _check_read(fname):
    """
    Validate the file path and log an error if it is not a string or has an unrecognized extension.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str): Path to the file to be read. Recognized extensions: ['.edf', '.tiff', '.tif', '.h5',
            '.hdf', '.npy', '.nc', '.xrm', '.txrm', '.txm', '.xmt', '.nxs'].

    Returns:
        str: Absolute path to the file (problems are only logged, not raised).
    """
    known_extensions = {'.edf', '.tiff', '.tif', '.h5', '.hdf', '.npy', '.nc', '.xrm', '.txrm', '.txm', '.xmt', '.nxs'}
    if not isinstance(fname, str):
        logger.error('File name must be a string')
    else:
        _, ext = os.path.splitext(fname)
        if ext.lower() not in known_extensions:
            logger.error('Unknown file extension')
    return os.path.abspath(fname)


def read_xrm(fname):
    """
    Read data from xrm file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str): String defining the path of file or file name.

    Returns:
        np.ndarray: Output 2D image with shape (num_det_rows, num_det_channels), normalized to float32.
        dict: Output metadata.

    Raises:
        OSError: If the file cannot be opened as an OLE file.
    """
    fname = _check_read(fname)
    try:
        ole = olefile.OleFileIO(fname)
    except IOError:
        # Re-raise instead of returning False: callers unpack (arr, metadata).
        print(f'No such file or directory: {fname}')
        raise

    try:
        # Read metadata and raw scan bytes while the file is open
        metadata = read_metadata(ole)
        data = ole.openstream("ImageData1/Image1").read()
    finally:
        ole.close()

    # Pixel data is little-endian
    data_type = _get_ole_data_type(metadata).newbyteorder('<')
    arr = np.reshape(np.frombuffer(data, data_type), (metadata["num_det_rows"], metadata["num_det_channels"]))
    _log_imported_data(fname, arr)

    # Normalize the scan data
    arr = mjp.utilities._normalize_to_float32(arr)
    return arr, metadata


def read_xrm_dir(dir_path):
    """
    Read all .xrm files in a directory (sorted by file name), stack them into
    (num_views, num_det_rows, num_det_channels), and concatenate per-view metadata.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        dir_path (str): Path to the directory to be read.

    Returns:
        np.ndarray: Output 3D image with shape (num_views, num_det_rows, num_det_channels).
        dict: Output metadata; the metadata of the first file, with 'num_views' set to the file count and
            'x_positions', 'y_positions', 'z_positions', 'thetas' replaced by per-file lists.

    Raises:
        FileNotFoundError: If the directory contains no .xrm files.
    """
    dir_path = Path(dir_path)
    # Only .xrm files, in a deterministic order (iterdir order is filesystem-dependent)
    files = sorted(p for p in dir_path.iterdir() if p.is_file() and p.suffix.lower() == '.xrm')
    if not files:
        raise FileNotFoundError(f"No .xrm files found in {dir_path}")

    # Load the scan data and metadata from first file
    proj0, md0 = read_xrm(str(files[0]))
    num_views = len(files)
    num_det_rows, num_det_channels = proj0.shape
    arr = np.empty((num_views, num_det_rows, num_det_channels), dtype=proj0.dtype)
    arr[0] = proj0

    metadata = dict(md0)
    metadata['num_views'] = num_views
    # Per-view object positions and rotation angle, starting with the first file
    metadata['x_positions'] = [md0['x_positions'][0]]
    metadata['y_positions'] = [md0['y_positions'][0]]
    metadata['z_positions'] = [md0['z_positions'][0]]
    metadata['thetas'] = [md0['thetas'][0]]

    # Load the remaining files and stack them together
    for i, p in enumerate(files[1:], start=1):
        proj, md = read_xrm(str(p))
        arr[i] = proj
        metadata['x_positions'].append(md['x_positions'][0])
        metadata['y_positions'].append(md['y_positions'][0])
        metadata['z_positions'].append(md['z_positions'][0])
        metadata['thetas'].append(md['thetas'][0])

    _log_imported_data(str(dir_path), arr)

    # Normalize the scan data
    arr = mjp.utilities._normalize_to_float32(arr)
    return arr, metadata


def read_metadata(ole):
    """
    Read metadata from an xradia OLE file (.xrm, .txrm, .txm).

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO instance): An ole file to read from.

    Returns:
        dict: A dictionary of image metadata. Missing streams yield None values. The 'reference' entry is a
            3D array (num_reference, num_det_rows, num_det_channels); if no reference images are stored,
            an all-ones image is used and a warning is issued.
    """
    number_of_images = _read_ole_value(ole, "ImageInfo/NoOfImages", "<I")
    number_of_reference = _read_ole_reference(
        ole, ["ReferenceData/ImageInfo/NoOfImages", "MultiReferenceData/ImageInfo/NoOfImages"], "<I")
    per_image_fmt = "<{0}f".format(number_of_images)
    per_reference_fmt = "<{0}f".format(number_of_reference)

    metadata = {
        'num_det_channels': _read_ole_value(ole, 'ImageInfo/ImageWidth', '<I'),
        'num_det_rows': _read_ole_value(ole, 'ImageInfo/ImageHeight', '<I'),
        'data_type': _read_ole_value(ole, 'ImageInfo/DataType', '<1I'),
        'reference_data_type': _read_ole_reference(
            ole, ['ReferenceData/DataType', 'MultiReferenceData/DataType'], '<1I'),
        'num_views': number_of_images,
        'num_reference': number_of_reference,
        'iso_pixel_pitch': _read_ole_value(ole, 'ImageInfo/PixelSize', '<f'),
        'det_pixel_pitch': _read_ole_value(ole, 'ImageInfo/CamPixelSize', '<f'),
        'iso_det_dist': _read_ole_value(ole, 'ImageInfo/D2RADistance', '<f'),
        'source_iso_dist': _read_ole_value(ole, 'ImageInfo/StoRADistance', per_image_fmt),
        'thetas': _read_ole_arr(ole, 'ImageInfo/Angles', per_image_fmt),
        'x_positions': _read_ole_arr(ole, 'ImageInfo/XPosition', per_image_fmt),
        'y_positions': _read_ole_arr(ole, 'ImageInfo/YPosition', per_image_fmt),
        'z_positions': _read_ole_arr(ole, 'ImageInfo/ZPosition', per_image_fmt),
        'x_shifts': _read_ole_arr(ole, 'alignment/x-shifts', per_image_fmt),
        'y_shifts': _read_ole_arr(ole, 'alignment/y-shifts', per_image_fmt),
        'current': _read_ole_value(ole, "ImageInfo/XrayCurrent", per_image_fmt),
        'voltage': _read_ole_value(ole, "ImageInfo/XrayVoltage", per_image_fmt),
        'ExpTimes': _read_ole_value(ole, "ImageInfo/ExpTimes", per_image_fmt),
        'center_shift': _read_ole_value(ole, "ReconSettings/CenterShift", '<f'),
        'opt_mag': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/OpticalMagnification',
                  'MultiReferenceData/ImageInfo/OpticalMagnification'], '<f'),
        'fan_angle': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/FanAngle', 'MultiReferenceData/ImageInfo/FanAngle'], per_reference_fmt),
        'cone_angle': _read_ole_reference(
            ole, ['ReferenceData/ImageInfo/ConeAngle', 'MultiReferenceData/ImageInfo/ConeAngle'], per_reference_fmt),
        'axis_names': _read_ole_str(ole, 'PositionInfo/AxisNames'),
        'axis_units': _read_ole_str(ole, 'PositionInfo/AxisUnits'),
    }

    reference = None
    if ole.exists('ReferenceData'):
        reference = _read_ole_image(ole, 'ReferenceData/Image', metadata, metadata['reference_data_type'])
    elif ole.exists('MultiReferenceData'):
        reference_images = []
        # num_reference may be None if its stream is missing
        for idx in range(metadata['num_reference'] or 0):
            img_string = f"MultiReferenceData/Image{idx + 1}"
            if ole.exists(img_string):
                reference_images.append(
                    _read_ole_image(ole, img_string, metadata, metadata['reference_data_type']))
        if reference_images:
            reference = np.stack(reference_images, axis=0)

    # Fall back to all ones whenever no reference image could be read (including an empty
    # MultiReferenceData storage, which previously left reference as None and crashed below).
    if reference is None:
        warnings.warn('No reference data available. Using an array of all 1s.')
        reference = np.ones((metadata['num_det_rows'], metadata['num_det_channels']))

    if reference.ndim == 2:
        reference = reference[None, :, :]
    metadata['reference'] = reference

    return metadata


def _log_imported_data(fname, arr):
    """
    Log information about imported data.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        fname (str): Path of the file from which data was imported.
        arr (np.ndarray): Array containing the image data.
    """
    logger.debug('Data shape & type: %s %s', arr.shape, arr.dtype)
    logger.info('Data successfully imported: %s', fname)


def get_index_in_list(input_list, target):
    """
    Find the index of target in the given list.

    Args:
        input_list (list or None): List to search; None is treated as an empty list.
        target: Item to look for.

    Returns:
        int: Index of the first occurrence of ``target``, or -1 if not present.
    """
    if input_list is None or target not in input_list:
        return -1
    return input_list.index(target)


def _get_ole_data_type(metadata, datatype=None):
    """
    Determine the Numpy data type for image data stored in a Zeiss OLE (.xrm, .txrm, .txm) file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        metadata (dict): Metadata extracted from the OLE file; must include "data_type" when ``datatype`` is None.
        datatype (int, optional): Integer code for the data type. If None, ``metadata["data_type"]`` is used.

    Returns:
        np.dtype: float32 for code 10, uint16 for code 5.

    Raises:
        ValueError: For any other data type code.
    """
    # 10 float; 5 uint16 (unsigned 16-bit (2-byte) integers)
    if datatype is None:
        datatype = metadata["data_type"]
    if datatype == 10:
        return np.dtype(np.float32)
    if datatype == 5:
        return np.dtype(np.uint16)
    raise ValueError("Unsupported XRM datatype: %s" % str(datatype))


def _read_ole_struct(ole, label, struct_fmt):
    """
    Read the struct associated with label in an ole file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO): An ole file to read from.
        label (str): Stream label within the OLE file.
        struct_fmt (str): ``struct`` format string used to unpack the stream.

    Returns:
        tuple or None: Unpacked values if the label exists, otherwise None.

    Raises:
        struct.error: If the stream size does not match ``struct_fmt``.
    """
    if not ole.exists(label):
        return None
    data = ole.openstream(label).read()
    return struct.unpack(struct_fmt, data)


def _read_ole_value(ole, label, struct_fmt):
    """
    Read the first value associated with label in an ole file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO): An ole file to read from.
        label (str): Stream label within the OLE file.
        struct_fmt (str): ``struct`` format string used to unpack the stream.

    Returns:
        int or float or None: The first unpacked value if the label exists, otherwise None.
    """
    value = _read_ole_struct(ole, label, struct_fmt)
    if value is not None:
        value = value[0]
    return value


def _read_ole_arr(ole, label, struct_fmt):
    """
    Read the numpy array associated with label in an ole file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO): An ole file to read from.
        label (str): Stream label within the OLE file.
        struct_fmt (str): ``struct`` format string used to unpack the stream.

    Returns:
        np.ndarray or None: The unpacked values as an array if the label exists, otherwise None.
    """
    arr = _read_ole_struct(ole, label, struct_fmt)
    if arr is not None:
        arr = np.array(arr)
    return arr


def _read_ole_image(ole, label, metadata, datatype=None):
    """
    Read the image data associated with label in an ole file.

    Notes:
        Portions of this code are adapted from the DXchange library: https://github.com/data-exchange/dxchange

    Args:
        ole (OleFileIO): An ole file to read from.
        label (str): Stream label within the OLE file.
        metadata (dict): Metadata extracted from the OLE file (provides image size and default data type).
        datatype (int, optional): Data type code of the image data. Defaults to None (use metadata).

    Returns:
        np.ndarray: Output 2D image with shape (num_det_rows, num_det_channels).
    """
    data = ole.openstream(label).read()
    # Pixel data is little-endian
    data_type = _get_ole_data_type(metadata, datatype).newbyteorder('<')
    return np.reshape(np.frombuffer(data, data_type), (metadata["num_det_rows"], metadata["num_det_channels"]))


def _read_ole_str(ole, label):
    """
    Read the NUL-separated strings associated with label in an ole file.

    Args:
        ole (OleFileIO): An ole file to read from.
        label (str): Stream label within the OLE file.

    Returns:
        list or None: The non-empty UTF-8 decoded strings if the label exists, otherwise None.
    """
    if not ole.exists(label):
        return None
    data = ole.openstream(label).read()
    return [name.decode('utf-8') for name in data.split(b'\x00') if name]


def _read_ole_reference(ole, labels, struct_fmt):
    """
    Read a reference-related value from the first matching label in an ole file.

    The stream is unpacked once with ``struct_fmt``: a single value is returned as a scalar, several
    values as a NumPy array. Labels whose stream does not match ``struct_fmt`` are skipped.

    Args:
        ole (OleFileIO): An ole file to read from.
        labels (list of str): Candidate labels for the reference data parameter, tried in order.
        struct_fmt (str): ``struct`` format string used to unpack the stream.

    Returns:
        int, float, ndarray, or None: The unpacked value(s) from the first readable label, or None if no
        label could be read.
    """
    for label in labels:
        if not ole.exists(label):
            continue
        try:
            values = _read_ole_struct(ole, label, struct_fmt)
        except (struct.error, OSError):
            # Size/format mismatch (or unreadable stream): best-effort, try the next candidate
            continue
        if values is None:
            continue
        return values[0] if len(values) == 1 else np.array(values)
    return None
######## END subroutines for parsing Zeiss object scan, blank scan, and dark scan


def correct_sino_shifts(sino, zeiss_params, downsample_factor, subsample_view_factor):
    """
    Align each sinogram view based on the per-view projection offset.

    The txrm file stores the horizontal (x-shift) and vertical (y-shift) offsets for each projection.
    This function compensates the object's vibration by shifting each view of the sinogram accordingly.

    Coordinate convention (view from source):
        - x-shift: applied in the horizontal (channel) direction; positive shifts the view to the right.
        - y-shift: applied in the vertical (row) direction; positive shifts the view down.

    Each view is edge-padded by the largest absolute shift, translated with linear interpolation, and the
    padding is removed afterward, so no view samples outside the padded image.

    Args:
        sino (numpy array or jax array): 3D sinogram data with shape (num_views, num_det_rows, num_det_channels).
        zeiss_params (dict): Parameters with per-view 'x_shifts' and 'y_shifts' for the full (non-subsampled)
            scan, in full-resolution pixels. This dict is not modified.
        downsample_factor (Tuple[int, int]): Downsample factors for detector rows and channels.
        subsample_view_factor (int): Factor by which views were subsampled.

    Returns:
        corrected_sino (numpy array): 3D sinogram data after alignment. If no shifts are available, ``sino`` is
        returned unchanged (with a warning).
    """
    x_shifts = zeiss_params.get("x_shifts")
    y_shifts = zeiss_params.get("y_shifts")
    if x_shifts is None or y_shifts is None:
        warnings.warn('No per-view alignment shifts found in Zeiss metadata; skipping sinogram shift correction.')
        return sino

    # TODO: Currently the view offsets are assumed to have units of (full-resolution) pixels.
    # Build new arrays rather than dividing in place: slicing returns views, and in-place division
    # would silently rescale the caller's zeiss_params on every call.
    sino_x_offset = np.asarray(x_shifts, dtype=float)[::subsample_view_factor] / downsample_factor[1]
    sino_y_offset = np.asarray(y_shifts, dtype=float)[::subsample_view_factor] / downsample_factor[0]
    if sino_x_offset.size == 0:
        return sino

    ### Pad the sinogram to handle boundaries
    # Pad by the largest absolute shift (not the max-min range): a uniform shift has zero range but
    # still moves content past the image edge.
    max_shift = max(np.max(np.abs(sino_x_offset)), np.max(np.abs(sino_y_offset)))
    pad_size = int(np.ceil(max_shift))
    if pad_size > 0:
        sino_pad = np.pad(sino, ((0, 0), (pad_size, pad_size), (pad_size, pad_size)), mode='edge')
    else:
        sino_pad = sino

    # Apply per-view translation (one view at a time to bound device memory use)
    corrected_sino = np.zeros_like(sino_pad)
    for view in range(sino.shape[0]):
        corrected_sino[view] = jax.image.scale_and_translate(sino_pad[view], shape=sino_pad[view].shape,
                                                             spatial_dims=(0, 1),
                                                             scale=jnp.array([1.0, 1.0]),
                                                             translation=jnp.array([sino_y_offset[view],
                                                                                    sino_x_offset[view]]),
                                                             method='linear', antialias=False)

    # Remove padding
    if pad_size > 0:
        corrected_sino = corrected_sino[:, pad_size:-pad_size, pad_size:-pad_size]

    return corrected_sino