warpfield.register
import warnings
import gc
import pathlib
import os
from typing import List, Union, Callable

import numpy as np
import scipy.signal
import cupy as cp
import cupyx
import cupyx.scipy.fft
import cupyx.scipy.ndimage
from pydantic import BaseModel, ValidationError
from tqdm.auto import tqdm
import h5py

from .warp import warp_volume
from .utils import create_rgb_video, mips_callback
from .ndimage import (
    accumarray,
    dogfilter,
    gausskernel_sheared,
    infill_nans,
    ndwindow,
    periodic_smooth_decomposition_nd_rfft,
    sliding_block,
    upsampled_dft_rfftn,
    soften_edges,
)

_ArrayType = Union[np.ndarray, cp.ndarray]

# Dataset names written by WarpMap.to_h5 (used to clear them when saving into the file root).
_H5_DATASETS = ("warp_field", "block_size", "block_stride", "ref_shape", "mov_shape")


def _is_h5_root(group):
    """Return True if ``group`` designates the root of an HDF5 file (None, "" or "/")."""
    return group in (None, "", "/")


class WarpMap:
    """Represents a 3D displacement field sampled on a regular grid of blocks.

    Sample ``(i, j, k)`` of the field is located at voxel coordinate
    ``(i, j, k) * block_stride + block_size / 2`` of the reference volume.

    Args:
        warp_field (numpy.array or cupy.array): the displacement field data (3-x-y-z), in voxels
        block_size (3-element list or numpy.array): size of the blocks (in voxels)
        block_stride (3-element list or numpy.array): spacing between neighboring blocks (in voxels)
        ref_shape (tuple): shape of the reference volume
        mov_shape (tuple): shape of the moving volume
    """

    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
        self.warp_field = cp.array(warp_field, dtype="float32")
        self.block_size = cp.array(block_size, dtype="float32")
        self.block_stride = cp.array(block_stride, dtype="float32")
        self.ref_shape = ref_shape
        self.mov_shape = mov_shape

    def warp(self, vol, out=None):
        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

        Args:
            vol (cupy.array): the volume to be warped (expected shape: ``mov_shape``)
            out (cupy.array, optional): preallocated float32 output of shape ``ref_shape``;
                a new array is allocated if None

        Returns:
            cupy.array: warped volume
        """
        # array_equal tolerates tuple/ndarray mixes and differing ranks
        # (an elementwise `!=` raises when the lengths differ).
        if not np.array_equal(vol.shape, self.mov_shape):
            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
        if out is None:
            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
        return warp_volume(
            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
        )

    def apply(self, *args, **kwargs):
        """Alias of warp method"""
        return self.warp(*args, **kwargs)

    def fit_affine(self, target=None):
        """Fit an affine transformation to the displacement field and return it as a new WarpMap.

        The affine model maps block-center coordinates ``x`` to ``x + u(x)`` and is fitted by
        least squares over all blocks.

        Args:
            target (dict, optional): output grid, with keys "warp_field_shape", "block_size",
                and "block_stride". If None, the current grid is used.

        Returns:
            WarpMap: affine displacement field sampled on the target grid
            cupy.array: (4, 3) affine coefficients such that ``x_moving = x_fixed @ coeff[:3] + coeff[3]``
        """
        if target is None:
            warp_field_shape = self.warp_field.shape
            block_size = self.block_size
            block_stride = self.block_stride
        else:
            warp_field_shape = target["warp_field_shape"]
            block_size = cp.array(target["block_size"]).astype("float32")
            block_stride = cp.array(target["block_stride"]).astype("float32")

        # Block-center coordinates (voxels) of every sample of the current grid, shape (n, 3).
        centers = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
        centers = centers * self.block_stride + self.block_size / 2
        # Homogeneous design matrix [x, 1] and displaced targets x + u(x); all blocks are used.
        design = cp.hstack([centers, cp.ones((centers.shape[0], 1))])
        displaced = centers + self.warp_field.reshape(3, -1).T
        coeff = cp.linalg.lstsq(design, displaced, rcond=None)[0]
        # Evaluate the fitted affine on the target grid; subtracting the identity yields displacements.
        centers_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
        linfit = ((centers_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

    def median_filter(self):
        """Apply a 3x3x3 median filter to each component of the displacement field

        Returns:
            WarpMap: new WarpMap with median filtered displacement field
        """
        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

    def resize_to(self, target):
        """Resample the displacement field onto another block grid using linear interpolation.

        Args:
            target (WarpMap, WarpMapper, or dict): grid to resample to. A dict must have the keys
                "warp_field_shape", "block_size", and "block_stride".

        Returns:
            WarpMap: resized WarpMap

        Raises:
            ValueError: if target is not a WarpMap, WarpMapper, or dict
        """
        if isinstance(target, WarpMap):
            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
        elif isinstance(target, WarpMapper):
            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
        elif isinstance(target, dict):
            t_sh, t_bsz, t_bst = (
                target["warp_field_shape"][1:],
                cp.array(target["block_size"]),
                cp.array(target["block_stride"]),
            )
        else:
            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
        ix = cp.indices(t_sh).reshape(3, -1)
        # Target block centers (i * t_bst + t_bsz / 2) expressed as fractional indices into this grid.
        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
        dm_r = cp.array(
            [
                cupyx.scipy.ndimage.map_coordinates(self.warp_field[i], ix, mode="nearest", order=1).reshape(t_sh)
                for i in range(3)
            ]
        )
        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

    def chain(self, target):
        """Chain displacement maps by summing their displacement fields.

        The fields are added sample by sample (neither is resampled through the other), so both must
        live on the same grid. The result adopts the target's block_size and block_stride.

        Args:
            target (WarpMap): WarpMap to be added to existing map

        Returns:
            WarpMap: new WarpMap with chained displacement field
        """
        warp_field = self.warp_field.copy()
        warp_field += target.warp_field
        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

    def invert(self, **kwargs):
        """alias for invert_fast method"""
        return self.invert_fast(**kwargs)

    def invert_fast(self, sigma=0.5, truncate=20):
        """Invert the displacement field using accumulation and Gaussian basis interpolation.

        Args:
            sigma (float): standard deviation for Gaussian basis interpolation
            truncate (float): truncate parameter for Gaussian basis interpolation

        Returns:
            WarpMap: inverted WarpMap (maps the moving space back to the reference space)
        """
        warp_field = self.warp_field.get()
        # Grid positions (in block units) that each sample lands on in the moving space.
        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
        num_coords = accumarray(target_coords, wf_shape)
        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
        for i in range(3):
            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
            with np.errstate(invalid="ignore"):
                inv_field[i] /= num_coords
            # Cells that received no samples are filled by interpolation.
            inv_field[i][num_coords == 0] = np.nan
            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

    def push_coordinates(self, coords, negative_shifts=False):
        """Push voxel coordinates from fixed to moving space.

        Args:
            coords (numpy.array or cupy.array): 3D *voxel* coordinates to be warped (3-by-n array)
            negative_shifts (bool): if True, subtract the interpolated displacements instead of adding them

        Returns:
            cupy.array: transformed voxel coordinates (3-by-n, float32)
        """
        assert coords.shape[0] == 3
        coords = cp.array(coords, dtype="float32")
        # Voxel coordinates -> fractional block-grid indices (inverse of i * stride + size / 2).
        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
        shifts = cp.zeros_like(coords)
        for idim in range(3):
            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
                self.warp_field[idim], coords_blocked, order=1, mode="nearest"
            )
        if negative_shifts:
            shifts = -shifts
        return coords + shifts

    def pull_coordinates(self, coords):
        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

        Args:
            coords (numpy.array): 3D *voxel* coordinates to be warped (3-by-n array)

        Returns:
            cupy.array: transformed voxel coordinates
        """
        return self.invert().push_coordinates(coords, negative_shifts=True)

    def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
        """Compute det J = det(∇φ) for φ(x) = x + u(x) on the block grid.

        Args:
            units_per_voxel (3-element sequence): physical size of a voxel along each axis.
                The grid spacing is ``units_per_voxel * block_stride``. Default is (1, 1, 1)
            edge_order: passed to cupy.gradient (1 or 2)

        Returns:
            cupy.ndarray: det J, with the spatial shape of the warp field
        """
        units = cp.array(units_per_voxel, dtype="float32")
        scaling = units * self.block_stride
        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
        # The displacement field is in voxels; express it in the same physical units as the grid
        # coordinates so that det J is unit-independent.
        phi = coords + self.warp_field * units[:, None, None, None]
        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
        for i in range(3):
            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
            for j in range(3):
                J[..., i, j] = grads[j]
        return cp.linalg.det(J)

    def as_ants_image(self, voxel_size_um=1):
        """Convert to ANTsImage

        Args:
            voxel_size_um (scalar or array): voxel size (default is 1)

        Returns:
            ants.core.ants_image.ANTsImage:

        Raises:
            ImportError: if ANTsPy is not installed
        """
        try:
            import ants
        except ImportError as err:
            # The ANTsPy distribution on PyPI is named "antspyx" (the "ants" package is unrelated).
            raise ImportError("ANTs is not installed. Please install it using 'pip install antspyx'") from err

        return ants.from_numpy(
            self.warp_field.get().transpose(1, 2, 3, 0),
            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
            spacing=list(self.block_stride.get() * voxel_size_um),
            has_components=True,
        )

    def __repr__(self):
        """String representation of the WarpMap object."""
        return (
            f"WarpMap("
            f"warp_field_shape={self.warp_field.shape}, "
            f"block_size={self.block_size.get()}, "
            f"block_stride={self.block_stride.get()}, "
            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)})"
        )

    def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
        """
        Save this WarpMap to an HDF5 file.

        Args:
            h5_path (str or os.PathLike): Path to the HDF5 file.
            group (str or None): Group path inside the HDF5 file to store the WarpMap (created if missing).
                None, "" or "/" store the datasets directly in the file root.
            compression (str or None): Dataset compression for the warp field (e.g., 'gzip', None).
            overwrite (bool): If True, overwrite existing datasets/attrs inside the group.

        Raises:
            ValueError: if the group (or, for the root, any WarpMap dataset) exists and overwrite is False
        """
        with h5py.File(h5_path, "a") as f:
            if _is_h5_root(group):
                grp = f
                # The root cannot be deleted, so clear only the datasets this method writes.
                existing = [name for name in _H5_DATASETS if name in grp]
                if existing and not overwrite:
                    raise ValueError(
                        f"Datasets {existing} already exist in {h5_path}. Set 'overwrite=True' to overwrite them."
                    )
                for name in existing:
                    del grp[name]
            else:
                if group in f:
                    if not overwrite:
                        raise ValueError(
                            f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it."
                        )
                    del f[group]
                grp = f.require_group(group)
            grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
            grp.create_dataset("block_size", data=self.block_size.get())
            grp.create_dataset("block_stride", data=self.block_stride.get())
            grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
            grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
            grp.attrs["class"] = "WarpMap"

    @classmethod
    def from_h5(cls, h5_path, group="warp_map"):
        """
        Load a WarpMap from an HDF5 file.

        Args:
            h5_path (str or os.PathLike): Path to the HDF5 file.
            group (str or None): Group path inside the HDF5 file where the WarpMap is stored.
                None, "" or "/" read from the file root.

        Returns:
            WarpMap: The loaded WarpMap object.
        """
        with h5py.File(h5_path, "r") as f:
            grp = f if _is_h5_root(group) else f[group]
            warp_field = grp["warp_field"][:]
            block_size = grp["block_size"][:]
            block_stride = grp["block_stride"][:]
            ref_shape = tuple(grp["ref_shape"][:].tolist())
            mov_shape = tuple(grp["mov_shape"][:].tolist())
        return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)


class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array or cupy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (usually identical to block_size).
            Defaults to block_size.
        proj_method (callable): Projection method ``f(blocks, axis)`` returning 2D projections of each block.
            Defaults to ``Projector()``.
        subpixel (int): upsampling factor used for sub-voxel peak localization
        epsilon (float): bias added to the zero-shift bin of each cross-correlation before peak search
            (favors zero displacement in case of ties)
        tukey_alpha (float or None): alpha of the Tukey window applied to the reference projections.
            No window is applied if None or >= 1.

    Raises:
        ValueError: if block_size exceeds the reference volume shape along any axis
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        if np.any(np.asarray(block_size) > np.array(ref_vol.shape)):
            raise ValueError(
                f"Block size (currently: {block_size}) must be smaller than the volume shape ({np.array(ref_vol.shape)})."
            )
        # Fall back to the default projector so that the documented default (None) is usable.
        self.proj_method = Projector() if proj_method is None else proj_method
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """(Re)compute and cache the Fourier transforms of the reference block projections.

        Args:
            ref_vol (numpy.array or cupy.array): The reference volume
            block_size (3-element list or numpy.array): shape of blocks
            block_stride (3-element list or numpy.array): stride; defaults to block_size
        """
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha is not None and self.tukey_alpha < 1:
            alpha = self.tukey_alpha
            ref_blocks_proj = [
                proj
                * cp.array(
                    ndwindow([1, 1, 1, *proj.shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=alpha))
                ).astype("float32")
                for proj in ref_blocks_proj
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(proj, axes=(-2, -1), value_type="R2C") for proj in ref_blocks_proj
        ]
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(proj, axes=(-2, -1), plan=plan).conj()
            for proj, plan in zip(ref_blocks_proj, self.plan_fwd)
        ]
        # Inverse FFT plans depend on the projection shape; drop any built for a previous reference.
        self.plan_rev = [None, None, None]
        self.block_size = block_size
        self.block_stride = block_stride
        self.ref_shape = np.array(ref_vol.shape)

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (cupy.array): Input volume. It must yield the same block grid as the reference volume,
                since the cached FFT plans are reused.
            smooth_func (callable): Smoothing function ``f(xcorr, block_size)`` applied to the cross-correlation

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # Cross-power spectrum of the i-th projection with the conjugated reference spectrum.
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # After fftshift the zero-lag bin is at the center.
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # Integer peak location, relative to zero lag.
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            # Refine each peak with an upsampled DFT around the integer location.
            i0, j0 = max_ix.reshape(2, -1)
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        disp_field = cp.array(disp_field)
        # Projection i drops axis i, so every axis is estimated by two projections; average them.
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)


class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array or cupy.array): Reference volume
        recipe (Recipe): Settings for each level of the pyramid.
            IMPORTANT: the block size in the last level cannot be larger than the block_size in any previous level.
        reg_mask (numpy.array or scalar): Mask multiplied with the volumes during pre-filtering. Default is 1 (no mask)
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        self.mapper_ix = []
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -d means "volume shape // d" along that axis (e.g. -1: whole volume).
            tmp = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = tmp[block_size < 0]
            if isinstance(level.block_stride, (int, float)):
                # A scalar stride is a fraction of the block size.
                block_stride = (block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback applied to the current pre-filtered, warped volume once before
                the first level and after every repeat of every level
            verbose (bool): if True, show progress bars

        Returns:
            - vol (array_like): Registered volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        # Start from a pure translation that centers the moving volume on the reference.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        # All updates are accumulated on the block grid of the last level.
        warp_map = warp_map.resize_to(self.mappers[-1])
        callback_output = []
        vol_tmp0 = self.recipe.pre_filter(vol, reg_mask=self.reg_mask) if self.recipe.pre_filter is not None else vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_tmp0, out=vol_tmp)
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # Per-axis comparison against the smallest stride used on that axis by any level.
        if np.any(self.mappers[-1].block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *self.mappers[-1].blocks_shape[:3]),
                            block_size=self.mappers[-1].block_size,
                            block_stride=self.mappers[-1].block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(self.mappers[-1])

                warp_map = warp_map.chain(wm)
                warp_map.warp(vol_tmp0, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # The final warp is applied to the original (unfiltered) volume.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output


def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames.
            Requires a callback; if video generation fails, a warning is issued and the registration results are still returned.
            Default is None.
        vmax (float): Maximum pixel value (to scale video brightness). If none, set to 99.9 percentile of pixel values.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        if callback is None or len(cbout) == 0:
            # Without callback frames there is nothing to render; don't lose the registration result.
            warnings.warn("Video generation failed with error: a callback returning 2D frames is required")
        else:
            try:
                if np.ndim(cbout[0]) != 2:
                    raise ValueError("Callback output must be a 2D array")
                ref_filtered = recipe.pre_filter(ref) if recipe.pre_filter is not None else cp.asarray(ref, dtype="float32")
                ref_frame = callback(ref_filtered)
                vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
                create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
            except (ValueError, AssertionError) as e:
                warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout


class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, project by maximum; otherwise by mean. Default is True
        normalize: if > 0, divide each projection by its L2 norm raised to this power
            (True acts as 1, giving correlations rather than covariances). Default is False
        dog: if True, apply a DoG filter to each projection. Default is True
        low: the lower sigma value(s) for the DoG filter (scalar or one per axis). If dog is False and low
            is nonzero, a Gaussian filter with this sigma is applied instead. Default is 0.5
        high: the higher sigma value(s) for the DoG filter. Default is 10.0
        periodic_smooth: if True, apply a periodic/smooth decomposition to each projection before filtering.
            Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block
        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last 2 dimensions being 2D projections)
        """
        if self.max:
            out = vol_blocks.max(axis)
        else:
            out = vol_blocks.mean(axis)
        if self.periodic_smooth:
            out = periodic_smooth_decomposition_nd_rfft(out)
        # Per-axis sigmas with the projected axis removed.
        low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        if self.dog:
            out = dogfilter(out, [0, 0, 0, *low], [0, 0, 0, *high], mode="reflect")
        elif not np.all(np.array(self.low) == 0):
            out = cupyx.scipy.ndimage.gaussian_filter(out, [0, 0, 0, *low], mode="reflect", truncate=5.0)
        if self.normalize > 0:
            out /= cp.sqrt(cp.sum(out**2, axis=(-2, -1), keepdims=True)) ** self.normalize + 1e-9
        return out


class Smoother(BaseModel):
    """Smooth cross-correlation maps across neighboring blocks with a Gaussian kernel

    Args:
        sigmas (float, list or None): Gaussian sigma(s) along the three block axes, [sigma0, sigma1, sigma2].
            A scalar is used for all three axes. If None, no smoothing is applied.
        shear (float): shear parameter for a sheared Gaussian kernel over the first two block axes. Default is None.
        long_range_ratio (float): weight of an additional Gaussian with 5x larger sigmas. Default is 0.05.
    """

    sigmas: Union[float, List[float], None] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data
        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated (required if shear is set)
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        # Broadcast a scalar sigma to all three block axes.
        sigmas = [float(s) for s in np.broadcast_to(np.asarray(self.sigmas, dtype="float64"), (3,))]
        if self.shear is not None:
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            gw = gausskernel_sheared(sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*(s * 5 for s in sigmas), 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj


class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    Parameters:
        clip_thresh: value subtracted from the volume before negative values are clipped to 0. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: if > 0 (scalar or per axis), soften the volume edges with ``soften_edges``. Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume
        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration (multiplied with the volume)
        Returns:
            cupy.ndarray: Filtered volume
        """
        vol = cp.clip(cp.array(vol, "float32", copy=False) - self.clip_thresh, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            vol = soften_edges(vol, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            vol *= cp.array(reg_mask, dtype="float32", copy=False)
        if self.dog:
            vol = dogfilter(vol, self.low, self.high, mode="reflect")
        return vol


class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated.
            A negative value -d means "volume shape // d" along that axis.
        block_stride (list or float): stride (usually identical to block_size); a float is a fraction of block_size
        repeats (int): number of iterations for this level (disable level by setting repeats to 0)
        smooth (Smoother or None): Smoother object
        project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
        tukey_ref (float): if not None (and < 1), apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
        affine (bool): if True, apply affine transformation to the displacement field
        median_filter (bool): if True, apply median filter to the displacement field
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
    """

    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5


def _as_block_size_list(block_size):
    """Expand a scalar block size to three axes and check that exactly three values are given."""
    if isinstance(block_size, (int, float)):
        block_size = [block_size] * 3
    if len(block_size) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    return block_size


class Recipe(BaseModel):
    """Configuration for the registration recipe. By default, a recipe has a translation level followed by an affine level.

    Args:
        pre_filter (RegFilter, callable or None): Filter applied to the reference and moving volumes before registration
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=3),  # translation level
        LevelConfig(  # affine level
            block_size=[-2, -2, -2],
            block_stride=0.5,
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the registration recipe

        Args:
            block_size (int or list): shape of blocks, whose rigid displacement is estimated (scalar: same for all axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        self.levels.append(LevelConfig(block_size=_as_block_size_list(block_size), **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (int or list): shape of blocks, whose rigid displacement is estimated (scalar: same for all axes)
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        self.levels.insert(index, LevelConfig(block_size=_as_block_size_list(block_size), **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str): path to the YAML file, or the name of a recipe bundled in the package's "recipes" folder

        Returns:
            Recipe: Recipe object
        """
        import yaml

        if not os.path.isfile(yaml_path):
            # Fall back to the recipes shipped next to this module.
            this_file_dir = pathlib.Path(__file__).resolve().parent
            yaml_path = os.path.join(this_file_dir, "recipes", yaml_path)

        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")
class WarpMap:
    """Represents a 3D displacement field sampled on a regular grid of blocks.

    Args:
        warp_field (numpy.array or cupy.array): the displacement field data (3-x-y-z), in voxels
        block_size (3-element list or numpy.array): size (in voxels) of the blocks on which the field is sampled
        block_stride (3-element list or numpy.array): stride (in voxels) between neighbouring blocks
        ref_shape (tuple): shape of the reference volume
        mov_shape (tuple): shape of the moving volume
    """

    def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
        self.warp_field = cp.array(warp_field, dtype="float32")
        self.block_size = cp.array(block_size, dtype="float32")
        self.block_stride = cp.array(block_stride, dtype="float32")
        self.ref_shape = ref_shape
        self.mov_shape = mov_shape

    def warp(self, vol, out=None):
        """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

        Args:
            vol (cupy.array): the volume to be warped
            out (cupy.array, optional): preallocated output of shape ``ref_shape``; a zero-filled float32
                array is allocated when None

        Returns:
            cupy.array: warped volume
        """
        # Compare as tuples: comparing a tuple with an ndarray of a different length can raise
        # (depending on the NumPy version) instead of reporting the mismatch.
        if tuple(vol.shape) != tuple(self.mov_shape):
            warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
        if out is None:
            out = cp.zeros(self.ref_shape, dtype="float32", order="C")
        # grid offset in block units (-block_size / (2 * block_stride)), as passed to warp_volume
        vol_out = warp_volume(
            vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
        )
        return vol_out

    def apply(self, *args, **kwargs):
        """Alias of warp method"""
        return self.warp(*args, **kwargs)

    def fit_affine(self, target=None):
        """Fit an affine transformation to the displacement field and return it as a new WarpMap.

        Args:
            target (dict, optional): dict with keys "warp_field_shape", "block_size" and "block_stride"
                describing the grid on which the fitted field is sampled. Defaults to this map's own grid.

        Returns:
            tuple: (WarpMap with the fitted affine displacement field,
                cupy.array of shape (4, 3) with the least-squares affine transformation coefficients)
        """
        if target is None:
            warp_field_shape = self.warp_field.shape
            block_size = self.block_size
            block_stride = self.block_stride
        else:
            warp_field_shape = target["warp_field_shape"]
            block_size = cp.array(target["block_size"]).astype("float32")
            block_stride = cp.array(target["block_stride"]).astype("float32")

        # block-centre coordinates (in voxels) of every grid point, one row per point
        ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
        ix = ix * self.block_stride + self.block_size / 2
        # least squares on all grid points: [x, 1] @ coeff ~= x + u(x)
        a = cp.hstack([ix, cp.ones((ix.shape[0], 1))])
        b = ix + self.warp_field.reshape(3, -1).T
        coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
        ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
        # evaluate the affine map on the output grid and subtract the identity to obtain displacements
        linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
        return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff

    def median_filter(self):
        """Apply median filter to the displacement field

        Returns:
            WarpMap: new WarpMap with median filtered displacement field
        """
        # size 1 along the component axis: the three displacement components are filtered independently
        warp_field = cupyx.scipy.ndimage.median_filter(self.warp_field, size=[1, 3, 3, 3], mode="nearest")
        return WarpMap(warp_field, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)

    def resize_to(self, target):
        """Resize to the grid of a target, using linear interpolation

        Args:
            target (WarpMap, WarpMapper or dict): target to resize to; a dict must have the keys
                "warp_field_shape", "block_size" and "block_stride"

        Returns:
            WarpMap: resized WarpMap

        Raises:
            ValueError: if target is not a WarpMap, WarpMapper or dict
        """
        if isinstance(target, WarpMap):
            t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
        elif isinstance(target, WarpMapper):
            t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
        elif isinstance(target, dict):
            t_sh = target["warp_field_shape"][1:]
            t_bsz = cp.array(target["block_size"])
            t_bst = cp.array(target["block_stride"])
        else:
            raise ValueError("target must be a WarpMap, WarpMapper, or dict")
        # target block centres (idx * t_bst + t_bsz / 2) expressed as fractional block indices of this field
        ix = cp.indices(t_sh).reshape(3, -1)
        ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
        dm_r = cp.array(
            [
                cupyx.scipy.ndimage.map_coordinates(self.warp_field[i], ix, mode="nearest", order=1).reshape(t_sh)
                for i in range(3)
            ]
        )
        return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)

    def chain(self, target):
        """Chain displacement maps by adding the displacement field of target to this one.

        Displacements are summed point-wise (no resampling of one field through the other).

        Args:
            target (WarpMap): WarpMap to be added to existing map; its warp_field must broadcast to the shape
                of this map's warp_field (typically both are sampled on the same block grid)

        Returns:
            WarpMap: new WarpMap with the summed displacement field, on target's block grid
        """
        warp_field = self.warp_field.copy()
        warp_field += target.warp_field
        return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)

    def invert(self, **kwargs):
        """alias for invert_fast method"""
        return self.invert_fast(**kwargs)

    def invert_fast(self, sigma=0.5, truncate=20):
        """Invert the displacement field using accumulation and Gaussian basis interpolation.

        The negated displacement of every grid point is accumulated (via ``accumarray``) at the coordinates it
        points to, averaged by the number of contributions, and empty locations are filled by ``infill_nans``.

        Args:
            sigma (float): standard deviation for Gaussian basis interpolation
            truncate (float): truncate parameter for Gaussian basis interpolation

        Returns:
            WarpMap: inverted WarpMap (ref_shape and mov_shape are swapped)
        """
        warp_field = self.warp_field.get()
        # destination of each grid point, in block-index units
        target_coords = np.indices(warp_field.shape[1:]) + warp_field / self.block_stride[:, None, None, None].get()
        wf_shape = np.ceil(np.array(self.mov_shape) / self.block_stride.get() + 1).astype("int")
        num_coords = accumarray(target_coords, wf_shape)
        inv_field = np.zeros((3, *wf_shape), dtype=warp_field.dtype)
        for i in range(3):
            inv_field[i] = -accumarray(target_coords, wf_shape, weights=warp_field[i].ravel())
            with np.errstate(invalid="ignore"):  # 0/0 where no grid point landed
                inv_field[i] /= num_coords
            inv_field[i][num_coords == 0] = np.nan
            inv_field[i] = infill_nans(inv_field[i], sigma=sigma, truncate=truncate)
        return WarpMap(inv_field, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)

    def push_coordinates(self, coords, negative_shifts=False):
        """Push voxel coordinates from fixed to moving space.

        Args:
            coords (numpy.array or cupy.array): 3D *voxel* coordinates to be warped (3-by-n array)
            negative_shifts (bool): if True, subtract the interpolated displacements instead of adding them

        Returns:
            cupy.array: transformed voxel coordinates (3-by-n, float32)
        """
        assert coords.shape[0] == 3
        coords = cp.array(coords, dtype="float32")
        # voxel coordinates -> fractional block indices (block centres sit at block_size / 2)
        coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
        shifts = cp.zeros_like(coords)
        for idim in range(3):
            shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
                self.warp_field[idim], coords_blocked, order=1, mode="nearest"
            )
        if negative_shifts:
            shifts = -shifts
        return coords + shifts

    def pull_coordinates(self, coords):
        """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

        Args:
            coords (numpy.array or cupy.array): 3D *voxel* coordinates to be warped (3-by-n array)

        Returns:
            cupy.array: transformed voxel coordinates
        """
        return self.invert().push_coordinates(coords, negative_shifts=True)

    def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
        """
        Compute det J = det(∇φ) for φ(x)=x+u(x), using cupy.indices for the identity grid.

        Args:
            units_per_voxel (3-element sequence): per-axis voxel size. Grid coordinates, grid spacing and the
                displacements (stored in voxels) are all scaled by it, so that φ and x share the same units.
            edge_order : passed to cupy.gradient (1 or 2)

        Returns:
            detJ: cupy.ndarray with the spatial shape of the warp field (warp_field.shape[1:])
        """
        units = cp.array(units_per_voxel, dtype="float32")
        scaling = units * self.block_stride
        coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
        # displacements are in voxels: convert them to the units of coords before forming φ
        phi = coords + self.warp_field * units[:, None, None, None]
        J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
        for i in range(3):
            grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
            for j in range(3):
                J[..., i, j] = grads[j]
        return cp.linalg.det(J)

    def as_ants_image(self, voxel_size_um=1):
        """Convert to ANTsImage

        Args:
            voxel_size_um (scalar or array): voxel size (default is 1)

        Returns:
            ants.core.ants_image.ANTsImage:

        Raises:
            ImportError: if ANTsPy is not installed
        """
        try:
            import ants
        except ImportError as err:
            # the ANTsPy distribution on PyPI is named "antspyx"
            raise ImportError("ANTs is not installed. Please install it using 'pip install antspyx'") from err

        ants_image = ants.from_numpy(
            self.warp_field.get().transpose(1, 2, 3, 0),
            origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
            spacing=list(self.block_stride.get() * voxel_size_um),
            has_components=True,
        )
        return ants_image

    def __repr__(self):
        """String representation of the WarpMap object."""
        info = (
            f"WarpMap("
            f"warp_field_shape={self.warp_field.shape}, "
            f"block_size={self.block_size.get()}, "
            f"block_stride={self.block_stride.get()}, "
            f"transformation: {str(self.mov_shape)} --> {str(self.ref_shape)})"
        )
        return info

    def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
        """
        Save this WarpMap to an HDF5 file.

        Args:
            h5_path (str or os.PathLike): Path to the HDF5 file.
            group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
                None, "" or "/" store the datasets directly in the file's root group.
            compression (str or None): Dataset compression (e.g., 'gzip', None).
            overwrite (bool): If True, overwrite existing datasets/attrs inside the group.

        Raises:
            ValueError: if the target group (or, for the root group, any of the datasets) already exists
                and overwrite is False.
        """
        names = ("warp_field", "block_size", "block_stride", "ref_shape", "mov_shape")
        with h5py.File(h5_path, "a") as f:
            if group in (None, "", "/"):
                # the root group always exists: only the datasets written below can collide
                grp = f
                existing = [name for name in names if name in grp]
                if existing and not overwrite:
                    raise ValueError(
                        f"Datasets {existing} already exist in {h5_path}. Set 'overwrite=True' to overwrite them."
                    )
                for name in existing:
                    del grp[name]
            else:
                if group in f:
                    if not overwrite:
                        raise ValueError(
                            f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it."
                        )
                    del f[group]
                grp = f.require_group(group)
            grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
            grp.create_dataset("block_size", data=self.block_size.get())
            grp.create_dataset("block_stride", data=self.block_stride.get())
            grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
            grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
            grp.attrs["class"] = "WarpMap"

    @classmethod
    def from_h5(cls, h5_path, group="warp_map"):
        """
        Load a WarpMap from an HDF5 file.

        Args:
            h5_path (str or os.PathLike): Path to the HDF5 file.
            group (str): Group path inside the HDF5 file where the WarpMap is stored
                (None, "" or "/" read from the file's root group).

        Returns:
            WarpMap: The loaded WarpMap object.
        """
        with h5py.File(h5_path, "r") as f:
            grp = f if group in (None, "", "/") else f[group]
            warp_field = grp["warp_field"][:]
            block_size = grp["block_size"][:]
            block_stride = grp["block_stride"][:]
            ref_shape = tuple(grp["ref_shape"][:].tolist())
            mov_shape = tuple(grp["mov_shape"][:].tolist())
        return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)
Represents a 3D displacement field
Arguments:
- warp_field (numpy.array): the displacement field data (3-x-y-z)
- block_size (3-element list or numpy.array): size (in voxels) of the blocks on which the displacement field is sampled
- block_stride (3-element list or numpy.array): stride (in voxels) between neighbouring blocks
- ref_shape (tuple): shape of the reference volume
- mov_shape (tuple): shape of the moving volume
def __init__(self, warp_field, block_size, block_stride, ref_shape, mov_shape):
    """Store the displacement field and its block grid as float32 cupy arrays.

    Args:
        warp_field (numpy.array or cupy.array): the displacement field data (3-x-y-z)
        block_size (3-element list or numpy.array): block size
        block_stride (3-element list or numpy.array): block stride
        ref_shape (tuple): shape of the reference volume
        mov_shape (tuple): shape of the moving volume
    """
    # converted in this order: warp_field, block_size, block_stride
    for attr_name, raw in (("warp_field", warp_field), ("block_size", block_size), ("block_stride", block_stride)):
        setattr(self, attr_name, cp.array(raw, dtype="float32"))
    self.ref_shape, self.mov_shape = ref_shape, mov_shape
def warp(self, vol, out=None):
    """Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.

    Args:
        vol (cupy.array): the volume to be warped
        out (cupy.array, optional): preallocated output of shape ``ref_shape``; a zero-filled float32
            array is allocated when None

    Returns:
        cupy.array: warped volume
    """
    # Compare as tuples: comparing a tuple with an ndarray of a different length can raise
    # (depending on the NumPy version) instead of reporting the mismatch.
    if tuple(vol.shape) != tuple(self.mov_shape):
        warnings.warn(f"Volume shape {vol.shape} does not match the expected shape {self.mov_shape}.")
    if out is None:
        out = cp.zeros(self.ref_shape, dtype="float32", order="C")
    # grid offset in block units (-block_size / (2 * block_stride)), as passed to warp_volume
    vol_out = warp_volume(
        vol, self.warp_field, self.block_stride, cp.array(-self.block_size / self.block_stride / 2), out=out
    )
    return vol_out
Apply the warp to a volume. Can be thought of as pulling the moving volume to the fixed volume space.
Arguments:
- vol (cupy.array): the volume to be warped
Returns:
cupy.array: warped volume
def apply(self, *args, **kwargs):
    """Alias of warp method: forwards every argument to ``self.warp`` and returns its result."""
    warp_method = self.warp
    return warp_method(*args, **kwargs)
Alias of warp method
def fit_affine(self, target=None):
    """Fit an affine transformation to the displacement field and return it as a new WarpMap.

    Args:
        target (dict, optional): dict with keys "warp_field_shape", "block_size" and "block_stride"
            describing the grid on which the fitted field is sampled. Defaults to this map's own grid.

    Returns:
        tuple: (WarpMap with the fitted affine displacement field,
            cupy.array of shape (4, 3) with the least-squares affine transformation coefficients)
    """
    if target is None:
        warp_field_shape = self.warp_field.shape
        block_size = self.block_size
        block_stride = self.block_stride
    else:
        warp_field_shape = target["warp_field_shape"]
        block_size = cp.array(target["block_size"]).astype("float32")
        block_stride = cp.array(target["block_stride"]).astype("float32")

    # block-centre coordinates (in voxels) of every grid point, one row per point
    ix = cp.indices(self.warp_field.shape[1:]).reshape(3, -1).T
    ix = ix * self.block_stride + self.block_size / 2
    # least squares on all grid points: [x, 1] @ coeff ~= x + u(x)
    a = cp.hstack([ix, cp.ones((ix.shape[0], 1))])
    b = ix + self.warp_field.reshape(3, -1).T
    coeff = cp.linalg.lstsq(a, b, rcond=None)[0]
    ix_out = cp.indices(warp_field_shape[1:]).reshape(3, -1).T * block_stride + block_size / 2
    # evaluate the affine map on the output grid and subtract the identity to obtain displacements
    linfit = ((ix_out @ (coeff[:3] - cp.eye(3))) + coeff[3]).T.reshape(warp_field_shape)
    return WarpMap(linfit, block_size, block_stride, self.ref_shape, self.mov_shape), coeff
Fit affine transformation and return new fitted WarpMap
Arguments:
- target (dict, optional): dict with keys "warp_field_shape", "block_size", and "block_stride" (defaults to this map's own grid)
Returns:
tuple: (WarpMap: fitted affine displacement field, cupy.array: affine transformation coefficients)
def median_filter(self):
    """Median-filter the displacement field over a 3x3x3 block neighbourhood.

    Returns:
        WarpMap: new WarpMap with median filtered displacement field
    """
    # window of 1 along the component axis: the three components are filtered independently
    window = [1, 3, 3, 3]
    filtered = cupyx.scipy.ndimage.median_filter(self.warp_field, size=window, mode="nearest")
    return WarpMap(filtered, self.block_size, self.block_stride, self.ref_shape, self.mov_shape)
Apply median filter to the displacement field
Returns:
WarpMap: new WarpMap with median filtered displacement field
def resize_to(self, target):
    """Resize to the grid of a target, using linear interpolation

    Args:
        target (WarpMap, WarpMapper or dict): target to resize to; a dict must have the keys
            "warp_field_shape", "block_size" and "block_stride"

    Returns:
        WarpMap: resized WarpMap

    Raises:
        ValueError: if target is not a WarpMap, WarpMapper or dict
    """
    if isinstance(target, WarpMap):
        t_sh, t_bsz, t_bst = target.warp_field.shape[1:], target.block_size, target.block_stride
    elif isinstance(target, WarpMapper):
        t_sh, t_bsz, t_bst = target.blocks_shape[:3], cp.array(target.block_size), cp.array(target.block_stride)
    elif isinstance(target, dict):
        t_sh = target["warp_field_shape"][1:]
        t_bsz = cp.array(target["block_size"])
        t_bst = cp.array(target["block_stride"])
    else:
        raise ValueError("target must be a WarpMap, WarpMapper, or dict")
    # target block centres (idx * t_bst + t_bsz / 2) expressed as fractional block indices of this field
    ix = cp.indices(t_sh).reshape(3, -1)
    ix = (ix * t_bst[:, None] + (t_bsz - self.block_size)[:, None] / 2) / self.block_stride[:, None]
    dm_r = cp.array(
        [
            cupyx.scipy.ndimage.map_coordinates(self.warp_field[i], ix, mode="nearest", order=1).reshape(t_sh)
            for i in range(3)
        ]
    )
    return WarpMap(dm_r, t_bsz, t_bst, self.ref_shape, self.mov_shape)
Resize to target WarpMap, using linear interpolation
Arguments:
- target (WarpMap or WarpMapper): target to resize to or a dict with keys "warp_field_shape", "block_size", and "block_stride"
Returns:
WarpMap: resized WarpMap
def chain(self, target):
    """Chain displacement maps by adding the displacement field of target to this one.

    Displacements are summed point-wise (no resampling of one field through the other).

    Args:
        target (WarpMap): WarpMap to be added to existing map; its warp_field must broadcast to the shape
            of this map's warp_field (typically both are sampled on the same block grid)

    Returns:
        WarpMap: new WarpMap with the summed displacement field, on target's block grid
    """
    warp_field = self.warp_field.copy()
    warp_field += target.warp_field
    return WarpMap(warp_field, target.block_size, target.block_stride, self.ref_shape, self.mov_shape)
Chain displacement maps
Arguments:
- target (WarpMap): WarpMap to be added to existing map
Returns:
WarpMap: new WarpMap with chained displacement field
def invert(self, **kwargs):
    """Invert the displacement field; alias for invert_fast (keyword arguments are forwarded)."""
    inverter = self.invert_fast
    return inverter(**kwargs)
alias for invert_fast method
def invert_fast(self, sigma=0.5, truncate=20):
    """Invert the displacement field using accumulation and Gaussian basis interpolation.

    The negated displacement of every grid point is accumulated (via ``accumarray``) at the coordinates it
    points to, divided by the number of contributions, and locations without contributions are filled by
    ``infill_nans``.

    Args:
        sigma (float): standard deviation for Gaussian basis interpolation
        truncate (float): truncate parameter for Gaussian basis interpolation

    Returns:
        WarpMap: inverted WarpMap (ref_shape and mov_shape are swapped)
    """
    field = self.warp_field.get()
    stride = self.block_stride.get()
    # destination of each grid point, in block-index units
    dest = np.indices(field.shape[1:]) + field / stride[:, None, None, None]
    inv_shape = np.ceil(np.array(self.mov_shape) / stride + 1).astype("int")
    counts = accumarray(dest, inv_shape)
    empty = counts == 0
    inverse = np.zeros((3, *inv_shape), dtype=field.dtype)
    for axis in range(3):
        inverse[axis] = -accumarray(dest, inv_shape, weights=field[axis].ravel())
        with np.errstate(invalid="ignore"):  # 0/0 where nothing landed
            inverse[axis] /= counts
        inverse[axis][empty] = np.nan
        inverse[axis] = infill_nans(inverse[axis], sigma=sigma, truncate=truncate)
    return WarpMap(inverse, self.block_size, self.block_stride, self.mov_shape, self.ref_shape)
Invert the displacement field using accumulation and Gaussian basis interpolation.
Arguments:
- sigma (float): standard deviation for Gaussian basis interpolation
- truncate (float): truncate parameter for Gaussian basis interpolation
Returns:
WarpMap: inverted WarpMap
def push_coordinates(self, coords, negative_shifts=False):
    """Push voxel coordinates from fixed to moving space.

    Args:
        coords (numpy.array or cupy.array): 3D *voxel* coordinates to be warped (3-by-n array)
        negative_shifts (bool): if True, subtract the interpolated displacements instead of adding them

    Returns:
        cupy.array: transformed voxel coordinates (3-by-n, float32)
    """
    assert coords.shape[0] == 3
    coords = cp.array(coords, dtype="float32")
    # voxel coordinates -> fractional block indices (block centres sit at block_size / 2)
    coords_blocked = coords / self.block_stride[:, None] - (self.block_size / (2 * self.block_stride))[:, None]
    shifts = cp.zeros_like(coords)
    for idim in range(3):
        # read-only interpolation: no need to copy the field
        shifts[idim] = cupyx.scipy.ndimage.map_coordinates(
            self.warp_field[idim], coords_blocked, order=1, mode="nearest"
        )
    if negative_shifts:
        shifts = -shifts
    return coords + shifts
Push voxel coordinates from fixed to moving space.
Arguments:
- coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:
cupy.array: transformed voxel coordinates
def pull_coordinates(self, coords):
    """Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.

    Args:
        coords (numpy.array or cupy.array): 3D *voxel* coordinates to be warped (3-by-n array)

    Returns:
        cupy.array: transformed voxel coordinates (as returned by push_coordinates of the inverted map)
    """
    return self.invert().push_coordinates(coords, negative_shifts=True)
Pull voxel coordinates through the warp field. Involves inversion, followed by pushing coordinates.
Arguments:
- coords (numpy.array): 3D voxel coordinates to be warped (3-by-n array)
Returns:
cupy.array: transformed voxel coordinates
def jacobian_det(self, units_per_voxel=(1, 1, 1), edge_order=1):
    """
    Compute det J = det(∇φ) for φ(x)=x+u(x), using cupy.indices for the identity grid.

    Args:
        units_per_voxel (3-element sequence): per-axis voxel size. Grid coordinates, grid spacing and the
            displacements (stored in voxels) are all scaled by it, so that φ and x share the same units.
        edge_order : passed to cupy.gradient (1 or 2)

    Returns:
        detJ: cupy.ndarray with the spatial shape of the warp field (warp_field.shape[1:])
    """
    units = cp.array(units_per_voxel, dtype="float32")
    scaling = units * self.block_stride
    coords = cp.indices(self.warp_field.shape[1:], dtype="float32") * scaling[:, None, None, None]
    # displacements are in voxels: convert them to the units of coords before forming φ
    phi = coords + self.warp_field * units[:, None, None, None]
    J = cp.empty(self.warp_field.shape[1:] + (3, 3), dtype="float32")
    for i in range(3):
        grads = cp.gradient(phi[i], *scaling, edge_order=edge_order)
        for j in range(3):
            J[..., i, j] = grads[j]
    return cp.linalg.det(J)
Compute det J = det(∇φ) for φ(x)=x+u(x), using cupy.indices for the identity grid.
Arguments:
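- units_per_voxel (3-element list): per-axis scale factors applied to the block grid (multiplied by block_stride)
- edge_order : passed to cupy.gradient (1 or 2)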
Returns:
detJ: cupy.ndarray with the spatial shape of the warp field (warp_field.shape[1:])
def as_ants_image(self, voxel_size_um=1):
    """Convert to ANTsImage

    Args:
        voxel_size_um (scalar or array): voxel size (default is 1)

    Returns:
        ants.core.ants_image.ANTsImage:

    Raises:
        ImportError: if ANTsPy is not installed
    """
    try:
        import ants
    except ImportError as err:
        # the ANTsPy distribution on PyPI is named "antspyx"
        raise ImportError("ANTs is not installed. Please install it using 'pip install antspyx'") from err

    ants_image = ants.from_numpy(
        self.warp_field.get().transpose(1, 2, 3, 0),
        origin=list((self.block_size.get() - 1) / 2 * voxel_size_um),
        spacing=list(self.block_stride.get() * voxel_size_um),
        has_components=True,
    )
    return ants_image
Convert to ANTsImage
Arguments:
- voxel_size_um (scalar or array): voxel size (default is 1)
Returns:
ants.core.ants_image.ANTsImage:
def to_h5(self, h5_path, group="warp_map", compression="gzip", overwrite=True):
    """
    Save this WarpMap to an HDF5 file.

    Args:
        h5_path (str or os.PathLike): Path to the HDF5 file.
        group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
            None, "" or "/" store the datasets directly in the file's root group.
        compression (str or None): Dataset compression (e.g., 'gzip', None).
        overwrite (bool): If True, overwrite existing datasets/attrs inside the group.

    Raises:
        ValueError: if the target group (or, for the root group, any of the datasets) already exists
            and overwrite is False.
    """
    names = ("warp_field", "block_size", "block_stride", "ref_shape", "mov_shape")
    with h5py.File(h5_path, "a") as f:
        if group in (None, "", "/"):
            # the root group always exists: only the datasets written below can collide
            grp = f
            existing = [name for name in names if name in grp]
            if existing and not overwrite:
                raise ValueError(
                    f"Datasets {existing} already exist in {h5_path}. Set 'overwrite=True' to overwrite them."
                )
            for name in existing:
                del grp[name]
        else:
            if group in f:
                if not overwrite:
                    raise ValueError(
                        f"Group '{group}' already exists in {h5_path}. Set 'overwrite=True' to overwrite it."
                    )
                del f[group]
            grp = f.require_group(group)
        grp.create_dataset("warp_field", data=self.warp_field.get(), compression=compression)
        grp.create_dataset("block_size", data=self.block_size.get())
        grp.create_dataset("block_stride", data=self.block_stride.get())
        grp.create_dataset("ref_shape", data=np.array(self.ref_shape, dtype="int64"))
        grp.create_dataset("mov_shape", data=np.array(self.mov_shape, dtype="int64"))
        grp.attrs["class"] = "WarpMap"
Save this WarpMap to an HDF5 file.
Arguments:
- h5_path (str or os.PathLike): Path to the HDF5 file.
- group (str): Group path inside the HDF5 file to store the WarpMap (created if missing).
- compression (str or None): Dataset compression (e.g., 'gzip', None).
- overwrite (bool): If True, overwrite existing datasets/attrs inside the group.
@classmethod
def from_h5(cls, h5_path, group="warp_map"):
    """
    Load a WarpMap from an HDF5 file.

    Args:
        h5_path (str or os.PathLike): Path to the HDF5 file.
        group (str): Group path inside the HDF5 file where the WarpMap is stored
            (None, "" or "/" read from the file's root group).

    Returns:
        WarpMap: The loaded WarpMap object.
    """
    with h5py.File(h5_path, "r") as f:
        grp = f if group in (None, "", "/") else f[group]
        warp_field = grp["warp_field"][:]
        block_size = grp["block_size"][:]
        block_stride = grp["block_stride"][:]
        ref_shape = tuple(grp["ref_shape"][:].tolist())
        mov_shape = tuple(grp["mov_shape"][:].tolist())
    return cls(warp_field, block_size, block_stride, ref_shape, mov_shape)
Load a WarpMap from an HDF5 file.
Arguments:
- h5_path (str or os.PathLike): Path to the HDF5 file.
- group (str): Group path inside the HDF5 file where the WarpMap is stored.
Returns:
WarpMap: The loaded WarpMap object.
class WarpMapper:
    """Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.

    Args:
        ref_vol (numpy.array or cupy.array): The reference volume
        block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
        block_stride (3-element list or numpy.array): stride (defaults to block_size)
        proj_method (callable): projection method, called as ``proj_method(blocks, axis=ax)``
        subpixel (int): upsampling factor used to refine the correlation peak to sub-voxel precision
        epsilon (float): value added to the zero-shift bin of the cross-correlation (favours zero shift on ties)
        tukey_alpha (float): alpha of the Tukey window applied to the reference projections
            (no window is applied if tukey_alpha >= 1)

    Raises:
        ValueError: if block_size is larger than the reference volume along any axis
    """

    def __init__(
        self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
    ):
        if np.any(block_size > np.array(ref_vol.shape)):
            raise ValueError(f"Block size (currently: {block_size}) must be smaller than the volume shape ({np.array(ref_vol.shape)}).")
        self.proj_method = proj_method
        self.plan_rev = [None, None, None]
        self.subpixel = subpixel
        self.epsilon = epsilon
        self.tukey_alpha = tukey_alpha
        self.update_reference(ref_vol, block_size, block_stride)
        self.ref_shape = np.array(ref_vol.shape)

    def update_reference(self, ref_vol, block_size, block_stride=None):
        """(Re)compute the cached reference projections and FFT plans for a reference volume.

        Args:
            ref_vol (numpy.array or cupy.array): reference volume
            block_size (3-element list or numpy.array): block shape
            block_stride (3-element list or numpy.array): stride between blocks (defaults to block_size)
        """
        block_size = np.array(block_size)
        block_stride = block_size if block_stride is None else np.array(block_stride)
        ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
        self.blocks_shape = ref_blocks.shape
        # one projection of every block along each of its three axes
        ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
        if self.tukey_alpha < 1:
            alpha = self.tukey_alpha  # honour the configured taper instead of a fixed 0.5
            ref_blocks_proj = [
                proj
                * cp.array(
                    ndwindow([1, 1, 1, *proj.shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=alpha))
                ).astype("float32")
                for proj in ref_blocks_proj
            ]
        self.plan_fwd = [
            cupyx.scipy.fft.get_fft_plan(proj, axes=(-2, -1), value_type="R2C") for proj in ref_blocks_proj
        ]
        self.ref_blocks_proj_ft_conj = [
            cupyx.scipy.fft.rfftn(proj, axes=(-2, -1), plan=plan).conj()
            for proj, plan in zip(ref_blocks_proj, self.plan_fwd)
        ]
        # inverse-FFT plans are shape-specific: drop any built for a previous block layout
        self.plan_rev = [None, None, None]
        self.block_size = block_size
        self.block_stride = block_stride

    def get_displacement(self, vol, smooth_func=None):
        """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

        Args:
            vol (cupy.array): Input volume
            smooth_func (callable): Smoothing function to be applied to the cross-correlation volume,
                called as ``smooth_func(xcorr, block_size)``

        Returns:
            WarpMap
        """
        vol_blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
        vol_blocks_proj = [self.proj_method(vol_blocks, axis=iax) for iax in [-3, -2, -1]]
        del vol_blocks

        disp_field = []
        for i in range(3):
            # cross-power spectrum with the cached conjugate reference spectrum
            R = (
                cupyx.scipy.fft.rfftn(vol_blocks_proj[i], axes=(-2, -1), plan=self.plan_fwd[i])
                * self.ref_blocks_proj_ft_conj[i]
            )
            if self.plan_rev[i] is None:
                self.plan_rev[i] = cupyx.scipy.fft.get_fft_plan(R, axes=(-2, -1), value_type="C2R")
            xcorr_proj = cp.fft.fftshift(cupyx.scipy.fft.irfftn(R, axes=(-2, -1), plan=self.plan_rev[i]), axes=(-2, -1))
            if smooth_func is not None:
                xcorr_proj = smooth_func(xcorr_proj, self.block_size)
            # after fftshift the centre bin is zero shift: nudge it so ties resolve to no motion
            xcorr_proj[..., xcorr_proj.shape[-2] // 2, xcorr_proj.shape[-1] // 2] += self.epsilon

            # integer peak, relative to the centre
            max_ix = cp.array(cp.unravel_index(cp.argmax(xcorr_proj, axis=(-2, -1)), xcorr_proj.shape[-2:]))
            max_ix = max_ix - cp.array(xcorr_proj.shape[-2:])[:, None, None, None] // 2
            del xcorr_proj
            i0, j0 = max_ix.reshape(2, -1)
            # sub-voxel refinement around the integer peak
            shifts = upsampled_dft_rfftn(
                R.reshape(-1, *R.shape[-2:]),
                upsampled_region_size=int(self.subpixel * 2 + 1),
                upsample_factor=self.subpixel,
                axis_offsets=(i0, j0),
            )
            del R
            max_sub = cp.array(cp.unravel_index(cp.argmax(shifts, axis=(-2, -1)), shifts.shape[-2:]))
            max_sub = (
                max_sub.reshape(max_ix.shape) - cp.array(shifts.shape[-2:])[:, None, None, None] // 2
            ) / self.subpixel
            del shifts
            disp_field.append(max_ix + max_sub)

        disp_field = cp.array(disp_field)
        # every axis is seen by two of the three projections: average the two estimates
        disp_field = (
            cp.array(
                [
                    disp_field[1, 0] + disp_field[2, 0],
                    disp_field[0, 0] + disp_field[2, 1],
                    disp_field[0, 1] + disp_field[1, 1],
                ]
            ).astype("float32")
            / 2
        )
        return WarpMap(disp_field, self.block_size, self.block_stride, self.ref_shape, vol.shape)
Class that estimates warp field using cross-correlation, based on a piece-wise rigid model.
Arguments:
- ref_vol (numpy.array): The reference volume
- block_size (3-element list or numpy.array): shape of blocks, whose rigid displacement is estimated
- block_stride (3-element list or numpy.array): stride (usually identical to block_size)
- proj_method (str or callable): Projection method
def __init__(
    self, ref_vol, block_size, block_stride=None, proj_method=None, subpixel=4, epsilon=1e-6, tukey_alpha=0.5
):
    """Validate the block size, store the settings and precompute the reference data.

    Raises:
        ValueError: if block_size exceeds the reference volume shape along any axis
    """
    vol_shape = np.array(ref_vol.shape)
    if np.any(block_size > vol_shape):
        raise ValueError(f"Block size (currently: {block_size}) must be smaller than the volume shape ({vol_shape}).")
    self.proj_method = proj_method
    self.plan_rev = [None] * 3
    self.subpixel, self.epsilon, self.tukey_alpha = subpixel, epsilon, tukey_alpha
    self.update_reference(ref_vol, block_size, block_stride)
    self.ref_shape = vol_shape
def update_reference(self, ref_vol, block_size, block_stride=None):
    """(Re)compute the cached reference projections and FFT plans for a reference volume.

    Args:
        ref_vol (numpy.array or cupy.array): reference volume
        block_size (3-element list or numpy.array): block shape
        block_stride (3-element list or numpy.array): stride between blocks (defaults to block_size)
    """
    block_size = np.array(block_size)
    block_stride = block_size if block_stride is None else np.array(block_stride)
    ref_blocks = sliding_block(cp.array(ref_vol), block_size=block_size, block_stride=block_stride)
    self.blocks_shape = ref_blocks.shape
    # one projection of every block along each of its three axes
    ref_blocks_proj = [self.proj_method(ref_blocks, axis=iax) for iax in [-3, -2, -1]]
    if self.tukey_alpha < 1:
        alpha = self.tukey_alpha  # honour the configured taper instead of a fixed 0.5
        ref_blocks_proj = [
            proj
            * cp.array(
                ndwindow([1, 1, 1, *proj.shape[-2:]], lambda n: scipy.signal.windows.tukey(n, alpha=alpha))
            ).astype("float32")
            for proj in ref_blocks_proj
        ]
    self.plan_fwd = [
        cupyx.scipy.fft.get_fft_plan(proj, axes=(-2, -1), value_type="R2C") for proj in ref_blocks_proj
    ]
    self.ref_blocks_proj_ft_conj = [
        cupyx.scipy.fft.rfftn(proj, axes=(-2, -1), plan=plan).conj()
        for proj, plan in zip(ref_blocks_proj, self.plan_fwd)
    ]
    # inverse-FFT plans are shape-specific: drop any built for a previous block layout
    self.plan_rev = [None, None, None]
    self.block_size = block_size
    self.block_stride = block_stride
def get_displacement(self, vol, smooth_func=None):
    """Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.

    Args:
        vol (cupy.array): Input volume
        smooth_func (callable): Smoothing function applied to the cross-correlation,
            called as ``smooth_func(xcorr, block_size)``

    Returns:
        WarpMap
    """
    blocks = sliding_block(vol, block_size=self.block_size, block_stride=self.block_stride)
    projections = [self.proj_method(blocks, axis=ax) for ax in [-3, -2, -1]]
    del blocks

    per_projection = []
    for k, proj in enumerate(projections):
        cross_power = cupyx.scipy.fft.rfftn(proj, axes=(-2, -1), plan=self.plan_fwd[k])
        cross_power = cross_power * self.ref_blocks_proj_ft_conj[k]
        if self.plan_rev[k] is None:
            self.plan_rev[k] = cupyx.scipy.fft.get_fft_plan(cross_power, axes=(-2, -1), value_type="C2R")
        xcorr = cupyx.scipy.fft.irfftn(cross_power, axes=(-2, -1), plan=self.plan_rev[k])
        xcorr = cp.fft.fftshift(xcorr, axes=(-2, -1))
        if smooth_func is not None:
            xcorr = smooth_func(xcorr, self.block_size)
        # centre bin == zero shift after fftshift; the small bias breaks ties towards no motion
        centre_r, centre_c = xcorr.shape[-2] // 2, xcorr.shape[-1] // 2
        xcorr[..., centre_r, centre_c] += self.epsilon

        plane_shape = xcorr.shape[-2:]
        peak = cp.array(cp.unravel_index(cp.argmax(xcorr, axis=(-2, -1)), plane_shape))
        peak = peak - cp.array(plane_shape)[:, None, None, None] // 2
        del xcorr
        rows, cols = peak.reshape(2, -1)
        upsampled = upsampled_dft_rfftn(
            cross_power.reshape(-1, *cross_power.shape[-2:]),
            upsampled_region_size=int(self.subpixel * 2 + 1),
            upsample_factor=self.subpixel,
            axis_offsets=(rows, cols),
        )
        del cross_power
        fine_shape = upsampled.shape[-2:]
        fine = cp.array(cp.unravel_index(cp.argmax(upsampled, axis=(-2, -1)), fine_shape))
        fine = (fine.reshape(peak.shape) - cp.array(fine_shape)[:, None, None, None] // 2) / self.subpixel
        del upsampled
        per_projection.append(peak + fine)

    d = cp.array(per_projection)
    # each axis appears in two projections; average those two in-plane estimates
    averaged = cp.array([d[1, 0] + d[2, 0], d[0, 0] + d[2, 1], d[0, 1] + d[1, 1]]).astype("float32") / 2
    return WarpMap(averaged, self.block_size, self.block_stride, self.ref_shape, vol.shape)
Estimate the displacement of vol with the reference volume, via piece-wise rigid cross-correlation with the pre-saved blocks.
Arguments:
- vol (numpy.array): Input volume
- smooth_func (callable): Smoothing function to be applied to the cross-correlation volume
Returns:
WarpMap
class RegistrationPyramid:
    """A class for performing multi-resolution registration.

    Args:
        ref_vol (numpy.array or cupy.array): Reference volume
        recipe (Recipe): Registration recipe with the settings for each level of the pyramid.
            Levels with ``repeats < 1`` are skipped.
            IMPORTANT: the block size and block stride in the last level should not be larger than
            in any previous level (along any axis); ``register_single`` warns about the block stride.
        reg_mask (numpy.array or float): Mask for registration, passed to ``recipe.pre_filter``.
            Default is 1 (no mask).

    Raises:
        AssertionError: if the recipe has no level with ``repeats >= 1``.
    """

    def __init__(self, ref_vol, recipe, reg_mask=1):
        # Validate a round-tripped copy of the recipe so an invalid configuration fails early.
        recipe.model_validate(recipe.model_dump())
        self.recipe = recipe
        self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
        self.mappers = []
        ref_vol = cp.array(ref_vol, dtype="float32", copy=False, order="C")
        self.ref_shape = ref_vol.shape
        if self.recipe.pre_filter is not None:
            ref_vol = self.recipe.pre_filter(ref_vol, reg_mask=self.reg_mask)
        # For each mapper, the index of the recipe level it was built from.
        self.mapper_ix = []
        for i, level in enumerate(recipe.levels):
            if level.repeats < 1:
                continue
            block_size = np.array(level.block_size)
            # A negative entry -n is replaced by (volume size // n) along that axis.
            tmp = np.r_[ref_vol.shape] // -block_size
            block_size[block_size < 0] = tmp[block_size < 0]
            if isinstance(level.block_stride, (int, float)):
                # A scalar stride is a fraction of the block size.
                block_stride = (block_size * level.block_stride).astype("int")
            else:
                block_stride = np.array(level.block_stride)
            self.mappers.append(
                WarpMapper(
                    ref_vol,
                    block_size,
                    block_stride=block_stride,
                    proj_method=level.project,
                    tukey_alpha=level.tukey_ref,
                )
            )
            self.mapper_ix.append(i)
        assert len(self.mappers) > 0, "At least one level of registration is required"

    def register_single(self, vol, callback=None, verbose=False):
        """Register a single volume to the reference volume.

        Args:
            vol (array_like): Volume to be registered (numpy or cupy array)
            callback (function): Callback called on the warped, pre-filtered volume once before
                registration and after every repeat of every level
            verbose (bool): If True, show progress bars. Default is False

        Returns:
            - vol (array_like): Registered (unfiltered, float32) volume (numpy or cupy array, depending on input)
            - warp_map (WarpMap): Displacement field
            - callback_output (list): List of outputs from the callback function

        Raises:
            ValueError: if an affine level has fewer than two axes with at least 2 blocks
        """
        was_numpy = isinstance(vol, np.ndarray)
        vol = cp.array(vol, "float32", copy=False, order="C")
        last_mapper = self.mappers[-1]
        # Start from a constant displacement of half the shape difference between vol and the reference.
        offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
        warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
        # All accumulated displacement fields live on the block grid of the last level.
        warp_map = warp_map.resize_to(last_mapper)
        callback_output = []
        if self.recipe.pre_filter is not None:
            vol_filtered = self.recipe.pre_filter(vol, reg_mask=self.reg_mask)
        else:
            vol_filtered = vol
        vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
        warp_map.warp(vol_filtered, out=vol_tmp)
        # Per-axis minimum block stride over all levels.
        min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
        if callback is not None:
            callback_output.append(callback(vol_tmp))

        # Compare element-wise, axis by axis, as the warning describes.
        if np.any(last_mapper.block_stride > min_block_stride):
            warnings.warn(
                "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
            )
        for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
            level = self.recipe.levels[self.mapper_ix[k]]
            for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
                wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
                wm.warp_field *= level.update_rate
                if level.median_filter:
                    wm = wm.median_filter()
                if level.affine:
                    if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                        raise ValueError(
                            f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                        )
                    wm, _ = wm.fit_affine(
                        target=dict(
                            warp_field_shape=(3, *last_mapper.blocks_shape[:3]),
                            block_size=last_mapper.block_size,
                            block_stride=last_mapper.block_stride,
                        )
                    )
                else:
                    wm = wm.resize_to(last_mapper)

                warp_map = warp_map.chain(wm)
                # Re-warp the pre-filtered input from scratch with the accumulated map.
                warp_map.warp(vol_filtered, out=vol_tmp)
                if callback is not None:
                    callback_output.append(callback(vol_tmp))
        # The returned volume is the unfiltered input warped with the final map.
        warp_map.warp(vol, out=vol_tmp)
        if was_numpy:
            vol_tmp = vol_tmp.get()
        return vol_tmp, warp_map, callback_output
A class for performing multi-resolution registration.
Arguments:
- ref_vol (numpy.array): Reference volume
- recipe (Recipe): Registration recipe with the settings for each level of the pyramid. IMPORTANT: the block size and block stride in the last level should not be larger than in any previous level (along any axis); a warning is issued for the block stride.
- reg_mask (numpy.array): Mask for registration
- Note: there is no clip_thresh argument; the threshold for clipping the reference volume is set through the recipe's pre_filter (RegFilter.clip_thresh).
def __init__(self, ref_vol, recipe, reg_mask=1):
    """Pre-filter the reference volume and build one WarpMapper per active recipe level.

    Args:
        ref_vol (numpy.array or cupy.array): Reference volume
        recipe (Recipe): Registration recipe; levels with ``repeats < 1`` are skipped
        reg_mask (numpy.array or float): Mask for registration, passed to ``recipe.pre_filter``

    Raises:
        AssertionError: if no level has ``repeats >= 1``
    """
    recipe.model_validate(recipe.model_dump())
    self.recipe = recipe
    self.reg_mask = cp.array(reg_mask, dtype="float32", copy=False, order="C")
    self.mappers = []
    reference = cp.array(ref_vol, dtype="float32", copy=False, order="C")
    self.ref_shape = reference.shape
    if self.recipe.pre_filter is not None:
        reference = self.recipe.pre_filter(reference, reg_mask=self.reg_mask)
    self.mapper_ix = []
    for level_index, level in enumerate(recipe.levels):
        if level.repeats < 1:
            continue
        size = np.array(level.block_size)
        # Negative entries -n resolve to (volume size // n) along that axis.
        negative = size < 0
        size[negative] = (np.r_[reference.shape] // -size)[negative]
        stride_spec = level.block_stride
        if isinstance(stride_spec, (int, float)):
            # Scalar stride: fraction of the block size.
            stride = (size * stride_spec).astype("int")
        else:
            stride = np.array(stride_spec)
        mapper = WarpMapper(
            reference,
            size,
            block_stride=stride,
            proj_method=level.project,
            tukey_alpha=level.tukey_ref,
        )
        self.mappers.append(mapper)
        self.mapper_ix.append(level_index)
    assert self.mappers, "At least one level of registration is required"
def register_single(self, vol, callback=None, verbose=False):
    """Register a single volume to the reference volume.

    Args:
        vol (array_like): Volume to be registered (numpy or cupy array)
        callback (function): Callback called on the warped, pre-filtered volume once before
            registration and after every repeat of every level
        verbose (bool): If True, show progress bars. Default is False

    Returns:
        - vol (array_like): Registered (unfiltered, float32) volume (numpy or cupy array, depending on input)
        - warp_map (WarpMap): Displacement field
        - callback_output (list): List of outputs from the callback function

    Raises:
        ValueError: if an affine level has fewer than two axes with at least 2 blocks
    """
    was_numpy = isinstance(vol, np.ndarray)
    vol = cp.array(vol, "float32", copy=False, order="C")
    last_mapper = self.mappers[-1]
    # Start from a constant displacement of half the shape difference between vol and the reference.
    offsets = (cp.array(vol.shape) - cp.array(self.ref_shape)) / 2
    warp_map = WarpMap(offsets[:, None, None, None], cp.ones(3), cp.ones(3), self.ref_shape, vol.shape)
    # All accumulated displacement fields live on the block grid of the last level.
    warp_map = warp_map.resize_to(last_mapper)
    callback_output = []
    if self.recipe.pre_filter is not None:
        vol_filtered = self.recipe.pre_filter(vol, reg_mask=self.reg_mask)
    else:
        vol_filtered = vol
    vol_tmp = cp.zeros(self.ref_shape, dtype="float32", order="C")
    warp_map.warp(vol_filtered, out=vol_tmp)
    # Per-axis minimum block stride over all levels.
    min_block_stride = np.min([mapper.block_stride for mapper in self.mappers], axis=0)
    if callback is not None:
        callback_output.append(callback(vol_tmp))

    # Compare element-wise, axis by axis, as the warning describes.
    if np.any(last_mapper.block_stride > min_block_stride):
        warnings.warn(
            "The block stride (in voxels) in the last level should not be larger than the block stride in any previous level (along any axis)."
        )
    for k, mapper in enumerate(tqdm(self.mappers, desc="Levels", disable=not verbose)):
        level = self.recipe.levels[self.mapper_ix[k]]
        for _ in tqdm(range(level.repeats), leave=False, desc="Repeats", disable=not verbose):
            wm = mapper.get_displacement(vol_tmp, smooth_func=level.smooth)
            wm.warp_field *= level.update_rate
            if level.median_filter:
                wm = wm.median_filter()
            if level.affine:
                if (np.array(mapper.blocks_shape[:3]) < 2).sum() > 1:
                    raise ValueError(
                        f"Affine fit needs at least two axes with at least 2 blocks! Volume shape: {self.ref_shape}; block size: {mapper.block_size}"
                    )
                wm, _ = wm.fit_affine(
                    target=dict(
                        warp_field_shape=(3, *last_mapper.blocks_shape[:3]),
                        block_size=last_mapper.block_size,
                        block_stride=last_mapper.block_stride,
                    )
                )
            else:
                wm = wm.resize_to(last_mapper)

            warp_map = warp_map.chain(wm)
            # Re-warp the pre-filtered input from scratch with the accumulated map.
            warp_map.warp(vol_filtered, out=vol_tmp)
            if callback is not None:
                callback_output.append(callback(vol_tmp))
    # The returned volume is the unfiltered input warped with the final map.
    warp_map.warp(vol, out=vol_tmp)
    if was_numpy:
        vol_tmp = vol_tmp.get()
    return vol_tmp, warp_map, callback_output
Register a single volume to the reference volume.
Arguments:
- vol (array_like): Volume to be registered (numpy or cupy array)
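- callback (function): Callback function called on the warped volume once before registration and after every iteration (each repeat of each level)
- verbose (bool): If True, show progress bars. Default is False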
Returns:
- vol (array_like): Registered volume (numpy or cupy array, depending on input)
- warp_map (WarpMap): Displacement field
- callback_output (list): List of outputs from the callback function
def register_volumes(ref, vol, recipe, reg_mask=1, callback=None, verbose=True, video_path=None, vmax=None):
    """Register a volume to a reference volume using a registration pyramid.

    Args:
        ref (numpy.array or cupy.array): Reference volume
        vol (numpy.array or cupy.array): Volume to be registered
        recipe (Recipe): Registration recipe
        reg_mask (numpy.array): Mask passed to the recipe's pre_filter for both volumes. Default is 1 (no mask)
        callback (function): Callback function to be called on the volume after each iteration. Default is None.
            Can be used to monitor and optimize registration. Example: `callback = lambda vol: vol.mean(1).get()`
            (note that `vol` is a 3D cupy array. Use `.get()` to turn the output into a numpy array and save GPU memory).
            Callback outputs for each registration step will be returned as a list.
        verbose (bool): If True, show progress bars. Default is True
        video_path (str): Save a video of the registration process, using callback outputs. The callback has to
            return 2D frames. Default is None. Video problems are reported as warnings, not raised.
        vmax (float): Maximum pixel value (to scale video brightness). If None, set to the 99.9th percentile of
            the pixel values of the reference frame.

    Returns:
        - numpy.array or cupy.array (depending on vol input): Registered volume
        - WarpMap: Displacement field
        - list: List of outputs from the callback function
    """
    recipe.model_validate(recipe.model_dump())
    reg = RegistrationPyramid(ref, recipe, reg_mask=reg_mask)
    registered_vol, warp_map, cbout = reg.register_single(vol, callback=callback, verbose=verbose)
    # Release the pyramid (and its GPU buffers) and the CuPy FFT plan cache before returning.
    del reg
    gc.collect()
    cp.fft.config.get_plan_cache().clear()

    if video_path is not None:
        if callback is None:
            # The video frames are the callback outputs, so nothing can be rendered without a callback.
            warnings.warn("Video generation failed with error: video_path requires a callback that returns 2D frames")
        else:
            try:
                if cbout[0].ndim != 2:
                    raise ValueError("Callback output must be a 2D array")
                # Hand the callback a cupy float32 volume, as register_single does.
                ref_vol = cp.array(ref, dtype="float32", copy=False, order="C")
                if recipe.pre_filter is not None:
                    ref_vol = recipe.pre_filter(ref_vol)
                ref_frame = callback(ref_vol)
                vmax = np.percentile(ref_frame, 99.9).item() if vmax is None else vmax
                create_rgb_video(video_path, ref_frame / vmax, np.array(cbout) / vmax, fps=10)
            except (ValueError, AssertionError) as e:
                warnings.warn(f"Video generation failed with error: {e}")
    return registered_vol, warp_map, cbout
Register a volume to a reference volume using a registration pyramid.
Arguments:
- ref (numpy.array or cupy.array): Reference volume
- vol (numpy.array or cupy.array): Volume to be registered
- recipe (Recipe): Registration recipe
- reg_mask (numpy.array): Mask to be multiplied with the reference volume. Default is 1 (no mask)
- callback (function): Callback function to be called on the volume after each iteration. Default is None.
Can be used to monitor and optimize registration. Example:
`callback = lambda vol: vol.mean(1).get()` (note that `vol` is a 3D cupy array; use `.get()` to turn the output into a numpy array and save GPU memory). Callback outputs for each registration step will be returned as a list.
- verbose (bool): If True, show progress bars. Default is True
- video_path (str): Save a video of the registration process, using callback outputs. The callback has to return 2D frames. Default is None.
- vmax (float): Maximum pixel value (to scale video brightness). If None, set to the 99.9th percentile of the pixel values of the reference frame.
Returns:
- numpy.array or cupy.array (depending on vol input): Registered volume
- WarpMap: Displacement field
- list: List of outputs from the callback function
class Projector(BaseModel):
    """A class to apply a 2D projection and filters to a volume block

    Parameters:
        max: if True, project by taking the maximum along the projection axis; otherwise use the mean. Default is True
        normalize: if True (or a positive float p), divide each projection by its L2 norm (raised to the power p),
            to get correlations, not covariances. Default is False
        dog: if True, apply a DoG filter to the projections. Default is True
        low: the lower sigma value for the DoG filter (used for a plain Gaussian filter when dog is False).
            Scalar or one value per voxel axis. Default is 0.5
        high: the higher sigma value for the DoG filter. Scalar or one value per voxel axis. Default is 10.0
        periodic_smooth: if True, pass the projections through ``periodic_smooth_decomposition_nd_rfft``
            before filtering. Default is False
    """

    max: bool = True
    normalize: Union[bool, float] = False
    dog: bool = True
    low: Union[Union[int, float], List[Union[int, float]]] = 0.5
    high: Union[Union[int, float], List[Union[int, float]]] = 10.0
    periodic_smooth: bool = False

    def __call__(self, vol_blocks, axis):
        """Apply a 2D projection and filters to a volume block

        Args:
            vol_blocks (cupy.array): Blocked volume to be projected (6D dataset, with the first 3 dimensions being
                blocks and the last 3 dimensions being voxels)
            axis (int): Axis along which to project. It also indexes the 3 per-axis sigmas, so it is presumably
                one of -3, -2, -1 (TODO: confirm with callers)
        Returns:
            cupy.array: Projected volume block (5D dataset, with the first 3 dimensions being blocks and the last
                2 dimensions being 2D projections)
        """
        projected = vol_blocks.max(axis) if self.max else vol_blocks.mean(axis)
        if self.periodic_smooth:
            projected = periodic_smooth_decomposition_nd_rfft(projected)
        # Expand the sigmas to the 3 voxel axes and drop the projected one.
        sigma_low = np.delete(np.r_[1, 1, 1] * self.low, axis)
        sigma_high = np.delete(np.r_[1, 1, 1] * self.high, axis)
        # The three leading (block) axes are never filtered.
        no_blur = [0, 0, 0]
        if self.dog:
            projected = dogfilter(projected, [*no_blur, *sigma_low], [*no_blur, *sigma_high], mode="reflect")
        elif np.any(np.array(self.low) != 0):
            projected = cupyx.scipy.ndimage.gaussian_filter(
                projected, [*no_blur, *sigma_low], mode="reflect", truncate=5.0
            )
        if self.normalize > 0:
            l2_norm = cp.sqrt(cp.sum(projected**2, axis=(-2, -1), keepdims=True))
            # The epsilon guards against division by zero for empty projections.
            projected /= l2_norm**self.normalize + 1e-9
        return projected
A class to apply a 2D projection and filters to a volume block
Arguments:
- max: if True, project the volume block by taking its maximum along the projection axis; otherwise use the mean. Default is True
- normalize: if True, normalize projections by the L2 norm (to get correlations, not covariances). A positive float is used as the exponent applied to the L2 norm. Default is False
- dog: if True, apply a DoG filter to the volume block. Default is True
- low: the lower sigma value for the DoG filter (also used for a plain Gaussian filter when dog is False). Default is 0.5
- high: the higher sigma value for the DoG filter. Default is 10.0
- periodic_smooth: if True, apply a periodic-smooth decomposition to the projections before filtering. Default is False
- tukey_env, gauss_env: not available as fields of this class.
class Smoother(BaseModel):
    """Smooth blocks with a Gaussian kernel

    Args:
        sigmas (float, list or None): [sigma0, sigma1, sigma2], Gaussian sigmas along the three block axes.
            A single number is used for all three axes. If None, no smoothing is applied.
            Default is [1.0, 1.0, 1.0].
        shear (float): shear parameter for the Gaussian kernel over the first two block axes.
            Default is None (no shear).
        long_range_ratio (float): weight of an additional Gaussian with 5x larger sigmas (double Gaussian
            kernel). None disables it. Default is 0.05.

    All kernels use ``truncate=4.0`` (not configurable).
    """

    sigmas: Union[float, List[float], None] = [1.0, 1.0, 1.0]
    shear: Union[float, None] = None
    long_range_ratio: Union[float, None] = 0.05

    def __call__(self, xcorr_proj, block_size=None):
        """Apply a Gaussian filter to the cross-correlation data

        Args:
            xcorr_proj (cupy.array): cross-correlation data (5D array, with the first 3 dimensions being the
                blocks and the last 2 dimensions being the 2D projection)
            block_size (list): shape of blocks, whose rigid displacement is estimated (required when shear is set)
        Returns:
            cupy.array: smoothed cross-correlation volume
        """
        truncate = 4.0
        if self.sigmas is None:
            return xcorr_proj
        # A scalar sigma applies to all three block axes.
        if isinstance(self.sigmas, (int, float)):
            sigmas = [float(self.sigmas)] * 3
        else:
            sigmas = list(self.sigmas)
        if self.shear is not None:
            # Sheared 2D kernel over block axes 0 and 1, then a 1D Gaussian along block axis 2.
            shear_blocks = self.shear * (block_size[1] / block_size[0])
            gw = gausskernel_sheared(sigmas[:2], shear_blocks, truncate=truncate)
            gw = cp.array(gw[:, :, None, None, None])
            xcorr_proj = cupyx.scipy.ndimage.convolve(xcorr_proj, gw, mode="constant")
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter1d(
                xcorr_proj, sigmas[2], axis=2, mode="constant", truncate=truncate
            )
        else:  # shear is None: smooth across blocks only (sigma 0 on the two projection axes)
            xcorr_proj = cupyx.scipy.ndimage.gaussian_filter(
                xcorr_proj, [*sigmas, 0, 0], mode="constant", truncate=truncate
            )
        if self.long_range_ratio is not None:
            # The wide Gaussian is computed from the already down-weighted array, so the effective
            # weight of the long-range term is long_range_ratio * (1 - long_range_ratio).
            xcorr_proj *= 1 - self.long_range_ratio
            xcorr_proj += (
                cupyx.scipy.ndimage.gaussian_filter(
                    xcorr_proj, [*[s * 5 for s in sigmas], 0, 0], mode="constant", truncate=truncate
                )
                * self.long_range_ratio
            )
        return xcorr_proj
Smooth blocks with a Gaussian kernel
Arguments:
- sigmas (list): [sigma0, sigma1, sigma2]. If None, no smoothing is applied. Default is [1.0, 1.0, 1.0].
- truncate: not configurable; the Gaussian kernels are truncated at 4 standard deviations.
- shear (float): shear parameter for gaussian kernel. Default is None.
- long_range_ratio (float): weight of an additional Gaussian with 5x larger sigmas (double Gaussian kernel); None disables it. Default is 0.05.
class RegFilter(BaseModel):
    """A class to apply a filter to the volume before registration

    The steps are applied in this order: threshold and clip, optional edge softening, optional mask,
    optional DoG filter.

    Parameters:
        clip_thresh: value subtracted from the volume before negative values are clipped to 0. Default is 0
        dog: if True, apply a DoG filter to the volume. Default is True
        low: the lower sigma value for the DoG filter. Default is 0.5
        high: the higher sigma value for the DoG filter. Default is 10.0
        soft_edge: if any value is greater than 0, soften the volume edges with ``soften_edges``
            (scalar or per-axis). Default is 0.0
    """

    clip_thresh: float = 0
    dog: bool = True
    low: float = 0.5
    high: float = 10.0
    soft_edge: Union[Union[int, float], List[Union[int, float]]] = 0.0

    def __call__(self, vol, reg_mask=None):
        """Apply the filter to the volume

        Args:
            vol (cupy or numpy array): 3D volume to be filtered
            reg_mask (array): Mask for registration, multiplied into the volume when given
        Returns:
            cupy.ndarray: Filtered volume
        """
        shifted = cp.array(vol, "float32", copy=False) - self.clip_thresh
        filtered = cp.clip(shifted, 0, None)
        if np.any(np.array(self.soft_edge) > 0):
            filtered = soften_edges(filtered, soft_edge=self.soft_edge, copy=False)
        if reg_mask is not None:
            filtered *= cp.array(reg_mask, dtype="float32", copy=False)
        if not self.dog:
            return filtered
        return dogfilter(filtered, self.low, self.high, mode="reflect")
A class to apply a filter to the volume before registration
Arguments:
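- clip_thresh: threshold subtracted from the volume before negative values are clipped to 0. Default is 0
- dog: if True, apply a DoG filter to the volume. Default is True
- low: the lower sigma value for the DoG filter. Default is 0.5
- high: the higher sigma value for the DoG filter. Default is 10.0
- soft_edge: if any value is greater than 0, soften the volume edges (scalar or per-axis). Default is 0.0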
class LevelConfig(BaseModel):
    """Configuration for each level of the registration pyramid

    Args:
        block_size (list): shape of blocks, whose rigid displacement is estimated
        block_stride (list or float): stride between blocks in voxels, or, if a number, a fraction of
            block_size. Default is 1.0 (stride equal to block_size)
        project (Projector or callable): Projector object. The callable should take a volume block and an axis
            as input and return a projected volume block. Default is Projector()
        tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref).
            Default is 0.5
        smooth (Smoother or None): Smoother object. Default is Smoother()
        affine (bool): if True, apply affine transformation to the displacement field. Default is False
        median_filter (bool): if True, apply median filter to the displacement field. Default is True
        update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen
            oscillations.
        repeats (int): number of iterations for this level (disable the level by setting repeats to 0).
            Default is 5
    """

    # Field order is kept stable so model_dump() / YAML output keeps the same key order.
    block_size: List[int]
    block_stride: Union[List[int], float] = 1.0
    project: Union[Projector, Callable[[_ArrayType, int], _ArrayType]] = Projector()
    tukey_ref: Union[float, None] = 0.5
    smooth: Union[Smoother, None] = Smoother()
    affine: bool = False
    median_filter: bool = True
    update_rate: float = 1.0
    repeats: int = 5
Configuration for each level of the registration pyramid
Arguments:
- block_size (list): shape of blocks, whose rigid displacement is estimated
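- block_stride (list or float): stride between blocks in voxels, or, if a number, a fraction of block_size. Default is 1.0 (stride equal to block_size)
- repeats (int): number of iterations for this level (disable the level by setting repeats to 0). Default is 5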
- smooth (Smoother or None): Smoother object
- project (Projector, callable or None): Projector object. The callable should take a volume block and an axis as input and return a projected volume block.
- tukey_ref (float): if not None, apply a Tukey window to the reference volume (alpha = tukey_ref). Default is 0.5
- affine (bool): if True, apply affine transformation to the displacement field
- median_filter (bool): if True, apply median filter to the displacement field
- update_rate (float): update rate for the displacement field. Default is 1.0. Can be lowered to dampen oscillations.
class Recipe(BaseModel):
    """Configuration for the registration recipe.

    A new Recipe is initialized with two levels: a translation level (block_size=[-1, -1, -1])
    followed by an affine level (block_size=[-2, -2, -2]).

    Args:
        pre_filter (RegFilter, callable or None): Filter applied to the reference and moving volumes
            before registration. Default is RegFilter()
        levels (list): List of LevelConfig objects
    """

    pre_filter: Union[RegFilter, Callable[[_ArrayType], _ArrayType], None] = RegFilter()
    levels: List[LevelConfig] = [
        LevelConfig(block_size=[-1, -1, -1], repeats=3),  # translation level
        LevelConfig(  # affine level
            block_size=[-2, -2, -2],
            block_stride=0.5,
            repeats=10,
            affine=True,
            median_filter=False,
            smooth=Smoother(sigmas=[0.5, 0.5, 0.5]),
        ),
    ]

    def add_level(self, block_size, **kwargs):
        """Add a level to the end of the registration recipe

        Args:
            block_size (list or number): shape of blocks, whose rigid displacement is estimated;
                a single number is used for all three axes
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.append(LevelConfig(block_size=block_size, **kwargs))

    def insert_level(self, index, block_size, **kwargs):
        """Insert a level to the registration recipe

        Args:
            index (int): A number specifying in which position to insert the level
            block_size (list or number): shape of blocks, whose rigid displacement is estimated;
                a single number is used for all three axes
            **kwargs: additional arguments for LevelConfig

        Raises:
            ValueError: if block_size does not have 3 entries
        """
        if isinstance(block_size, (int, float)):
            block_size = [block_size] * 3
        if len(block_size) != 3:
            raise ValueError("block_size must be a list of 3 integers")
        self.levels.insert(index, LevelConfig(block_size=block_size, **kwargs))

    @classmethod
    def from_yaml(cls, yaml_path):
        """Load a recipe from a YAML file

        Args:
            yaml_path (str or path-like): path to the YAML file. If no file exists at this path, it is
                looked up in the ``recipes`` directory next to this module.

        Returns:
            Recipe: Recipe object

        Raises:
            FileNotFoundError: if the file exists in neither location
        """
        import yaml

        path = pathlib.Path(yaml_path)
        if not path.is_file():
            recipes_dir = pathlib.Path(__file__).resolve().parent / "recipes"
            path = recipes_dir / path
            if not path.is_file():
                raise FileNotFoundError(f"Recipe file not found: {yaml_path!r} (also looked in {recipes_dir})")

        with open(path, "r", encoding="utf-8") as f:
            # safe_load: recipe files must not be able to construct arbitrary Python objects.
            data = yaml.safe_load(f)

        return cls.model_validate(data)

    def to_yaml(self, yaml_path):
        """Save the recipe to a YAML file and print a confirmation

        Args:
            yaml_path (str): path to the YAML file
        """
        import yaml

        with open(yaml_path, "w", encoding="utf-8") as f:
            yaml.dump(self.model_dump(), f)
        print(f"Recipe saved to {yaml_path}")
Configuration for the registration recipe. A new Recipe is initialized with two levels: a translation level followed by an affine level.
Arguments:
- pre_filter (RegFilter, callable or None): Filter applied to the reference and moving volumes before registration
- levels (list): List of LevelConfig objects
def add_level(self, block_size, **kwargs):
    """Append a new level at the end of the registration recipe.

    Args:
        block_size (list or number): shape of blocks, whose rigid displacement is estimated;
            a single number is used for all three axes
        **kwargs: additional arguments for LevelConfig

    Raises:
        ValueError: if block_size does not have 3 entries
    """
    sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
    if len(sizes) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    new_level = LevelConfig(block_size=sizes, **kwargs)
    self.levels.append(new_level)
Add a level to the registration recipe
Arguments:
- block_size (list or number): shape of blocks, whose rigid displacement is estimated; a single number is used for all three axes
- **kwargs: additional arguments for LevelConfig
def insert_level(self, index, block_size, **kwargs):
    """Insert a new level into the registration recipe at position ``index``.

    Args:
        index (int): A number specifying in which position to insert the level
        block_size (list or number): shape of blocks, whose rigid displacement is estimated;
            a single number is used for all three axes
        **kwargs: additional arguments for LevelConfig

    Raises:
        ValueError: if block_size does not have 3 entries
    """
    sizes = [block_size] * 3 if isinstance(block_size, (int, float)) else block_size
    if len(sizes) != 3:
        raise ValueError("block_size must be a list of 3 integers")
    new_level = LevelConfig(block_size=sizes, **kwargs)
    self.levels.insert(index, new_level)
Insert a level to the registration recipe
Arguments:
- index (int): A number specifying in which position to insert the level
- block_size (list or number): shape of blocks, whose rigid displacement is estimated; a single number is used for all three axes
- **kwargs: additional arguments for LevelConfig
@classmethod
def from_yaml(cls, yaml_path):
    """Load a recipe from a YAML file

    Args:
        yaml_path (str or path-like): path to the YAML file. If no file exists at this path, it is
            looked up in the ``recipes`` directory next to this module.

    Returns:
        Recipe: Recipe object

    Raises:
        FileNotFoundError: if the file exists in neither location
    """
    import yaml

    path = pathlib.Path(yaml_path)
    if not path.is_file():
        recipes_dir = pathlib.Path(__file__).resolve().parent / "recipes"
        path = recipes_dir / path
        if not path.is_file():
            raise FileNotFoundError(f"Recipe file not found: {yaml_path!r} (also looked in {recipes_dir})")

    with open(path, "r", encoding="utf-8") as f:
        # safe_load: recipe files must not be able to construct arbitrary Python objects.
        data = yaml.safe_load(f)

    return cls.model_validate(data)
Load a recipe from a YAML file
Arguments:
- yaml_path (str): path to the YAML file; if no file exists at this path, it is looked up in the package's `recipes` directory
Returns:
Recipe: Recipe object
def to_yaml(self, yaml_path):
    """Save the recipe to a YAML file and print a confirmation

    Args:
        yaml_path (str): path to the YAML file
    """
    import yaml

    # Explicit encoding so the file does not depend on the platform's locale.
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(self.model_dump(), f)
    print(f"Recipe saved to {yaml_path}")
Save the recipe to a YAML file
Arguments:
- yaml_path (str): path to the YAML file