warpfield.register
"""Block-wise (piece-wise rigid) 3D volume registration on the GPU.

The module provides:

* ``WarpMap`` - a block-sampled 3D displacement field and operations on it.
* ``WarpMapper`` - estimates a ``WarpMap`` for one block size via
  cross-correlation of 2D block projections.
* ``RegistrationPyramid`` / ``register_volumes`` - run several ``WarpMapper``
  levels in sequence according to a ``Recipe``.
* ``Projector``, ``Smoother``, ``RegFilter``, ``LevelConfig``, ``Recipe`` -
  pydantic models that configure the registration.
"""
import warnings
import gc
import pathlib
import os
from typing import List, Union, Callable

import numpy as np
import scipy.signal
import cupy as cp
import cupyx
import cupyx.scipy.fft  # used below (get_fft_plan / rfftn / irfftn); import it explicitly
import cupyx.scipy.ndimage
from pydantic import BaseModel, ValidationError
from tqdm.auto import tqdm

from .warp import warp_volume
from .utils import create_rgb_video, mips_callback
from .ndimage import (
    accumarray,
    dogfilter,
    gausskernel_sheared,
    infill_nans,
    ndwindow,
    periodic_smooth_decomposition_nd_rfft,
    sliding_block,
    upsampled_dft_rfftn,
    soften_edges,
)

_ArrayType = Union[np.ndarray, cp.ndarray]


class WarpMap:
    """Represents a 3D displacement field sampled on a regular grid of blocks.

    Args:
        warp_field (numpy.array): the displacement field data (3-x-y-z)
        block_size (3-element list or numpy.array): size of the blocks the field was estimated on
        block_stride (3-element list or numpy.array): spacing between neighbouring blocks
        ref_shape (tuple): shape of the reference volume
        mov_shape (tuple): shape of the moving volume
    """

    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
        # Field, block size and stride are always held as float32 cupy arrays.
        self.warp_field = cp.array(warp_field, dtype="float32")
        self.block_size = cp.array(block_size, dtype="float32")
        self.block_stride = cp.array(block_stride, dtype="float32")
        self.ref_shape = ref_shape
        self.mov_shape = mov_shape

    def warp(self, vol, out=None):
        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

        Args:
            vol (cupy.array): the volume to be warped
            out (cupy.array): optional pre-allocated output array. If None, a zero-filled
                float32 array of shape ``ref_shape`` is allocated.

        Returns:
            cupy.array: warped volume (the value returned by ``warp_volume``)
        """
        # Compare as plain tuples: ref_shape/mov_shape may be tuples or numpy arrays,
        # and a rank mismatch should produce the warning rather than a broadcasting error.
        expected_shape = tuple(int(s) for s in self.mov_shape)
        if tuple(vol.shape) != expected_shape:
            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
        if out is None:
            out = cp.zeros(tuple(int(s) for s in self.ref_shape), dtype="float32", order="C")
        # The last positional argument is the offset of the first block centre, in units of blocks.
        vol_out = warp_volume(
            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
        )
        return vol_out

    def apply(self, *args, **kwargs):
        """Alias of warp method"""
        return self.warp(*args, **kwargs)

    def fit_affine(self, target=None):
        """Fit affine transformation and return new fitted WarpMap

        Args:
            target (dict): dict with keys "warp_field_shape", "block_size", and "block_stride"
                describing the grid on which the fitted field is evaluated.
                If None, the grid of this WarpMap is used.

        Returns:
            WarpMap: displacement field of the fitted affine transform
            cupy.array: affine transformation coefficients (4-by-3; rows 0-2 linear part, row 3 offset)
        """
        if target is None:
            warp_field_shape = self.warp_field.shape
            block_size = self.block_size
            block_stride = self.block_stride
        else:
            warp_field_shape = target["warp_field_shape"]
            block_size = cp.array(target["block_size"]).astype("float32")
            block_stride = cp.array(target["block_stride"]).astype("float32")

        # Voxel coordinates of all block centres of this map (n-by-3); every block is used for the fit.
        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
        ix = ix * self.block_stride + self.block_size / 2
        # Least-squares solve [ix, 1] @ coeff = ix + displacement.
        a = cp.hstack([ix, cp.ones((ix.shape[0], 1))])
        b = ix + self.warp_field.reshape(3, -1).T
        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
        # Evaluate the fitted transform on the target grid and convert it back to displacements.
        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

    def median_filter(self):
        """Apply median filter to the displacement field

        Returns:
            WarpMap: new WarpMap with median filtered displacement field
        """
        # 3x3x3 spatial neighbourhood; size 1 on the first axis keeps the components separate.
        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

    def resize_to(self, target):
        """Resize to target WarpMap, using linear interpolation

        Args:
            target (WarpMap or WarpMapper): target to resize to
                or a dict with keys "warp_field_shape", "block_size", and "block_stride"

        Returns:
            WarpMap: resized WarpMap

        Raises:
            ValueError: if target is not a WarpMap, WarpMapper, or dict
        """
        if isinstance(target, WarpMap):
            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
        elif isinstance(target, WarpMapper):
            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
        elif isinstance(target, dict):
            t_sh, t_bsz, t_bst = (
                target["warp_field_shape"][1:],
                cp.array(target["block_size"]),
                cp.array(target["block_stride"]),
            )
        else:
            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
        ix = cp.array(cp.indices(t_sh).reshape(3, -1))
        # Target block centres expressed in (fractional) block indices of this map.
        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
        dm_r = cp.array(
            [
                cupyx.scipy.ndimage.map_coordinates(cp.array(self.warp_field[i]), ix, mode="nearest", order=1).reshape(
                    t_sh
                )
                for i in range(3)
            ]
        )
        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

    def chain(self, target):
        """Chain displacement maps

        The two displacement fields are added element-wise, so both maps must be
        sampled on the same grid.

        Args:
            target (WarpMap): WarpMap to be added to existing map

        Returns:
            WarpMap: new WarpMap with chained displacement field
        """
        warp_field = self.warp_field.copy()
        warp_field += target.warp_field
        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

    def invert(self, **kwargs):
        """alias for invert_fast method"""
        return self.invert_fast(**kwargs)

    def invert_fast(self, sigma=0.5, truncate=20):
        """Invert the displacement field using accumulation and Gaussian basis interpolation.

        Args:
            sigma (float): standard deviation for Gaussian basis interpolation
            truncate (float): truncate parameter for Gaussian basis interpolation

        Returns:
            WarpMap: inverted WarpMap (ref_shape and mov_shape are swapped)
        """
        warp_field = self.warp_field.get()
        # Where each block lands after displacement, in units of blocks.
        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
        num_coords = accumarray(target_coords, wf_shape)
        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
        for i in range(3):
            # Average of the negated displacements that landed in each target cell.
            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
            with np.errstate(invalid="ignore"):
                inv_field[i] /= num_coords
            # Cells that received no contribution are marked NaN and then filled in.
            inv_field[i][num_coords == 0] = np.nan
            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

    def push_coordinates(self, coords, negative_shifts=False):
        """Push voxel coordinates from fixed to moving space.

        Args:
            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
            negative_shifts (bool): if True, subtract the interpolated shifts instead of adding them

        Returns:
            cupy.array: transformed voxel coordinates (3-by-n, float32)
        """
        assert coords.shape[0] == 3
        coords = cp.array(coords, dtype="float32")
        # Voxel coordinates -> fractional block indices (block centres map to integer indices).
        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
        shifts = cp.zeros_like(coords)
        for idim in range(3):
            # map_coordinates only reads the field, so no copy of it is needed.
            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
                self.warp_field[idim], coords_blocked, order=1, mode="nearest"
            )
        if negative_shifts:
            shifts = -shifts
        return coords + shifts

    def pull_coordinates(self, coords):
        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

        Args:
            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)

        Returns:
            cupy.array: transformed voxel coordinates
        """
        return self.invert().push_coordinates(coords, negative_shifts=True)

    def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
        """
        Compute det J = det(∇φ) for φ(x)=x+u(x), using an index grid for the identity map.

        Args:
            units_per_voxel (3-element sequence): physical size of one voxel along each axis (default 1)
            edge_order : passed to cp.gradient (1 or 2)

        Returns:
            detJ: cp.ndarray with the spatial shape of the warp field
        """
        scaling = cp.array(units_per_voxel, dtype="float32") * self.block_stride
        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
        # NOTE(review): coords are scaled by units_per_voxel but warp_field is added unscaled;
        # this looks consistent only for units_per_voxel == 1 - confirm the intended units.
        phi = coords + self.warp_field
        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
        for i in range(3):
            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
            for j in range(3):
                J[..., i, j] = grads[j]
        return cp.linalg.det(J)

    def as_ants_image(self, voxel_size_um=1):
        """Convert to ANTsImage

        Args:
            voxel_size_um (scalar or array): voxel size (default is 1)

        Returns:
            ants.core.ants_image.ANTsImage:

        Raises:
            ImportError: if the ``ants`` module cannot be imported
        """
        try:
            import ants
        except ImportError as err:
            raise ImportError("ANTs is not installed. Please install it using 'pip install ants'") from err

        # Components go last; origin is the centre of the first block, spacing is the block stride.
        ants_image = ants.from_numpy(
            self.warp_field.get().transpose(1, 2, 3, 0),
            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
            spacing=list(self.block_stride.get() * voxel_size_um),
            has_components=True,
        )
        return ants_image

    def __repr__(self):
        """String representation of the WarpMap object."""
        info = (
            f"WarpMap("
            f"warp_field_shape={self.warp_field.shape}, "
            f"block_size={self.block_size.get()}, "
            f"block_stride={self.block_stride.get()}, "
            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)})"
        )
        return info


class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (usually identical to block_size).
            If None, block_size is used.
        proj_method (callable): Projection method, called as ``proj_method(blocks, axis=...)``
        subpixel (int): upsampling factor for the sub-pixel refinement of the correlation peak
        epsilon (float): small value added to the zero-shift bin of the cross-correlation
        tukey_alpha (float or None): alpha of the Tukey window applied to the reference projections.
            No window is applied if None or >= 1.

    Raises:
        ValueError: if the block size exceeds the volume shape along any axis
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        # Validate before doing any (expensive) block extraction.
        if np.any(np.asarray(block_size) > np.array(ref_vol.shape)):
            raise ValueError(f"Block size ({block_size}) must be smaller than the volume shape ({ref_vol.shape})")
        self.proj_method = proj_method
        self.plan_rev = [None, None, None]
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """Pre-compute the (conjugated) Fourier transforms of the reference block projections.

        Args:
            ref_vol (numpy.array or cupy.array): The reference volume
            block_size (3-element list or numpy.array): shape of blocks
            block_stride (3-element list or numpy.array): stride. If None, block_size is used.
        """
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        # One 2D projection per block and per projection axis.
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        alpha = self.tukey_alpha
        if alpha is not None and alpha < 1:
            # Window the last two (projection) axes with the configured Tukey alpha.
            ref_blocks_proj = [
                proj
                * cp.array(
                    ndwindow([1, 1, 1, *proj.shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=alpha))
                ).astype("float32")
                for proj in ref_blocks_proj
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
        ]
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
        ]
        # Inverse plans depend on the block layout; drop cached ones so they are rebuilt lazily.
        self.plan_rev = [None, None, None]
        self.block_size = block_size
        self.block_stride = block_stride
        self.ref_shape = np.array(ref_vol.shape)

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (numpy.array): Input volume
            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume,
                called as ``smooth_func(xcorr_proj, block_size)``

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Product of the volume-projection spectrum with the conjugated reference spectrum.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            # Cross-correlation with zero shift moved to the centre of the last two axes.
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # Slightly favour the zero-shift bin.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer peak position, relative to the centre.
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # Upsampled correlation in a small window around the integer peak.
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            # Sub-pixel correction from the peak of the upsampled window.
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        disp_field = cp.array(disp_field)
        # Each axis is seen by the two projections that do not collapse it; average the two estimates.
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)


class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array): Reference volume
        recipe (Recipe): Registration recipe with one LevelConfig per level.
            IMPORTANT: the block stride in the last level should not be larger than the block stride in any previous level.
        reg_mask (numpy.array): Mask for registration, passed to the recipe's pre_filter. Default is 1 (no mask)
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        # mapper_ix[k] is the index into recipe.levels of mappers[k] (levels with repeats < 1 are skipped).
        self.mapper_ix = []
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -n means "volume extent divided by n" along that axis.
            negative = block_size < 0
            tmp = np.r_[ref_vol.shape] // -block_size
            block_size[negative] = tmp[negative]
            if isinstance(level.block_stride, (int, float)):
                # Scalar stride is a fraction of the block size.
                block_stride = np.round(block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback function to be called after each level of registration
            verbose (bool): If True, show progress bars. Default is False

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function

        Raises:
            ValueError: if an affine level has fewer than two axes with at least 2 blocks
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Start from a constant shift that aligns the centres of the two volumes.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        # Per-axis minimum stride over all levels.
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # All maps are accumulated on the grid of the last level, so it should be the finest one.
        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    # Fit on this level's grid, evaluate on the last level's grid.
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # Final result: warp the original (unfiltered) volume.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output


def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Release GPU memory held by the pyramid and by cached FFT plans.
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        if callback is None or len(cbout) == 0:
            # Without callback frames there is nothing to render; never lose the
            # registration result over a missing video.
            warnings.warn("Video generation failed with error: a callback returning 2D frames is required")
        else:
            try:
                assert cbout[0].ndim == 2, "Callback output must be a 2D array"
                if recipe.pre_filter is not None:
                    ref_filtered = recipe.pre_filter(ref)
                else:
                    ref_filtered = cp.array(ref, dtype="float32", copy=False)
                ref_frame = callback(ref_filtered)
                vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
                create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
            except (ValueError, AssertionError) as e:
                warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout


class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, use a max projection; otherwise a mean projection. Default is True
        normalize: if > 0, divide projections by their L2 norm raised to this power
            (True/1 gives correlations, not covariances). Default is False
        dog: if True, apply a DoG filter to the projections. Default is True
        low: the lower sigma value for the DoG filter (scalar or one value per axis). Default is 0.5
        high: the higher sigma value for the DoG filter (scalar or one value per axis). Default is 10.0
        periodic_smooth: if True, apply periodic_smooth_decomposition_nd_rfft to the projections
            before filtering. Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block
        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        if self.max:
            out = vol_blocks.max(axis)
        else:
            out = vol_blocks.mean(axis)
        if self.periodic_smooth:
            out = periodic_smooth_decomposition_nd_rfft(out)
        # Broadcast sigmas to three axes and drop the entry of the projected axis.
        low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        if self.dog:
            out = dogfilter(out, [0, 0, 0, *low], [0, 0, 0, *high], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            # Without DoG, fall back to a plain Gaussian low-pass with the "low" sigmas.
            out = cupyx.scipy.ndimage.gaussian_filter(out, [0, 0, 0, *low], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            # 1e-9 guards against division by zero for empty projections.
            out /= cp.sqrt(cp.sum(out**2, axis=(-2, -1), keepdims=True)) ** self.normalize + 1e-9
        return out


class Smoother(BaseModel):
    """Smooth blocks with a Gaussian kernel
    Args:
        sigmas (float or list): [sigma0, sigma1, sigma2], or a single value used for all three.
            If None, no smoothing is applied.
        shear (float): shear parameter for gaussian kernel. Default is None.
        long_range_ratio (float): long range ratio for double gaussian kernel. Default is 0.05.
            If None, no long-range component is added.
    """

    sigmas: Union[float, List[float], None] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data
        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated (only used when shear is set)
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        # A scalar sigma applies to all three block axes.
        if isinstance(self.sigmas, (int, float)):
            sigmas = [float(self.sigmas)] * 3
        else:
            sigmas = list(self.sigmas)
        if self.shear is not None:
            # Express the shear in units of blocks.
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            gw = gausskernel_sheared(sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:  # shear is None:
            # Smooth across blocks only (sigma 0 on the two projection axes).
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            # Mix in a 5x wider Gaussian with weight long_range_ratio.
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*(np.array(sigmas) * 5), 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj


class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    Parameters:
        clip_thresh: value subtracted from the volume before clipping at 0. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: if any value is > 0, passed to soften_edges (scalar or one value per axis). Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume
        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration (multiplied with the volume). Default is None (no mask)
        Returns:
            cupy.ndarray: Filtered volume
        """
        # cp.clip returns a new array, so the in-place operations below do not touch the input.
        vol = cp.clip(cp.array(vol, "float32", copy=False) - self.clip_thresh, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            vol = soften_edges(vol, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            vol *= cp.array(reg_mask, dtype="float32", copy=False)
        if self.dog:
            vol = dogfilter(vol, self.low, self.high, mode="reflect")
        return vol


class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated.
            A negative value -n means "volume extent divided by n" along that axis.
        block_stride (list or float): stride in voxels (list), or as a fraction of block_size (float). Default is 1.0
        repeats (int): number of iterations for this level (disable level by setting repeats to 0)
        smooth (Smoother or None): Smoother object
        project (Projector or callable): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
        affine (bool): if True, apply affine transformation to the displacement field
        median_filter (bool): if True, apply median filter to the displacement field
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
    """

    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5


def _as_block_size(block_size):
    """Expand a scalar block size to three axes and check that three values are given."""
    if isinstance(block_size, (int, float)):
        block_size = [block_size] * 3
    if len(block_size) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    return block_size


class Recipe(BaseModel):
    """Configuration for the registration recipe. By default, a Recipe has a translation level and an affine level.

    Args:
        pre_filter (RegFilter, callable or None): Filter to be applied to the volumes before registration
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=1),  # translation level
        LevelConfig(  # affine level
            block_size=[-4, -4, -4],
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the registration recipe

        Args:
            block_size (list): shape of blocks, whose rigid displacement is estimated
                (a scalar is used for all three axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 elements
        """
        self.levels.append(LevelConfig(block_size=_as_block_size(block_size), **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (list): shape of blocks, whose rigid displacement is estimated
                (a scalar is used for all three axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 elements
        """
        self.levels.insert(index, LevelConfig(block_size=_as_block_size(block_size), **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file. If it is not an existing file, it is
                looked up in the "recipes" directory next to this module.

        Returns:
            Recipe: Recipe object
        """
        import yaml

        if not os.path.isfile(yaml_path):
            this_file_dir = pathlib.Path(__file__).resolve().parent
            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)

        with open(yaml_path, "r") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")
class WarpMap:
    """Represents a 3D displacement field sampled on a regular grid of blocks.

    Args:
        warp_field (numpy.array): the displacement field data (3-x-y-z)
        block_size (3-element list or numpy.array): size of a block along each axis
        block_stride (3-element list or numpy.array): step between neighbouring blocks along each axis
        ref_shape (tuple): shape of the reference volume
        mov_shape (tuple): shape of the moving volume
    """

    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
        # Field and block geometry are stored as float32 cupy arrays.
        self.warp_field = cp.array(warp_field, dtype="float32")
        self.block_size = cp.array(block_size, dtype="float32")
        self.block_stride = cp.array(block_stride, dtype="float32")
        self.ref_shape = ref_shape
        self.mov_shape = mov_shape

    def warp(self, vol, out=None):
        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

        Args:
            vol (cupy.array): the volume to be warped
            out (cupy.array, optional): output buffer passed to ``warp_volume``. When omitted,
                a zero-filled float32 array of shape ``ref_shape`` is allocated.

        Returns:
            cupy.array: warped volume
        """
        # Compare as tuples: an element-wise comparison against an ndarray raises on a
        # rank mismatch in recent NumPy versions instead of producing the warning.
        if tuple(vol.shape) != tuple(self.mov_shape):
            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
        if out is None:
            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
        vol_out = warp_volume(
            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
        )
        return vol_out

    def apply(self, *args, **kwargs):
        """Alias of warp method"""
        return self.warp(*args, **kwargs)

    def fit_affine(self, target=None):
        """Fit affine transformation and return new fitted WarpMap

        Args:
            target (dict): dict with keys "warp_field_shape", "block_size", and "block_stride"
                describing the grid on which the fitted field is evaluated. Defaults to this
                map's own grid.

        Returns:
            WarpMap: map holding the displacements predicted by the affine fit
            cupy.array: affine transformation coefficients, shape (4, 3); rows 0-2 are the
                linear part and row 3 is the translation
        """
        if target is None:
            warp_field_shape = self.warp_field.shape
            block_size = self.block_size
            block_stride = self.block_stride
        else:
            warp_field_shape = target["warp_field_shape"]
            block_size = cp.array(target["block_size"]).astype("float32")
            block_stride = cp.array(target["block_stride"]).astype("float32")

        # Block-centre coordinates of the current grid, one row per block.
        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
        ix = ix * self.block_stride + self.block_size / 2
        # Least-squares fit over all blocks: [x, 1] @ coeff ~= x + displacement.
        a = cp.hstack([ix, cp.ones((len(ix), 1))])
        b = ix + self.warp_field.reshape(3, -1).T
        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
        # Subtract the identity so the result is a displacement, not an absolute position.
        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

    def median_filter(self):
        """Apply median filter to the displacement field

        Returns:
            WarpMap: new WarpMap with median filtered displacement field
        """
        # Size 1 on the component axis: each displacement component is filtered separately.
        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

    def resize_to(self, target):
        """Resize to target WarpMap, using linear interpolation

        Args:
            target (WarpMap or WarpMapper): target to resize to
                or a dict with keys "warp_field_shape", "block_size", and "block_stride"

        Returns:
            WarpMap: resized WarpMap

        Raises:
            ValueError: if target is not a WarpMap, WarpMapper, or dict
        """
        if isinstance(target, WarpMap):
            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
        elif isinstance(target, WarpMapper):
            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
        elif isinstance(target, dict):
            t_sh, t_bsz, t_bst = (
                target["warp_field_shape"][1:],
                cp.array(target["block_size"]),
                cp.array(target["block_stride"]),
            )
        else:
            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
        # Target block indices expressed in (fractional) indices of the current grid.
        ix = cp.indices(t_sh).reshape(3, -1)
        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
        dm_r = cp.array(
            [
                cupyx.scipy.ndimage.map_coordinates(self.warp_field[i], ix, mode="nearest", order=1).reshape(t_sh)
                for i in range(3)
            ]
        )
        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

    def chain(self, target):
        """Chain displacement maps

        The two displacement fields are added element-wise; the result uses the block
        geometry of ``target`` and the reference/moving shapes of this map.

        Args:
            target (WarpMap): WarpMap to be added to existing map

        Returns:
            WarpMap: new WarpMap with chained displacement field
        """
        warp_field = self.warp_field.copy()
        warp_field += target.warp_field
        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

    def invert(self, **kwargs):
        """alias for invert_fast method"""
        return self.invert_fast(**kwargs)

    def invert_fast(self, sigma=0.5, truncate=20):
        """Invert the displacement field using accumulation and Gaussian basis interpolation.

        Args:
            sigma (float): standard deviation for Gaussian basis interpolation
            truncate (float): truncate parameter for Gaussian basis interpolation

        Returns:
            WarpMap: inverted WarpMap (reference and moving shapes are swapped)
        """
        warp_field = self.warp_field.get()
        # Where each block lands, in block-index units of the inverse grid.
        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
        num_coords = accumarray(target_coords, wf_shape)
        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
        for i in range(3):
            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
            with np.errstate(invalid="ignore"):
                inv_field[i] /= num_coords
            # Cells that received no contribution are marked missing, then filled in.
            inv_field[i][num_coords == 0] = np.nan
            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

    def push_coordinates(self, coords, negative_shifts=False):
        """Push voxel coordinates from fixed to moving space.

        Args:
            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
            negative_shifts (bool): if True, subtract the interpolated shifts instead of adding them

        Returns:
            cupy.array: transformed voxel coordinates (float32)
        """
        # Convert first so that nested lists are accepted as well as arrays.
        coords = cp.array(coords, dtype="float32")
        assert coords.shape[0] == 3
        # Voxel coordinates -> fractional block indices of the warp field.
        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
        shifts = cp.zeros_like(coords)
        for idim in range(3):
            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
                self.warp_field[idim], coords_blocked, order=1, mode="nearest"
            )
        if negative_shifts:
            shifts = -shifts
        return coords + shifts

    def pull_coordinates(self, coords):
        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

        Args:
            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)

        Returns:
            cupy.array: transformed voxel coordinates
        """
        return self.invert().push_coordinates(coords, negative_shifts=True)

    def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
        """
        Compute det J = det(∇φ) for φ(x)=x+u(x), using cp.indices for the identity grid.

        Args:
            units_per_voxel (3-element sequence): physical size of a voxel along each axis.
                Both the grid coordinates and the displacements (stored in voxels) are
                converted to these units before differentiation.
            edge_order : passed to cp.gradient (1 or 2)

        Returns:
            detJ: cp.ndarray with the spatial shape of the warp field
        """
        units = cp.array(units_per_voxel, dtype="float32")
        scaling = units * self.block_stride
        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
        # The field is stored in voxels; scale it so it shares the units of ``coords``.
        phi = coords + self.warp_field * units[:, None, None, None]
        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
        for i in range(3):
            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
            for j in range(3):
                J[..., i, j] = grads[j]
        return cp.linalg.det(J)

    def as_ants_image(self, voxel_size_um=1):
        """Convert to ANTsImage

        Args:
            voxel_size_um (scalar or array): voxel size (default is 1)

        Returns:
            ants.core.ants_image.ANTsImage:

        Raises:
            ImportError: if the ``ants`` module cannot be imported
        """
        try:
            import ants
        except ImportError as err:
            # The ``ants`` module is distributed on PyPI as ``antspyx``.
            raise ImportError("ANTs is not installed. Please install it using 'pip install antspyx'") from err

        ants_image = ants.from_numpy(
            self.warp_field.get().transpose(1, 2, 3, 0),
            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
            spacing=list(self.block_stride.get() * voxel_size_um),
            has_components=True,
        )
        return ants_image

    def __repr__(self):
        """String representation of the WarpMap object."""
        info = (
            f"WarpMap("
            f"warp_field_shape={self.warp_field.shape}, "
            f"block_size={self.block_size.get()}, "
            f"block_stride={self.block_stride.get()}, "
            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)})"
        )
        return info
Represents a 3D displacement field
Arguments:
- warp_field (numpy.array): the displacement field data (3-x-y-z)
- block_size (3-element list or numpy.array):
- block_stride (3-element list or numpy.array):
- ref_shape (tuple): shape of the reference volume
- mov_shape (tuple): shape of the moving volume
def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
    """Store the displacement field and its block geometry.

    The field, block size and block stride are converted to float32 cupy arrays;
    the two shapes are stored as given.
    """
    self.ref_shape, self.mov_shape = ref_shape, mov_shape
    self.warp_field = cp.array(warp_field, dtype="float32")
    self.block_size = cp.array(block_size, dtype="float32")
    self.block_stride = cp.array(block_stride, dtype="float32")
def warp(self, vol, out=None):
    """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

    Args:
        vol (cupy.array): the volume to be warped
        out (cupy.array, optional): output buffer passed to ``warp_volume``. When omitted,
            a zero-filled float32 array of shape ``ref_shape`` is allocated.

    Returns:
        cupy.array: warped volume
    """
    # Compare as tuples: an element-wise comparison against an ndarray raises on a
    # rank mismatch in recent NumPy versions instead of producing the warning.
    if tuple(vol.shape) != tuple(self.mov_shape):
        warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
    if out is None:
        out = cp.zeros(self.ref_shape, dtype="float32", order="C")
    vol_out = warp_volume(
        vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
    )
    return vol_out
Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
Arguments:
- vol (cupy.array): the volume to be warped
Returns:
cupy.array: warped volume
68 def apply(self, *args, **kwargs): 69 """Alias of warp method""" 70 return self.warp(*args, **kwargs)
Alias of warp method
def fit_affine(self, target=None):
    """Fit affine transformation and return new fitted WarpMap

    Args:
        target (dict): dict with keys "warp_field_shape", "block_size", and "block_stride"
            describing the grid on which the fitted field is evaluated. Defaults to this
            map's own grid.

    Returns:
        WarpMap: map holding the displacements predicted by the affine fit
        cupy.array: affine transformation coefficients, shape (4, 3); rows 0-2 are the
            linear part and row 3 is the translation
    """
    if target is None:
        warp_field_shape = self.warp_field.shape
        block_size = self.block_size
        block_stride = self.block_stride
    else:
        warp_field_shape = target["warp_field_shape"]
        block_size = cp.array(target["block_size"]).astype("float32")
        block_stride = cp.array(target["block_stride"]).astype("float32")

    # Block-centre coordinates of the current grid, one row per block.
    ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
    ix = ix * self.block_stride + self.block_size / 2
    # Least-squares fit over all blocks: [x, 1] @ coeff ~= x + displacement.
    a = cp.hstack([ix, cp.ones((len(ix), 1))])
    b = ix + self.warp_field.reshape(3, -1).T
    coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
    ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
    # Subtract the identity so the result is a displacement, not an absolute position.
    linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
    return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
Fit affine transformation and return new fitted WarpMap
Arguments:
- target (dict): dict with keys "warp_field_shape", "block_size", and "block_stride"
Returns:
WarpMap: the fitted WarpMap; cupy.array: affine transformation coefficients
def median_filter(self):
    """Return a copy of this map whose displacement field has been median filtered.

    The filter window is 3x3x3 over the spatial axes and 1 over the component axis,
    with "nearest" boundary handling.

    Returns:
        WarpMap: new WarpMap with median filtered displacement field
    """
    filtered = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
    return WarpMap(
        warp_field=filtered,
        block_size=self.block_size,
        block_stride=self.block_stride,
        ref_shape=self.ref_shape,
        mov_shape=self.mov_shape,
    )
Apply median filter to the displacement field
Returns:
WarpMap: new WarpMap with median filtered displacement field
def resize_to(self, target):
    """Resize to target WarpMap, using linear interpolation

    Args:
        target (WarpMap or WarpMapper): target to resize to
            or a dict with keys "warp_field_shape", "block_size", and "block_stride"

    Returns:
        WarpMap: resized WarpMap

    Raises:
        ValueError: if target is not a WarpMap, WarpMapper, or dict
    """
    if isinstance(target, WarpMap):
        t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
    elif isinstance(target, WarpMapper):
        t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
    elif isinstance(target, dict):
        t_sh, t_bsz, t_bst = (
            target["warp_field_shape"][1:],
            cp.array(target["block_size"]),
            cp.array(target["block_stride"]),
        )
    else:
        raise ValueError("target must be a WarpMap, WarpMapper, or dict")
    # Target block indices expressed in (fractional) indices of the current grid.
    ix = cp.indices(t_sh).reshape(3, -1)
    ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
    dm_r = cp.array(
        [
            cupyx.scipy.ndimage.map_coordinates(self.warp_field[i], ix, mode="nearest", order=1).reshape(t_sh)
            for i in range(3)
        ]
    )
    return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
Resize to target WarpMap, using linear interpolation
Arguments:
- target (WarpMap or WarpMapper): target to resize to or a dict with keys "warp_field_shape", "block_size", and "block_stride"
Returns:
WarpMap: resized WarpMap
def chain(self, target):
    """Chain displacement maps

    The two displacement fields are added element-wise; the result uses the block
    geometry of ``target`` and the reference/moving shapes of this map.

    Args:
        target (WarpMap): WarpMap to be added to existing map

    Returns:
        WarpMap: new WarpMap with chained displacement field
    """
    warp_field = self.warp_field.copy()
    warp_field += target.warp_field
    return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
Chain displacement maps
Arguments:
- target (WarpMap): WarpMap to be added to existing map
Returns:
WarpMap: new WarpMap with chained displacement field
162 def invert(self, **kwargs): 163 """alias for invert_fast method""" 164 return self.invert_fast(**kwargs)
alias for invert_fast method
def invert_fast(self, sigma=0.5, truncate=20):
    """Invert the displacement field using accumulation and Gaussian basis interpolation.

    Args:
        sigma (float): standard deviation for Gaussian basis interpolation
        truncate (float): truncate parameter for Gaussian basis interpolation

    Returns:
        WarpMap: inverted WarpMap (reference and moving shapes are swapped)
    """
    field = self.warp_field.get()
    stride = self.block_stride.get()
    # Where each block lands, in block-index units of the inverse grid.
    landing = np.indices(field.shape[1:]) + field / stride[:, None, None, None]
    grid_shape = np.ceil(np.array(self.mov_shape) / stride + 1).astype("int")
    hits = accumarray(landing, grid_shape)
    unvisited = hits == 0
    inverse = np.zeros((3, *grid_shape), dtype=field.dtype)
    for axis, component in enumerate(field):
        inverse[axis] = -accumarray(landing, grid_shape, weights=component.ravel())
        with np.errstate(invalid="ignore"):
            inverse[axis] /= hits
        # Cells that received no contribution are marked missing, then filled in.
        inverse[axis][unvisited] = np.nan
        inverse[axis] = infill_nans(inverse[axis], sigma=sigma, truncate=truncate)
    return WarpMap(inverse, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
Invert the displacement field using accumulation and Gaussian basis interpolation.
Arguments:
- sigma (float): standard deviation for Gaussian basis interpolation
- truncate (float): truncate parameter for Gaussian basis interpolation
Returns:
WarpMap: inverted WarpMap
def push_coordinates(self, coords, negative_shifts=False):
    """Push voxel coordinates from fixed to moving space.

    Args:
        coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)
        negative_shifts (bool): if True, subtract the interpolated shifts instead of adding them

    Returns:
        cupy.array: transformed voxel coordinates (float32)
    """
    # Convert first so that nested lists are accepted as well as arrays.
    coords = cp.array(coords, dtype="float32")
    assert coords.shape[0] == 3
    # Voxel coordinates -> fractional block indices of the warp field.
    coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
    shifts = cp.zeros_like(coords)
    for idim in range(3):
        # The field is only read here, so no defensive copy is needed.
        shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
            self.warp_field[idim], coords_blocked, order=1, mode="nearest"
        )
    if negative_shifts:
        shifts = -shifts
    return coords + shifts
Push voxel coordinates from fixed to moving space.
Arguments:
- coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:
numpy.array: transformed voxel coordinates
212 def pull_coordinates(self, coords): 213 """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates. 214 215 Args: 216 coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array) 217 218 Returns: 219 numpy.array: transformed voxel coordinates 220 """ 221 return self.invert().push_coordinates(coords, negative_shifts=True)
Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
Arguments:
- coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:
numpy.array: transformed voxel coordinates
def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
    """
    Compute det J = det(∇φ) for φ(x)=x+u(x), using cp.indices for the identity grid.

    Args:
        units_per_voxel (3-element sequence): physical size of a voxel along each axis.
            Both the grid coordinates and the displacements (stored in voxels) are
            converted to these units before differentiation.
        edge_order : passed to cp.gradient (1 or 2)

    Returns:
        detJ: cp.ndarray with the spatial shape of the warp field
    """
    units = cp.array(units_per_voxel, dtype="float32")
    scaling = units * self.block_stride
    coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
    # The field is stored in voxels; scale it so it shares the units of ``coords``.
    phi = coords + self.warp_field * units[:, None, None, None]
    J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
    for i in range(3):
        grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
        for j in range(3):
            J[..., i, j] = grads[j]
    return cp.linalg.det(J)
Compute det J = det(∇φ) for φ(x)=x+u(x), using cp.indices for the identity grid.
Arguments:
- edge_order : passed to cp.gradient (1 or 2)
Returns:
detJ: cp.ndarray of shape spatial
243 def as_ants_image(self, voxel_size_um=1): 244 """Convert to ANTsImage 245 246 Args: 247 voxel_size_um (scalar or array): voxel size (default is 1) 248 249 Returns: 250 ants.core.ants_image.ANTsImage: 251 """ 252 try: 253 import ants 254 except ImportError: 255 raise ImportError("ANTs is not installed. Please install it using 'pip install ants'") 256 257 ants_image = ants.from_numpy( 258 self.warp_field.get().transpose(1, 2, 3, 0), 259 origin=list((self.block_size.get() - 1) / 2 * voxel_size_um), 260 spacing=list(self.block_stride.get() * voxel_size_um), 261 has_components=True, 262 ) 263 return ants_image
Convert to ANTsImage
Arguments:
- voxel_size_um (scalar or array): voxel size (default is 1)
Returns:
ants.core.ants_image.ANTsImage:
class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (usually identical to block_size)
        proj_method (callable): Projection method, called as ``proj_method(blocks, axis=...)``
        subpixel (int): upsampling factor used for the sub-pixel peak search
        epsilon (float): value added to the zero-shift bin of each cross-correlation
        tukey_alpha (float): the reference projections are windowed when this is below 1

    Raises:
        ValueError: if a block is larger than the reference volume along any axis
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        # Validate before update_reference so a bad block size fails fast,
        # without first building blocks and FFT plans.
        if np.any(np.asarray(block_size) > np.array(ref_vol.shape)):
            raise ValueError(f"Block size ({block_size}) must be smaller than the volume shape ({ref_vol.shape})")
        self.proj_method = proj_method
        self.plan_rev = [None, None, None]
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """Cut the reference volume into blocks and cache the conjugate FFT of their projections.

        Args:
            ref_vol (numpy.array): The reference volume
            block_size (3-element list or numpy.array): shape of the blocks
            block_stride (3-element list or numpy.array): stride between blocks; defaults to block_size
        """
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha < 1:
            # NOTE(review): the window always uses alpha=0.5; self.tukey_alpha only decides
            # whether a window is applied at all. Confirm whether this is intended.
            ref_blocks_proj = [
                ref_blocks_proj[i]
                * cp.array(
                    ndwindow(
                        [1, 1, 1, *ref_blocks_proj[i].shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=0.5)
                    )
                ).astype("float32")
                for i in range(3)
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
        ]
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
        ]
        self.block_size = block_size
        self.block_stride = block_stride

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (numpy.array): Input volume
            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Cross-power spectrum of the moving and reference projections.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # Small bonus for the centre (zero-shift) bin.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer peak position relative to the centre of the shifted correlation.
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # Upsampled evaluation around the integer peak for the sub-pixel part.
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        disp_field = cp.array(disp_field)
        # Each projection yields two in-plane shifts; every volume axis is covered by
        # two projections, whose estimates are averaged.
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)
Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.
Arguments:
- ref_vol (numpy.array): The reference volume
- block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
- block_stride (3-element list or numpy.array): stride (usually identical to block_size)
- proj_method (str or callable): Projection method
287 def __init__( 288 self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5 289 ): 290 self.proj_method = proj_method 291 self.plan_rev = [None, None, None] 292 self.subpixel = subpixel 293 self.epsilon = epsilon 294 self.tukey_alpha = tukey_alpha 295 self.update_reference(ref_vol, block_size, block_stride) 296 self.ref_shape = np.array(ref_vol.shape) 297 if np.any(block_size > np.array(ref_vol.shape)): 298 raise ValueError(f"Block size ({block_size}) must be smaller than the volume shape ({ref_vol.shape})")
def update_reference(self, ref_vol, block_size, block_stride=None):
    """Cut the reference volume into blocks and cache the conjugate FFT of their projections.

    Sets ``blocks_shape``, ``plan_fwd``, ``ref_blocks_proj_ft_conj``, ``block_size`` and
    ``block_stride`` on the instance.

    Args:
        ref_vol (numpy.array): The reference volume
        block_size (3-element list or numpy.array): shape of the blocks
        block_stride (3-element list or numpy.array): stride between blocks; defaults to block_size
    """
    block_size = np.array(block_size)
    block_stride = block_size if block_stride is None else np.array(block_stride)
    ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
    self.blocks_shape = ref_blocks.shape
    ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
    if self.tukey_alpha < 1:
        # NOTE(review): the window always uses alpha=0.5; self.tukey_alpha only decides
        # whether a window is applied at all. Confirm whether this is intended.
        ref_blocks_proj = [
            ref_blocks_proj[i]
            * cp.array(
                ndwindow(
                    [1, 1, 1, *ref_blocks_proj[i].shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=0.5)
                )
            ).astype("float32")
            for i in range(3)
        ]
    self.plan_fwd = [
        cupyx.scipy.fft.get_fft_plan(ref_blocks_proj[i], axes=(-2, -1), value_type="R2C") for i in range(3)
    ]
    self.ref_blocks_proj_ft_conj = [
        cupyx.scipy.fft.rfftn(ref_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i]).conj() for i in range(3)
    ]
    self.block_size = block_size
    self.block_stride = block_stride
def get_displacement(self, vol, smooth_func=None):
    """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

    For each of the three block projections, the cross-correlation with the cached
    reference projection is computed through the FFT, its integer peak is located and
    then refined with an upsampled evaluation around that peak.

    Args:
        vol (numpy.array): Input volume
            (NOTE(review): passed to sliding_block without conversion — presumably a cupy array; confirm)
        smooth_func (callable): Smoothing function to be applied to the cross-correlation volume;
            called as ``smooth_func(xcorr, block_size)``

    Returns:
        WarpMap
    """
    vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
    vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
    del vol_blocks

    disp_field = []
    for i in range(3):
        # Cross-power spectrum of the moving and reference projections.
        R = (
            cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
            * self.ref_blocks_proj_ft_conj[i]
        )
        # The inverse FFT plan is created on first use and reused afterwards.
        if self.plan_rev[i] is None:
            self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
        xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
        if smooth_func is not None:
            xcorr_proj = smooth_func(xcorr_proj, self.block_size)
        # Small bonus for the centre (zero-shift) bin.
        xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

        # Integer peak position relative to the centre of the shifted correlation.
        max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
        max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
        del xcorr_proj
        i0, j0 = max_ix.reshape(2, -1)
        # Upsampled evaluation around the integer peak for the sub-pixel part.
        shifts = upsampled_dft_rfftn(
            R.reshape(-1, *R.shape[-2:]),
            upsampled_region_size=int(self.subpixel * 2 + 1),
            upsample_factor=self.subpixel,
            axis_offsets=(i0, j0),
        )
        del R
        max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
        max_sub = (
            max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
        ) / self.subpixel
        del shifts
        disp_field.append(max_ix + max_sub)

    disp_field = cp.array(disp_field)
    # Each projection yields two in-plane shifts; every volume axis is covered by
    # two projections, whose estimates are averaged.
    disp_field = (
        cp.array(
            [
                disp_field[1, 0] + disp_field[2, 0],
                disp_field[0, 0] + disp_field[2, 1],
                disp_field[0, 1] + disp_field[1, 1],
            ]
        ).astype("float32")
        / 2
    )
    return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)
Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.
Arguments:
- vol (numpy.array): Input volume
- smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
Returns:
WarpMap
class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array): Reference volume
        recipe (Recipe): Settings for the registration: an optional ``pre_filter`` and one
            entry in ``levels`` per pyramid level.
            IMPORTANT: the block stride in the last level should not be larger than the
            block stride in any previous level.
        reg_mask (numpy.array): Mask for registration; passed to ``recipe.pre_filter``. Default is 1.
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        # mapper_ix[k] is the index into recipe.levels that mappers[k] was built from.
        self.mapper_ix = []
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -n is replaced by volume_size // n along that axis.
            negative = block_size < 0
            block_size[negative] = (np.r_[ref_vol.shape] // -block_size)[negative]
            if isinstance(level.block_stride, (int, float)):
                # A scalar stride is a fraction of the block size.
                block_stride = np.round(block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback function to be called after each level of registration
            verbose (bool): If True, show progress bars

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function

        Raises:
            ValueError: if an affine level has fewer than two axes with at least 2 blocks
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Start from a constant shift that aligns the centres of the two volumes.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # Compare per axis: each axis of the last level against the minimum over all levels.
        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                # The running map always lives on the grid of the last level.
                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # The returned volume is the unfiltered input warped with the final map.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output
A class for performing multi-resolution registration.
Arguments:
- ref_vol (numpy.array): Reference volume
- recipe (Recipe): Settings for each level of the pyramid. IMPORTANT: the block size in the last level cannot be larger than the block_size in any previous level.
- reg_mask (numpy.array): Mask for registration
- clip_thresh (float): Threshold for clipping the reference volume
def __init__(self, ref_vol, recipe, reg_mask=1):
    """Validate the recipe and build one WarpMapper per level with at least one repeat.

    Args:
        ref_vol (numpy.array): Reference volume
        recipe (Recipe): registration settings; provides ``pre_filter`` and ``levels``
        reg_mask (numpy.array): Mask for registration; passed to ``recipe.pre_filter``. Default is 1.

    Raises:
        AssertionError: if no level has ``repeats >= 1``
    """
    recipe.model_validate(recipe.model_dump())
    self.recipe = recipe
    self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
    self.mappers = []
    ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
    self.ref_shape = ref_vol.shape
    if self.recipe.pre_filter is not None:
        ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
    # mapper_ix[k] is the index into recipe.levels that mappers[k] was built from.
    self.mapper_ix = []
    for i, level in enumerate(recipe.levels):
        if level.repeats < 1:
            continue
        block_size = np.array(level.block_size)
        # A negative entry -n is replaced by volume_size // n along that axis.
        negative = block_size < 0
        block_size[negative] = (np.r_[ref_vol.shape] // -block_size)[negative]
        if isinstance(level.block_stride, (int, float)):
            # A scalar stride is a fraction of the block size.
            block_stride = np.round(block_size * level.block_stride).astype("int")
        else:
            block_stride = np.array(level.block_stride)
        self.mappers.append(
            WarpMapper(
                ref_vol,
                block_size,
                block_stride=block_stride,
                proj_method=level.project,
                tukey_alpha=level.tukey_ref,
            )
        )
        self.mapper_ix.append(i)
    assert len(self.mappers) > 0, "At least one level of registration is required"
def register_single(self, vol, callback=None, verbose=False):
    """Register a single volume to the reference volume.

    Args:
        vol (array_like): Volume to be registered (numpy or cupy array)
        callback (function): Callback function to be called after each level of registration
        verbose (bool): If True, show progress bars

    Returns:
        - vol (array_like): Registered volume (numpy or cupy array, depending on input)
        - warp_map (WarpMap): Displacement field
        - callback_output (list): List of outputs from the callback function

    Raises:
        ValueError: if an affine level has fewer than two axes with at least 2 blocks
    """
    was_numpy = isinstance(vol, np.ndarray)
    vol = cp.array(vol, "float32", copy=False, order="C")
    # Start from a constant shift that aligns the centres of the two volumes.
    offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
    warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
    warp_map = warp_map.resize_to(self.mappers[-1])
    callback_output = []
    vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
    vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
    warp_map.warp(vol_tmp0, out=vol_tmp)
    min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
    if callback is not None:
        callback_output.append(callback(vol_tmp))

    # Compare per axis: each axis of the last level against the minimum over all levels.
    if np.any(self.mappers[-1].block_stride > min_block_stride):
        warnings.warn(
            "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
        )
    for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
        level = self.recipe.levels[self.mapper_ix[k]]
        for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
            wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
            wm.warp_field *= level.update_rate
            if level.median_filter:
                wm = wm.median_filter()
            if level.affine:
                if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                    raise ValueError(
                        f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                    )
                wm, _ = wm.fit_affine(
                    target=dict(
                        warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                        block_size=self.mappers[-1].block_size,
                        block_stride=self.mappers[-1].block_stride,
                    )
                )
            else:
                wm = wm.resize_to(self.mappers[-1])

            # The running map always lives on the grid of the last level.
            warp_map = warp_map.chain(wm)
            warp_map.warp(vol_tmp0, out=vol_tmp)
            if callback is not None:
                callback_output.append(callback(vol_tmp))
    # The returned volume is the unfiltered input warped with the final map.
    warp_map.warp(vol, out=vol_tmp)
    if was_numpy:
        vol_tmp = vol_tmp.get()
    return vol_tmp, warp_map, callback_output
Register a single volume to the reference volume.
Arguments:
- vol (array_like): Volume to be registered (numpy or cupy array)
- callback (function): Callback function to be called after each level of registration
Returns:
- vol (array_like): Registered volume (numpy or cupy array, depending on input)
- warp_map (WarpMap): Displacement field
- callback_output (list): List of outputs from the callback function
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to
            return 2D frames. If the video cannot be generated (no callback, non-2D frames, ...), a warning is
            issued and the registration results are still returned. Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Drop the pyramid and clear the FFT plan cache before continuing
    # (presumably to release GPU memory held by them).
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        # Video generation is best effort: failures are reported as warnings so the
        # registration result computed above is never lost.
        try:
            if callback is None or len(cbout) == 0:
                raise ValueError("A callback returning 2D frames is required to generate a video")
            if np.ndim(cbout[0]) != 2:
                raise ValueError("Callback output must be a 2D array")
            if recipe.pre_filter is not None:
                ref_filtered = recipe.pre_filter(ref)
            else:
                # Without a pre-filter, register_single hands the callback a float32 cupy array.
                ref_filtered = cp.array(ref, "float32", copy=False, order="C")
            ref_frame = callback(ref_filtered)
            vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
            create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
        except (ValueError, AssertionError) as e:
            warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout
Register a volume to a reference volume using a registration pyramid.
Arguments:
- ref (numpy.array or cupy.array): Reference volume
- vol (numpy.array or cupy.array): Volume to be registered
- recipe (Recipe): Registration recipe
- reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
- callback (function): Callback function to be called on the volume after each iteration. Default is None.
Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
(note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
Callback outputs for each registration step will be returned as a list.
- verbose (bool): If True, show progress bars. Default is True
- video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
- vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.
Returns:
- numpy.array or cupy.array (depending on vol input): Registered volume
- WarpMap: Displacement field
- list: List of outputs from the callback function
class Projector(BaseModel):
    """Project volume blocks along one axis and filter the resulting 2D images.

    Parameters:
        max: if True, use a maximum projection, otherwise a mean projection. Default is True
        normalize: if > 0, divide each projection by its L2 norm raised to this power
            (True behaves like 1: correlations instead of covariances). Default is False
        dog: if True, apply a DoG filter to the projections. Default is True
        low: the lower sigma value for the DoG filter (scalar or one value per axis). Default is 0.5.
            If dog is False and low is non-zero, a plain Gaussian filter with this sigma is applied instead.
        high: the higher sigma value for the DoG filter (scalar or one value per axis). Default is 10.0
        periodic_smooth: if True, apply a periodic-smooth decomposition to the projections
            before filtering. Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Project the blocked volume along ``axis`` and filter the projections.

        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions
                being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project. It is also used to drop the matching entry
                from the 3-element sigma vectors.

        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and
            the last 2 dimensions being 2D projections)
        """
        projected = vol_blocks.max(axis) if self.max else vol_blocks.mean(axis)
        if self.periodic_smooth:
            projected = periodic_smooth_decomposition_nd_rfft(projected)

        # Expand the sigmas to one value per voxel axis, then remove the projected axis.
        unit = np.r_[1, 1, 1]
        sigma_low = np.delete(unit * self.low, axis)
        sigma_high = np.delete(unit * self.high, axis)

        # The three leading (block) dimensions are never filtered.
        if self.dog:
            projected = dogfilter(projected, [0, 0, 0, *sigma_low], [0, 0, 0, *sigma_high], mode="reflect")
        elif np.any(np.array(self.low) != 0):
            projected = cupyx.scipy.ndimage.gaussian_filter(
                projected, [0, 0, 0, *sigma_low], mode="reflect", truncate=5.0
            )

        if self.normalize > 0:
            l2_norm = cp.sqrt(cp.sum(projected**2, axis=(-2, -1), keepdims=True))
            # The small constant guards against division by zero.
            projected /= l2_norm**self.normalize + 1e-9
        return projected
A class to apply a 2D projection and filters to a volume block
Arguments:
- max: if True, apply a max filter to the volume block. Default is True
- normalize: if True, normalize projections by the L2 norm (to get correlations, not covariances). Default is False
- dog: if True, apply a DoG filter to the volume block. Default is True
- low: the lower sigma value for the DoG filter. Default is 0.5
- high: the higher sigma value for the DoG filter. Default is 10.0
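- tukey_env: if True, apply a Tukey window to the output. Default is False (note: not a field of the current `Projector` class)
- gauss_env: if True, apply a Gaussian window to the output. Default is False (note: not a field of the current `Projector` class)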
class Smoother(BaseModel):
    """Smooth blocks with a Gaussian kernel

    Args:
        sigmas (float or list): [sigma0, sigma1, sigma2] along the three block dimensions. A single
            number is used for all three. If None, no smoothing is applied. Default is [1.0, 1.0, 1.0]
        shear (float): shear parameter for gaussian kernel. Default is None (no shear).
        long_range_ratio (float): weight of an additional Gaussian with 5x larger sigmas that is
            mixed into the result. If None, it is skipped. Default is 0.05.

    The Gaussian kernels are truncated at a fixed 4.0 standard deviations.
    """

    sigmas: Union[float, List[float]] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being
                the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated. Required when
                ``shear`` is set.

        Returns:
            cupy.array: smoothed cross-correlation volume

        Raises:
            ValueError: if ``shear`` is set but ``block_size`` is not given.
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj

        # The annotation allows a scalar; expand it to one sigma per block dimension.
        sigmas = self.sigmas
        if isinstance(sigmas, (int, float)):
            sigmas = [float(sigmas)] * 3

        if self.shear is not None:
            if block_size is None:
                raise ValueError("block_size is required when shear is set")
            # Scale the shear by the ratio of the block extents along axes 1 and 0.
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            gw = gausskernel_sheared(sigmas[:2], shear_blocks, truncate=truncate)
            # The sheared 2D kernel acts on the first two block dimensions only.
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:
            # Smooth across the three block dimensions; the 2D projection axes stay untouched.
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )

        if self.long_range_ratio is not None:
            # The wide filter is applied to the already down-weighted array.
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*np.array(sigmas) * 5, 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj
Smooth blocks with a Gaussian kernel
Arguments:
- sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied.
- truncate (float): truncate parameter for gaussian kernel. Not a configurable field; it is fixed at 4.0 inside `__call__`.
- shear (float): shear parameter for gaussian kernel. Default is None.
- long_range_ratio (float): long range ratio for double gaussian kernel. Default is 0.05; if None, the long-range term is skipped.
class RegFilter(BaseModel):
    """Filter applied to a volume before registration.

    Parameters:
        clip_thresh: value subtracted from the volume before negative values are clipped to 0. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: passed to soften_edges when any entry is > 0 (scalar or list). Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Filter a volume: threshold, soften edges, mask, then DoG.

        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration; multiplied with the volume if not None

        Returns:
            cupy.ndarray: Filtered volume
        """
        shifted = cp.array(vol, "float32", copy=False) - self.clip_thresh
        filtered = cp.clip(shifted, 0, None)

        wants_soft_edge = np.any(np.array(self.soft_edge) > 0)
        if wants_soft_edge:
            filtered = soften_edges(filtered, soft_edge=self.soft_edge, copy=False)

        if reg_mask is not None:
            mask = cp.array(reg_mask, dtype="float32", copy=False)
            filtered *= mask

        if not self.dog:
            return filtered
        return dogfilter(filtered, self.low, self.high, mode="reflect")
A class to apply a filter to the volume before registration
Arguments:
- clip_thresh: threshold for clipping the reference volume. Default is 0
- dog: if True, apply a DoG filter to the volume. Default is True
- low: the lower sigma value for the DoG filter. Default is 0.5
- high: the higher sigma value for the DoG filter. Default is 10.0
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated. Required.
            RegistrationPyramid replaces a negative entry -n by (volume extent // n) along that axis.
        block_stride (list or float): stride (usually identical to block_size). RegistrationPyramid
            interprets a single number as a fraction of block_size. Default is 1.0
        project (Projector or callable): Projector object. The callable should take a volume block and an
            axis as input and return a projected volume block. Default is Projector()
        tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
        smooth (Smoother or None): Smoother object. Default is Smoother()
        affine (bool): if True, apply affine transformation to the displacement field. Default is False
        median_filter (bool): if True, apply median filter to the displacement field. Default is True
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
        repeats (int): number of iterations for this level (disable level by setting repeats to 0). Default is 5
    """

    # Only block_size has no default and must always be given.
    block_size: Union[List[int]]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5
Configuration for each level of the registration pyramid
Arguments:
- block_size (list): shape of blocks, whose rigid displacement is estimated
- block_stride (list): stride (usually identical to block_size)
- repeats (int): number of iterations for this level (disable level by setting repeats to 0)
- smooth (Smoother or None): Smoother object
- project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
- tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
- affine (bool): if True, apply affine transformation to the displacement field
- median_filter (bool): if True, apply median filter to the displacement field
- update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
class Recipe(BaseModel):
    """Configuration for the registration recipe.

    By default the recipe holds two levels: a translation level (one block spanning the
    whole volume, 1 repeat) followed by an affine level (4 blocks per axis, 10 repeats).

    Args:
        pre_filter (RegFilter, callable or None): Filter applied to the volumes before registration
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        # translation level
        LevelConfig(block_size=[-1, -1, -1], repeats=1),
        # affine level
        LevelConfig(
            block_size=[-4, -4, -4],
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Append a level to the end of the registration recipe

        Args:
            block_size (list or number): shape of blocks, whose rigid displacement is estimated.
                A single number is used for all three axes.
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have exactly 3 entries
        """
        sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
        if len(sizes) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        level = LevelConfig(block_size=sizes, **kwargs)
        self.levels.append(level)

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level at a given position of the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (list or number): shape of blocks, whose rigid displacement is estimated.
                A single number is used for all three axes.
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have exactly 3 entries
        """
        sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
        if len(sizes) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        level = LevelConfig(block_size=sizes, **kwargs)
        self.levels.insert(index, level)

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file. If it is not an existing file, it is
                looked up in the "recipes" directory next to this module.

        Returns:
            Recipe: Recipe object
        """
        import yaml

        resolved = yaml_path
        if not os.path.isfile(resolved):
            package_dir = pathlib.Path(__file__).resolve().parent
            resolved = os.path.join(package_dir, "recipes", yaml_path)

        with open(resolved, "r") as stream:
            payload = yaml.safe_load(stream)
        return cls.model_validate(payload)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w") as stream:
            yaml.dump(self.model_dump(), stream)
        print(f"Recipe saved to {yaml_path}")
Configuration for the registration recipe. By default, the recipe is initialized with a translation level followed by an affine level.
Arguments:
- pre_filter (RegFilter, callable or None): Filter to be applied to the volumes before registration
- levels (list): List of LevelConfig objects
def add_level(self, block_size, **kwargs):
    """Append a level to the end of the registration recipe

    Args:
        block_size (list or number): shape of blocks, whose rigid displacement is estimated.
            A single number is used for all three axes.
        **kwargs: additional arguments for LevelConfig

    Raises:
        ValueError: if block_size does not have exactly 3 entries
    """
    sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
    if len(sizes) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    level = LevelConfig(block_size=sizes, **kwargs)
    self.levels.append(level)
Add a level to the registration recipe
Arguments:
- block_size (list): shape of blocks, whose rigid displacement is estimated
- **kwargs: additional arguments for LevelConfig
def insert_level(self, index, block_size, **kwargs):
    """Insert a level at a given position of the registration recipe

    Args:
        index (int): A number specifying in which position to insert the level
        block_size (list or number): shape of blocks, whose rigid displacement is estimated.
            A single number is used for all three axes.
        **kwargs: additional arguments for LevelConfig

    Raises:
        ValueError: if block_size does not have exactly 3 entries
    """
    sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
    if len(sizes) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    level = LevelConfig(block_size=sizes, **kwargs)
    self.levels.insert(index, level)
Insert a level to the registration recipe
Arguments:
- index (int): A number specifying in which position to insert the level
- block_size (list): shape of blocks, whose rigid displacement is estimated
- **kwargs: additional arguments for LevelConfig
@classmethod
def from_yaml(cls, yaml_path):
    """Load a recipe from a YAML file

    Args:
        yaml_path (str): path to the YAML file. If it is not an existing file, it is
            looked up in the "recipes" directory next to this module.

    Returns:
        Recipe: Recipe object
    """
    import yaml

    resolved = yaml_path
    if not os.path.isfile(resolved):
        package_dir = pathlib.Path(__file__).resolve().parent
        resolved = os.path.join(package_dir, "recipes", yaml_path)

    with open(resolved, "r") as stream:
        payload = yaml.safe_load(stream)
    return cls.model_validate(payload)
Load a recipe from a YAML file
Arguments:
- yaml_path (str): path to the YAML file
Returns:
Recipe: Recipe object
def to_yaml(self, yaml_path):
    """Write the recipe to a YAML file and print a confirmation

    Args:
        yaml_path (str): path to the YAML file
    """
    import yaml

    with open(yaml_path, "w") as stream:
        yaml.dump(self.model_dump(), stream)
    print(f"Recipe saved to {yaml_path}")
Save the recipe to a YAML file
Arguments:
- yaml_path (str): path to the YAML file