Source code for mbirjax.viewer

import os
import warnings
import matplotlib
import easygui

import mbirjax as mj

# Choose the matplotlib backend; this has to happen before pyplot is imported below.
if os.environ.get("READTHEDOCS") != "True":
    # Normal use: the viewer's dialogs and menus are written for the TkAgg backend.
    try:
        matplotlib.use('TkAgg')
    except ImportError:
        # Keep going with whatever backend matplotlib has; the TkAgg-only features are disabled later.
        warnings.warn("TkAgg not available for matplotlib. slice_viewer menus may not display correctly.")
else:
    # READTHEDOCS is set to "True": use the non-interactive Agg backend.
    matplotlib.use('Agg')

# Now it's safe to import pyplot
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import gridspec
from matplotlib.widgets import RangeSlider, Slider, RadioButtons, CheckButtons
import time
import h5py

# === CONSTANTS ===
# Appearance and content of the tooltip annotation attached to each image axes.
TOOLTIP_FONT_SIZE = 9
TOOLTIP_BOX_ALPHA = 0.9
TOOLTIP_OFFSET = (10, 10)
TOOLTIP_TEXT = chr(10).join([
    "Click and drag to move",
    "Click edge to resize",
    "Press Esc to remove",
])

# Appearance of the ROI circle patch.
CIRCLE_COLOR = 'red'
CIRCLE_LINEWIDTH = 2
CIRCLE_ALPHA = 1.0
CIRCLE_FILL = False

# Fonts and marker size for the "Slice axis" radio buttons.
SLICE_AXIS_FONT_SIZE = 9
SLICE_AXIS_LABEL_FONT_SIZE = 8
SLICE_AXIS_RADIO_SIZE = 30

# Lower bound for the slice slider maximum (and guard against division by zero) for single-slice volumes.
VALMAX_EPS = 1e-6

# Backend-dependent settings.  Matplotlib treats backend names case-insensitively, so the
# name reported by get_backend() is compared in lower case instead of requiring the exact
# spelling 'TkAgg'.
TKAGG = matplotlib.get_backend().lower() == 'tkagg'
if TKAGG:
    Y_SKIP = 28
    LOAD_LABEL = 'Load'
    SAVE_LABEL = 'Save data to h5'
else:
    # Without TkAgg the file dialogs are unavailable, so the menu labels say so.
    Y_SKIP = 50
    LOAD_LABEL = 'Load disabled - requires TkAgg'
    SAVE_LABEL = 'Save disabled - requires TkAgg'


def multiline(*lines):
    """Return the given strings joined into a single string, separated by newline characters."""
    newline = chr(10)
    return newline.join(lines)


