warpfield.register

  1import warnings
  2import gc
  3import pathlib
  4import os
  5from typing import List, Union, Callable
  6
  7import numpy as np
  8import scipy.signal
  9import cupy as cp
 10import cupyx
 11import cupyx.scipy.ndimage
 12from pydantic import BaseModel, ValidationError
 13from tqdm.auto import tqdm
 14import h5py
 15
 16from .warp import warp_volume
 17from .utils import create_rgb_video, mips_callback
 18from .ndimage import (
 19    accumarray,
 20    dogfilter,
 21    gausskernel_sheared,
 22    infill_nans,
 23    ndwindow,
 24    periodic_smooth_decomposition_nd_rfft,
 25    sliding_block,
 26    upsampled_dft_rfftn,
 27    soften_edges,
 28)
 29
# Convenience alias: functions in this module accept host (numpy) or device (cupy) arrays.
_ArrayType = Union[np.ndarray, cp.ndarray]
 31
 32class WarpMap:
 33    """Represents a 3D displacement field
 34
 35    Args:
 36        warp_field (numpy.array): the displacement field data (3-x-y-z)
 37        block_size (3-element list or numpy.array):
 38        block_stride (3-element list or numpy.array):
 39        ref_shape (tuple): shape of the reference volume
 40        mov_shape (tuple): shape of the moving volume
 41    """
 42
 43    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
 44        self.warp_field = cp.array(warp_field, dtype="float32")
 45        self.block_size = cp.array(block_size, dtype="float32")
 46        self.block_stride = cp.array(block_stride, dtype="float32")
 47        self.ref_shape = ref_shape
 48        self.mov_shape = mov_shape
 49
 50    def warp(self, vol, out=None):
 51        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
 52
 53        Args:
 54            vol (cupy.array): the volume to be warped
 55
 56        Returns:
 57            cupy.array: warped volume
 58        """
 59        if np.any(vol.shape != np.array(self.mov_shape)):
 60            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
 61        if out is None:
 62            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
 63        vol_out = warp_volume(
 64            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
 65        )
 66        return vol_out
 67
 68    def apply(self, *args, **kwargs):
 69        """Alias of warp method"""
 70        return self.warp(*args, **kwargs)
 71
 72    def fit_affine(self, target=None):
 73        """Fit affine transformation and return new fitted WarpMap
 74
 75        Args:
 76            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 77
 78        Returns:
 79            WarpMap:
 80            numpy.array: affine tranformation coefficients
 81        """
 82        if target is None:
 83            warp_field_shape = self.warp_field.shape
 84            block_size = self.block_size
 85            block_stride = self.block_stride
 86        else:
 87            warp_field_shape = target["warp_field_shape"]
 88            block_size = cp.array(target["block_size"]).astype("float32")
 89            block_stride = cp.array(target["block_stride"]).astype("float32")
 90
 91        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 92        ix = ix * self.block_stride + self.block_size / 2
 93        M = cp.zeros(self.warp_field.shape[1:])
 94        #M[1:-1, 1:-1, 1:-1] = 1
 95        M[:,:,:] = 1
 96        ixg = cp.where(M.flatten() > 0)[0]
 97        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 98        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
 99        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
100        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
101        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
102        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
103
104    def median_filter(self):
105        """Apply median filter to the displacement field
106
107        Returns:
108            WarpMap: new WarpMap with median filtered displacement field
109        """
110        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
111        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)
112
113    def resize_to(self, target):
114        """Resize to target WarpMap, using linear interpolation
115
116        Args:
117            target (WarpMap or WarpMapper): target to resize to
118                or a dict with keys "shape", "block_size", and "block_stride"
119
120        Returns:
121            WarpMap: resized WarpMap
122        """
123        if isinstance(target, WarpMap):
124            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
125        elif isinstance(target, WarpMapper):
126            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
127        elif isinstance(target, dict):
128            t_sh, t_bsz, t_bst = (
129                target["warp_field_shape"][1:],
130                cp.array(target["block_size"]),
131                cp.array(target["block_stride"]),
132            )
133        else:
134            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
135        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
136        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
137        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
138        dm_r = cp.array(
139            [
140                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
141                    t_sh
142                )
143                for i in range(3)
144            ]
145        )
146        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
147
148    def chain(self, target):
149        """Chain displacement maps
150
151        Args:
152            target (WarpMap): WarpMap to be added to existing map
153
154        Returns:
155            WarpMap: new WarpMap with chained displacement field
156        """
157        indices = cp.indices(target.warp_field.shape[1:])
158        warp_field = self.warp_field.copy()
159        warp_field += target.warp_field
160        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
161
162    def invert(self, **kwargs):
163        """alias for invert_fast method"""
164        return self.invert_fast(**kwargs)
165
166    def invert_fast(self, sigma=0.5, truncate=20):
167        """Invert the displacement field using accumulation and Gaussian basis interpolation.
168
169        Args:
170            sigma (float): standard deviation for Gaussian basis interpolation
171            truncate (float): truncate parameter for Gaussian basis interpolation
172
173        Returns:
174            WarpMap: inverted WarpMap
175        """
176        warp_field = self.warp_field.get()
177        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
178        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
179        num_coords = accumarray(target_coords, wf_shape)
180        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
181        for i in range(3):
182            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
183            with np.errstate(invalid="ignore"):
184                inv_field[i] /= num_coords
185            inv_field[i][num_coords == 0] = np.nan
186            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
187        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
188
189    def push_coordinates(self, coords, negative_shifts=False):
190        """Push voxel coordinates from fixed to moving space.
191
192        Args:
193            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
194
195        Returns:
196            numpy.array: transformed voxel coordinates
197        """
198        assert coords.shape[0] == 3
199        coords = cp.array(coords, dtype="float32")
200        # coords_blocked = coords / self.block_size[:, None] - 0.5
201        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
202        warp_field = self.warp_field.copy()
203        shifts = cp.zeros_like(coords)
204        for idim in range(3):
205            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
206                warp_field[idim], coords_blocked, order=1, mode="nearest"
207            )
208        if negative_shifts:
209            shifts = -shifts
210        return coords + shifts
211
212    def pull_coordinates(self, coords):
213        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
214
215        Args:
216            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
217
218        Returns:
219            numpy.array: transformed voxel coordinates
220        """
221        return self.invert().push_coordinates(coords, negative_shifts=True)
222
223    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
224        """
225        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
226
227        Args:
228            edge_order : passed to np.gradient (1 or 2)
229
230        Returns:
231            detJ: cp.ndarray of shape spatial
232        """
233        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
234        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
235        phi = coords + self.warp_field
236        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
237        for i in range(3):
238            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
239            for j in range(3):
240                J[..., i, j] = grads[j]
241        return cp.linalg.det(J)
242
243    def as_ants_image(self, voxel_size_um=1):
244        """Convert to ANTsImage
245
246        Args:
247            voxel_size_um (scalar or array): voxel size (default is 1)
248
249        Returns:
250            ants.core.ants_image.ANTsImage:
251        """
252        try:
253            import ants
254        except ImportError:
255            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
256
257        ants_image = ants.from_numpy(
258            self.warp_field.get().transpose(1, 2, 3, 0),
259            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
260            spacing=list(self.block_stride.get() * voxel_size_um),
261            has_components=True,
262        )
263        return ants_image
264
265    def __repr__(self):
266        """String representation of the WarpMap object."""
267        info = (
268            f"WarpMap("
269            f"warp_field_shape={self.warp_field.shape}, "
270            f"block_size={self.block_size.get()}, "
271            f"block_stride={self.block_stride.get()}, "
272            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)}"
273        )
274        return info
275
276    def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
277        """
278        Save this WarpMap to an HDF5 file.
279
280        Args:
281            h5_path (str or os.PathLike): Path to the HDF5 file.
282            group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
283            compression (str or None): Dataset compression (e.g., 'gzip', None).
284            overwrite (bool): If True, overwrite existing datasets/attrs inside the group.
285        """
286        with h5py.File(h5_path, "a") as f:
287            if overwrite and (group not in (None, "", "/")) and (group in f):
288                del f[group]
289            if (not overwrite) and (group in f):
290                raise ValueError(f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it.")
291            grp = f.require_group(group) if group not in (None, "", "/") else f
292            grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
293            grp.create_dataset("block_size", data=self.block_size.get())
294            grp.create_dataset("block_stride", data=self.block_stride.get())
295            grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
296            grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
297            grp.attrs["class"] = "WarpMap"
298
299    @classmethod
300    def from_h5(cls, h5_path, group="warp_map"):
301        """
302        Load a WarpMap from an HDF5 file.
303
304        Args:
305            h5_path (str or os.PathLike): Path to the HDF5 file.
306            group (str): Group path inside the HDF5 file where the WarpMap is stored.
307
308        Returns:
309            WarpMap: The loaded WarpMap object.
310        """
311        with h5py.File(h5_path, "r") as f:
312            grp = f[group]
313            warp_field = grp["warp_field"][:]
314            block_size = grp["block_size"][:]
315            block_stride = grp["block_stride"][:]
316            ref_shape = tuple(grp["ref_shape"][:].tolist())
317            mov_shape = tuple(grp["mov_shape"][:].tolist())
318        return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)
319
320
class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (usually identical to block_size)
        proj_method (str or callable): Projection method
        subpixel (int): upsampling factor for subpixel displacement estimation
        epsilon (float): small bias added at zero shift to break ties in flat correlation maps
        tukey_alpha (float): alpha of the Tukey window applied to reference projections (values >= 1 disable it)
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        if np.any(block_size > np.array(ref_vol.shape)):
            raise ValueError(f"Block size (currently: {block_size}) must be smaller than the volume shape ({np.array(ref_vol.shape)}).")
        self.proj_method = proj_method
        self.plan_rev = [None, None, None]  # inverse FFT plans, created lazily in get_displacement
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """Precompute the projected, windowed and Fourier-transformed reference blocks."""
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        # One 2D projection of each block along each voxel axis.
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha < 1:
            # BUG FIX: the window alpha was hard-coded to 0.5, ignoring self.tukey_alpha.
            ref_blocks_proj = [
                ref_blocks_proj[i]
                * cp.array(
                    ndwindow(
                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]],
                        lambda n: scipy.signal.windows.tukey(n, alpha=self.tukey_alpha),
                    )
                ).astype("float32")
                for i in range(3)
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
        ]
        # Conjugated reference spectra: multiplying by them implements cross-correlation.
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
        ]
        self.block_size = block_size
        self.block_stride = block_stride

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (numpy.array): Input volume
            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Cross-power spectrum of moving vs. reference projections.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # Bias toward zero shift so featureless blocks don't drift.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer-pixel peak per block, relative to the center of the correlation map.
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # Refine the peak by upsampled DFT around the integer estimate.
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        disp_field = cp.array(disp_field)
        # Each axis shift is observed in two of the three projections; average them.
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)
427
428
class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array): Reference volume
        recipe (Recipe): Settings for each level of the pyramid.
            IMPORTANT: the block size in the last level cannot be larger than the block_size in any previous level.
        reg_mask (numpy.array): Mask for registration (multiplied into the volumes by the pre-filter)
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        self.mapper_ix = []  # recipe-level index for each instantiated mapper
        for i in range(len(recipe.levels)):
            if recipe.levels[i].repeats < 1:
                continue  # level disabled
            block_size = np.array(recipe.levels[i].block_size)
            # Negative block sizes mean "split the volume into that many blocks".
            tmp = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = tmp[block_size < 0]
            if isinstance(recipe.levels[i].block_stride, (int, float)):
                # Scalar stride is interpreted as a fraction of the block size.
                block_stride = (block_size * recipe.levels[i].block_stride).astype("int")
            else:
                block_stride = np.array(recipe.levels[i].block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=recipe.levels[i].project,
                    tukey_alpha=recipe.levels[i].tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback function to be called after each level of registration
            verbose (bool): If True, show progress bars

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Initialize with a pure translation that centers vol on the reference.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # BUG FIX: compare element-wise along all axes (previously only min_block_stride[0]).
        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # Final pass: warp the original (unfiltered) volume.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output
535
536
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Free GPU memory held by the pyramid (FFT plans, reference blocks).
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        try:
            assert cbout[0].ndim == 2, "Callback output must be a 2D array"
            # BUG FIX: recipe.pre_filter may be None; fall back to the raw reference
            # (previously raised an uncaught TypeError).
            ref_filtered = recipe.pre_filter(ref) if recipe.pre_filter is not None else cp.array(ref, "float32", copy=False)
            ref_frame = callback(ref_filtered)
            vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
            create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
        except (ValueError, AssertionError) as e:
            warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout
574
575
class Projector(BaseModel):
    """Projects a blocked volume to 2D and applies optional band-pass / normalization filters

    Parameters:
        max: if True, use a maximum-intensity projection; otherwise a mean projection. Default is True
        normalize: if truthy, normalize projections by the L2 norm raised to this power
            (to get correlations, not covariances). Default is False
        dog: if True, apply a DoG (difference-of-Gaussians) filter to the projections. Default is True
        low: the lower sigma value for the DoG filter (scalar or per-axis). Default is 0.5
        high: the higher sigma value for the DoG filter (scalar or per-axis). Default is 10.0
        periodic_smooth: if True, remove periodic boundary components before filtering. Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Project blocked volume data along ``axis`` and filter the result

        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        proj = vol_blocks.max(axis) if self.max else vol_blocks.mean(axis)
        if self.periodic_smooth:
            proj = periodic_smooth_decomposition_nd_rfft(proj)
        # Drop the projected axis from the per-axis sigma vectors.
        sig_lo = np.delete(np.r_[1, 1, 1] * self.low, axis)
        sig_hi = np.delete(np.r_[1, 1, 1] * self.high, axis)
        if self.dog:
            proj = dogfilter(proj, [0, 0, 0, *sig_lo], [0, 0, 0, *sig_hi], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            proj = cupyx.scipy.ndimage.gaussian_filter(proj, [0, 0, 0, *sig_lo], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            # L2 norm per projection, raised to the configured power; epsilon avoids division by zero.
            norm = cp.sqrt(cp.sum(proj**2, axis=(-2, -1), keepdims=True)) ** self.normalize
            proj /= norm + 1e-9
        return proj
619
620
class Smoother(BaseModel):
    """Smooth cross-correlation blocks with a Gaussian kernel

    Args:
        sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied.
        shear (float): shear parameter for the gaussian kernel. Default is None.
        long_range_ratio (float): long range ratio for a double gaussian kernel. Default is 0.05.
    """

    sigmas: Union[float, List[float]] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated

        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        if self.shear is not None:
            # Shear expressed in block units along the first two axes.
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            # BUG FIX: the attribute is 'sigmas' (plural); 'self.sigma' raised AttributeError.
            gw = gausskernel_sheared(self.sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, self.sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:  # shear is None:
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*self.sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            # Blend in a 5x-wider Gaussian to add a long-range attraction basin.
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*np.array(self.sigmas) * 5, 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj
666
667
class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    Parameters:
        clip_thresh: threshold subtracted from the volume before clipping at zero. Default is 0
        dog: if True, apply a DoG (difference-of-Gaussians) filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: width of the soft edge roll-off (scalar or per-axis). Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume

        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): optional multiplicative mask for registration
        Returns:
            cupy.ndarray: Filtered volume
        """
        # Subtract the background threshold and clip negative values to zero.
        filtered = cp.array(vol, "float32", copy=False) - self.clip_thresh
        filtered = cp.clip(filtered, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            filtered = soften_edges(filtered, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            filtered *= cp.array(reg_mask, dtype="float32", copy=False)
        if self.dog:
            filtered = dogfilter(filtered, self.low, self.high, mode="reflect")
        return filtered
700
701
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated
        block_stride (list): stride (usually identical to block_size)
        repeats (int): number of iterations for this level (disable level by setting repeats to 0)
        smooth (Smoother or None): Smoother object
        project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
        affine (bool): if True, apply affine transformation to the displacement field
        median_filter (bool): if True, apply median filter to the displacement field
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
    """

    # flattened from the degenerate single-member Union[List[int]]
    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5
726
727
class Recipe(BaseModel):
    """Configuration for the registration recipe. Recipe is initialized with a single affine level.

    Args:
        pre_filter (RegFilter, callable or None): Filter to be applied to the volumes before registration
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=3),  # translation level
        LevelConfig(  # affine level
            block_size=[-2, -2, -2],
            block_stride=0.5,
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    @staticmethod
    def _coerce_block_size(block_size):
        """Expand a scalar block_size to 3 elements and validate its length."""
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        return block_size

    def add_level(self, block_size, **kwargs):
        """Add a level to the registration recipe

        Args:
            block_size (list): shape of blocks, whose rigid displacement is estimated
            **kwargs: additional arguments for LevelConfig
        """
        self.levels.append(LevelConfig(block_size=self._coerce_block_size(block_size), **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (list): shape of blocks, whose rigid displacement is estimated
            **kwargs: additional arguments for LevelConfig
        """
        self.levels.insert(index, LevelConfig(block_size=self._coerce_block_size(block_size), **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file; if it is not an existing
                file, the name is resolved against the bundled "recipes" directory

        Returns:
            Recipe: Recipe object
        """
        import yaml

        yaml_path = pathlib.Path(yaml_path)
        if not yaml_path.is_file():
            # fall back to the recipe files shipped alongside this module
            yaml_path = pathlib.Path(__file__).resolve().parent / "recipes" / yaml_path

        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")
class WarpMap:
 33class WarpMap:
 34    """Represents a 3D displacement field
 35
 36    Args:
 37        warp_field (numpy.array): the displacement field data (3-x-y-z)
 38        block_size (3-element list or numpy.array):
 39        block_stride (3-element list or numpy.array):
 40        ref_shape (tuple): shape of the reference volume
 41        mov_shape (tuple): shape of the moving volume
 42    """
 43
 44    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
 45        self.warp_field = cp.array(warp_field, dtype="float32")
 46        self.block_size = cp.array(block_size, dtype="float32")
 47        self.block_stride = cp.array(block_stride, dtype="float32")
 48        self.ref_shape = ref_shape
 49        self.mov_shape = mov_shape
 50
 51    def warp(self, vol, out=None):
 52        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
 53
 54        Args:
 55            vol (cupy.array): the volume to be warped
 56
 57        Returns:
 58            cupy.array: warped volume
 59        """
 60        if np.any(vol.shape != np.array(self.mov_shape)):
 61            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
 62        if out is None:
 63            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
 64        vol_out = warp_volume(
 65            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
 66        )
 67        return vol_out
 68
 69    def apply(self, *args, **kwargs):
 70        """Alias of warp method"""
 71        return self.warp(*args, **kwargs)
 72
 73    def fit_affine(self, target=None):
 74        """Fit affine transformation and return new fitted WarpMap
 75
 76        Args:
 77            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 78
 79        Returns:
 80            WarpMap:
 81            numpy.array: affine transformation coefficients
 82        """
 83        if target is None:
 84            warp_field_shape = self.warp_field.shape
 85            block_size = self.block_size
 86            block_stride = self.block_stride
 87        else:
 88            warp_field_shape = target["warp_field_shape"]
 89            block_size = cp.array(target["block_size"]).astype("float32")
 90            block_stride = cp.array(target["block_stride"]).astype("float32")
 91
 92        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 93        ix = ix * self.block_stride + self.block_size / 2
 94        M = cp.zeros(self.warp_field.shape[1:])
 95        #M[1:-1, 1:-1, 1:-1] = 1
 96        M[:,:,:] = 1
 97        ixg = cp.where(M.flatten() > 0)[0]
 98        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 99        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
100        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
101        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
102        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
103        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
104
105    def median_filter(self):
106        """Apply median filter to the displacement field
107
108        Returns:
109            WarpMap: new WarpMap with median filtered displacement field
110        """
111        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
112        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)
113
114    def resize_to(self, target):
115        """Resize to target WarpMap, using linear interpolation
116
117        Args:
118            target (WarpMap or WarpMapper): target to resize to
119                or a dict with keys "shape", "block_size", and "block_stride"
120
121        Returns:
122            WarpMap: resized WarpMap
123        """
124        if isinstance(target, WarpMap):
125            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
126        elif isinstance(target, WarpMapper):
127            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
128        elif isinstance(target, dict):
129            t_sh, t_bsz, t_bst = (
130                target["warp_field_shape"][1:],
131                cp.array(target["block_size"]),
132                cp.array(target["block_stride"]),
133            )
134        else:
135            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
136        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
137        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
138        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
139        dm_r = cp.array(
140            [
141                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
142                    t_sh
143                )
144                for i in range(3)
145            ]
146        )
147        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
148
149    def chain(self, target):
150        """Chain displacement maps
151
152        Args:
153            target (WarpMap): WarpMap to be added to existing map
154
155        Returns:
156            WarpMap: new WarpMap with chained displacement field
157        """
158        indices = cp.indices(target.warp_field.shape[1:])
159        warp_field = self.warp_field.copy()
160        warp_field += target.warp_field
161        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
162
163    def invert(self, **kwargs):
164        """alias for invert_fast method"""
165        return self.invert_fast(**kwargs)
166
167    def invert_fast(self, sigma=0.5, truncate=20):
168        """Invert the displacement field using accumulation and Gaussian basis interpolation.
169
170        Args:
171            sigma (float): standard deviation for Gaussian basis interpolation
172            truncate (float): truncate parameter for Gaussian basis interpolation
173
174        Returns:
175            WarpMap: inverted WarpMap
176        """
177        warp_field = self.warp_field.get()
178        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
179        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
180        num_coords = accumarray(target_coords, wf_shape)
181        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
182        for i in range(3):
183            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
184            with np.errstate(invalid="ignore"):
185                inv_field[i] /= num_coords
186            inv_field[i][num_coords == 0] = np.nan
187            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
188        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
189
190    def push_coordinates(self, coords, negative_shifts=False):
191        """Push voxel coordinates from fixed to moving space.
192
193        Args:
194            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
195
196        Returns:
197            numpy.array: transformed voxel coordinates
198        """
199        assert coords.shape[0] == 3
200        coords = cp.array(coords, dtype="float32")
201        # coords_blocked = coords / self.block_size[:, None] - 0.5
202        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
203        warp_field = self.warp_field.copy()
204        shifts = cp.zeros_like(coords)
205        for idim in range(3):
206            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
207                warp_field[idim], coords_blocked, order=1, mode="nearest"
208            )
209        if negative_shifts:
210            shifts = -shifts
211        return coords + shifts
212
213    def pull_coordinates(self, coords):
214        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
215
216        Args:
217            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
218
219        Returns:
220            numpy.array: transformed voxel coordinates
221        """
222        return self.invert().push_coordinates(coords, negative_shifts=True)
223
224    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
225        """
226        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
227
228        Args:
229            edge_order : passed to np.gradient (1 or 2)
230
231        Returns:
232            detJ: cp.ndarray of shape spatial
233        """
234        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
235        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
236        phi = coords + self.warp_field
237        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
238        for i in range(3):
239            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
240            for j in range(3):
241                J[..., i, j] = grads[j]
242        return cp.linalg.det(J)
243
244    def as_ants_image(self, voxel_size_um=1):
245        """Convert to ANTsImage
246
247        Args:
248            voxel_size_um (scalar or array): voxel size (default is 1)
249
250        Returns:
251            ants.core.ants_image.ANTsImage:
252        """
253        try:
254            import ants
255        except ImportError:
256            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
257
258        ants_image = ants.from_numpy(
259            self.warp_field.get().transpose(1, 2, 3, 0),
260            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
261            spacing=list(self.block_stride.get() * voxel_size_um),
262            has_components=True,
263        )
264        return ants_image
265
266    def __repr__(self):
267        """String representation of the WarpMap object."""
268        info = (
269            f"WarpMap("
270            f"warp_field_shape={self.warp_field.shape}, "
271            f"block_size={self.block_size.get()}, "
272            f"block_stride={self.block_stride.get()}, "
273            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)}"
274        )
275        return info
276
277    def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
278        """
279        Save this WarpMap to an HDF5 file.
280
281        Args:
282            h5_path (str or os.PathLike): Path to the HDF5 file.
283            group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
284            compression (str or None): Dataset compression (e.g., 'gzip', None).
285            overwrite (bool): If True, overwrite existing datasets/attrs inside the group.
286        """
287        with h5py.File(h5_path, "a") as f:
288            if overwrite and (group not in (None, "", "/")) and (group in f):
289                del f[group]
290            if (not overwrite) and (group in f):
291                raise ValueError(f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it.")
292            grp = f.require_group(group) if group not in (None, "", "/") else f
293            grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
294            grp.create_dataset("block_size", data=self.block_size.get())
295            grp.create_dataset("block_stride", data=self.block_stride.get())
296            grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
297            grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
298            grp.attrs["class"] = "WarpMap"
299
300    @classmethod
301    def from_h5(cls, h5_path, group="warp_map"):
302        """
303        Load a WarpMap from an HDF5 file.
304
305        Args:
306            h5_path (str or os.PathLike): Path to the HDF5 file.
307            group (str): Group path inside the HDF5 file where the WarpMap is stored.
308
309        Returns:
310            WarpMap: The loaded WarpMap object.
311        """
312        with h5py.File(h5_path, "r") as f:
313            grp = f[group]
314            warp_field = grp["warp_field"][:]
315            block_size = grp["block_size"][:]
316            block_stride = grp["block_stride"][:]
317            ref_shape = tuple(grp["ref_shape"][:].tolist())
318            mov_shape = tuple(grp["mov_shape"][:].tolist())
319        return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)

Represents a 3D displacement field

Arguments:
  • warp_field (numpy.array): the displacement field data (3-x-y-z)
  • block_size (3-element list or numpy.array):
  • block_stride (3-element list or numpy.array):
  • ref_shape (tuple): shape of the reference volume
  • mov_shape (tuple): shape of the moving volume
WarpMap(warp_field, block_size, block_stride, ref_shape, mov_shape)
44    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
45        self.warp_field = cp.array(warp_field, dtype="float32")
46        self.block_size = cp.array(block_size, dtype="float32")
47        self.block_stride = cp.array(block_stride, dtype="float32")
48        self.ref_shape = ref_shape
49        self.mov_shape = mov_shape
warp_field
block_size
block_stride
ref_shape
mov_shape
def warp(self, vol, out=None):
51    def warp(self, vol, out=None):
52        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
53
54        Args:
55            vol (cupy.array): the volume to be warped
56
57        Returns:
58            cupy.array: warped volume
59        """
60        if np.any(vol.shape != np.array(self.mov_shape)):
61            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
62        if out is None:
63            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
64        vol_out = warp_volume(
65            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
66        )
67        return vol_out

Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

Arguments:
  • vol (cupy.array): the volume to be warped
Returns:

cupy.array: warped volume

def apply(self, *args, **kwargs):
69    def apply(self, *args, **kwargs):
70        """Alias of warp method"""
71        return self.warp(*args, **kwargs)

Alias of warp method

def fit_affine(self, target=None):
 73    def fit_affine(self, target=None):
 74        """Fit affine transformation and return new fitted WarpMap
 75
 76        Args:
 77            target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
 78
 79        Returns:
 80            WarpMap:
 81            numpy.array: affine transformation coefficients
 82        """
 83        if target is None:
 84            warp_field_shape = self.warp_field.shape
 85            block_size = self.block_size
 86            block_stride = self.block_stride
 87        else:
 88            warp_field_shape = target["warp_field_shape"]
 89            block_size = cp.array(target["block_size"]).astype("float32")
 90            block_stride = cp.array(target["block_stride"]).astype("float32")
 91
 92        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
 93        ix = ix * self.block_stride + self.block_size / 2
 94        M = cp.zeros(self.warp_field.shape[1:])
 95        #M[1:-1, 1:-1, 1:-1] = 1
 96        M[:,:,:] = 1
 97        ixg = cp.where(M.flatten() > 0)[0]
 98        a = cp.hstack([ix[ixg], cp.ones((len(ixg), 1))])
 99        b = ix[ixg] + self.warp_field.reshape(3, -1).T[ixg]
100        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
101        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
102        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
103        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

Fit affine transformation and return new fitted WarpMap

Arguments:
  • target (dict): dict with keys "blocks_shape", "block_size", and "block_stride"
Returns:

WarpMap: numpy.array: affine transformation coefficients

def median_filter(self):
105    def median_filter(self):
106        """Apply median filter to the displacement field
107
108        Returns:
109            WarpMap: new WarpMap with median filtered displacement field
110        """
111        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
112        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

Apply median filter to the displacement field

Returns:

WarpMap: new WarpMap with median filtered displacement field

def resize_to(self, target):
114    def resize_to(self, target):
115        """Resize to target WarpMap, using linear interpolation
116
117        Args:
118            target (WarpMap or WarpMapper): target to resize to
119                or a dict with keys "shape", "block_size", and "block_stride"
120
121        Returns:
122            WarpMap: resized WarpMap
123        """
124        if isinstance(target, WarpMap):
125            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
126        elif isinstance(target, WarpMapper):
127            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
128        elif isinstance(target, dict):
129            t_sh, t_bsz, t_bst = (
130                target["warp_field_shape"][1:],
131                cp.array(target["block_size"]),
132                cp.array(target["block_stride"]),
133            )
134        else:
135            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
136        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
137        # ix = (ix + 0.5) / cp.array(self.block_size / t_bsz)[:, None] - 0.5
138        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
139        dm_r = cp.array(
140            [
141                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
142                    t_sh
143                )
144                for i in range(3)
145            ]
146        )
147        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

Resize to target WarpMap, using linear interpolation

Arguments:
  • target (WarpMap or WarpMapper): target to resize to or a dict with keys "shape", "block_size", and "block_stride"
Returns:

WarpMap: resized WarpMap

def chain(self, target):
149    def chain(self, target):
150        """Chain displacement maps
151
152        Args:
153            target (WarpMap): WarpMap to be added to existing map
154
155        Returns:
156            WarpMap: new WarpMap with chained displacement field
157        """
158        indices = cp.indices(target.warp_field.shape[1:])
159        warp_field = self.warp_field.copy()
160        warp_field += target.warp_field
161        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

Chain displacement maps

Arguments:
  • target (WarpMap): WarpMap to be added to existing map
Returns:

WarpMap: new WarpMap with chained displacement field

def invert(self, **kwargs):
163    def invert(self, **kwargs):
164        """alias for invert_fast method"""
165        return self.invert_fast(**kwargs)

alias for invert_fast method

def invert_fast(self, sigma=0.5, truncate=20):
167    def invert_fast(self, sigma=0.5, truncate=20):
168        """Invert the displacement field using accumulation and Gaussian basis interpolation.
169
170        Args:
171            sigma (float): standard deviation for Gaussian basis interpolation
172            truncate (float): truncate parameter for Gaussian basis interpolation
173
174        Returns:
175            WarpMap: inverted WarpMap
176        """
177        warp_field = self.warp_field.get()
178        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
179        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
180        num_coords = accumarray(target_coords, wf_shape)
181        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
182        for i in range(3):
183            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
184            with np.errstate(invalid="ignore"):
185                inv_field[i] /= num_coords
186            inv_field[i][num_coords == 0] = np.nan
187            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
188        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

Invert the displacement field using accumulation and Gaussian basis interpolation.

Arguments:
  • sigma (float): standard deviation for Gaussian basis interpolation
  • truncate (float): truncate parameter for Gaussian basis interpolation
Returns:

WarpMap: inverted WarpMap

def push_coordinates(self, coords, negative_shifts=False):
190    def push_coordinates(self, coords, negative_shifts=False):
191        """Push voxel coordinates from fixed to moving space.
192
193        Args:
194            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
195
196        Returns:
197            numpy.array: transformed voxel coordinates
198        """
199        assert coords.shape[0] == 3
200        coords = cp.array(coords, dtype="float32")
201        # coords_blocked = coords / self.block_size[:, None] - 0.5
202        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
203        warp_field = self.warp_field.copy()
204        shifts = cp.zeros_like(coords)
205        for idim in range(3):
206            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
207                warp_field[idim], coords_blocked, order=1, mode="nearest"
208            )
209        if negative_shifts:
210            shifts = -shifts
211        return coords + shifts

Push voxel coordinates from fixed to moving space.

Arguments:
  • coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:

numpy.array: transformed voxel coordinates

def pull_coordinates(self, coords):
213    def pull_coordinates(self, coords):
214        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
215
216        Args:
217            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
218
219        Returns:
220            numpy.array: transformed voxel coordinates
221        """
222        return self.invert().push_coordinates(coords, negative_shifts=True)

Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

Arguments:
  • coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:

numpy.array: transformed voxel coordinates

def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
224    def jacobian_det(self, units_per_voxel=[1, 1, 1], edge_order=1):
225        """
226        Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.
227
228        Args:
229            edge_order : passed to np.gradient (1 or 2)
230
231        Returns:
232            detJ: cp.ndarray of shape spatial
233        """
234        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
235        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
236        phi = coords + self.warp_field
237        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
238        for i in range(3):
239            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
240            for j in range(3):
241                J[..., i, j] = grads[j]
242        return cp.linalg.det(J)

Compute det J = det(∇φ) for φ(x)=x+u(x), using np.indices for the identity grid.

Arguments:
  • edge_order : passed to np.gradient (1 or 2)
Returns:

detJ: cp.ndarray of shape spatial

def as_ants_image(self, voxel_size_um=1):
244    def as_ants_image(self, voxel_size_um=1):
245        """Convert to ANTsImage
246
247        Args:
248            voxel_size_um (scalar or array): voxel size (default is 1)
249
250        Returns:
251            ants.core.ants_image.ANTsImage:
252        """
253        try:
254            import ants
255        except ImportError:
256            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'")
257
258        ants_image = ants.from_numpy(
259            self.warp_field.get().transpose(1, 2, 3, 0),
260            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
261            spacing=list(self.block_stride.get() * voxel_size_um),
262            has_components=True,
263        )
264        return ants_image

Convert to ANTsImage

Arguments:
  • voxel_size_um (scalar or array): voxel size (default is 1)
Returns:

ants.core.ants_image.ANTsImage:

def to_h5(self, h5_path, group='warp_map', compression='gzip', overwrite=True):
277    def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
278        """
279        Save this WarpMap to an HDF5 file.
280
281        Args:
282            h5_path (str or os.PathLike): Path to the HDF5 file.
283            group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
284            compression (str or None): Dataset compression (e.g., 'gzip', None).
285            overwrite (bool): If True, overwrite existing datasets/attrs inside the group.
286        """
287        with h5py.File(h5_path, "a") as f:
288            if overwrite and (group not in (None, "", "/")) and (group in f):
289                del f[group]
290            if (not overwrite) and (group in f):
291                raise ValueError(f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it.")
292            grp = f.require_group(group) if group not in (None, "", "/") else f
293            grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
294            grp.create_dataset("block_size", data=self.block_size.get())
295            grp.create_dataset("block_stride", data=self.block_stride.get())
296            grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
297            grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
298            grp.attrs["class"] = "WarpMap"

Save this WarpMap to an HDF5 file.

Arguments:
  • h5_path (str or os.PathLike): Path to the HDF5 file.
  • group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
  • compression (str or None): Dataset compression (e.g., 'gzip', None).
  • overwrite (bool): If True, overwrite existing datasets/attrs inside the group.
@classmethod
def from_h5(cls, h5_path, group='warp_map'):
300    @classmethod
301    def from_h5(cls, h5_path, group="warp_map"):
302        """
303        Load a WarpMap from an HDF5 file.
304
305        Args:
306            h5_path (str or os.PathLike): Path to the HDF5 file.
307            group (str): Group path inside the HDF5 file where the WarpMap is stored.
308
309        Returns:
310            WarpMap: The loaded WarpMap object.
311        """
312        with h5py.File(h5_path, "r") as f:
313            grp = f[group]
314            warp_field = grp["warp_field"][:]
315            block_size = grp["block_size"][:]
316            block_stride = grp["block_stride"][:]
317            ref_shape = tuple(grp["ref_shape"][:].tolist())
318            mov_shape = tuple(grp["mov_shape"][:].tolist())
319        return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)

Load a WarpMap from an HDF5 file.

Arguments:
  • h5_path (str or os.PathLike): Path to the HDF5 file.
  • group (str): Group path inside the HDF5 file where the WarpMap is stored.
Returns:

WarpMap: The loaded WarpMap object.

class WarpMapper:
class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (usually identical to block_size)
        proj_method (str or callable): Projection method
        subpixel (int): upsampling factor for sub-voxel displacement refinement (default 4)
        epsilon (float): small bias added at zero displacement to break ties in flat correlations
        tukey_alpha (float): alpha of the Tukey window applied to reference projections;
            values >= 1 disable windowing
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        if np.any(block_size > np.array(ref_vol.shape)):
            raise ValueError(f"Block size (currently: {block_size}) must be smaller than the volume shape ({np.array(ref_vol.shape)}).")
        self.proj_method = proj_method
        # Reverse FFT plans are created lazily, one per projection axis.
        self.plan_rev = [None, None, None]
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """Block the reference volume, project each block along all three axes, window,
        and cache the conjugate FFTs used for cross-correlation."""
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha < 1:
            # FIX: the window previously hard-coded alpha=0.5, silently ignoring
            # the tukey_alpha constructor argument; it is now applied.
            ref_blocks_proj = [
                ref_blocks_proj[i]
                * cp.array(
                    ndwindow(
                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]],
                        lambda n: scipy.signal.windows.tukey(n, alpha=self.tukey_alpha),
                    )
                ).astype("float32")
                for i in range(3)
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
        ]
        # Conjugated reference spectra: multiplying by them computes cross-correlation.
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
        ]
        self.block_size = block_size
        self.block_stride = block_stride

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (numpy.array): Input volume
            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Cross-power spectrum of moving vs. reference projections.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # Bias the zero-displacement bin so featureless blocks stay put.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer-voxel peak location, re-centered around zero displacement.
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # Refine the peak by upsampled DFT in a small window around the integer peak.
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        # Each axis displacement is observed in two of the three projections; average them.
        disp_field = cp.array(disp_field)
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)

Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

