warpfield.register

  1import warnings
  2import gc
  3import pathlib
  4import os
  5from typing import List, Union, Callable
  6
  7import numpy as np
  8import scipy.signal
  9import cupy as cp
 10import cupyx
 11import cupyx.scipy.ndimage
 12from pydantic import BaseModel, ValidationError
 13from tqdm.auto import tqdm
 14
 15from .warp import warp_volume
 16from .utils import create_rgb_video, mips_callback
 17from .ndimage import (
 18    accumarray,
 19    dogfilter,
 20    gausskernel_sheared,
 21    infill_nans,
 22    ndwindow,
 23    periodic_smooth_decomposition_nd_rfft,
 24    sliding_block,
 25    upsampled_dft_rfftn,
 26    soften_edges,
 27)
 28
 29_ArrayType = Union[np.ndarray, cp.ndarray]
 30
 31class WarpMap:
 32    """Represents a 3D displacement field
 33
 34    Args:
 35        warp_field (numpy.array): the displacement field data (3-x-y-z)
 36        block_size (3-element list or numpy.array):
 37        block_stride (3-element list or numpy.array):
 38        ref_shape (tuple): shape of the reference volume
 39        mov_shape (tuple): shape of the moving volume
 40    """
 41
 42    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
 43        self.warp_field = cp.array(warp_field, dtype="float32")
 44        self.block_size = cp.array(block_size, dtype="float32")
 45        self.block_stride = cp.array(block_stride, dtype="float32")
 46        self.ref_shape = ref_shape
 47        self.mov_shape = mov_shape
 48
 49    def warp(self, vol, out=None):
 50        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
 51
 52        Args:
 53            vol (cupy.array): the volume to be warped
 54
 55        Returns:
 56            cupy.array: warped volume
 57        """
 58        if np.any(vol.shape != np.array(self.mov_shape)):
 59            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
 60        if out is None:
 61            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
 62        vol_out = warp_volume(
 63            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
 64        )
 65        return vol_out
 66
 67    def apply(self, *args, **kwargs):
 68        """Alias of warp method"""
 69        return self.warp(*args, **kwargs)
 70
 71    def fit_affine(self, target=None):
 72        """Fit affine transformation and return new fitted WarpMap
 73
 74        Args:
 75            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 76
 77        Returns:
 78            WarpMap:
 79            numpy.array: affine tranformation coefficients
 80        """
 81        if target is None:
 82            warp_field_shape = self.warp_field.shape
 83            block_size = self.block_size
 84            block_stride = self.block_stride
 85        else:
 86            warp_field_shape = target["warp_field_shape"]
 87            block_size = cp.array(target["block_size"]).astype("float32")
 88            block_stride = cp.array(target["block_stride"]).astype("float32")
 89
 90        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 91        ix = ix * self.block_stride + self.block_size / 2
 92        M = cp.zeros(self.warp_field.shape[1:])
 93        #M[1:-1, 1:-1, 1:-1] = 1
 94        M[:,:,:] = 1
 95        ixg = cp.where(M.flatten() > 0)[0]
 96        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 97        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
 98        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
 99        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
100        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
101        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
102
103    def median_filter(self):
104        """Apply median filter to the displacement field
105
106        Returns:
107            WarpMap: new WarpMap with median filtered displacement field
108        """
109        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
110        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)
111
112    def resize_to(self, target):
113        """Resize to target WarpMap, using linear interpolation
114
115        Args:
116            target (WarpMap or WarpMapper): target to resize to
117                or a dict with keys "shape", "block_size", and "block_stride"
118
119        Returns:
120            WarpMap: resized WarpMap
121        """
122        if isinstance(target, WarpMap):
123            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
124        elif isinstance(target, WarpMapper):
125            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
126        elif isinstance(target, dict):
127            t_sh, t_bsz, t_bst = (
128                target["warp_field_shape"][1:],
129                cp.array(target["block_size"]),
130                cp.array(target["block_stride"]),
131            )
132        else:
133            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
134        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
135        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
136        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
137        dm_r = cp.array(
138            [
139                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
140                    t_sh
141                )
142                for i in range(3)
143            ]
144        )
145        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
146
147    def chain(self, target):
148        """Chain displacement maps
149
150        Args:
151            target (WarpMap): WarpMap to be added to existing map
152
153        Returns:
154            WarpMap: new WarpMap with chained displacement field
155        """
156        indices = cp.indices(target.warp_field.shape[1:])
157        warp_field = self.warp_field.copy()
158        warp_field += target.warp_field
159        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
160
161    def invert(self, **kwargs):
162        """alias for invert_fast method"""
163        return self.invert_fast(**kwargs)
164
165    def invert_fast(self, sigma=0.5, truncate=20):
166        """Invert the displacement field using accumulation and Gaussian basis interpolation.
167
168        Args:
169            sigma (float): standard deviation for Gaussian basis interpolation
170            truncate (float): truncate parameter for Gaussian basis interpolation
171
172        Returns:
173            WarpMap: inverted WarpMap
174        """
175        warp_field = self.warp_field.get()
176        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
177        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
178        num_coords = accumarray(target_coords, wf_shape)
179        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
180        for i in range(3):
181            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
182            with np.errstate(invalid="ignore"):
183                inv_field[i] /= num_coords
184            inv_field[i][num_coords == 0] = np.nan
185            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
186        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
187
188    def push_coordinates(self, coords, negative_shifts=False):
189        """Push voxel coordinates from fixed to moving space.
190
191        Args:
192            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
193
194        Returns:
195            numpy.array: transformed voxel coordinates
196        """
197        assert coords.shape[0] == 3
198        coords = cp.array(coords, dtype="float32")
199        # coords_blocked = coords / self.block_size[:, None] - 0.5
200        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
201        warp_field = self.warp_field.copy()
202        shifts = cp.zeros_like(coords)
203        for idim in range(3):
204            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
205                warp_field[idim], coords_blocked, order=1, mode="nearest"
206            )
207        if negative_shifts:
208            shifts = -shifts
209        return coords + shifts
210
211    def pull_coordinates(self, coords):
212        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
213
214        Args:
215            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
216
217        Returns:
218            numpy.array: transformed voxel coordinates
219        """
220        return self.invert().push_coordinates(coords, negative_shifts=True)
221
222    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
223        """
224        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
225
226        Args:
227            edge_order : passed to np.gradient (1 or 2)
228
229        Returns:
230            detJ: cp.ndarray of shape spatial
231        """
232        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
233        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
234        phi = coords + self.warp_field
235        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
236        for i in range(3):
237            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
238            for j in range(3):
239                J[..., i, j] = grads[j]
240        return cp.linalg.det(J)
241
242    def as_ants_image(self, voxel_size_um=1):
243        """Convert to ANTsImage
244
245        Args:
246            voxel_size_um (scalar or array): voxel size (default is 1)
247
248        Returns:
249            ants.core.ants_image.ANTsImage:
250        """
251        try:
252            import ants
253        except ImportError:
254            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
255
256        ants_image = ants.from_numpy(
257            self.warp_field.get().transpose(1, 2, 3, 0),
258            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
259            spacing=list(self.block_stride.get() * voxel_size_um),
260            has_components=True,
261        )
262        return ants_image
263
264    def __repr__(self):
265        """String representation of the WarpMap object."""
266        info = (
267            f"WarpMap("
268            f"warp_field_shape={self.warp_field.shape}, "
269            f"block_size={self.block_size.get()}, "
270            f"block_stride={self.block_stride.get()}, "
271            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)}"
272        )
273        return info
274
275
class WarpMapper:
    """Estimate a warp field by cross-correlation, based on a piece-wise rigid model.

    The reference volume is cut into blocks. Each block is projected along each of the
    three axes, and the Fourier transforms of the 2D projections are stored, so that
    :meth:`get_displacement` only has to transform the moving volume.

    Args:
        ref_vol (numpy.array or cupy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (defaults to block_size)
        proj_method (callable): projection ``proj_method(blocks, axis)``. Defaults to ``Projector()``.
        subpixel (int): upsampling factor for sub-pixel peak localisation
        epsilon (float): added to the zero-shift bin of each cross-correlation to break ties towards zero shift
        tukey_alpha (float or None): alpha of the Tukey window applied to the reference projections;
            no window if None or >= 1

    Raises:
        ValueError: if block_size is larger than the volume along any axis
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        # Validate first, so an oversized block gives this error instead of failing inside sliding_block.
        if np.any(np.asarray(block_size) > np.array(ref_vol.shape)):
            raise ValueError(f"Block size ({block_size}) must not be larger than the volume shape ({ref_vol.shape})")
        self.proj_method = Projector() if proj_method is None else proj_method
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """Replace the reference volume (and block geometry) used for matching.

        Args:
            ref_vol (numpy.array or cupy.array): new reference volume
            block_size (3-element list or numpy.array): block shape
            block_stride (3-element list or numpy.array, optional): stride (defaults to block_size)
        """
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha is not None and self.tukey_alpha < 1:
            alpha = self.tukey_alpha
            ref_blocks_proj = [
                proj
                * cp.array(
                    ndwindow([1, 1, 1, *proj.shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=alpha))
                ).astype("float32")
                for proj in ref_blocks_proj
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
        ]
        # Inverse plans depend on the block geometry; drop any cached from a previous reference.
        self.plan_rev = [None, None, None]
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
        ]
        self.block_size = block_size
        self.block_stride = block_stride

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (cupy.array): Input volume
            smooth_func (callable): Smoothing function ``smooth_func(xcorr, block_size)`` applied to the cross-correlation

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks  # free the full block array before the FFTs

        disp_field = []
        for i in range(3):
            # cross-power spectrum of the i-th projection with the stored reference
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            # fftshift moves zero lag to index shape // 2
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # integer peak location per block, relative to zero lag
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # refine the peak with a locally upsampled DFT around the integer estimate
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        # Projection i drops axis i, so each axis is observed in the two other projections: average them.
        disp_field = cp.array(disp_field)
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)
382
383
class RegistrationPyramid:
    """Multi-resolution registration driven by a Recipe.

    One WarpMapper is built per recipe level with ``repeats >= 1``. The levels run in order,
    each refining the accumulated displacement field, which lives on the grid of the last level.

    Args:
        ref_vol (numpy.array or cupy.array): Reference volume
        recipe (Recipe): pre-filter and per-level settings.
            IMPORTANT: the block stride in the last active level should not be larger than
            the block stride in any previous level (along any axis).
        reg_mask (numpy.array, cupy.array or scalar): mask passed to the recipe's pre-filter. Default is 1 (no mask)
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        self.mapper_ix = []  # index into recipe.levels for each mapper
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -n means "split this axis into n blocks".
            split_size = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = split_size[block_size < 0]
            if isinstance(level.block_stride, (int, float)):
                # A scalar stride is a fraction of the block size.
                block_stride = np.round(block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback function to be called after each level of registration
            verbose (bool): if True, show progress bars

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Start from the translation that centres vol on the reference volume.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        # Per-axis minimum stride over all levels.
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # Final output: warp the unfiltered input volume.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output
490
491
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask passed to the recipe's pre-filter (applied to both volumes). Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
            Video problems only produce a warning; the registration result is always returned.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Release the pyramid's GPU buffers and cached FFT plans.
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        if callback is None:
            warnings.warn("Video generation failed with error: a callback returning 2D frames is required")
        else:
            try:
                if cbout[0].ndim != 2:
                    raise ValueError("Callback output must be a 2D array")
                # Filter the reference exactly as RegistrationPyramid does, so frames are comparable.
                ref_vol = cp.array(ref, dtype="float32", copy=False, order="C")
                if recipe.pre_filter is not None:
                    mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
                    ref_vol = recipe.pre_filter(ref_vol, reg_mask=mask)
                ref_frame = callback(ref_vol)
                vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
                create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
            except (ValueError, AssertionError) as e:
                warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout
529
530
class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, project with the maximum; otherwise with the mean. Default is True
        normalize: if > 0, divide each projection by its L2 norm raised to this power
            (True = 1, i.e. correlations instead of covariances). Default is False
        dog: if True, apply a DoG filter to the projections. Default is True
        low: the lower sigma value(s) for the DoG filter (scalar or per-axis list). Default is 0.5.
            If dog is False and low is non-zero, a Gaussian filter with this sigma is applied instead.
        high: the higher sigma value(s) for the DoG filter (scalar or per-axis list). Default is 10.0
        periodic_smooth: if True, apply a periodic-smooth decomposition to the projections before filtering. Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block

        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        out = vol_blocks.max(axis) if self.max else vol_blocks.mean(axis)
        if self.periodic_smooth:
            out = periodic_smooth_decomposition_nd_rfft(out)
        # Per-axis sigmas, keeping only the two in-plane axes that remain after projection.
        low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        if self.dog:
            out = dogfilter(out, [0, 0, 0, *low], [0, 0, 0, *high], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            out = cupyx.scipy.ndimage.gaussian_filter(out, [0, 0, 0, *low], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            out /= cp.sqrt(cp.sum(out**2, axis=(-2, -1), keepdims=True)) ** self.normalize + 1e-9
        return out
574
575
class Smoother(BaseModel):
    """Smooth block-wise cross-correlations with a Gaussian kernel across neighbouring blocks.

    Smoothing acts only on the three block axes of the 5D cross-correlation array, never on
    the two in-plane lag axes. All kernels are truncated at 4 standard deviations.

    Args:
        sigmas (float, list or None): [sigma0, sigma1, sigma2] in units of blocks; a scalar is used
            for all three axes. If None, no smoothing is applied.
        shear (float): shear parameter for a sheared Gaussian kernel over block axes 0 and 1. Default is None.
        long_range_ratio (float): weight of an extra, 5x wider Gaussian (double Gaussian kernel). Default is 0.05.
    """

    sigmas: Union[float, List[float], None] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated (required when shear is set)
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        # Accept a scalar sigma for all three block axes.
        sigmas = [float(s) for s in np.broadcast_to(np.asarray(self.sigmas, dtype=float), (3,))]
        if self.shear is not None:
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            gw = gausskernel_sheared(sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*(5 * s for s in sigmas), 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj
621
622
class RegFilter(BaseModel):
    """Pre-registration filter: threshold, optional edge softening, masking and DoG filtering.

    Parameters:
        clip_thresh: value subtracted from the volume before clipping negatives to zero. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: if any entry is > 0, soften the volume edges with this width (scalar or per-axis list). Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Filter a volume prior to registration.

        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): optional mask multiplied into the volume before the DoG filter
        Returns:
            cupy.ndarray: Filtered volume (float32)
        """
        shifted = cp.array(vol, "float32", copy=False) - self.clip_thresh
        filtered = cp.clip(shifted, 0, None)  # a fresh array, safe to modify in place below
        wants_soft_edge = np.any(np.array(self.soft_edge) > 0)
        if wants_soft_edge:
            filtered = soften_edges(filtered, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            filtered *= cp.array(reg_mask, dtype="float32", copy=False)
        if not self.dog:
            return filtered
        return dogfilter(filtered, self.low, self.high, mode="reflect")
655
656
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated. A negative entry
            ``-n`` splits that axis of the volume into ``n`` blocks.
        block_stride (list or float): stride between blocks in voxels or, if a single number, a fraction
            of block_size. Default is 1.0 (stride equal to block_size).
        project (Projector or callable): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float or None): if not None, apply a Tukey window (alpha = tukey_ref) to the reference projections. Default is 0.5
        smooth (Smoother or None): Smoother applied to the cross-correlation; None disables smoothing
        affine (bool): if True, replace the displacement update by its best affine fit
        median_filter (bool): if True, apply median filter to the displacement field
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
        repeats (int): number of iterations for this level (disable the level by setting repeats to 0)
    """

    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5
681
682
class Recipe(BaseModel):
    """Configuration for the registration recipe.

    By default the recipe has two levels: a translation level (one block spanning the
    whole volume) followed by an affine level (4 blocks per axis).

    Args:
        pre_filter (RegFilter, callable or None): Filter applied to the reference and moving volumes before registration
        levels (list): List of LevelConfig objects, run in order
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=1),  # translation level
        LevelConfig(  # affine level
            block_size=[-4, -4, -4],
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the registration recipe

        Args:
            block_size (int or list): shape of blocks, whose rigid displacement is estimated
                (a single number is used for all three axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level into the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (int or list): shape of blocks, whose rigid displacement is estimated
                (a single number is used for all three axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str or pathlib.Path): path to the YAML file; if it is not an existing
                file, it is looked up in the package's bundled "recipes" directory

        Returns:
            Recipe: Recipe object
        """
        import yaml

        path = pathlib.Path(yaml_path)
        if not path.is_file():
            path = pathlib.Path(__file__).resolve().parent / "recipes" / path

        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")
class WarpMap:
 32class WarpMap:
 33    """Represents a 3D displacement field
 34
 35    Args:
 36        warp_field (numpy.array): the displacement field data (3-x-y-z)
 37        block_size (3-element list or numpy.array):
 38        block_stride (3-element list or numpy.array):
 39        ref_shape (tuple): shape of the reference volume
 40        mov_shape (tuple): shape of the moving volume
 41    """
 42
 43    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
 44        self.warp_field = cp.array(warp_field, dtype="float32")
 45        self.block_size = cp.array(block_size, dtype="float32")
 46        self.block_stride = cp.array(block_stride, dtype="float32")
 47        self.ref_shape = ref_shape
 48        self.mov_shape = mov_shape
 49
 50    def warp(self, vol, out=None):
 51        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
 52
 53        Args:
 54            vol (cupy.array): the volume to be warped
 55
 56        Returns:
 57            cupy.array: warped volume
 58        """
 59        if np.any(vol.shape != np.array(self.mov_shape)):
 60            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
 61        if out is None:
 62            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
 63        vol_out = warp_volume(
 64            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
 65        )
 66        return vol_out
 67
 68    def apply(self, *args, **kwargs):
 69        """Alias of warp method"""
 70        return self.warp(*args, **kwargs)
 71
 72    def fit_affine(self, target=None):
 73        """Fit affine transformation and return new fitted WarpMap
 74
 75        Args:
 76            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 77
 78        Returns:
 79            WarpMap:
 80            numpy.array: affine tranformation coefficients
 81        """
 82        if target is None:
 83            warp_field_shape = self.warp_field.shape
 84            block_size = self.block_size
 85            block_stride = self.block_stride
 86        else:
 87            warp_field_shape = target["warp_field_shape"]
 88            block_size = cp.array(target["block_size"]).astype("float32")
 89            block_stride = cp.array(target["block_stride"]).astype("float32")
 90
 91        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 92        ix = ix * self.block_stride + self.block_size / 2
 93        M = cp.zeros(self.warp_field.shape[1:])
 94        #M[1:-1, 1:-1, 1:-1] = 1
 95        M[:,:,:] = 1
 96        ixg = cp.where(M.flatten() > 0)[0]
 97        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 98        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
 99        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
100        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
101        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
102        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
103
104    def median_filter(self):
105        """Apply median filter to the displacement field
106
107        Returns:
108            WarpMap: new WarpMap with median filtered displacement field
109        """
110        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
111        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)
112
113    def resize_to(self, target):
114        """Resize to target WarpMap, using linear interpolation
115
116        Args:
117            target (WarpMap or WarpMapper): target to resize to
118                or a dict with keys "shape", "block_size", and "block_stride"
119
120        Returns:
121            WarpMap: resized WarpMap
122        """
123        if isinstance(target, WarpMap):
124            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
125        elif isinstance(target, WarpMapper):
126            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
127        elif isinstance(target, dict):
128            t_sh, t_bsz, t_bst = (
129                target["warp_field_shape"][1:],
130                cp.array(target["block_size"]),
131                cp.array(target["block_stride"]),
132            )
133        else:
134            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
135        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
136        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
137        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
138        dm_r = cp.array(
139            [
140                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
141                    t_sh
142                )
143                for i in range(3)
144            ]
145        )
146        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
147
148    def chain(self, target):
149        """Chain displacement maps
150
151        Args:
152            target (WarpMap): WarpMap to be added to existing map
153
154        Returns:
155            WarpMap: new WarpMap with chained displacement field
156        """
157        indices = cp.indices(target.warp_field.shape[1:])
158        warp_field = self.warp_field.copy()
159        warp_field += target.warp_field
160        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
161
162    def invert(self, **kwargs):
163        """alias for invert_fast method"""
164        return self.invert_fast(**kwargs)
165
166    def invert_fast(self, sigma=0.5, truncate=20):
167        """Invert the displacement field using accumulation and Gaussian basis interpolation.
168
169        Args:
170            sigma (float): standard deviation for Gaussian basis interpolation
171            truncate (float): truncate parameter for Gaussian basis interpolation
172
173        Returns:
174            WarpMap: inverted WarpMap
175        """
176        warp_field = self.warp_field.get()
177        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
178        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
179        num_coords = accumarray(target_coords, wf_shape)
180        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
181        for i in range(3):
182            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
183            with np.errstate(invalid="ignore"):
184                inv_field[i] /= num_coords
185            inv_field[i][num_coords == 0] = np.nan
186            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
187        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
188
189    def push_coordinates(self, coords, negative_shifts=False):
190        """Push voxel coordinates from fixed to moving space.
191
192        Args:
193            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
194
195        Returns:
196            numpy.array: transformed voxel coordinates
197        """
198        assert coords.shape[0] == 3
199        coords = cp.array(coords, dtype="float32")
200        # coords_blocked = coords / self.block_size[:, None] - 0.5
201        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
202        warp_field = self.warp_field.copy()
203        shifts = cp.zeros_like(coords)
204        for idim in range(3):
205            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
206                warp_field[idim], coords_blocked, order=1, mode="nearest"
207            )
208        if negative_shifts:
209            shifts = -shifts
210        return coords + shifts
211
212    def pull_coordinates(self, coords):
213        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
214
215        Args:
216            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
217
218        Returns:
219            numpy.array: transformed voxel coordinates
220        """
221        return self.invert().push_coordinates(coords, negative_shifts=True)
222
223    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
224        """
225        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
226
227        Args:
228            edge_order : passed to np.gradient (1 or 2)
229
230        Returns:
231            detJ: cp.ndarray of shape spatial
232        """
233        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
234        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
235        phi = coords + self.warp_field
236        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
237        for i in range(3):
238            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
239            for j in range(3):
240                J[..., i, j] = grads[j]
241        return cp.linalg.det(J)
242
243    def as_ants_image(self, voxel_size_um=1):
244        """Convert to ANTsImage
245
246        Args:
247            voxel_size_um (scalar or array): voxel size (default is 1)
248
249        Returns:
250            ants.core.ants_image.ANTsImage:
251        """
252        try:
253            import ants
254        except ImportError:
255            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
256
257        ants_image = ants.from_numpy(
258            self.warp_field.get().transpose(1, 2, 3, 0),
259            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
260            spacing=list(self.block_stride.get() * voxel_size_um),
261            has_components=True,
262        )
263        return ants_image
264
265    def __repr__(self):
266        """String representation of the WarpMap object."""
267        info = (
268            f"WarpMap("
269            f"warp_field_shape={self.warp_field.shape}, "
270            f"block_size={self.block_size.get()}, "
271            f"block_stride={self.block_stride.get()}, "
272            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)}"
273        )
274        return info

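Represents a 3D displacement field sampled on a regular grid of blocks.

Arguments:
  • warp_field (numpy.array): the displacement field data, of shape (3, X, Y, Z) — one displacement vector per grid point
  • block_size (3-element list or numpy.array): shape of the blocks whose rigid displacement was estimated
  • block_stride (3-element list or numpy.array): spacing between neighbouring blocks, in voxels
  • ref_shape (tuple): shape of the reference volume
  • mov_shape (tuple): shape of the moving volume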
WarpMap(warp_field, block_size, block_stride, ref_shape, mov_shape)
43    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
44        self.warp_field = cp.array(warp_field, dtype="float32")
45        self.block_size = cp.array(block_size, dtype="float32")
46        self.block_stride = cp.array(block_stride, dtype="float32")
47        self.ref_shape = ref_shape
48        self.mov_shape = mov_shape
warp_field
block_size
block_stride
ref_shape
mov_shape
def warp(self, vol, out=None):
50    def warp(self, vol, out=None):
51        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
52
53        Args:
54            vol (cupy.array): the volume to be warped
55
56        Returns:
57            cupy.array: warped volume
58        """
59        if np.any(vol.shape != np.array(self.mov_shape)):
60            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
61        if out is None:
62            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
63        vol_out = warp_volume(
64            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
65        )
66        return vol_out

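Apply the warp to a volume. Can be thought of as pulling the moving volume into the reference (fixed) volume space.

Arguments:
  • vol (cupy.array): the volume to be warped; a warning is issued if its shape differs from mov_shape
  • out (cupy.array, optional): array that receives the result; if None, a float32 array of shape ref_shape is allocated
Returns:

cupy.array: warped volume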

def apply(self, *args, **kwargs):
68    def apply(self, *args, **kwargs):
69        """Alias of warp method"""
70        return self.warp(*args, **kwargs)

Alias of warp method

def fit_affine(self, target=None):
 72    def fit_affine(self, target=None):
 73        """Fit affine transformation and return new fitted WarpMap
 74
 75        Args:
 76            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 77
 78        Returns:
 79            WarpMap:
 80            numpy.array: affine tranformation coefficients
 81        """
 82        if target is None:
 83            warp_field_shape = self.warp_field.shape
 84            block_size = self.block_size
 85            block_stride = self.block_stride
 86        else:
 87            warp_field_shape = target["warp_field_shape"]
 88            block_size = cp.array(target["block_size"]).astype("float32")
 89            block_stride = cp.array(target["block_stride"]).astype("float32")
 90
 91        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 92        ix = ix * self.block_stride + self.block_size / 2
 93        M = cp.zeros(self.warp_field.shape[1:])
 94        #M[1:-1, 1:-1, 1:-1] = 1
 95        M[:,:,:] = 1
 96        ixg = cp.where(M.flatten() > 0)[0]
 97        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 98        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
 99        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
100        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
101        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
102        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

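Fit an affine transformation to the displacement field and return a new WarpMap sampled from the fit.

Arguments:
  • target (dict, optional): dict with keys "warp_field_shape", "block_size", and "block_stride" describing the grid on which the fitted field is sampled; if None, the grid of this WarpMap is used
Returns:

tuple: the fitted WarpMap, and the affine transformation coefficients (cupy.array of shape 4 × 3, obtained by least squares)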

def median_filter(self):
104    def median_filter(self):
105        """Apply median filter to the displacement field
106
107        Returns:
108            WarpMap: new WarpMap with median filtered displacement field
109        """
110        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
111        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

Apply median filter to the displacement field

Returns:

WarpMap: new WarpMap with median filtered displacement field

def resize_to(self, target):
113    def resize_to(self, target):
114        """Resize to target WarpMap, using linear interpolation
115
116        Args:
117            target (WarpMap or WarpMapper): target to resize to
118                or a dict with keys "shape", "block_size", and "block_stride"
119
120        Returns:
121            WarpMap: resized WarpMap
122        """
123        if isinstance(target, WarpMap):
124            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
125        elif isinstance(target, WarpMapper):
126            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
127        elif isinstance(target, dict):
128            t_sh, t_bsz, t_bst = (
129                target["warp_field_shape"][1:],
130                cp.array(target["block_size"]),
131                cp.array(target["block_stride"]),
132            )
133        else:
134            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
135        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
136        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
137        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
138        dm_r = cp.array(
139            [
140                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
141                    t_sh
142                )
143                for i in range(3)
144            ]
145        )
146        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

Resize to the grid of a target, using linear interpolation.

Arguments:
  • target (WarpMap, WarpMapper, or dict): target whose grid is adopted; a dict must have the keys "warp_field_shape", "block_size", and "block_stride"
Returns:

WarpMap: resized WarpMap

def chain(self, target):
148    def chain(self, target):
149        """Chain displacement maps
150
151        Args:
152            target (WarpMap): WarpMap to be added to existing map
153
154        Returns:
155            WarpMap: new WarpMap with chained displacement field
156        """
157        indices = cp.indices(target.warp_field.shape[1:])
158        warp_field = self.warp_field.copy()
159        warp_field += target.warp_field
160        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

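Chain displacement maps by adding the two displacement fields element-wise.

Arguments:
  • target (WarpMap): WarpMap whose displacement field is added to this one; the result takes target's block_size and block_stride
Returns:

WarpMap: new WarpMap with chained displacement field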

def invert(self, **kwargs):
162    def invert(self, **kwargs):
163        """alias for invert_fast method"""
164        return self.invert_fast(**kwargs)

Alias for the invert_fast method; keyword arguments are passed through.

def invert_fast(self, sigma=0.5, truncate=20):
166    def invert_fast(self, sigma=0.5, truncate=20):
167        """Invert the displacement field using accumulation and Gaussian basis interpolation.
168
169        Args:
170            sigma (float): standard deviation for Gaussian basis interpolation
171            truncate (float): truncate parameter for Gaussian basis interpolation
172
173        Returns:
174            WarpMap: inverted WarpMap
175        """
176        warp_field = self.warp_field.get()
177        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
178        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
179        num_coords = accumarray(target_coords, wf_shape)
180        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
181        for i in range(3):
182            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
183            with np.errstate(invalid="ignore"):
184                inv_field[i] /= num_coords
185            inv_field[i][num_coords == 0] = np.nan
186            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
187        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

Invert the displacement field using accumulation and Gaussian basis interpolation.

Arguments:
  • sigma (float): standard deviation for Gaussian basis interpolation
  • truncate (float): truncate parameter for Gaussian basis interpolation
Returns:

WarpMap: inverted WarpMap

def push_coordinates(self, coords, negative_shifts=False):
189    def push_coordinates(self, coords, negative_shifts=False):
190        """Push voxel coordinates from fixed to moving space.
191
192        Args:
193            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
194
195        Returns:
196            numpy.array: transformed voxel coordinates
197        """
198        assert coords.shape[0] == 3
199        coords = cp.array(coords, dtype="float32")
200        # coords_blocked = coords / self.block_size[:, None] - 0.5
201        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
202        warp_field = self.warp_field.copy()
203        shifts = cp.zeros_like(coords)
204        for idim in range(3):
205            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
206                warp_field[idim], coords_blocked, order=1, mode="nearest"
207            )
208        if negative_shifts:
209            shifts = -shifts
210        return coords + shifts

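Push voxel coordinates from fixed to moving space.

Arguments:
  • coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
  • negative_shifts (bool): if True, the interpolated displacements are subtracted instead of added (default False)
Returns:

cupy.array: transformed voxel coordinates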

def pull_coordinates(self, coords):
212    def pull_coordinates(self, coords):
213        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
214
215        Args:
216            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
217
218        Returns:
219            numpy.array: transformed voxel coordinates
220        """
221        return self.invert().push_coordinates(coords, negative_shifts=True)

Pull voxel coordinates through the warp field. Involves inverting the warp field, then pushing the coordinates with negated shifts.

Arguments:
  • coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:

cupy.array: transformed voxel coordinates

def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
223    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
224        """
225        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
226
227        Args:
228            edge_order : passed to np.gradient (1 or 2)
229
230        Returns:
231            detJ: cp.ndarray of shape spatial
232        """
233        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
234        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
235        phi = coords + self.warp_field
236        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
237        for i in range(3):
238            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
239            for j in range(3):
240                J[..., i, j] = grads[j]
241        return cp.linalg.det(J)

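Compute the Jacobian determinant det J = det(∇φ) of the mapping φ(x) = x + u(x), using cp.indices for the identity grid.

Arguments:
  • units_per_voxel (3-element list): per-axis size of a voxel; multiplied by block_stride to give the grid spacing (default [1, 1, 1])
  • edge_order : passed to cp.gradient (1 or 2)
Returns:

detJ: cupy.array of shape warp_field.shape[1:], one determinant per grid point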

def as_ants_image(self, voxel_size_um=1):
243    def as_ants_image(self, voxel_size_um=1):
244        """Convert to ANTsImage
245
246        Args:
247            voxel_size_um (scalar or array): voxel size (default is 1)
248
249        Returns:
250            ants.core.ants_image.ANTsImage:
251        """
252        try:
253            import ants
254        except ImportError:
255            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
256
257        ants_image = ants.from_numpy(
258            self.warp_field.get().transpose(1, 2, 3, 0),
259            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
260            spacing=list(self.block_stride.get() * voxel_size_um),
261            has_components=True,
262        )
263        return ants_image

Convert to ANTsImage

Arguments:
  • voxel_size_um (scalar or array): voxel size (default is 1)
Returns:

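ants.core.ants_image.ANTsImage: multi-component image holding the displacement field, with spacing block_stride · voxel_size_um and origin (block_size − 1) / 2 · voxel_size_um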

class WarpMapper:
277class WarpMapper:
278    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.
279
280    Args:
281        ref_vol (numpy.array): The reference volume
282        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
283        block_stride (3-element list or numpy.array): stride (usually identical to block_size)
284        proj_method (str or callable): Projection method
285    """
286
287    def __init__(
288        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
289    ):
290        self.proj_method = proj_method
291        self.plan_rev = [None, None, None]
292        self.subpixel = subpixel
293        self.epsilon = epsilon
294        self.tukey_alpha = tukey_alpha
295        self.update_reference(ref_vol, block_size, block_stride)
296        self.ref_shape = np.array(ref_vol.shape)
297        if np.any(block_size > np.array(ref_vol.shape)):
298            raise ValueError(f"Block size ({block_size}) must be smaller than the volume shape ({ref_vol.shape})")
299
300    def update_reference(self, ref_vol, block_size, block_stride=None):
301        ft = lambda arr: cp.fft.rfftn(arr, axes=(-2, -1))
302        block_size = np.array(block_size)
303        block_stride = block_size if block_stride is None else np.array(block_stride)
304        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
305        self.blocks_shape = ref_blocks.shape
306        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
307        if self.tukey_alpha < 1:
308            ref_blocks_proj = [
309                ref_blocks_proj[i]
310                * cp.array(
311                    ndwindow(
312                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=0.5)
313                    )
314                ).astype("float32")
315                for i in range(3)
316            ]
317        self.plan_fwd = [
318            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
319        ]
320        self.ref_blocks_proj_ft_conj = [
321            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
322        ]
323        self.block_size = block_size
324        self.block_stride = block_stride
325
326    def get_displacement(self, vol, smooth_func=None):
327        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.
328
329        Args:
330            vol (numpy.array): Input volume
331            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
332
333        Returns:
334            WarpMap
335        """
336        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
337        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
338        del vol_blocks
339
340        disp_field = []
341        for i in range(3):
342            R = (
343                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
344                * self.ref_blocks_proj_ft_conj[i]
345            )
346            if self.plan_rev[i] is None:
347                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
348            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
349            if smooth_func is not None:
350                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
351            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon
352
353            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
354            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
355            del xcorr_proj
356            i0, j0 = max_ix.reshape(2, -1)
357            shifts = upsampled_dft_rfftn(
358                R.reshape(-1, *R.shape[-2:]),
359                upsampled_region_size=int(self.subpixel * 2 + 1),
360                upsample_factor=self.subpixel,
361                axis_offsets=(i0, j0),
362            )
363            del R
364            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
365            max_sub = (
366                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
367            ) / self.subpixel
368            del shifts
369            disp_field.append(max_ix + max_sub)
370
371        disp_field = cp.array(disp_field)
372        disp_field = (
373            cp.array(
374                [
375                    disp_field[1, 0] + disp_field[2, 0],
376                    disp_field[0, 0] + disp_field[2, 1],
377                    disp_field[0, 1] + disp_field[1, 1],
378                ]
379            ).astype("float32")
380            / 2
381        )
382        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)

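Class that estimates a warp field using cross-correlation, based on a piece-wise rigid model.

Arguments:
  • ref_vol (numpy.array): The reference volume
  • block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated; must not exceed the volume shape
  • block_stride (3-element list or numpy.array): stride (usually identical to block_size; defaults to block_size when None)
  • proj_method (callable): projection method, called as proj_method(blocks, axis=ax) to project the blocks along each of the last three axes
  • subpixel (int): upsampling factor used for subpixel refinement of the cross-correlation peak (default 4)
  • epsilon (float): small value added to the zero-shift bin of the cross-correlation, slightly favouring zero displacement (default 1e-6)
  • tukey_alpha (float): if < 1, a Tukey window is applied to the projected reference blocks (the window currently uses a fixed alpha of 0.5)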
WarpMapper( ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-06, tukey_alpha=0.5)
287    def __init__(
288        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
289    ):
290        self.proj_method = proj_method
291        self.plan_rev = [None, None, None]
292        self.subpixel = subpixel
293        self.epsilon = epsilon
294        self.tukey_alpha = tukey_alpha
295        self.update_reference(ref_vol, block_size, block_stride)
296        self.ref_shape = np.array(ref_vol.shape)
297        if np.any(block_size > np.array(ref_vol.shape)):
298            raise ValueError(f"Block size ({block_size}) must be smaller than the volume shape ({ref_vol.shape})")
proj_method
plan_rev
subpixel
epsilon
tukey_alpha
ref_shape
def update_reference(self, ref_vol, block_size, block_stride=None):
300    def update_reference(self, ref_vol, block_size, block_stride=None):
301        ft = lambda arr: cp.fft.rfftn(arr, axes=(-2, -1))
302        block_size = np.array(block_size)
303        block_stride = block_size if block_stride is None else np.array(block_stride)
304        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
305        self.blocks_shape = ref_blocks.shape
306        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
307        if self.tukey_alpha < 1:
308            ref_blocks_proj = [
309                ref_blocks_proj[i]
310                * cp.array(
311                    ndwindow(
312                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=0.5)
313                    )
314                ).astype("float32")
315                for i in range(3)
316            ]
317        self.plan_fwd = [
318            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
319        ]
320        self.ref_blocks_proj_ft_conj = [
321            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
322        ]
323        self.block_size = block_size
324        self.block_stride = block_stride
def get_displacement(self, vol, smooth_func=None):
    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        For each of the three projection axes, every block projection of ``vol`` is
        cross-correlated with the matching reference projection. The integer peak is
        refined to 1 / ``self.subpixel`` voxel with an upsampled DFT. Each voxel axis is
        seen by two of the three projections, and the two estimates are averaged.

        Args:
            vol (array_like): Input volume, passed unchanged to ``sliding_block``
                (register_single passes a cupy array; TODO confirm numpy input is supported)
            smooth_func (callable): Smoothing function applied to the cross-correlation,
                called as ``smooth_func(xcorr_proj, self.block_size)``. None disables smoothing.

        Returns:
            WarpMap: displacement field with 3 components on the block grid
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        # Entry i holds the 2D projections of all blocks along voxel axis -3, -2 or -1.
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Cross-power spectrum FFT(moving) * conj(FFT(reference)).
            # plan_fwd[i] was built from the reference projections in update_reference,
            # so the moving projections are assumed to have the same shape.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            # Inverse plans are created on first use and cached per projection axis.
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            # fftshift moves zero lag to index n // 2 of each 2D correlation plane.
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # Small bonus at zero lag, presumably so that ties (e.g. empty blocks) resolve to no displacement.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer peak per block, as an offset from the zero-lag index; shape (2, *blocks).
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # Sub-pixel refinement: upsampled DFT of R over a (2 * subpixel + 1)^2 region around each integer peak.
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            # Peak of the upsampled region, relative to its centre, in voxel units.
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        # disp_field[i, j] is in-plane shift j of the projection that removed axis i:
        # projection 0 -> axes (1, 2), projection 1 -> axes (0, 2), projection 2 -> axes (0, 1).
        # Every voxel axis is estimated by two projections; average the two.
        disp_field = cp.array(disp_field)
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)

Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

Arguments:
  • vol (numpy.array): Input volume
  • smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
Returns:

WarpMap

class RegistrationPyramid:
class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    One WarpMapper is built per enabled level of the recipe; levels with ``repeats < 1``
    are skipped. ``register_single`` runs the levels in order. It refines one displacement
    field defined on the block grid of the last enabled level.

    Args:
        ref_vol (numpy.array or cupy.array): Reference volume
        recipe (Recipe): Registration recipe, re-validated on construction. Its ``pre_filter``
            (if not None) is applied to the reference volume.
            IMPORTANT: the block stride in the last level should not be larger than the block
            stride in any previous level, along any axis. With the default fractional stride
            of 1.0 this is the same as the block size.
        reg_mask (numpy.array, cupy.array or float): Mask for registration, passed to
            ``recipe.pre_filter``. Default is 1 (no mask).

    Raises:
        AssertionError: if no level has ``repeats >= 1``.
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        # Re-validate: the recipe may have been modified (e.g. via add_level) after it was created.
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        # mapper_ix[k] is the index into recipe.levels of self.mappers[k].
        self.mapper_ix = []
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -n means "block size = shape // n" along that axis.
            size_from_count = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = size_from_count[block_size < 0]
            if isinstance(level.block_stride, (int, float)):
                # A scalar stride is a fraction of the block size.
                block_stride = np.round(block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback called on the current (pre-filtered, warped) volume,
                once before the first level and after each repeat of each level
            verbose (bool): If True, show progress bars

        Returns:
            - vol (array_like): Registered volume in the reference shape (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Initial guess: a constant displacement of half the shape difference (presumably aligning the centres).
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        last_mapper = self.mappers[-1]
        warp_map = warp_map.resize_to(last_mapper)
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # Compare axis by axis (previously every axis was compared against the axis-0 minimum).
        if np.any(last_mapper.block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                # Every update is brought onto the block grid of the last level before chaining.
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *last_mapper.blocks_shape[:3]),
                            block_size=last_mapper.block_size,
                            block_stride=last_mapper.block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(last_mapper)

                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # The returned volume is warped from the unfiltered input.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output

A class for performing multi-resolution registration.

Arguments:
  • ref_vol (numpy.array): Reference volume
  • recipe (Recipe): Registration recipe with the settings for each level of the pyramid. IMPORTANT: the block stride in the last level should not be larger than the block stride in any previous level (along any axis).
  • reg_mask (numpy.array): Mask for registration
  • clip_thresh: not an argument of this class; the clipping threshold is set on the recipe's pre-filter (RegFilter.clip_thresh).
RegistrationPyramid(ref_vol, recipe, reg_mask=1)
396    def __init__(self, ref_vol, recipe, reg_mask=1):
397        recipe.model_validate(recipe.model_dump())
398        self.recipe = recipe
399        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
400        self.mappers = []
401        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
402        self.ref_shape = ref_vol.shape
403        if self.recipe.pre_filter is not None:
404            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
405        self.mapper_ix = []
406        for i in range(len(recipe.levels)):
407            if recipe.levels[i].repeats < 1:
408                continue
409            block_size = np.array(recipe.levels[i].block_size)
410            tmp = np.r_[ref_vol.shape] // -block_size
411            block_size[block_size < 0] = tmp[block_size < 0]
412            if isinstance(recipe.levels[i].block_stride, (int, float)):
413                block_stride = np.round(block_size * recipe.levels[i].block_stride).astype("int")
414            else:
415                block_stride = np.array(recipe.levels[i].block_stride)
416            self.mappers.append(
417                WarpMapper(
418                    ref_vol,
419                    block_size,
420                    block_stride=block_stride,
421                    proj_method=recipe.levels[i].project,
422                    tukey_alpha=recipe.levels[i].tukey_ref,
423                )
424            )
425            self.mapper_ix.append(i)
426        assert len(self.mappers) > 0, "At least one level of registration is required"
recipe
reg_mask
mappers
ref_shape
mapper_ix
def register_single(self, vol, callback=None, verbose=False):
428    def register_single(self, vol, callback=None, verbose=False):
429        """Register a single volume to the reference volume.
430
431        Args:
432            vol (array_like): Volume to be registered (numpy or cupy array)
433            callback (function): Callback function to be called after each level of registration
434
435        Returns:
436            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
437            - warp_map (WarpMap): Displacement field
438            - callback_output (list): List of outputs from the callback function
439        """
440        was_numpy = isinstance(vol, np.ndarray)
441        vol = cp.array(vol, "float32", copy=False, order="C")
442        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
443        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
444        warp_map = warp_map.resize_to(self.mappers[-1])
445        callback_output = []
446        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
447        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
448        warp_map.warp(vol_tmp0, out=vol_tmp)
449        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
450        if callback is not None:
451            callback_output.append(callback(vol_tmp))
452
453        if np.any(self.mappers[-1].block_stride > min_block_stride[0]):
454            warnings.warn(
455                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
456            )
457        for k, mapper in enumerate(tqdm(self.mappers, desc=f"Levels", disable=not verbose)):
458            for _ in tqdm(
459                range(self.recipe.levels[self.mapper_ix[k]].repeats), leave=False, desc=f"Repeats", disable=not verbose
460            ):
461                wm = mapper.get_displacement(
462                    vol_tmp, smooth_func=self.recipe.levels[self.mapper_ix[k]].smooth  # * self.reg_mask,
463                )
464                wm.warp_field *= self.recipe.levels[self.mapper_ix[k]].update_rate
465                if self.recipe.levels[self.mapper_ix[k]].median_filter:
466                    wm = wm.median_filter()
467                if self.recipe.levels[self.mapper_ix[k]].affine:
468                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
469                        raise ValueError(
470                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
471                        )
472                    wm, _ = wm.fit_affine(
473                        target=dict(
474                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
475                            block_size=self.mappers[-1].block_size,
476                            block_stride=self.mappers[-1].block_stride,
477                        )
478                    )
479                else:
480                    wm = wm.resize_to(self.mappers[-1])
481
482                warp_map = warp_map.chain(wm)
483                warp_map.warp(vol_tmp0, out=vol_tmp)
484                if callback is not None:
485                    # callback_output.append(callback(warp_map.unwarp(vol)))
486                    callback_output.append(callback(vol_tmp))
487        warp_map.warp(vol, out=vol_tmp)
488        if was_numpy:
489            vol_tmp = vol_tmp.get()
490        return vol_tmp, warp_map, callback_output

Register a single volume to the reference volume.

Arguments:
  • vol (array_like): Volume to be registered (numpy or cupy array)
  • callback (function): Callback function to be called after each level of registration
Returns:
  • vol (array_like): Registered volume (numpy or cupy array, depending on input)
  • warp_map (WarpMap): Displacement field
  • callback_output (list): List of outputs from the callback function
def register_volumes( ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask passed to ``recipe.pre_filter`` for both volumes
            (RegFilter multiplies the volume by it). Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. Requires a callback
            that returns 2D frames. The reference frame is the pre-filtered (and masked) reference passed through
            the callback. Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values
            of the reference frame.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Release the pyramid's GPU buffers and cached FFT plans before (optionally) rendering the video.
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        if callback is None:
            # Without a callback there are no frames (cbout is empty) and no way to render the reference.
            warnings.warn("Video generation requires a callback that returns 2D frames; no video was written.")
        else:
            try:
                if np.ndim(cbout[0]) != 2:
                    raise ValueError("Callback output must be a 2D array")
                # Filter the reference like the registered frames were filtered (same pre_filter and mask),
                # and tolerate recipes without a pre_filter.
                ref_frame = cp.asarray(ref, dtype="float32")
                if recipe.pre_filter is not None:
                    ref_frame = recipe.pre_filter(ref_frame, reg_mask=cp.asarray(reg_mask, dtype="float32"))
                ref_frame = callback(ref_frame)
                vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
                create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
            except (ValueError, AssertionError) as e:
                warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout

Register a volume to a reference volume using a registration pyramid.

Arguments:
  • ref (numpy.array or cupy.array): Reference volume
  • vol (numpy.array or cupy.array): Volume to be registered
  • recipe (Recipe): Registration recipe
  • reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
  • callback (function): Callback function to be called on the volume after each iteration. Default is None. Can be used to monitor and optimize registration. Example: callback = lambda vol: vol.mean(1).get() (note that vol is a 3D cupy array. Use .get() to turn the output into a numpy array and save GPU memory). Callback outputs for each registration step will be returned as a list.
  • verbose (bool): If True, show progress bars. Default is True
  • video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
  • vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.
Returns:
  • numpy.array or cupy.array (depending on vol input): Registered volume
  • WarpMap: Displacement field
  • list: List of outputs from the callback function
class Projector(pydantic.main.BaseModel):
class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, project with the maximum along the axis; otherwise use the mean. Default is True
        normalize: if truthy, divide each projection by its L2 norm raised to this power (True acts as 1),
            to get correlations, not covariances. Default is False
        dog: if True, apply a DoG filter to the projections. Default is True
        low: the lower sigma value for the DoG filter (scalar or one value per volume axis). If dog is False
            and low is nonzero, a Gaussian filter with this sigma is applied instead. Default is 0.5
        high: the higher sigma value for the DoG filter (scalar or one value per volume axis). Default is 10.0
        periodic_smooth: if True, apply a periodic-smooth decomposition to the projections before filtering.
            Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[int, float, List[Union[int, float]]] = 0.5
    high: Union[int, float, List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block

        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        out = vol_blocks.max(axis) if self.max else vol_blocks.mean(axis)
        if self.periodic_smooth:
            out = periodic_smooth_decomposition_nd_rfft(out)
        # Broadcast scalar sigmas to the 3 voxel axes, then drop the projected axis.
        low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        # Block axes (the first 3) are never filtered, hence the leading zeros.
        if self.dog:
            out = dogfilter(out, [0, 0, 0, *low], [0, 0, 0, *high], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            out = cupyx.scipy.ndimage.gaussian_filter(out, [0, 0, 0, *low], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            out /= cp.sqrt(cp.sum(out**2, axis=(-2, -1), keepdims=True)) ** self.normalize + 1e-9
        return out

A class to apply a 2D projection and filters to a volume block

Arguments:
  • max: if True, project each block with its maximum along the axis (otherwise with the mean). Default is True
  • normalize: if True, normalize projections by the L2 norm (to get correlations, not covariances). Default is False
  • dog: if True, apply a DoG filter to the volume block. Default is True
  • low: the lower sigma value for the DoG filter. Default is 0.5
  • high: the higher sigma value for the DoG filter. Default is 10.0
  • periodic_smooth: if True, apply a periodic-smooth decomposition to the projection before filtering. Default is False
  • (There are no tukey_env / gauss_env options; windowing of the reference is configured via LevelConfig.tukey_ref.)
max: bool
normalize: Union[bool, float]
dog: bool
low: Union[int, float, List[Union[int, float]]]
high: Union[int, float, List[Union[int, float]]]
periodic_smooth: bool
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Smoother(pydantic.main.BaseModel):
class Smoother(BaseModel):
    """Smooth blocks with a Gaussian kernel

    The cross-correlation data is smoothed across neighbouring blocks (the first three,
    block-grid axes). The 2D correlation planes themselves are not blurred. All Gaussian
    kernels are truncated at 4.0 sigma (not configurable).

    Args:
        sigmas (float or list): [sigma0, sigma1, sigma2] in units of blocks. A single number is used for all
            three axes. If None, no smoothing is applied.
        shear (float): shear parameter for the gaussian kernel over the first two block axes, scaled by
            block_size[1] / block_size[0]. Default is None (no shear).
        long_range_ratio (float): weight of an additional Gaussian with 5x the sigmas (double gaussian kernel).
            Default is 0.05; None disables it.
    """

    sigmas: Union[float, List[float], None] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated (required when shear is set)
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        # Accept a scalar sigma as shorthand for the same value on all three block axes
        # (previously a scalar crashed on unpacking).
        sigmas = np.r_[1.0, 1.0, 1.0] * self.sigmas
        if self.shear is not None:
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            # Fixed: this used the nonexistent attribute `self.sigma`.
            gw = gausskernel_sheared(sigmas[:2].tolist(), shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, float(sigmas[2]), axis=2, mode="constant", truncate=truncate
            )
        else:
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            # Note: the wide kernel is applied to the data after it has been scaled by (1 - ratio).
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*(sigmas * 5), 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj

Smooth blocks with a Gaussian kernel

Arguments:
  • sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied.
  • truncate: not configurable; the Gaussian kernels are truncated at 4.0 sigma.
  • shear (float): shear parameter for gaussian kernel. Default is None.
  • long_range_ratio (float): long range ratio for double gaussian kernel. Default is 0.05; None disables the long-range component.
sigmas: Union[float, List[float]]
shear: Optional[float]
long_range_ratio: Optional[float]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class RegFilter(pydantic.main.BaseModel):
class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    The steps are: subtract ``clip_thresh`` and clip at 0, optionally soften the edges,
    multiply by the registration mask, then optionally apply a DoG filter.

    Parameters:
        clip_thresh: threshold subtracted from the volume before negative values are clipped to 0. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: if greater than 0 (scalar or per-axis list), soften the volume edges before masking and
            filtering. Default is 0.0 (off)
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[int, float, List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume

        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration, multiplied with the volume. None disables masking.
        Returns:
            cupy.ndarray: Filtered volume (float32)
        """
        # The subtraction allocates a new array, so the in-place steps below never modify the caller's input.
        vol = cp.clip(cp.array(vol, "float32", copy=False) - self.clip_thresh, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            vol = soften_edges(vol, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            vol *= cp.array(reg_mask, dtype="float32", copy=False)
        if self.dog:
            vol = dogfilter(vol, self.low, self.high, mode="reflect")
        return vol

A class to apply a filter to the volume before registration

Arguments:
  • clip_thresh: threshold subtracted from the volume before negative values are clipped to 0. Default is 0
  • dog: if True, apply a DoG filter to the volume. Default is True
  • low: the lower sigma value for the DoG filter. Default is 0.5
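  • high: the higher sigma value for the DoG filter. Default is 10.0
  • soft_edge: if greater than 0 (scalar or per-axis list), soften the volume edges before masking and filtering. Default is 0.0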
clip_thresh: float
dog: bool
low: float
high: float
soft_edge: Union[int, float, List[Union[int, float]]]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class LevelConfig(pydantic.main.BaseModel):
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated. A negative entry -n sets the
            block size along that axis to (volume size // n), i.e. the axis is split into about n blocks.
        block_stride (list or float): stride in voxels, or (if a number) a fraction of block_size, rounded to
            whole voxels. Default is 1.0 (stride equal to the block size).
        repeats (int): number of iterations for this level (disable the level by setting repeats to 0)
        smooth (Smoother or None): Smoother applied to the cross-correlation of each block; None disables smoothing
        project (Projector or callable): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float or None): if not None, apply a Tukey window to the projected reference blocks (alpha = tukey_ref). Default is 0.5
        affine (bool): if True, fit an affine transformation to the displacement field
        median_filter (bool): if True, apply median filter to the displacement field. Default is True
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
    """

    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5

Configuration for each level of the registration pyramid

Arguments:
  • block_size (list): shape of blocks, whose rigid displacement is estimated
  • block_stride (list): stride (usually identical to block_size)
  • repeats (int): number of iterations for this level (disable a level by setting repeats to 0)
  • smooth (Smoother or None): Smoother object
  • project (Projector or callable): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
  • tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
  • affine (bool): if True, apply affine transformation to the displacement field
  • median_filter (bool): if True, apply median filter to the displacement field
  • update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
block_size: List[int]
block_stride: Union[List[int], float]
project: Union[Projector, Callable[[Union[numpy.ndarray, cupy.ndarray], int], Union[numpy.ndarray, cupy.ndarray]]]
tukey_ref: Optional[float]
smooth: Optional[Smoother]
affine: bool
median_filter: bool
update_rate: float
repeats: int
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class Recipe(pydantic.main.BaseModel):
class Recipe(BaseModel):
    """Configuration for the registration recipe.

    By default a recipe has two levels. The first is a translation level: one block spanning
    the whole volume, 1 repeat. The second is an affine level: block_size [-4, -4, -4], 10 repeats.

    Args:
        pre_filter (RegFilter, callable or None): Filter applied to the reference and the moving volume
            before registration. Default is RegFilter()
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=1),  # translation level
        LevelConfig(  # affine level
            block_size=[-4, -4, -4],
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the end of the registration recipe

        Args:
            block_size (int, float or list): shape of blocks, whose rigid displacement is estimated.
                A single number is used for all three axes.
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have exactly 3 elements
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level (as in list.insert)
            block_size (int, float or list): shape of blocks, whose rigid displacement is estimated.
                A single number is used for all three axes.
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have exactly 3 elements
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file. If no file exists at this path, it is looked up
                in the ``recipes`` directory next to this module.

        Returns:
            Recipe: Recipe object
        """
        import yaml

        if not os.path.isfile(yaml_path):
            this_file_dir = pathlib.Path(__file__).resolve().parent
            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)

        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        # NOTE(review): a custom callable pre_filter/project is dumped as a Python object and
        # cannot be read back by from_yaml's safe_load.
        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")

Configuration for the registration recipe. By default, a recipe contains a translation level followed by an affine level.

Arguments:
  • pre_filter (RegFilter, callable or None): Filter applied to the reference and moving volumes before registration
  • levels (list): List of LevelConfig objects
pre_filter: Union[RegFilter, Callable[[Union[numpy.ndarray, cupy.ndarray]], Union[numpy.ndarray, cupy.ndarray]], NoneType]
levels: List[LevelConfig]
def add_level(self, block_size, **kwargs):
704    def add_level(self, block_size, **kwargs):
705        """Add a level to the registration recipe
706
707        Args:
708            block_size (list): shape of blocks, whose rigid displacement is estimated
709            **kwargs: additional arguments for LevelConfig
710        """
711        if isinstance(block_size, (int, float)):
712            block_size = [block_size] * 3
713        if len(block_size) != 3:
714            raise ValueError("block_size must be a list of 3 integers")
715        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

Append a level to the end of the registration recipe

Arguments:
  • block_size (list or scalar): shape of the blocks whose rigid displacement is estimated; a single int/float is broadcast to all three axes
  • **kwargs: additional arguments for LevelConfig
def insert_level(self, index, block_size, **kwargs):
717    def insert_level(self, index, block_size, **kwargs):
718        """Insert a level to the registration recipe
719
720        Args:
721            index (int): A number specifying in which position to insert the level
722            block_size (list): shape of blocks, whose rigid displacement is estimated
723            **kwargs: additional arguments for LevelConfig
724        """
725        if isinstance(block_size, (int, float)):
726            block_size = [block_size] * 3
727        if len(block_size) != 3:
728            raise ValueError("block_size must be a list of 3 integers")
729        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

Insert a level into the registration recipe

Arguments:
  • index (int): position at which to insert the level (same semantics as Python's `list.insert`)
  • block_size (list or scalar): shape of the blocks whose rigid displacement is estimated; a single int/float is broadcast to all three axes
  • **kwargs: additional arguments for LevelConfig
@classmethod
def from_yaml(cls, yaml_path):
731    @classmethod
732    def from_yaml(cls, yaml_path):
733        """Load a recipe from a YAML file
734
735        Args:
736            yaml_path (str): path to the YAML file
737
738        Returns:
739            Recipe: Recipe object
740        """
741        import yaml
742
743        this_file_dir = pathlib.Path(__file__).resolve().parent
744        if os.path.isfile(yaml_path):
745            yaml_path = yaml_path
746        else:
747            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)
748
749        with open(yaml_path, "r") as f:
750            data = yaml.safe_load(f)
751
752        return cls.model_validate(data)

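Load a recipe from a YAML file

The file is parsed with `yaml.safe_load` and validated into a Recipe. If `yaml_path` does not point to an existing file, it is looked up in the `recipes` directory located next to this module, so bundled recipes can be loaded by file name.

Arguments:
  • yaml_path (str): path to the YAML file, or the file name of a bundled recipe
Returns:
  • Recipe: the validated Recipe object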

def to_yaml(self, yaml_path):
754    def to_yaml(self, yaml_path):
755        """Save the recipe to a YAML file
756
757        Args:
758            yaml_path (str): path to the YAML file
759        """
760        import yaml
761
762        with open(yaml_path, "w") as f:
763            yaml.dump(self.model_dump(), f)
764        print(f"Recipe saved to {yaml_path}")

Save the recipe to a YAML file and print a confirmation message

Arguments:
  • yaml_path (str): path to the YAML file (an existing file at this path is overwritten)
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model; should be a dictionary conforming to `pydantic.config.ConfigDict`.