class SliceViewer:
    """
    Interactive multi-volume slice viewer for 2D and 3D NumPy arrays using matplotlib.

    This class provides a graphical interface for exploring one or more 3D volumes or 2D slices.
    Features include synchronized slice navigation, ROI statistics, axis transposition, file loading,
    dynamic intensity range adjustment, and interactive GUI tools for zooming and panning.

    Designed primarily for inspecting CT or other volumetric reconstructions in research workflows.

    Args:
        *datasets (ndarray or None): One or more 2D or 3D NumPy arrays to display.
            - 2D arrays are automatically promoted to 3D via a singleton axis.
            - `None` values are replaced with placeholder zero arrays.

        data_dicts (None or dict or list of None or dicts, optional): Dictionary (or list of dictionaries, one per dataset) of string entries to associate with the data (e.g., from :meth:`get_recon_dict`).
        title (str, optional): Window title. Defaults to an empty string.
        vmin (float, optional): Minimum intensity value for display. Defaults to the global minimum across all datasets.
        vmax (float, optional): Maximum intensity value for display. Defaults to the global maximum across all datasets.
        slice_label (str or list of str, optional): Label(s) for the current slice. Defaults to "Slice".
        slice_axis (int or list of int, optional): Axis along which to slice (0, 1, or 2). Defaults to the last axis (2).
        cmap (str, optional): Colormap to use. Defaults to "gray".
        show_instructions (bool, optional): Whether to display usage instructions in the figure. Defaults to True.

    Notes:
        - Right-click an image to access a context menu with options such as axis transposition and file loading.
        - Right-click the intensity slider (if using TkAgg backend) to manually set display range bounds.
        - Press 'h' to show help overlay. Press 'Esc' to clear overlays or reset ROI selections.
    """

    def __init__(self, *datasets, data_dicts=None, title='', vmin=None, vmax=None, slice_label=None,
                 slice_axis=None, cmap='gray', show_instructions=True):
        self.datasets = datasets
        self.n_volumes = len(datasets)
        self.title = title
        self.vmin = vmin
        self.vmax = vmax
        self.slice_label = slice_label
        self.slice_axis = slice_axis
        self.cmap = cmap
        self.show_instructions = show_instructions

        self.original_data = []
        self.data = []
        if data_dicts is None:
            self.data_dicts = [None] * self.n_volumes
        else:
            if isinstance(data_dicts, dict):
                data_dicts = [data_dicts]
            if len(data_dicts) == self.n_volumes and all([isinstance(d, dict) or d is None for d in data_dicts]):
                self.data_dicts = [mj.TomographyModel.convert_subdicts_to_strings(d) for d in data_dicts]
            else:
                raise ValueError('data_dicts must be single dict or a list of dicts of the same length as the number of datasets')

        self.axes_perms = np.zeros(0).astype(int)
        self.labels = []
        self.cur_slices = []
        self.axes = [None] * self.n_volumes
        self.caxes = [None] * self.n_volumes
        self.images = [None] * self.n_volumes
        self.circles = []
        self.text_boxes = []
        self.axis_buttons = []

        self._normalize_inputs()
        self._prepare_data()
        self._build_figure()

        self._syncing_limits = True
        self._syncing_axes = False
        self._tk_button_pressed = False
        self._image_selection_dict = {}
        self._clear_image_selection()
        self._difference_image_dicts = [None] * self.n_volumes  # Entries:  comparison_index, use_abs, prev_label
        easygui.boxes.global_state.prop_font_line_length = 80
        self.show()

    def _clear_image_selection(self):
        self._image_selection_dict = {'selecting_image': False, 'baseline_index': -1}

    @staticmethod
    def _get_perm_from_slice_ind(s):
        return list({0, 1, 2} - {s}) + [s]

    def _normalize_inputs(self):
        # Set up the slice axes and labels
        if isinstance(self.slice_axis, int) or self.slice_axis is None:
            slice_axes = [2 if self.slice_axis is None else self.slice_axis] * self.n_volumes
        else:
            slice_axes = list(self.slice_axis)
        self._syncing_axes = all(slice_axis == slice_axes[0] for slice_axis in slice_axes)
        self.axes_perms = np.array([self._get_perm_from_slice_ind(s) for s in slice_axes]).astype(int)

        if isinstance(self.slice_label, str) or self.slice_label is None:
            self.labels = [f"Slice" if self.slice_label is None else self.slice_label
                           for i in range(self.n_volumes)]
        else:
            self.labels = list(self.slice_label)

    def _prepare_data(self):
        # Promote arrays to 3D if needed and find max and min
        self.original_data = list(self.datasets)
        for i, data in enumerate(self.datasets):
            if data is None:
                data = np.zeros((20, 20, 20))
                self.original_data[i] = data
            if data.ndim == 2:
                data = data[..., np.newaxis]
                self.original_data[i] = data
            elif data.ndim != 3:
                raise ValueError("Each input data must be a 2D or 3D array")
            data = np.transpose(data, self.axes_perms[i])
            self.data.append(data)

        self.cur_slices = [d.shape[2] // 2 for d in self.data]
        if self.vmin is None:
            self.vmin = np.min([np.min(d) for d in self.data])
        if self.vmax is None:
            self.vmax = np.max([np.max(d) for d in self.data])
        if self.vmin == self.vmax:
            eps = 1e-6
            scale = np.clip(eps * np.abs(self.vmax), a_min=eps, a_max=None)
            self.vmin -= scale
            self.vmax += scale

    def _build_figure(self):
        """Create the figure and grid layout, then add the images, sliders, axis controls, and event handlers."""
        self._last_display_time = 0
        self._meshgrids = {}

        # 6 inches of width per volume; one tall image row above three short control rows.
        self.fig = plt.figure(figsize=(6 * self.n_volumes, 8))
        self.fig.suptitle(self.title)
        self.gs = gridspec.GridSpec(nrows=4, ncols=self.n_volumes, height_ratios=[15, 1, 1, 1])

        build_steps = (
            self._draw_images,
            self._add_slice_slider,
            self._add_intensity_slider,
            self._add_axis_controls,
            self._connect_events,
        )
        for build_step in build_steps:
            build_step()

        # State used by the mouse handlers while an ROI circle is being resized.
        self._resize_index = None
        self._resize_anchor = None

        if not self.show_instructions:
            return
        self.fig.text(0.01, 0.25, multiline('Press h', 'for help'), fontdict={'color': 'red'})

    def _draw_images(self, image_index=None):
        """
        (Re)create the image axes with titles and colorbars, and rebuild the tooltips.

        Any ROI circles and stats boxes are removed first.

        Args:
            image_index (int, optional): If given, only this image is redrawn; otherwise all images are.
        """

        def make_limit_sync(source_ax, which):
            # Build a callback that copies the 'x' or 'y' limits of source_ax to every other image axes.
            read_limits = getattr(source_ax, f"get_{which}lim")
            setter_name = f"set_{which}lim"

            def callback(lim):
                if not self._syncing_limits:
                    return
                # The flag is cleared while propagating so the callbacks of the other axes return immediately.
                self._syncing_limits = False
                try:
                    for other_ax in self.axes:
                        if other_ax != source_ax:
                            getattr(other_ax, setter_name)(read_limits())
                finally:
                    self._syncing_limits = True
                self.fig.canvas.draw_idle()

            return callback

        self._remove_graphics()
        if image_index is None:
            targets = list(enumerate(self.data))
        else:
            targets = [(image_index, self.data[image_index])]

        # Draw the images, titles, colorbars
        for i, volume in targets:
            # Replace any axes already occupying this slot.
            if len(self.axes) > i and self.axes[i]:
                self.axes[i].remove()
                self.caxes[i].remove()
            image_ax = self.fig.add_subplot(self.gs[0, i])
            current = self.cur_slices[i]
            image = image_ax.imshow(volume[:, :, current], cmap=self.cmap,
                                    aspect='equal',
                                    vmin=self.vmin, vmax=self.vmax)
            image_ax.set_title(multiline(f"{self.labels[i]} {current}",
                                         f"Shape: {self.original_data[i].shape}, Axes: {self.axes_perms[i]}"))
            colorbar_ax = make_axes_locatable(image_ax).append_axes('right', size='5%', pad=0.05)
            self.fig.colorbar(image, cax=colorbar_ax, orientation='vertical')
            image_ax.zorder = 2
            colorbar_ax.zorder = 1
            self.axes[i] = image_ax
            self.caxes[i] = colorbar_ax
            self.images[i] = image

            # Sync zoom/pan
            image_ax.callbacks.connect('xlim_changed', make_limit_sync(image_ax, 'x'))
            image_ax.callbacks.connect('ylim_changed', make_limit_sync(image_ax, 'y'))

        # Rebuild tooltips to match updated axes
        rebuilt_tooltips = []
        for ax in self.axes:
            tip = ax.annotate(
                TOOLTIP_TEXT,
                xy=(0, 0), xytext=TOOLTIP_OFFSET, textcoords='offset points',
                ha='left', fontsize=TOOLTIP_FONT_SIZE,
                bbox=dict(boxstyle='round', fc='w', alpha=TOOLTIP_BOX_ALPHA),
                arrowprops=dict(arrowstyle='->'),
                visible=False
            )
            rebuilt_tooltips.append(tip)
        self.tooltips = rebuilt_tooltips

    def _add_slice_slider(self):
        """Add the slice slider spanning grid row 2 and connect it to :meth:`_update_slice`."""
        slider_ax = self.fig.add_subplot(self.gs[2, :])
        deepest = max(vol.shape[2] for vol in self.data)
        # Keep the maximum strictly positive even when every volume has a single slice.
        upper = max(deepest - 1, VALMAX_EPS)
        self.slice_slider = Slider(slider_ax, label="Slice", valmin=0, valmax=upper,
                                   valinit=self.cur_slices[0], valfmt='%0.0f')
        self.slice_slider.on_changed(self._update_slice)

    def _add_intensity_slider(self):
        """
        Add the intensity range slider in grid row 3 and its right-click handler.

        Right-clicking the slider opens a dialog for entering new slider bounds; this is available
        only when the TkAgg backend is active.
        """
        slider_ax = self.fig.add_subplot(self.gs[3, :])
        # Choose the number of decimals shown from the order of magnitude of the range.
        magnitude = int(np.round(np.log10(self.vmax - self.vmin)))
        digits = max(2 - magnitude, 0)
        self.intensity_slider = RangeSlider(ax=slider_ax, label="Intensity range",
                                            valmin=self.vmin, valmax=self.vmax,
                                            valinit=(self.vmin, self.vmax), valfmt=f'%0.{digits}f')
        self.intensity_slider.on_changed(self._update_intensity)

        def on_right_click(event):
            # Handle a right-click on the intensity slider to change the upper and lower bounds.
            # This is enabled only when TkAgg is available.
            if event.button != 3 or event.inaxes != slider_ax:
                return
            if not TKAGG:
                warnings.warn("Right-click on intensity slider requires TkAgg: use matplotlib.use('TkAgg')")
                return

            # This is designed to avoid interpreting closing this window for starting a new ROI
            self._tk_button_pressed = True
            prompt = multiline("Set intensity range:", "  Cancel=use previous values",
                               "  Leave blank=determine from data")
            current_values = [f"{self.vmin:.3g}", f"{self.vmax:.3g}"]
            inputs = easygui.multenterbox(prompt, "Adjust intensity slider range", ["Min", "Max"], current_values)
            if inputs is None:
                return
            self._is_drawing = False

            # If the inputs are blank, then determine the min and max from data.
            if not inputs[0]:
                inputs[0] = f"{np.min([np.min(d) for d in self.data]):.3g}"
            if not inputs[1]:
                inputs[1] = f"{np.max([np.max(d) for d in self.data]):.3g}"
            try:
                # Set the new min and max from inputs
                low, high = map(float, inputs)
                if low > high:
                    raise ValueError("Minimum must be less than maximum")
                if low == high:
                    # Widen a degenerate range slightly so the slider has nonzero extent.
                    eps = 1e-6
                    margin = np.clip(eps * np.abs(high), a_min=eps, a_max=None)
                    low -= margin
                    high += margin
                self.vmin, self.vmax = low, high
                slider = self.intensity_slider
                slider.valmin = self.vmin
                slider.valmax = self.vmax
                slider.ax.set_xlim(self.vmin, self.vmax)
                slider.set_val((self.vmin, self.vmax))
                self._update_intensity((self.vmin, self.vmax))
                self.fig.canvas.draw_idle()
            except Exception as e:
                # Report bad input (e.g. non-numeric text) in a dialog instead of raising.
                easygui.msgbox(str(e), title="Invalid Input")

        self.fig.canvas.mpl_connect('button_press_event', on_right_click)

    def _add_axis_controls(self):
        # This is the radio button control to select the slice axis
        self.axis_buttons = []
        all_same = all(ax_ind == self.axes_perms[0, -1] for ax_ind in self.axes_perms[:, -1])

        if self.n_volumes == 1:
            self._create_axis_buttons(False)
        else:

            self._create_axis_buttons(not all_same)

    def _create_axis_buttons(self, decoupled):
        """
        Rebuild the "Slice axis" radio buttons.

        Args:
            decoupled (bool): If False, create one button group (grid column 0) that sets the slice
                axis of every volume, and first switch all volumes to the slice axis of volume 0.
                If True, create one button group per volume.
        """
        # Discard any existing button groups together with their axes.
        for old_buttons in self.axis_buttons:
            old_buttons.ax.remove()
        self.axis_buttons.clear()

        def build_buttons(column, active_axis):
            # Create one radio group in row 1 of the given grid column.  Both modes share the same
            # font and marker-size constants so they cannot drift apart.
            button_ax = self.fig.add_subplot(self.gs[1, column])
            button_ax.set_title("Slice axis", loc='left', fontsize=SLICE_AXIS_FONT_SIZE)
            radio = RadioButtons(button_ax, labels=["0", "1", "2"], radio_props={'s': [SLICE_AXIS_RADIO_SIZE]})
            for text in radio.labels:
                text.set_fontsize(SLICE_AXIS_LABEL_FONT_SIZE)
            # The click callback is attached by the caller after this call.
            radio.set_active(active_axis)
            return radio

        if decoupled:
            for i in range(self.n_volumes):
                radio = build_buttons(i, int(self.axes_perms[i, -1]))
                # Bind i as a default so each group updates its own volume.
                radio.on_clicked(lambda label, index=i: self._update_axis(index, int(label)))
                self.axis_buttons.append(radio)
            return

        def on_shared_click(label):
            # Apply the chosen slice axis to every volume.
            for i in range(self.n_volumes):
                self._update_axis(i, int(label))

        shared_axis = int(self.axes_perms[0, -1])
        radio = build_buttons(0, shared_axis)
        for i in range(self.n_volumes):
            self._update_axis(i, shared_axis)
        radio.on_clicked(on_shared_click)
        self.axis_buttons.append(radio)

    def _toggle_decouple_slice_axes(self):
        self._syncing_axes = not self._syncing_axes
        self._create_axis_buttons(self._syncing_axes)
        self._update_slice_slider()
        self.fig.canvas.draw_idle()

    def _toggle_sync_limits(self):
        self._syncing_limits = not self._syncing_limits
        self.fig.canvas.draw_idle()

    def _update_slice_slider(self):
        """Reset the slice slider's maximum to the largest current depth and its value to volume 0's slice."""
        # Update slice slider max if any volume changed depth
        slider = self.slice_slider
        deepest = max(vol.shape[2] for vol in self.data)
        upper = max(deepest - 1, VALMAX_EPS)
        slider.valmax = upper
        slider.ax.set_xlim(slider.valmin, upper)
        slider.set_val(self.cur_slices[0])
        self.fig.canvas.draw_idle()

    def _update_axis(self, i, new_perm):
        # If the order of axes changed for this image, then update the data and redraw.
        if isinstance(new_perm, (list, tuple, np.ndarray)):
            new_perm = list(new_perm)
        else:
            new_perm = self._get_perm_from_slice_ind(new_perm)

        if new_perm != list(self.axes_perms[i]):
            new_data = self.data[i]
            prev_fraction = self.slice_slider.val / (self.slice_slider.valmax + VALMAX_EPS)
            new_slice_axis = new_perm[-1]
            inverse_perm = np.argsort(self.axes_perms[i])
            new_data = np.transpose(new_data, inverse_perm)
            new_num_slices = new_data.shape[new_slice_axis]
            self.axes_perms[i] = np.array(new_perm)
            new_data = np.transpose(new_data, new_perm)
            self.data[i] = new_data
            self.cur_slices[i] = int(np.round(prev_fraction * (new_num_slices - 1)))
            self.images[i].set_data(new_data[:, :, self.cur_slices[i]])
            self.axes[i].set_xlim(0, new_data.shape[1])
            self.axes[i].set_ylim(new_data.shape[0], 0)
            self.axes[i].set_aspect('equal')
            self.axes[i].set_title(multiline(f"{self.labels[i]} {self.cur_slices[i]}",
                                             f"Shape: {new_data.shape}, Axes: {self.axes_perms[0]}"))  # f"Shape: {orig.shape}"))
            self._draw_images()
            self._update_slice_slider()
            self._update_intensity(self.intensity_slider.val)
            plt.tight_layout()
            self.fig.canvas.draw_idle()

    def _update_slice(self, val):
        """
        Show the slice selected on the slice slider in every image.

        Args:
            val (float): Slider value; it is rounded to the nearest integer slice index.
        """
        requested = int(round(val))
        for i, volume in enumerate(self.data):
            depth = volume.shape[2]
            # Volumes can differ in depth, so clamp the requested index to this volume's valid range.
            # NOTE(review): every volume gets the same absolute index; confirm whether a proportional
            # mapping across volumes of different depth was intended.
            idx = np.clip(requested, 0, depth - 1)
            self.cur_slices[i] = idx
            self.images[i].set_data(volume[:, :, idx])
            title = multiline(f"{self.labels[i]} {idx}",
                              f"Shape: {self.original_data[i].shape}, Axes: {self.axes_perms[i]}")
            self.axes[i].set_title(title)
        self._display_mean()
        self.fig.canvas.draw_idle()

    def _update_intensity(self, val):
        # Update the intensity range.
        for img in self.images:
            img.set_clim(val[0], val[1])
        self.fig.canvas.draw_idle()

    def _get_mask(self, data, x, y, r):
        # Get a mask for the ROI
        shape = data.shape
        if shape not in self._meshgrids:
            ny, nx = shape[:2]
            self._meshgrids[shape] = np.meshgrid(np.arange(nx), np.arange(ny))
        xv, yv = self._meshgrids[shape]
        return (xv - x) ** 2 + (yv - y) ** 2 <= r ** 2

    def _display_mean(self):
        # Show the stats associated with an ROI
        if all(c is None for c in self.circles):
            return
        now = time.time()
        if now - self._last_display_time < 0.3:  # Delay each update to avoid lag for large volumes
            return
        self._last_display_time = now

        # Get the stats for the circle in each image.
        for i, circle in enumerate(self.circles):
            if circle is None or self._is_drawing or self._is_moving:
                self.tooltips[i].set_visible(False)
            x, y = circle.center
            r = circle.get_radius()
            mask = self._get_mask(self.data[i][:, :, self.cur_slices[i]], x, y, r)
            values = self.data[i][:, :, self.cur_slices[i]][mask]
            if values.size:
                mean, std = np.mean(values), np.std(values)
                cur_min, cur_max = np.min(values), np.max(values)
                if self.text_boxes[i]:
                    self.text_boxes[i].remove()
                self.text_boxes[i] = self.axes[i].text(0.05, 0.95,
                                                       multiline(f"(µ, σ)=({mean:.3g}, {std:.3g})",
                                                                 f"(min, max)=({cur_min:.3g}, {cur_max:.3g})"),
                                                       transform=self.axes[i].transAxes,
                                                       fontsize=12, va='top', bbox=dict(facecolor='white', alpha=1.0))
        self.fig.canvas.draw_idle()

    def _remove_graphics(self):
        # Remove the ROI, tooltips, etc.
        for item in self.circles + self.text_boxes:
            if item is not None:
                item.remove()
        for tip in getattr(self, 'tooltips', []):
            tip.set_visible(False)
        self.circles = [None] * self.n_volumes
        self.text_boxes = [None] * self.n_volumes

    def _create_circle(self, center, radius):
        """Return a new circle patch at ``center`` with the given ``radius``, styled by the CIRCLE_* constants."""
        style = dict(color=CIRCLE_COLOR, lw=CIRCLE_LINEWIDTH, fill=CIRCLE_FILL, alpha=CIRCLE_ALPHA)
        return plt.Circle(center, radius, **style)

    def _connect_events(self):
        """
        Initialize the interaction state and connect the mouse and keyboard handlers to the canvas.

        Handlers defined here:
            - on_press: left-click to start drawing, moving, or resizing an ROI circle, or to pick a
              comparison image while image selection is active.
            - on_motion: update tooltips and the ROI circle being drawn, moved, or resized.
            - on_release: end any drag operation.
            - on_key: 'h' shows help; 'escape' clears messages, menus, selections, and ROIs.
            - on_context_menu / on_context_select: open the right-click image menu and run the
              clicked menu entry.
        """
        # Set up the main event handling routines.
        self._menu_texts = []
        self._is_moving = False
        self._move_offset = None
        # One hidden tooltip annotation per image axes; on_motion shows it near an ROI circle.
        self.tooltips = [
            ax.annotate(
                TOOLTIP_TEXT,
                xy=(0, 0), xytext=TOOLTIP_OFFSET, textcoords='offset points',
                ha='left', fontsize=TOOLTIP_FONT_SIZE,
                bbox=dict(boxstyle='round', fc='w', alpha=TOOLTIP_BOX_ALPHA),
                arrowprops=dict(arrowstyle='->'),
                visible=False
            ) for ax in self.axes
        ]
        self._drag_start = None
        self._is_drawing = False  # Add a flag to track drawing state

        def on_press(event):
            # Only left-clicks inside one of the image axes are handled here.
            if event.button != 1:
                return
            if event.inaxes not in self.axes:
                return
            if hasattr(event, 'menu_option_selected') and event.menu_option_selected:
                return  # This is set to true when the user selects a menu option.
            # Ignore clicks while a toolbar mode (non-empty toolbar.mode string) is active.
            if hasattr(self.fig.canvas, 'toolbar') and getattr(self.fig.canvas.toolbar, 'mode', ''):
                return
            if self._tk_button_pressed:  # Skip a press if the user just chose a menu item
                if event.button == 1:  # Consume only a left-click
                    self._tk_button_pressed = False
                    return
            # While image selection is active, the click chooses an image instead of starting an ROI.
            if self._image_selection_dict['selecting_image']:
                image_index = self.axes.index(event.inaxes)
                self._on_difference(image_index)
                return

            self._resize_index = None

            # Handle the case of resizing or moving an ROI circle
            for i, circle in enumerate(self.circles):
                if circle is None:
                    continue
                x, y = circle.center
                r = circle.get_radius()
                if event.xdata is not None and event.ydata is not None:
                    dx = event.xdata - x
                    dy = event.ydata - y
                    dist = np.sqrt(dx ** 2 + dy ** 2)
                    # A click within 10% of the radius from the circle's edge starts a resize.
                    if abs(dist - r) <= 0.1 * r:
                        self._resize_index = i
                        self._resize_anchor = (x, y)
                        return
                    # Any other click inside the circle starts a move; keep the click's offset from the center.
                    elif dx ** 2 + dy ** 2 <= r ** 2:
                        self._is_moving = True
                        self._move_offset = (dx, dy)
                        return

            # Otherwise start drawing an ROI circle
            self._remove_graphics()
            self._is_drawing = True
            self._drag_start = (event.xdata, event.ydata)
            # The same zero-radius circle is added to every image so the ROI is shared across volumes.
            for i, ax in enumerate(self.axes):
                circ = self._create_circle((event.xdata, event.ydata), 0)
                self.circles[i] = circ
                ax.add_patch(circ)
            self.fig.canvas.draw_idle()

        def on_motion(event):
            # Motion outside the image axes, or while a toolbar mode is active, is ignored.
            if event.inaxes not in self.axes:
                return
            if hasattr(self.fig.canvas, 'toolbar') and getattr(self.fig.canvas.toolbar, 'mode', ''):
                return

            # Check for motion in or out of an ROI and display a tooltip as needed.
            for i, circle in enumerate(self.circles):
                if circle is None:
                    continue
                x, y = circle.center
                r = circle.get_radius()
                if event.xdata is not None and event.ydata is not None:
                    dx = event.xdata - x
                    dy = event.ydata - y
                    dist2 = dx ** 2 + dy ** 2
                    # Tooltips are shown only when no draw, move, or resize is in progress.
                    if not (self._is_drawing or self._is_moving or self._resize_index is not None):
                        # Show the tooltip when the pointer is within 1.1 radii of the circle center.
                        if dist2 <= (1.1 * r) ** 2:
                            self.tooltips[i].xy = self.circles[i].center
                            self.tooltips[i].set_visible(True)
                        else:
                            self.tooltips[i].set_visible(False)
                    else:
                        self.tooltips[i].set_visible(False)

            # Update the ROI circle and stats if needed.
            if self._is_drawing and self._drag_start is not None:
                # Drawing: the radius is the distance from the press point to the pointer.
                x0, y0 = self._drag_start
                r = np.sqrt((event.xdata - x0) ** 2 + (event.ydata - y0) ** 2)
                for i, ax in enumerate(self.axes):
                    if self.circles[i]:
                        self.circles[i].remove()
                    circ = self._create_circle((x0, y0), r)
                    self.circles[i] = circ
                    ax.add_patch(circ)
                self._display_mean()

            elif getattr(self, '_is_moving', False):
                # Moving: keep the pointer at the same offset from the center as when the move started.
                dx, dy = self._move_offset
                new_center = (event.xdata - dx, event.ydata - dy)
                for j in range(self.n_volumes):
                    self.circles[j].center = new_center
                self._display_mean()

            elif self._resize_index is not None and self._resize_anchor is not None:
                # Resizing: the radius is the distance from the anchored center to the pointer.
                x0, y0 = self._resize_anchor
                new_radius = np.sqrt((event.xdata - x0) ** 2 + (event.ydata - y0) ** 2)
                for j in range(self.n_volumes):
                    self.circles[j].set_radius(new_radius)
                self._display_mean()

            self.fig.canvas.draw_idle()

        def on_release(event):
            # Reset any items from a mouse drag
            self._is_moving = False
            self._move_offset = None
            self._resize_index = None
            self._resize_anchor = None
            self._is_drawing = False
            self._drag_start = None
            # NOTE(review): _display_mean skips updates that come within 0.3 s of the previous one,
            # so this final refresh may be dropped right after a drag - confirm that is acceptable.
            self._display_mean()

        def on_key(event):
            # Handle any key press
            if event.key == 'h':
                self._show_message(True, message_type='help')
            elif event.key == 'escape':
                # Remove and reset as needed
                self._show_message(False)
                self._remove_menu()
                self._is_drawing = False
                # Abandon a pending image selection and clear the entry recorded for its baseline image.
                if self._image_selection_dict['selecting_image']:
                    image_index = self._image_selection_dict['baseline_index']
                    self._difference_image_dicts[image_index] = None
                    self._clear_image_selection()
                # The text boxes are removed and reset here; _remove_graphics below then removes the
                # circles and resets both lists again.
                for i in range(self.n_volumes):
                    if self.text_boxes[i] is not None:
                        self.text_boxes[i].remove()
                self.text_boxes = [None] * self.n_volumes
                self._remove_graphics()
                self.circles = [None] * self.n_volumes
                self._display_mean()
                self.fig.canvas.draw_idle()

        def on_context_menu(event):
            # No menu while the user is choosing a comparison image.
            if self._image_selection_dict['selecting_image']:
                return
            # Start the menu for an image
            if event.button == 3 and event.inaxes in self.axes:
                image_index = self.axes.index(event.inaxes)
                # Remove any existing menu items
                self._remove_menu()
                self._render_menu(event, image_index)

        def on_context_select(event):
            # Handle the image menu selection
            if event.button != 1:
                return
            clicked = False
            # Hit-test the click (in display coordinates) against each menu text's window extent.
            for txt in self._menu_texts:
                bbox = txt.get_window_extent()
                if bbox.contains(event.x, event.y):
                    clicked = True
                    # Mark the event so that on_press does not treat this click as the start of an ROI.
                    event.menu_option_selected = True
                    if hasattr(txt, '_viewer_callback'):
                        txt._viewer_callback()
            self._remove_menu()
            return clicked

        # NOTE(review): on_press checks the menu_option_selected flag set by on_context_select, so this
        # presumably relies on handlers running in connection order - confirm before reordering.
        self.fig.canvas.mpl_connect('button_press_event', on_context_menu)
        self.fig.canvas.mpl_connect('button_press_event', on_context_select)

        self.fig.canvas.mpl_connect('button_press_event', on_press)
        self.fig.canvas.mpl_connect('motion_notify_event', on_motion)
        self.fig.canvas.mpl_connect('button_release_event', on_release)
        self.fig.canvas.mpl_connect('key_press_event', on_key)

    def _get_context_menu_options(self, image_index):
        """
        Build the right-click menu entries for one image.

        Args:
            image_index (int): Index of the image the menu was opened on.

        Returns:
            list: ``[label, callback]`` pairs in display order.  The multi-volume entries appear
            only when there is more than one volume, and "Show data dict" only with TkAgg.
        """
        # Define the image menu options, taking care that some are available only with TkAgg
        options = []
        if self.n_volumes > 1:
            axes_match = all(ax == self.axes_perms[0, -1] for ax in self.axes_perms[:, -1])
            axes_verb = "Decouple" if axes_match else "Couple"
            options.append([f"{axes_verb} slice axes", self._toggle_decouple_slice_axes])

            limits_verb = "Decouple" if self._syncing_limits else "Couple"
            options.append([f"{limits_verb} pan/zoom", self._toggle_sync_limits])

            # Offer to restore the image if a difference entry is recorded for it, otherwise to replace it.
            if self._difference_image_dicts[image_index] is not None:
                options.append(['Restore original image', lambda: self._on_restore(image_index)])
            else:
                options.append(["Replace with difference image", lambda: self._on_difference(image_index)])
                options.append(["Replace with error image", lambda: self._on_difference(image_index, use_abs=True)])

        if TKAGG:
            options.append(["Show data dict", lambda: self._on_show_data_dict(image_index)])

        options.extend([
            ["Transpose image", lambda: self._on_transpose(image_index)],
            [LOAD_LABEL, lambda: self._on_load(image_index)],
            [SAVE_LABEL, lambda: self._on_save(image_index)],
            ["Reset", lambda: self._on_reset(image_index)],
            ["Cancel", self._remove_menu],
        ])
        return options

    def _on_show_data_dict(self, image_index):
        """
        Show one entry of the data dict associated with an image in a text dialog.

        If the dict has more than one entry, the user first chooses which entry to display.
        Choosing 'Cancel' or closing the chooser without a selection shows nothing.

        Args:
            image_index (int): Index of the image whose data dict is shown.
        """
        # Handle the display of image data_dict
        data_dict = self.data_dicts[image_index]
        if not data_dict:
            easygui.textbox(
                msg='No data dict available',
                title='No data dict',
                text='No data dict included with this image',
                codebox=False,
                callback=None,
                run=True
            )
            return

        names = list(data_dict.keys())
        if len(names) > 1:
            choice = easygui.buttonbox(
                msg=multiline('Choose an entry to display:', ' ', *names),
                title='Choose an entry to display',
                choices=names + ['Cancel']
            )
            # A dismissed dialog can yield None rather than a button label; treat it like Cancel
            # instead of trying to look None up among the entry names.
            if choice is None or choice == 'Cancel':
                return
        else:
            choice = names[0]

        easygui.textbox(
            msg=f'Dict entry = {choice}',
            title=choice,
            text=data_dict[choice],
            codebox=False,
            callback=None,
            run=True
        )

    def _on_transpose(self, image_index):
        """
        Swap the first two entries of one image's axis permutation and refresh the view.

        Args:
            image_index (int): Index of the image to transpose.
        """
        swapped = self.axes_perms[image_index].copy()
        # Fancy indexing copies the right-hand side, so the two entries trade places safely.
        swapped[[0, 1]] = swapped[[1, 0]]
        self._update_axis(image_index, swapped)
        # Re-apply the current intensity slider value after the axis update.
        self._update_intensity(self.intensity_slider.val)
        self._remove_menu()

    def _on_load(self, image_index):
        """
        Ask the user for a file and load it into the given image slot.

        The file dialog is not opened directly; it is scheduled through the canvas
        manager's ``window.after`` so that it launches after a short delay, which
        avoids oddities with tk. Without a ``window`` attribute on the manager the
        request is refused with a warning.

        Args:
            image_index (int): Index of the image to replace with the loaded data.
        """
        manager = self.fig.canvas.manager
        if not hasattr(manager, 'window'):
            warnings.warn("Load disabled: matplotlib backend is not TkAgg")
            return

        self._remove_menu()

        def deferred_load():
            chosen_path = easygui.fileopenbox(default='*.h5', filetypes=["*.npy", "*.npz", "*.h5", "*.hdf5"])
            if chosen_path:
                try:
                    self._load_file(chosen_path, image_index)
                except Exception as err:
                    # Report the failure in the figure rather than letting it escape the tk callback.
                    self._show_message(True, message=f"Failed to load file: {err}. Press Esc to exit.")

        # Hand the dialog to the tk event loop with a 10 ms delay.
        try:
            manager.window.after(10, deferred_load)
        except Exception:
            warnings.warn("Unable to load file.  Use matplotlib.use('TkAgg') to enable file load.")

    def _on_save(self, image_index):
        """
        Prompt for a filename and data dict entries, then save one image to HDF5.

        The array in ``self.original_data[image_index]`` is written with
        ``mj.save_data_hdf5`` under the name 'volume', together with a data dict
        that the user can edit entry by entry. The dialogs are scheduled through
        the canvas manager's ``window.after``; without a ``window`` attribute on
        the manager the request is refused with a warning.

        Args:
            image_index (int): Index of the image to save.
        """
        if not hasattr(self.fig.canvas.manager, 'window'):
            warnings.warn("Save disabled: matplotlib backend is not TkAgg")
            return

        self._remove_menu()

        def deferred_save():
            # Prompt for filename
            file_path = easygui.filesavebox(
                msg="Choose filename (must end in .h5)",
                default="volume",
            )
            if not file_path:
                return
            if not file_path.lower().endswith(".h5"):
                file_path += ".h5"

            # Start from the data dict already attached to this image. The slot may hold None
            # or another empty non-dict placeholder, so anything that is not a dict is replaced
            # by a fresh dict instead of being used as one.
            existing = self.data_dicts[image_index]
            data_dict = existing if isinstance(existing, dict) else dict(existing or {})

            # Ask for the standard entries first, then any extra entries already present, so
            # that the prompts appear in the same order on every run.
            standard_keys = ["model_params", "notes", "recon_log", "recon_params"]
            keys = standard_keys + [key for key in data_dict if key not in standard_keys]
            for key in keys:
                default_val = data_dict.get(key, "")
                val = easygui.textbox(
                    msg=f"Enter value for {key} (leave blank for none):",
                    title=key,
                    text=default_val
                )
                # If user cancelled, keep existing text; if cleared, treat as empty
                if val is None:
                    val = default_val
                data_dict[key] = val

            self.data_dicts[image_index] = data_dict

            # Save the stored volume; report a failure in the figure, as file loading does,
            # instead of letting the exception vanish inside the tk callback.
            data = self.original_data[image_index]
            try:
                mj.save_data_hdf5(file_path, data, 'volume', data_dict)
            except Exception as e:
                self._show_message(True, message=f"Failed to save file: {e}. Press Esc to exit.")
                return

            easygui.msgbox(f"Saved to {file_path}", title="Save complete")

        # Schedule the save dialog after a short delay (TkAgg-safe)
        try:
            self.fig.canvas.manager.window.after(10, deferred_save)
        except Exception:
            warnings.warn("Unable to schedule save dialog. Requires TkAgg backend.")

    def _on_difference(self, image_index: int, use_abs: bool = False) -> None:
        """
        Replace a baseline image with its difference from a comparison image.

        The method works in two phases that share state through
        ``self._image_selection_dict``:

        1. When no selection is in progress, ``image_index`` becomes the baseline and
           ``use_abs`` is stored for it. With more than two volumes the 'difference'
           instructions are shown and the method returns; with exactly two volumes the
           other image is taken as the comparison and processing continues immediately.
        2. When a selection is in progress, ``image_index`` is treated as the comparison.
           If the pair is not valid the instructions are shown (again) and the method
           returns with the selection still active. Otherwise the baseline's displayed
           data is replaced by ``data[comparison] - data[baseline]`` (absolute value if
           ``use_abs`` was requested) and its label is prefixed accordingly.

        Args:
            image_index (int): Baseline index when starting a selection, comparison
                index when completing one.
            use_abs (bool): If True, display the absolute value of the difference.
                Only the value given when the selection starts is used.
        """

        self._remove_menu()
        # Start the selection process for this baseline
        if not self._image_selection_dict['selecting_image']:
            self._image_selection_dict['selecting_image'] = True
            self._image_selection_dict['baseline_index'] = image_index
            # Until the comparison is completed this dict holds only 'use_abs'.
            self._difference_image_dicts[image_index] = {'use_abs': use_abs}
            if self.n_volumes > 2:
                # Show instructions
                self._show_message(True, message_type='difference')
                # Return to await selection or Esc
                # NOTE(review): presumably a handler outside this method calls back in with
                # the selected image - confirm against the event handlers.
                return
            else:
                image_index = 1 - image_index  # Only 2 images: use the other one as the comparison and continue below

        # Save this index as the comparison and set up the plot
        if self._image_selection_dict['selecting_image']:
            # Check the shape and axes perm of the two images
            baseline_index = self._image_selection_dict['baseline_index']
            comparison_index = image_index
            # The pair is rejected when any paired dimension of the original arrays differs,
            # when the axis permutations differ, or when both indices refer to the same image.
            if any(a != b for a, b in
                   zip(self.original_data[baseline_index].shape, self.original_data[comparison_index].shape)) or \
                    any(a != b for a, b in zip(self.axes_perms[baseline_index], self.axes_perms[comparison_index])) or \
                    baseline_index == comparison_index:
                # Show instructions
                self._show_message(True, message_type='difference')
                return  # The difference is not valid, so exit

            # Otherwise exit the selection process and set the current index as the comparison
            self._show_message(False)
            self._clear_image_selection()
            self._difference_image_dicts[baseline_index]['comparison_index'] = comparison_index

            # Set up the difference image from the currently displayed (permuted) data
            difference = self.data[comparison_index] - self.data[baseline_index]
            if self._difference_image_dicts[baseline_index]['use_abs']:
                difference = np.abs(difference)
            self.data[baseline_index] = difference
            self.images[baseline_index].set_data(difference[:, :, self.cur_slices[baseline_index]])

            # Adjust the labels and refresh the display
            # The previous label is kept under 'prev_label' so that _on_restore can put it back.
            self._difference_image_dicts[baseline_index]['prev_label'] = self.labels[baseline_index]
            if self._difference_image_dicts[baseline_index]['use_abs']:
                label_prepend = 'abs(Image {} minus current): '.format(comparison_index)
            else:
                label_prepend = 'Image {} minus current: '.format(comparison_index)
            self.labels[baseline_index] = label_prepend + self.labels[baseline_index]
            self._update_slice_slider()
            self.fig.canvas.draw_idle()

    def _on_restore(self, image_index):
        # Restore the original image
        original_data = self.original_data[image_index]
        self.data[image_index] = np.transpose(original_data, self.axes_perms[image_index])
        self.images[image_index].set_data(self.data[image_index][:, :, self.cur_slices[image_index]])
        if self._difference_image_dicts[image_index] is not None:
            self.labels[image_index] = self._difference_image_dicts[image_index]['prev_label']
        self._difference_image_dicts[image_index] = None
        self._update_slice_slider()
        self.fig.canvas.draw_idle()

    def _on_reset(self, image_index):
        """
        Hard-reset the display and dismiss any selection, message, and menu.

        ``_draw_images`` is called for every volume while pan/zoom limits are synced,
        otherwise only for ``image_index``.

        Args:
            image_index (int): Index of the image the reset was requested on.
        """
        targets = range(self.n_volumes) if self._syncing_limits else (image_index,)
        for volume_index in targets:
            self._draw_images(volume_index)

        # Refresh the controls and layout before asking for a redraw.
        self._update_slice_slider()
        self._update_intensity(self.intensity_slider.val)
        plt.tight_layout()
        self.fig.canvas.draw_idle()

        # Drop any pending image selection, overlay message, and context menu.
        self._clear_image_selection()
        self._show_message(False)
        self._remove_menu()

    def _show_message(self, show, message_type=None, message=None):
        # Clear any existing message
        if hasattr(self, 'help_overlay') and self.help_overlay:
            self.help_overlay.remove()
            self.help_overlay = None
        if show:
            if message_type == 'help':
                message = ['Left-click and drag for ROI', 'Right-click an image for menu']
                if TKAGG:
                    message += ['Right-click intensity slider to adjust range']
                message += ['Press [esc] to remove ROI/menu/help', 'Close image to quit']
                message = multiline(*message)

            elif message_type == 'difference':
                message = ['Select another image of the same shape and axes permutation', 'or press [esc] to exit']
                message = multiline(*message)

            if message is None:
                return

            self.help_overlay = self.fig.text(0.25, 0.5, message,
                                              ha='left', va='center', fontsize=12,
                                              bbox=dict(facecolor='white', alpha=0.9))

        self.fig.canvas.draw_idle()

    def _load_file(self, file_path, image_index):
        # Do the actual file load, with different behavior for npy, npz, and h5/hdf5.
        ext = os.path.splitext(file_path)[-1].lower()
        data_dict = ()

        if ext == ".npy":
            new_array = np.load(file_path)

        elif ext == ".npz":
            # Check the available arrays and have the user choose.
            array_dict = np.load(file_path)
            array_names = array_dict.files
            shapes = [array_dict[name].shape for name in array_names]
            choice = _choose_array_name(array_names, shapes, file_path)
            if choice is None:
                return
            new_array = array_dict[choice]

        elif ext in {'.h5', '.hdf5'}:
            # Check the available arrays and have the user choose.
            with h5py.File(file_path, "r") as f:
                array_names = [key for key in f.keys()]
                shapes = [f[name].shape for name in array_names]
                choice = _choose_array_name(array_names, shapes, file_path)
                if choice is None:
                    return
                new_array = f[choice][()]
                data_dict = {name: f[choice].attrs[name] for name in f[choice].attrs.keys()}

        if new_array.ndim == 2:
            new_array = new_array[..., np.newaxis]
            self.original_data[image_index] = new_array
        elif new_array.ndim == 3:
            self.original_data[image_index] = new_array
        elif new_array.ndim == 4:
            num_volumes_to_load = min(new_array.shape[-1], self.n_volumes)
            self._show_message(True, message='Loading first {} volumes in 4D array. Press Esc to exit'.format(
                num_volumes_to_load))
            for j in range(num_volumes_to_load):
                self.original_data[j] = new_array[..., j]
                self.data[j] = np.transpose(self.original_data[j], self.axes_perms[j])
            image_index = min(image_index, num_volumes_to_load - 1)
            new_array = new_array[..., image_index]
        else:
            raise ValueError("Loaded array must be 2D, 3D, or 4D")

        self.data_dicts[image_index] = data_dict
        self.axes_perms[image_index] = self._get_perm_from_slice_ind(self.axes_perms[image_index][-1])
        transposed = np.transpose(new_array, self.axes_perms[image_index])
        self.data[image_index] = transposed
        self.cur_slices[image_index] = transposed.shape[2] // 2
        self._draw_images()
        self._on_restore(image_index)

    def _make_option(self, label, position, y_offset, callback):
        # Build a matplotlib menu one option at a time.  This is to allow for cross platform display.
        bounds = self.fig.bbox.bounds
        x_frac = (position.x - bounds[0]) / (bounds[2] - bounds[0])
        y_frac = (position.y + y_offset - bounds[1]) / (bounds[3] - bounds[1])
        txt = self.fig.text(
            x_frac, y_frac,
            label,
            ha='left', va='bottom', fontsize=10, color='white',
            bbox=dict(facecolor='black', edgecolor='white')
        )
        self._menu_texts.append(txt)
        txt._viewer_callback = callback

    def _render_menu(self, event, image_index):
        """
        Draw the context menu for one image, stacking the entries downwards.

        Args:
            event: Position object forwarded to ``_make_option`` as the menu anchor.
            image_index (int): Index of the image the menu belongs to.
        """
        entries = self._get_context_menu_options(image_index)
        # Each entry sits Y_SKIP below the previous one, starting at the anchor itself.
        for row, (label, callback) in enumerate(entries):
            self._make_option(label, event, -row * Y_SKIP, callback)

        self.fig.canvas.draw_idle()

    def _remove_menu(self):
        if self._menu_texts:
            for txt in self._menu_texts:
                txt.remove()
            self._menu_texts.clear()
            self.fig.canvas.draw_idle()

    def show(self):
        """Display the viewer window and block execution until the window is closed."""
        figure_id = self.fig.number
        plt.show()
        # Select the figure again by number and then close it to make sure it closes properly.
        plt.figure(figure_id)
        plt.close(figure_id)


def _choose_array_name(array_names, shapes, file_path):
    """
    Pick the array to display from a file - for use with _load_file.

    A single available array is returned directly. With several arrays a button box
    listing each name and shape is shown and the user's choice is returned.

    Args:
        array_names (sequence of str): Names of the arrays available in the file.
        shapes (sequence): Shape of each array, in the same order as ``array_names``.
        file_path (str): Path of the file; only its base name is shown to the user.

    Returns:
        str or None: The chosen array name, or None if the user cancelled.

    Raises:
        ValueError: If ``array_names`` is empty.
    """
    array_names = list(array_names)
    if not array_names:
        # Fail with an explicit message instead of an IndexError from array_names[0].
        raise ValueError('No arrays found in {}'.format(os.path.basename(file_path)))
    if len(array_names) == 1:
        return array_names[0]

    msg = ['Arrays in {}'.format(os.path.basename(file_path))]
    msg += ['{}: shape={}'.format(name, shape) for name, shape in zip(array_names, shapes)]
    choice = easygui.buttonbox(msg=multiline(*msg), title='Choose the array to display',
                               choices=array_names + ['Cancel'])
    # NOTE: an array that is itself named 'Cancel' cannot be told apart from the Cancel button.
    return None if choice == 'Cancel' else choice


[docs] def slice_viewer(*datasets, data_dicts=None, title='', vmin=None, vmax=None, slice_label=None, slice_axis=None, cmap='gray', show_instructions=True): """ Launch an interactive viewer for inspecting one or more 2D or 3D image arrays. This function provides a graphical interface for exploring one or more 3D volumes or 2D slices. Features include synchronized slice navigation, ROI statistics, axis transposition, file loading, dynamic intensity range adjustment, and interactive GUI tools for zooming and panning. Each image can have an associated data dict, typically obtained from :meth:`TomographyModel.recon`, which can be viewed as a text file within the viewer. Designed primarily for inspecting CT or other volumetric reconstructions in research workflows. Args: *datasets (ndarray or None): One or more 2D or 3D NumPy arrays to display. - 2D arrays are automatically promoted to 3D via a singleton axis. - `None` values are replaced with placeholder zero arrays. data_dicts (None or dict or list of None or dicts, optional): Dictionary of string entries to associated with the data (e.g., from :meth:`TomographyModel.get_recon_dict`) title (str, optional): Window title. Defaults to an empty string. vmin (float, optional): Minimum intensity value for display. Defaults to the global minimum across all datasets. vmax (float, optional): Maximum intensity value for display. Defaults to the global maximum across all datasets. slice_label (str or list of str, optional): Label(s) for the current slice. Defaults to "Slice". slice_axis (int or list of int, optional): Axis along which to slice (0, 1, or 2). Defaults to the last axis (2). cmap (str, optional): Colormap to use. Defaults to "gray". show_instructions (bool, optional): Whether to display usage instructions in the figure. Defaults to True. Notes: - This function blocks execution until the viewer window is closed. - Right-click an image to access a context menu with options such as axis transposition and file loading. 
- Right-click the intensity slider (if using TkAgg backend) to manually set display range bounds. - Press 'h' to show help overlay. Press 'Esc' to clear overlays or reset ROI selections. Example: >>> denoiser = mj.QGGMRFDenoiser(noisy_image.shape) >>> denoised_image, denoised_dict = denoiser.denoise(noisy_image) # Estimate the noise level from the image >>> mj.slice_viewer(noisy_image, denoised_image, data_dicts=[None, denoised_dict], title='Noisy and denoised images') """ viewer = SliceViewer(*datasets, data_dicts=data_dicts, title=title, vmin=vmin, vmax=vmax, slice_label=slice_label, slice_axis=slice_axis, cmap=cmap, show_instructions=show_instructions)