Arguments:
  • ref_vol (numpy.array): The reference volume
  • block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
  • block_stride (3-element list or numpy.array): stride (usually identical to block_size)
  • proj_method (str or callable): Projection method
WarpMapper( ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-06, tukey_alpha=0.5)
332    def __init__(
333        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
334    ):
335        if np.any(block_size > np.array(ref_vol.shape)):
336            raise ValueError(f"Block size (currently: {block_size}) must be smaller than the volume shape ({np.array(ref_vol.shape)}).")
337        self.proj_method = proj_method
338        self.plan_rev = [None, None, None]
339        self.subpixel = subpixel
340        self.epsilon = epsilon
341        self.tukey_alpha = tukey_alpha
342        self.update_reference(ref_vol, block_size, block_stride)
343        self.ref_shape = np.array(ref_vol.shape)
proj_method
plan_rev
subpixel
epsilon
tukey_alpha
ref_shape
def update_reference(self, ref_vol, block_size, block_stride=None):
345    def update_reference(self, ref_vol, block_size, block_stride=None):
346        ft = lambda arr: cp.fft.rfftn(arr, axes=(-2, -1))
347        block_size = np.array(block_size)
348        block_stride = block_size if block_stride is None else np.array(block_stride)
349        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
350        self.blocks_shape = ref_blocks.shape
351        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
352        if self.tukey_alpha < 1:
353            ref_blocks_proj = [
354                ref_blocks_proj[i]
355                * cp.array(
356                    ndwindow(
357                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=0.5)
358                    )
359                ).astype("float32")
360                for i in range(3)
361            ]
362        self.plan_fwd = [
363            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
364        ]
365        self.ref_blocks_proj_ft_conj = [
366            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
367        ]
368        self.block_size = block_size
369        self.block_stride = block_stride
def get_displacement(self, vol, smooth_func=None):
371    def get_displacement(self, vol, smooth_func=None):
372        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.
373
374        Args:
375            vol (numpy.array): Input volume
376            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
377
378        Returns:
379            WarpMap
380        """
381        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
382        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
383        del vol_blocks
384
385        disp_field = []
386        for i in range(3):
387            R = (
388                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
389                * self.ref_blocks_proj_ft_conj[i]
390            )
391            if self.plan_rev[i] is None:
392                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
393            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
394            if smooth_func is not None:
395                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
396            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon
397
398            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
399            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
400            del xcorr_proj
401            i0, j0 = max_ix.reshape(2, -1)
402            shifts = upsampled_dft_rfftn(
403                R.reshape(-1, *R.shape[-2:]),
404                upsampled_region_size=int(self.subpixel * 2 + 1),
405                upsample_factor=self.subpixel,
406                axis_offsets=(i0, j0),
407            )
408            del R
409            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
410            max_sub = (
411                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
412            ) / self.subpixel
413            del shifts
414            disp_field.append(max_ix + max_sub)
415
416        disp_field = cp.array(disp_field)
417        disp_field = (
418            cp.array(
419                [
420                    disp_field[1, 0] + disp_field[2, 0],
421                    disp_field[0, 0] + disp_field[2, 1],
422                    disp_field[0, 1] + disp_field[1, 1],
423                ]
424            ).astype("float32")
425            / 2
426        )
427        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)

Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

Arguments:
  • vol (numpy.array): Input volume
  • smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
Returns:

WarpMap

class RegistrationPyramid:
class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array): Reference volume
        recipe (Recipe): Settings for each level of the pyramid.
            IMPORTANT: the block size in the last level cannot be larger than the block_size in any previous level.
        reg_mask (numpy.array): Mask for registration
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        # Round-trip the recipe through its own schema to validate it.
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        self.mapper_ix = []
        for i in range(len(recipe.levels)):
            # Levels with zero repeats are skipped entirely.
            if recipe.levels[i].repeats < 1:
                continue
            block_size = np.array(recipe.levels[i].block_size)
            # Negative entries mean "split the axis into that many blocks" (ceil division).
            tmp = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = tmp[block_size < 0]
            if isinstance(recipe.levels[i].block_stride, (int, float)):
                # Scalar stride is interpreted as a fraction of the block size.
                block_stride = (block_size * recipe.levels[i].block_stride).astype("int")
            else:
                block_stride = np.array(recipe.levels[i].block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=recipe.levels[i].project,
                    tukey_alpha=recipe.levels[i].tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback function to be called after each level of registration
            verbose (bool): If True, show per-level/per-repeat progress bars

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Initialize with a pure translation that centers vol on the reference.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # FIX: compare elementwise against the per-axis minimum (previously only
        # min_block_stride[0] was used, so violations on other axes went unnoticed).
        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            for _ in tqdm(
                range(self.recipe.levels[self.mapper_ix[k]].repeats), leave=False, desc="Repeats", disable=not verbose
            ):
                wm = mapper.get_displacement(
                    vol_tmp, smooth_func=self.recipe.levels[self.mapper_ix[k]].smooth  # * self.reg_mask,
                )
                # Damp the update for stability.
                wm.warp_field *= self.recipe.levels[self.mapper_ix[k]].update_rate
                if self.recipe.levels[self.mapper_ix[k]].median_filter:
                    wm = wm.median_filter()
                if self.recipe.levels[self.mapper_ix[k]].affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    # Fit a global affine and resample it onto the final level's grid.
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                # Compose the incremental displacement into the running warp.
                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # Apply the final warp to the original (unfiltered) volume.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output

A class for performing multi-resolution registration.

Arguments:
  • ref_vol (numpy.array): Reference volume
  • recipe (Recipe): Settings for each level of the pyramid. IMPORTANT: the block size in the last level cannot be larger than the block_size in any previous level.
  • reg_mask (numpy.array): Mask for registration
  • clip_thresh (float): Threshold for clipping the reference volume
RegistrationPyramid(ref_vol, recipe, reg_mask=1)
441    def __init__(self, ref_vol, recipe, reg_mask=1):
442        recipe.model_validate(recipe.model_dump())
443        self.recipe = recipe
444        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
445        self.mappers = []
446        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
447        self.ref_shape = ref_vol.shape
448        if self.recipe.pre_filter is not None:
449            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
450        self.mapper_ix = []
451        for i in range(len(recipe.levels)):
452            if recipe.levels[i].repeats < 1:
453                continue
454            block_size = np.array(recipe.levels[i].block_size)
455            tmp = np.r_[ref_vol.shape] // -block_size
456            block_size[block_size < 0] = tmp[block_size < 0]
457            if isinstance(recipe.levels[i].block_stride, (int, float)):
458                block_stride = (block_size * recipe.levels[i].block_stride).astype("int")
459            else:
460                block_stride = np.array(recipe.levels[i].block_stride)
461            self.mappers.append(
462                WarpMapper(
463                    ref_vol,
464                    block_size,
465                    block_stride=block_stride,
466                    proj_method=recipe.levels[i].project,
467                    tukey_alpha=recipe.levels[i].tukey_ref,
468                )
469            )
470            self.mapper_ix.append(i)
471        assert len(self.mappers) > 0, "At least one level of registration is required"
recipe
reg_mask
mappers
ref_shape
mapper_ix
def register_single(self, vol, callback=None, verbose=False):
473    def register_single(self, vol, callback=None, verbose=False):
474        """Register a single volume to the reference volume.
475
476        Args:
477            vol (array_like): Volume to be registered (numpy or cupy array)
478            callback (function): Callback function to be called after each level of registration
479
480        Returns:
481            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
482            - warp_map (WarpMap): Displacement field
483            - callback_output (list): List of outputs from the callback function
484        """
485        was_numpy = isinstance(vol, np.ndarray)
486        vol = cp.array(vol, "float32", copy=False, order="C")
487        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
488        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
489        warp_map = warp_map.resize_to(self.mappers[-1])
490        callback_output = []
491        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
492        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
493        warp_map.warp(vol_tmp0, out=vol_tmp)
494        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
495        if callback is not None:
496            callback_output.append(callback(vol_tmp))
497
498        if np.any(self.mappers[-1].block_stride > min_block_stride[0]):
499            warnings.warn(
500                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
501            )
502        for k, mapper in enumerate(tqdm(self.mappers, desc=f"Levels", disable=not verbose)):
503            for _ in tqdm(
504                range(self.recipe.levels[self.mapper_ix[k]].repeats), leave=False, desc=f"Repeats", disable=not verbose
505            ):
506                wm = mapper.get_displacement(
507                    vol_tmp, smooth_func=self.recipe.levels[self.mapper_ix[k]].smooth  # * self.reg_mask,
508                )
509                wm.warp_field *= self.recipe.levels[self.mapper_ix[k]].update_rate
510                if self.recipe.levels[self.mapper_ix[k]].median_filter:
511                    wm = wm.median_filter()
512                if self.recipe.levels[self.mapper_ix[k]].affine:
513                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
514                        raise ValueError(
515                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
516                        )
517                    wm, _ = wm.fit_affine(
518                        target=dict(
519                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
520                            block_size=self.mappers[-1].block_size,
521                            block_stride=self.mappers[-1].block_stride,
522                        )
523                    )
524                else:
525                    wm = wm.resize_to(self.mappers[-1])
526
527                warp_map = warp_map.chain(wm)
528                warp_map.warp(vol_tmp0, out=vol_tmp)
529                if callback is not None:
530                    # callback_output.append(callback(warp_map.unwarp(vol)))
531                    callback_output.append(callback(vol_tmp))
532        warp_map.warp(vol, out=vol_tmp)
533        if was_numpy:
534            vol_tmp = vol_tmp.get()
535        return vol_tmp, warp_map, callback_output

Register a single volume to the reference volume.

Arguments:
  • vol (array_like): Volume to be registered (numpy or cupy array)
  • callback (function): Callback function to be called after each level of registration
Returns:
  • vol (array_like): Registered volume (numpy or cupy array, depending on input)
  • warp_map (WarpMap): Displacement field
  • callback_output (list): List of outputs from the callback function
def register_volumes( ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If None, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Free GPU memory held by the pyramid and cached FFT plans.
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        try:
            assert cbout[0].ndim == 2, "Callback output must be a 2D array"
            # FIX: pre_filter may be None (it is guarded everywhere else);
            # calling it unconditionally raised an uncaught TypeError.
            ref_frame = recipe.pre_filter(ref) if recipe.pre_filter is not None else ref
            ref = callback(ref_frame)
            vmax = np.percentile(ref, 99.9).item() if vmax is None else vmax
            create_rgb_video(video_path, ref / vmax, np.array(cbout) / vmax, fps=10)
        except (ValueError, AssertionError) as e:
            warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout

Register a volume to a reference volume using a registration pyramid.

Arguments:
  • ref (numpy.array or cupy.array): Reference volume
  • vol (numpy.array or cupy.array): Volume to be registered
  • recipe (Recipe): Registration recipe
  • reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
  • callback (function): Callback function to be called on the volume after each iteration. Default is None. Can be used to monitor and optimize registration. Example: callback = lambda vol: vol.mean(1).get() (note that vol is a 3D cupy array. Use .get() to turn the output into a numpy array and save GPU memory). Callback outputs for each registration step will be returned as a list.
  • verbose (bool): If True, show progress bars. Default is True
  • video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
  • vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.
Returns:
  • numpy.array or cupy.array (depending on vol input): Registered volume
  • WarpMap: Displacement field
  • list: List of outputs from the callback function
class Projector(pydantic.main.BaseModel):
class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, project blocks with a max along the axis; otherwise use the mean. Default is True
        normalize: if truthy, normalize projections by the L2 norm raised to this power
            (to get correlations, not covariances). Default is False
        dog: if True, apply a DoG filter to the projection. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        periodic_smooth: if True, apply periodic-smooth decomposition to the projection
            (reduces FFT wrap-around artifacts). Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block
        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        # Collapse the chosen voxel axis to obtain a stack of 2D projections.
        if self.max:
            out = vol_blocks.max(axis)
        else:
            out = vol_blocks.mean(axis)
        if self.periodic_smooth:
            out = periodic_smooth_decomposition_nd_rfft(out)
        # Drop the sigma of the projected-away axis; filters act only on the 2D projection.
        low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        if self.dog:
            out = dogfilter(out, [0, 0, 0, *low], [0, 0, 0, *high], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            out = cupyx.scipy.ndimage.gaussian_filter(out, [0, 0, 0, *low], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            # Divide by ||out||^normalize (+eps) per projection for correlation-style matching.
            out /= cp.sqrt(cp.sum(out**2, axis=(-2, -1), keepdims=True)) ** self.normalize + 1e-9
        return out

A class to apply a 2D projection and filters to a volume block

Arguments:
  • max: if True, apply a max filter to the volume block. Default is True
  • normalize: if True, normalize projections by the L2 norm (to get correlations, not covariances). Default is False
  • dog: if True, apply a DoG filter to the volume block. Default is True
  • low: the lower sigma value for the DoG filter. Default is 0.5
  • high: the higher sigma value for the DoG filter. Default is 10.0
  • periodic_smooth: if True, apply periodic-smooth decomposition to the projection. Default is False
max: bool = True
normalize: bool | float = False
dog: bool = True
low: int | float | List[int | float] = 0.5
high: int | float | List[int | float] = 10.0
periodic_smooth: bool = False
class Smoother(pydantic.main.BaseModel):
class Smoother(BaseModel):
    """Smooth blocks with a Gaussian kernel.

    Args:
        sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied.
        shear (float): shear parameter for gaussian kernel. Default is None.
        long_range_ratio (float): long range ratio for double gaussian kernel. Default is 0.05.

    Note:
        The Gaussian kernel truncation is hard-coded to 4.0 sigmas inside ``__call__``.
    """

    sigmas: Union[float, List[float]] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0  # kernel support, in units of sigma
        if self.sigmas is None:
            return xcorr_proj
        if self.shear is not None:
            # Shear is expressed per-block; rescale by the block aspect ratio.
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            # BUG FIX: was `self.sigma[:2]` — the model field is `sigmas`, so the
            # sheared-smoothing branch raised AttributeError whenever shear was set.
            gw = gausskernel_sheared(self.sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            # Sheared 2D smoothing over the first two block axes, then a separable
            # 1D Gaussian along the third block axis.
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, self.sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:  # shear is None: plain separable smoothing over the three block axes
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*self.sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            # Double-Gaussian: blend in a 5x-wider Gaussian to add long-range support.
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*np.array(self.sigmas) * 5, 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj

Smooth blocks with a Gaussian kernel

Arguments:
  • sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied.
  • truncate (float): truncate parameter for gaussian kernel; hard-coded to 4.0 in the implementation (not a configurable field).
  • shear (float): shear parameter for gaussian kernel. Default is None.
  • long_range_ratio (float): long range ratio for double gaussian kernel. Default is 0.05.
sigmas: float | List[float] = [1.0, 1.0, 1.0]
shear: float | None = None
long_range_ratio: float | None = 0.05
class RegFilter(pydantic.main.BaseModel):
class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    Parameters:
        clip_thresh: threshold for clipping the reference volume. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: width(s) of a soft roll-off applied at the volume edges
            (scalar or per-axis list). Default is 0.0 (disabled)
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume
        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration
        Returns:
            cupy.ndarray: Filtered volume
        """
        # Subtract the clip threshold and clamp negatives; also moves data to GPU.
        vol = cp.clip(cp.array(vol, "float32", copy=False) - self.clip_thresh, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            # Taper the volume borders to suppress edge artifacts in registration.
            vol = soften_edges(vol, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            vol *= cp.array(reg_mask, dtype="float32", copy=False)
        if self.dog:
            vol = dogfilter(vol, self.low, self.high, mode="reflect")
        return vol

A class to apply a filter to the volume before registration

Arguments:
  • clip_thresh: threshold for clipping the reference volume. Default is 0
  • dog: if True, apply a DoG filter to the volume. Default is True
  • low: the lower sigma value for the DoG filter. Default is 0.5
  • high: the higher sigma value for the DoG filter. Default is 10.0
  • soft_edge: width(s) of a soft roll-off applied at the volume edges (scalar or per-axis list). Default is 0.0
clip_thresh: float = 0
dog: bool = True
low: float = 0.5
high: float = 10.0
soft_edge: int | float | List[int | float] = 0.0
class LevelConfig(pydantic.main.BaseModel):
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated
        block_stride (list): stride (usually identical to block_size)
        repeats (int): number of iterations for this level (disable level by setting repeats to 0)
        smooth (Smoother or None): Smoother object
        project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
        affine (bool): if True, apply affine transformation to the displacement field
        median_filter (bool): if True, apply median filter to the displacement field
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
    """

    # NOTE: was `Union[List[int]]` — a one-member Union is identical to its member type.
    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5

Configuration for each level of the registration pyramid

Arguments:
  • block_size (list): shape of blocks, whose rigid displacement is estimated
  • block_stride (list): stride (usually identical to block_size)
  • repeats (int): number of iterations for this level (disable level by setting repeats to 0)
  • smooth (Smoother or None): Smoother object
  • project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
  • tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
  • affine (bool): if True, apply affine transformation to the displacement field
  • median_filter (bool): if True, apply median filter to the displacement field
  • update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
block_size: List[int] = PydanticUndefined
block_stride: List[int] | float = 1.0
project: Projector | Callable[[numpy.ndarray | cupy.ndarray, int], numpy.ndarray | cupy.ndarray] = Projector(max=True, normalize=False, dog=True, low=0.5, high=10.0, periodic_smooth=False)
tukey_ref: float | None = 0.5
smooth: Smoother | None = Smoother(sigmas=[1.0, 1.0, 1.0], shear=None, long_range_ratio=0.05)
affine: bool = False
median_filter: bool = True
update_rate: float = 1.0
repeats: int = 5
class Recipe(pydantic.main.BaseModel):
class Recipe(BaseModel):
    """Configuration for the registration recipe. Recipe is initialized with a single affine level.

    Args:
        pre_filter (RegFilter, callable or None): Filter to be applied to the reference volume
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    # Mutable defaults are safe here: pydantic deep-copies field defaults per instance.
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=3),  # translation level
        LevelConfig(  # affine level
            block_size=[-2, -2, -2],
            block_stride=0.5,
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the registration recipe

        Args:
            block_size (list): shape of blocks, whose rigid displacement is estimated
            **kwargs: additional arguments for LevelConfig
        """
        # A scalar block_size is broadcast to all three axes.
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (list): shape of blocks, whose rigid displacement is estimated
            **kwargs: additional arguments for LevelConfig
        """
        # A scalar block_size is broadcast to all three axes.
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file. If the path does not point to an
                existing file, it is looked up in the package's bundled ``recipes`` directory.

        Returns:
            Recipe: Recipe object
        """
        import yaml

        # Fall back to the bundled recipes directory for bare recipe names.
        if not os.path.isfile(yaml_path):
            this_file_dir = pathlib.Path(__file__).resolve().parent
            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)

        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")

Configuration for the registration recipe. Recipe is initialized with a single affine level.

Arguments:
  • pre_filter (RegFilter, callable or None): Filter to be applied to the reference volume
  • levels (list): List of LevelConfig objects
pre_filter: RegFilter | Callable[[numpy.ndarray | cupy.ndarray], numpy.ndarray | cupy.ndarray] | None = RegFilter(clip_thresh=0, dog=True, low=0.5, high=10.0, soft_edge=0.0)
levels: List[LevelConfig] = [LevelConfig(block_size=[-1, -1, -1], block_stride=1.0, project=Projector(max=True, normalize=False, dog=True, low=0.5, high=10.0, periodic_smooth=False), tukey_ref=0.5, smooth=Smoother(sigmas=[1.0, 1.0, 1.0], shear=None, long_range_ratio=0.05), affine=False, median_filter=True, update_rate=1.0, repeats=3), LevelConfig(block_size=[-2, -2, -2], block_stride=0.5, project=Projector(max=True, normalize=False, dog=True, low=0.5, high=10.0, periodic_smooth=False), tukey_ref=0.5, smooth=Smoother(sigmas=[0.5, 0.5, 0.5], shear=None, long_range_ratio=0.05), affine=True, median_filter=False, update_rate=1.0, repeats=10)]
def add_level(self, block_size, **kwargs):
750    def add_level(self, block_size, **kwargs):
751        """Add a level to the registration recipe
752
753        Args:
754            block_size (list): shape of blocks, whose rigid displacement is estimated
755            **kwargs: additional arguments for LevelConfig
756        """
757        if isinstance(block_size, (int, float)):
758            block_size = [block_size] * 3
759        if len(block_size) != 3:
760            raise ValueError("block_size must be a list of 3 integers")
761        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

Add a level to the registration recipe

Arguments:
  • block_size (list): shape of blocks, whose rigid displacement is estimated
  • **kwargs: additional arguments for LevelConfig
def insert_level(self, index, block_size, **kwargs):
763    def insert_level(self, index, block_size, **kwargs):
764        """Insert a level to the registration recipe
765
766        Args:
767            index (int): A number specifying in which position to insert the level
768            block_size (list): shape of blocks, whose rigid displacement is estimated
769            **kwargs: additional arguments for LevelConfig
770        """
771        if isinstance(block_size, (int, float)):
772            block_size = [block_size] * 3
773        if len(block_size) != 3:
774            raise ValueError("block_size must be a list of 3 integers")
775        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

Insert a level to the registration recipe

Arguments:
  • index (int): A number specifying in which position to insert the level
  • block_size (list): shape of blocks, whose rigid displacement is estimated
  • **kwargs: additional arguments for LevelConfig
@classmethod
def from_yaml(cls, yaml_path):
777    @classmethod
778    def from_yaml(cls, yaml_path):
779        """Load a recipe from a YAML file
780
781        Args:
782            yaml_path (str): path to the YAML file
783
784        Returns:
785            Recipe: Recipe object
786        """
787        import yaml
788
789        this_file_dir = pathlib.Path(__file__).resolve().parent
790        if os.path.isfile(yaml_path):
791            yaml_path = yaml_path
792        else:
793            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)
794
795        with open(yaml_path, "r") as f:
796            data = yaml.safe_load(f)
797
798        return cls.model_validate(data)

Load a recipe from a YAML file

Arguments:
  • yaml_path (str): path to the YAML file
Returns:

Recipe: Recipe object

def to_yaml(self, yaml_path):
800    def to_yaml(self, yaml_path):
801        """Save the recipe to a YAML file
802
803        Args:
804            yaml_path (str): path to the YAML file
805        """
806        import yaml
807
808        with open(yaml_path, "w") as f:
809            yaml.dump(self.model_dump(), f)
810        print(f"Recipe saved to {yaml_path}")

Save the recipe to a YAML file

Arguments:
  • yaml_path (str): path to the YAML file