• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 9216753496079630

pending completion
9216753496079630

Pull #565

continuous-integration/UCL

olivier
add test
Pull Request #565: Regions volume

5 of 5 new or added lines in 1 file covered. (100.0%)

3055 of 17940 relevant lines covered (17.03%)

0.17 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

28.96
/ibllib/atlas/atlas.py
1
from dataclasses import dataclass
1✔
2
import logging
1✔
3
import matplotlib.pyplot as plt
1✔
4
from pathlib import Path, PurePosixPath
1✔
5
import numpy as np
1✔
6
import nrrd
1✔
7

8
from one.webclient import http_download_file
1✔
9
import one.params
1✔
10
import one.remote.aws as aws
1✔
11

12
from iblutil.numerical import ismember
1✔
13
from ibllib.atlas.regions import BrainRegions, FranklinPaxinosRegions
1✔
14

15

16
_logger = logging.getLogger(__name__)

# Bregma landmark positions; per the constant names these are in micrometres, (ml, ap, dv) order
ALLEN_CCF_LANDMARKS_MLAPDV_UM = dict(bregma=np.array([5739, 5400, 332]))
PAXINOS_CCF_LANDMARKS_MLAPDV_UM = dict(bregma=np.array([5700, 4300 + 160, 330]))

# Name of the IBL public S3 bucket
S3_BUCKET_IBL = 'ibl-brain-wide-map-public'
21

22

23
def cart2sph(x, y, z):
    """
    Converts cartesian to spherical coordinates, angles in degrees.

    theta: polar angle (angle from the +z axis), phi: azimuth (angle from the +x axis in the xy plane)

    :param x: scalar or array of x coordinates
    :param y: scalar or array of y coordinates (same shape as x)
    :param z: scalar or array of z coordinates (same shape as x)
    :return: tuple (r, theta, phi). theta is set to 0 where r == 0 and is returned as a Python
     float when it contains a single value.
    """
    # accept Python scalars as well as arrays (indexing below requires array semantics)
    x, y, z = (np.asarray(v) for v in (x, y, z))
    r = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    phi = np.arctan2(y, x) * 180 / np.pi
    # cos(theta) = z / r; at the origin the angle is undefined and left at theta = 0 (cos = 1).
    # Clipping guards against rounding errors pushing the ratio slightly outside [-1, 1] (-> NaN).
    cos_theta = np.divide(z, r, out=np.ones_like(r), where=r != 0)
    theta = np.arccos(np.clip(cos_theta, -1, 1)) * 180 / np.pi
    if np.size(theta) == 1:
        # .item() avoids the deprecated float() conversion of arrays with ndim > 0
        theta = np.asarray(theta).item()
    return r, theta, phi
36

37

38
def sph2cart(r, theta, phi):
    """
    Converts spherical coordinates, angles in degrees, to cartesian coordinates.

    :param r: radius
    :param theta: polar angle in degrees (angle from the +z axis)
    :param phi: azimuth in degrees (angle from the +x axis in the xy plane)
    :return: tuple (x, y, z)
    """
    polar = theta / 180 * np.pi
    azimuth = phi / 180 * np.pi
    sin_polar = np.sin(polar)
    return (r * np.cos(azimuth) * sin_polar,
            r * np.sin(azimuth) * sin_polar,
            r * np.cos(polar))
47

48

49
class BrainCoordinates:
    """
    Class for mapping and indexing a 3D array to real-world coordinates

    * x = ml, right positive
    * y = ap, anterior positive
    * z = dv, dorsal positive

    The layout of the Atlas dimension is done according to the most used sections so they lay
    contiguous on disk assuming C-ordering: V[iap, iml, idv]

    nxyz: number of elements along each cartesian axis (nx, ny, nz) = (nml, nap, ndv)
    xyz0: coordinates of the element volume[0, 0, 0] in the coordinate space
    dxyz: spatial interval of the volume along the 3 dimensions
    """

    def __init__(self, nxyz, xyz0=(0, 0, 0), dxyz=(1, 1, 1)):
        """
        :param nxyz: number of elements along each axis (nx, ny, nz)
        :param xyz0: coordinates of the element volume[0, 0, 0]
        :param dxyz: spacing along each axis, a 3 elements sequence or a scalar used for all axes
        """
        if np.isscalar(dxyz):
            dxyz = [dxyz] * 3
        self.x0, self.y0, self.z0 = list(xyz0)
        self.dx, self.dy, self.dz = list(dxyz)
        self.nx, self.ny, self.nz = list(nxyz)

    @property
    def dxyz(self):
        """np.array of the spacing (dx, dy, dz)"""
        return np.array([self.dx, self.dy, self.dz])

    @property
    def nxyz(self):
        """np.array of the number of elements (nx, ny, nz)"""
        return np.array([self.nx, self.ny, self.nz])

    # --- Methods ratios to indices: map a ratio in [0, 1] of an axis length to an index ---
    def r2ix(self, r):
        return int((self.nx - 1) * r)

    def r2iy(self, r):
        return int((self.ny - 1) * r)

    def r2iz(self, r):
        return int((self.nz - 1) * r)

    # --- Methods distance to indices ---
    @staticmethod
    def _round(i, round=True):
        """
        Rounds fractional indices to the nearest integer, NaN values are mapped to index 0.

        :param i: scalar or array of fractional indices
        :param round: if False, ``i`` is returned unchanged
        :return: integer numpy array if ``round`` else ``i``
        """
        if not round:
            return i
        nanval = 0
        i = np.asarray(i)
        # replace NaNs before the cast: casting NaN to int is undefined and emits a warning
        return np.array(np.round(np.where(np.isnan(i), nanval, i))).astype(int)

    def _axis2i(self, v, v0, dv, n, name, round=True, mode='raise'):
        """
        Converts coordinates along a single axis to indices.

        :param v: scalar or array of coordinates along the axis
        :param v0: coordinate of index 0 along the axis
        :param dv: spacing along the axis
        :param n: number of elements along the axis
        :param name: axis name used in the error message ('x', 'y' or 'z')
        :param round: if True, round to the nearest integer index
        :param mode: {'raise', 'clip', 'wrap'}, see xyz2i
        :return: numpy array of indices
        """
        i = np.asarray(self._round((np.asarray(v) - v0) / dv, round=round))
        if np.any(i < 0) or np.any(i >= n):
            if mode == 'clip':
                i[i < 0] = 0
                i[i >= n] = n - 1
            elif mode == 'raise':
                raise ValueError(f"At least one {name} value lies outside of the atlas volume.")
            # mode == 'wrap': out of volume indices are returned unchanged
        return i

    def x2i(self, x, round=True, mode='raise'):
        """Converts x coordinate(s) to volume index(ices), see xyz2i for the parameters"""
        return self._axis2i(x, self.x0, self.dx, self.nx, 'x', round=round, mode=mode)

    def y2i(self, y, round=True, mode='raise'):
        """Converts y coordinate(s) to volume index(ices), see xyz2i for the parameters"""
        return self._axis2i(y, self.y0, self.dy, self.ny, 'y', round=round, mode=mode)

    def z2i(self, z, round=True, mode='raise'):
        """Converts z coordinate(s) to volume index(ices), see xyz2i for the parameters"""
        return self._axis2i(z, self.z0, self.dz, self.nz, 'z', round=round, mode=mode)

    def xyz2i(self, xyz, round=True, mode='raise'):
        """
        Converts xyz coordinates to volume indices.

        :param xyz: [..., 3] array of coordinates
        :param round: if True, return integer indices, otherwise fractional float indices
        :param mode: {'raise', 'clip', 'wrap'} determines what to do when determined index lies outside the atlas volume
                     'raise' will raise a ValueError
                     'clip' will replace the index with the closest index inside the volume
                     'wrap' will wrap around to the other side of the volume. This is only here for legacy reasons
        :return: [..., 3] array of indices
        """
        xyz = np.array(xyz)
        dt = int if round else float
        out = np.zeros_like(xyz, dtype=dt)
        out[..., 0] = self.x2i(xyz[..., 0], round=round, mode=mode)
        out[..., 1] = self.y2i(xyz[..., 1], round=round, mode=mode)
        out[..., 2] = self.z2i(xyz[..., 2], round=round, mode=mode)
        return out

    # --- Methods indices to distance ---
    def i2x(self, ind):
        return ind * self.dx + self.x0

    def i2y(self, ind):
        return ind * self.dy + self.y0

    def i2z(self, ind):
        return ind * self.dz + self.z0

    def i2xyz(self, iii):
        """
        Converts volume indices to xyz coordinates.

        :param iii: [..., 3] array of (possibly fractional) indices
        :return: [..., 3] float array of coordinates
        """
        iii = np.array(iii, dtype=float)
        out = np.zeros_like(iii)
        out[..., 0] = self.i2x(iii[..., 0])
        out[..., 1] = self.i2y(iii[..., 1])
        out[..., 2] = self.i2z(iii[..., 2])
        return out

    # --- Methods bounds: coordinates of the first and last elements along each axis ---
    @property
    def xlim(self):
        return self.i2x(np.array([0, self.nx - 1]))

    @property
    def ylim(self):
        return self.i2y(np.array([0, self.ny - 1]))

    @property
    def zlim(self):
        return self.i2z(np.array([0, self.nz - 1]))

    def lim(self, axis):
        """
        :param axis: 0 for x, 1 for y, 2 for z
        :return: coordinates of the first and last elements along the axis
        :raises ValueError: if axis is not 0, 1 or 2
        """
        if axis == 0:
            return self.xlim
        elif axis == 1:
            return self.ylim
        elif axis == 2:
            return self.zlim
        raise ValueError(f"axis must be 0, 1 or 2, got {axis}")

    # --- returns scales: coordinates of every element along each axis ---
    @property
    def xscale(self):
        return self.i2x(np.arange(self.nx))

    @property
    def yscale(self):
        return self.i2y(np.arange(self.ny))

    @property
    def zscale(self):
        return self.i2z(np.arange(self.nz))

    # --- returns the 3d mgrid used for 3d visualization ---
    @property
    def mgrid(self):
        """3D coordinate grids from np.meshgrid (numpy default 'xy' indexing)"""
        return np.meshgrid(self.xscale, self.yscale, self.zscale)
207

208

209
class BrainAtlas:
    """
    Object that holds the image, labels and coordinate transforms for a brain atlas.
    Currently this is designed for the Allen CCF at several resolutions,
    yet this class can be used for other atlases as they arise.
    """
    def __init__(self, image, label, dxyz, regions, iorigin=(0, 0, 0),
                 dims2xyz=(0, 1, 2), xyz2dims=(0, 1, 2)):
        """
        self.image: image volume (ap, ml, dv)
        self.label: label volume (ap, ml, dv)
        self.bc: atlas.BrainCoordinate object
        self.regions: atlas.BrainRegions object
        self.top: 2d np array (ap, ml) containing the z-coordinate (m) of the surface of the brain
        self.dims2xyz and self.xyz2dims: map image axis order to xyz coordinates order

        :param image: image volume
        :param label: label volume, its values index the regions arrays (same shape as image)
        :param dxyz: voxel size along x, y, z (scalar or 3 elements)
        :param regions: brain regions object
        :param iorigin: volume indices, in xyz order, of the origin of the coordinates
        :param dims2xyz: axis permutation between the image dimensions and xyz
        :param xyz2dims: axis permutation between xyz and the image dimensions (inverse of dims2xyz)
        """

        self.image = image
        self.label = label
        self.regions = regions
        # store the permutations as arrays: they are used for fancy indexing, which fails on plain
        # lists (the former list defaults raised a TypeError in the assertions below)
        self.dims2xyz = np.asarray(dims2xyz)
        self.xyz2dims = np.asarray(xyz2dims)
        assert np.all(self.dims2xyz[self.xyz2dims] == np.array([0, 1, 2]))
        assert np.all(self.xyz2dims[self.dims2xyz] == np.array([0, 1, 2]))
        # create the coordinate transform object that maps volume indices to real world coordinates
        nxyz = np.array(self.image.shape)[self.dims2xyz]
        bc = BrainCoordinates(nxyz=nxyz, xyz0=(0, 0, 0), dxyz=dxyz)
        # shift the coordinates so that the voxel at `iorigin` sits at (0, 0, 0)
        self.bc = BrainCoordinates(nxyz=nxyz, xyz0=-bc.i2xyz(iorigin), dxyz=dxyz)

        self.surface = None
        self.boundary = None

    @staticmethod
    def _get_cache_dir():
        """
        :return: pathlib.Path of the histology/ATLAS/Needles/Allen/flatmaps folder within the
         ONE cache directory
        """
        par = one.params.get(silent=True)
        path_atlas = Path(par.CACHE_DIR).joinpath(PurePosixPath('histology', 'ATLAS', 'Needles', 'Allen', 'flatmaps'))
        return path_atlas

    def compute_surface(self):
        """
        Get the volume top, bottom, left and right surfaces, and from these the outer surface of
        the image volume. This is needed to compute probe insertions intersections.

        Sets self.top, self.bottom (z coordinates), self.surface (volume with 1 on the surface)
        and self.srf_xyz (xyz coordinates of the surface voxels).

        NOTE: In places where the top or bottom surface touch the top or bottom of the atlas volume, the surface
        will be set to np.nan. If you encounter issues working with these surfaces check if this might be the cause.
        """
        if self.surface is None:  # only compute if it hasn't already been computed
            axz = self.xyz2dims[2]  # this is the dv axis
            # 2 outside of the brain (label 0), 0 inside
            _surface = (self.label == 0).astype(np.int8) * 2
            # along increasing dv index: -2 when entering the brain, +2 when leaving it
            l0 = np.diff(_surface, axis=axz, append=2)
            # argmax returns 0 when no transition is found: flagged as NaN
            _top = np.argmax(l0 == -2, axis=axz).astype(float)
            _top[_top == 0] = np.nan
            _bottom = self.bc.nz - np.argmax(np.flip(l0, axis=axz) == 2, axis=axz).astype(float)
            _bottom[_bottom == self.bc.nz] = np.nan
            self.top = self.bc.i2z(_top + 1)
            self.bottom = self.bc.i2z(_bottom - 1)
            self.surface = np.diff(_surface, axis=self.xyz2dims[0], append=2) + l0
            idx_srf = np.where(self.surface != 0)
            self.surface[idx_srf] = 1
            self.srf_xyz = self.bc.i2xyz(np.c_[idx_srf[self.xyz2dims[0]], idx_srf[self.xyz2dims[1]],
                                               idx_srf[self.xyz2dims[2]]].astype(float))

    def _lookup_inds(self, ixyz, mode='raise'):
        """
        Performs a 3D lookup from volume indices ixyz to the image volume
        :param ixyz: [n, 3] array of indices in the mlapdv order
        :param mode: out of bounds handling forwarded to np.ravel_multi_index ('raise', 'wrap', 'clip')
        :return: n array of flat indices
        """
        idims = np.split(ixyz[..., self.xyz2dims], [1, 2], axis=-1)
        inds = np.ravel_multi_index(idims, self.bc.nxyz[self.xyz2dims], mode=mode)
        return inds.squeeze()

    def _lookup(self, xyz, mode='raise'):
        """
        Performs a 3D lookup from real world coordinates to the flat indices in the volume
        defined in the BrainCoordinates object
        :param xyz: [n, 3] array of coordinates
        :param mode: out of bounds handling, forwarded to BrainCoordinates.xyz2i and _lookup_inds
        :return: n array of flat indices
        """
        return self._lookup_inds(self.bc.xyz2i(xyz, mode=mode), mode=mode)

    def get_labels(self, xyz, mapping=None, radius_um=None, mode='raise'):
        """
        Performs a 3D lookup from real world coordinates to the volume labels
        and return the regions ids according to the mapping
        :param xyz: [n, 3] array of coordinates (a single [3] coordinate when radius_um is set)
        :param mapping: brain region mapping (defaults to original Allen mapping)
        :param radius_um: if not null, returns a regions ids array and an array of proportion
         of regions in a sphere of size radius around the coordinates.
        :param mode: out of bounds handling for the coordinates ('raise', 'clip', 'wrap')
        :return: n array of region ids, or (region ids, proportions) if radius_um is set
        """
        mapping = mapping or self.regions.default_mapping

        if not radius_um:
            regions_indices = self._get_mapping(mapping=mapping)[self.label.flat[self._lookup(xyz, mode=mode)]]
            return self.regions.id[regions_indices]

        # number of voxels spanned by the radius along each xyz axis (dxyz in m, radius in um)
        nr = [int(np.ceil(radius_um / abs(d) / 1e6)) for d in self.bc.dxyz]
        iii = self.bc.xyz2i(xyz)
        # computing the cube radius and indices is more complicated as volume indices are not
        # necessarily in ml, ap, dv order so the indices order is dynamic.
        # 'ij' indexing keeps the grid axes in the same order as the label cube extracted below
        rcube = np.meshgrid(*tuple((np.arange(
            -nr[i], nr[i] + 1) * self.bc.dxyz[i]) ** 2 for i in self.xyz2dims), indexing='ij')
        # distance of each voxel of the cube to its centre, in um
        rcube = np.sqrt(rcube[0] + rcube[1] + rcube[2]) * 1e6
        icube = tuple(slice(-nr[i] + iii[i], nr[i] + iii[i] + 1) for i in self.xyz2dims)
        cube = self._get_mapping(mapping=mapping)[self.label[icube]]
        ilabs, counts = np.unique(cube[rcube <= radius_um], return_counts=True)
        return self.regions.id[ilabs], counts / np.sum(counts)

    def _get_mapping(self, mapping=None):
        """
        Safe way to get mappings if nothing defined in regions.
        A mapping transforms from the full allen brain Atlas ids to the remapped ids
        new_ids = ids[mapping]
        """
        mapping = mapping or self.regions.default_mapping
        if hasattr(self.regions, 'mappings'):
            return self.regions.mappings[mapping]
        else:
            return np.arange(self.regions.id.size)

    def _label2rgb(self, imlabel):
        """
        Converts a slice from the label volume to its RGB equivalent for display
        :param imlabel: 2D np-array containing label ids (slice of the label volume)
        :return: 3D np-array of the slice uint8 rgb values, or the region ids if the regions
         have no rgb attribute
        """
        if getattr(self.regions, 'rgb', None) is None:
            return self.regions.id[imlabel]
        else:  # if the regions exist and have the rgb attribute, do the rgb lookup
            return self.regions.rgb[imlabel]

    def tilted_slice(self, xyz, axis, volume='image'):
        """
        From line coordinates, extracts the tilted plane containing the line from the 3D volume
        :param xyz: np.array: points defining a probe trajectory in 3D space (xyz triplets)
        if more than 2 points are provided will take the best fit
        :param axis:
            0: along ml = sagittal-slice
            1: along ap = coronal-slice
            2: along dv = horizontal-slice
        :param volume: 'image', 'annotation', 'surface' or a np.ndarray indexed like the image volume
        :return: np.array, abscissa extent (width), ordinate extent (height),
        squeezed axis extent (depth)
        :raises ValueError: for an unknown axis or volume
        """
        if axis == 0:   # sagittal slice (squeeze/take along ml-axis)
            wdim, hdim, ddim = (1, 2, 0)
        elif axis == 1:  # coronal slice (squeeze/take along ap-axis)
            wdim, hdim, ddim = (0, 2, 1)
        elif axis == 2:  # horizontal slice (squeeze/take along dv-axis)
            wdim, hdim, ddim = (0, 1, 2)
        else:
            raise ValueError(f"axis must be 0, 1 or 2, got {axis}")
        # get the best fit and find exit points of the volume along squeezed axis
        trj = Trajectory.fit(xyz)
        sub_volume = trj._eval(self.bc.lim(axis=hdim), axis=hdim)
        sub_volume[:, wdim] = self.bc.lim(axis=wdim)
        sub_volume_i = self.bc.xyz2i(sub_volume)
        tile_shape = np.array([np.diff(sub_volume_i[:, hdim])[0] + 1, self.bc.nxyz[wdim]])
        # get indices along each dimension
        indx = np.arange(tile_shape[1])
        indy = np.arange(tile_shape[0])
        inds = np.linspace(*sub_volume_i[:, ddim], tile_shape[0])
        # compute the slice indices and output the slice
        _, INDS = np.meshgrid(indx, np.int64(np.around(inds)))
        INDX, INDY = np.meshgrid(indx, indy)
        indsl = [[INDX, INDY, INDS][i] for i in np.argsort([wdim, hdim, ddim])[self.xyz2dims]]
        if isinstance(volume, np.ndarray):
            tslice = volume[indsl[0], indsl[1], indsl[2]]
        elif volume.lower() == 'annotation':
            tslice = self._label2rgb(self.label[indsl[0], indsl[1], indsl[2]])
        elif volume.lower() == 'image':
            tslice = self.image[indsl[0], indsl[1], indsl[2]]
        elif volume.lower() == 'surface':
            self.compute_surface()  # the surface volume is computed lazily
            tslice = self.surface[indsl[0], indsl[1], indsl[2]]
        else:
            raise ValueError(f"Unknown volume type: {volume}")

        #  get extents with correct convention NB: matplotlib flips the y-axis on imshow !
        width = np.sort(sub_volume[:, wdim])[np.argsort(self.bc.lim(axis=wdim))]
        height = np.flipud(np.sort(sub_volume[:, hdim])[np.argsort(self.bc.lim(axis=hdim))])
        depth = np.flipud(np.sort(sub_volume[:, ddim])[np.argsort(self.bc.lim(axis=ddim))])
        return tslice, width, height, depth

    def plot_tilted_slice(self, xyz, axis, volume='image', cmap=None, ax=None, sec_ax=False, **kwargs):
        """
        From line coordinates, extracts the tilted plane containing the line from the 3D volume
        :param xyz: np.array: points defining a probe trajectory in 3D space (xyz triplets)
        if more than 2 points are provided will take the best fit
        :param axis:
            0: along ml = sagittal-slice
            1: along ap = coronal-slice
            2: along dv = horizontal-slice
        :param volume: 'image', 'annotation', 'surface' or a np.ndarray (see tilted_slice)
        :param cmap: matplotlib colormap, defaults to 'bone'
        :param ax: matplotlib axis, a new figure is created if not provided
        :param sec_ax: currently has no effect, see the note below
        :param kwargs: forwarded to matplotlib imshow
        :return: tuple (matplotlib axis, secondary y-axis showing the squeezed axis coordinates)
        """
        if axis == 0:
            axis_labels = np.array(['ap (um)', 'dv (um)', 'ml (um)'])
        elif axis == 1:
            axis_labels = np.array(['ml (um)', 'dv (um)', 'ap (um)'])
        elif axis == 2:
            axis_labels = np.array(['ml (um)', 'ap (um)', 'dv (um)'])

        tslice, width, height, depth = self.tilted_slice(xyz, axis, volume=volume)
        width = width * 1e6
        height = height * 1e6
        depth = depth * 1e6
        if not ax:
            plt.figure()
            ax = plt.gca()
            ax.axis('equal')
        if not cmap:
            cmap = plt.get_cmap('bone')
        # get the linear transfer function from y-axis to squeezed axis for second axe
        ab = np.linalg.solve(np.c_[height, height * 0 + 1], depth)
        ax.imshow(tslice, extent=np.r_[width, height], cmap=cmap, **kwargs)
        secondary = ax.secondary_yaxis('right', functions=(
            lambda x: x * ab[0] + ab[1],
            lambda y: (y - ab[1]) / ab[0]))
        ax.set_xlabel(axis_labels[0])
        ax.set_ylabel(axis_labels[1])
        secondary.set_ylabel(axis_labels[2])
        # NOTE: the secondary axis is always created and (ax, secondary axis) is always returned,
        # whatever the value of the `sec_ax` argument; kept as-is for backward compatibility
        return ax, secondary

    @staticmethod
    def _plot_slice(im, extent, ax=None, cmap=None, volume=None, **kwargs):
        """
        Displays a slice with matplotlib imshow.

        :param im: 2D slice, or RGB(A) image
        :param extent: imshow extent
        :param ax: matplotlib axis, defaults to the current axis
        :param cmap: matplotlib colormap, defaults to 'bone'
        :param volume: if 'boundary', pixels equal to 1 are drawn opaque black, others transparent
        :param kwargs: forwarded to matplotlib imshow
        :return: matplotlib axis
        """
        if not ax:
            ax = plt.gca()
            ax.axis('equal')
        if not cmap:
            cmap = plt.get_cmap('bone')

        if volume == 'boundary':
            imb = np.zeros((*im.shape[:2], 4), dtype=np.uint8)
            imb[im == 1] = np.array([0, 0, 0, 255])
            im = imb

        ax.imshow(im, extent=extent, cmap=cmap, **kwargs)
        return ax

    def extent(self, axis):
        """
        :param axis: direction along which the volume is stacked:
         (2 = z for horizontal slice)
         (1 = y for coronal slice)
         (0 = x for sagittal slice)
        :return: imshow extent of the slice in um
        :raises ValueError: if axis is not 0, 1 or 2
        """

        if axis == 0:
            extent = np.r_[self.bc.ylim, np.flip(self.bc.zlim)] * 1e6
        elif axis == 1:
            extent = np.r_[self.bc.xlim, np.flip(self.bc.zlim)] * 1e6
        elif axis == 2:
            extent = np.r_[self.bc.xlim, np.flip(self.bc.ylim)] * 1e6
        else:
            raise ValueError(f"axis must be 0, 1 or 2, got {axis}")
        return extent

    def slice(self, coordinate, axis, volume='image', mode='raise', region_values=None,
              mapping=None, bc=None):
        """
        Get slice through atlas

        :param coordinate: coordinate to slice in metres, float
        :param axis: xyz convention:  0 for ml, 1 for ap, 2 for dv
            - 0: sagittal slice (along ml axis)
            - 1: coronal slice (along ap axis)
            - 2: horizontal slice (along dv axis)
        :param volume:
            - 'image' - allen image volume
            - 'annotation' - allen annotation volume
            - 'surface' - outer surface of mesh
            - 'boundary' - outline of boundaries between all regions
            - 'volume' - custom volume, must pass in volume of shape ba.image.shape as regions_value argument
            - 'value' - custom value per allen region, must pass in array of shape ba.regions.id as regions_value argument
            - a np.ndarray - custom volume sliced directly
        :param mode: error mode for out of bounds coordinates
            -   'raise' raise an error
            -   'clip' gets the first or last index
        :param region_values: custom values to plot
            - if volume='volume', region_values must have shape ba.image.shape
            - if volume='value', region_values must have shape ba.regions.id
        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param bc: optional BrainCoordinates object used instead of self.bc to compute the slice
         index when volume='volume'
        :return: 2d array or 3d RGB numpy int8 array
        :raises ValueError: for an unknown axis or volume
        """
        if axis == 0:
            index = self.bc.x2i(np.array(coordinate), mode=mode)
        elif axis == 1:
            index = self.bc.y2i(np.array(coordinate), mode=mode)
        elif axis == 2:
            index = self.bc.z2i(np.array(coordinate), mode=mode)
        else:
            raise ValueError(f"axis must be 0, 1 or 2, got {axis}")

        # np.take is 50 thousand times slower than straight slicing !
        def _take(vol, ind, axis):
            if mode == 'clip':
                ind = np.minimum(np.maximum(ind, 0), vol.shape[axis] - 1)
            if axis == 0:
                return vol[ind, :, :]
            elif axis == 1:
                return vol[:, ind, :]
            elif axis == 2:
                return vol[:, :, ind]

        def _take_remap(vol, ind, axis, mapping):
            # For the labels, remap the regions indices according to the mapping
            return self._get_mapping(mapping=mapping)[_take(vol, ind, axis)]

        if isinstance(volume, np.ndarray):
            return _take(volume, index, axis=self.xyz2dims[axis])
        elif volume == 'annotation':
            iregion = _take_remap(self.label, index, self.xyz2dims[axis], mapping)
            return self._label2rgb(iregion)
        elif volume == 'image':
            return _take(self.image, index, axis=self.xyz2dims[axis])
        elif volume == 'value':
            return region_values[_take_remap(self.label, index, self.xyz2dims[axis], mapping)]
        elif volume in ['surface', 'edges']:
            self.compute_surface()
            return _take(self.surface, index, axis=self.xyz2dims[axis])
        elif volume == 'boundary':
            iregion = _take_remap(self.label, index, self.xyz2dims[axis], mapping)
            return self.compute_boundaries(iregion)
        elif volume == 'volume':
            if bc is not None:
                index = bc.xyz2i(np.array([coordinate] * 3))[axis]
            return _take(region_values, index, axis=self.xyz2dims[axis])
        raise ValueError(f"Unknown volume type: {volume}")

    def compute_boundaries(self, values):
        """
        Compute the boundaries between regions on slice
        :param values: 2D array of region values (e.g. a remapped label slice)
        :return: 2D array with 1 where a pixel differs from any of its 4 neighbours (neighbours
         outside of the slice count as 0), 0 elsewhere
        """
        boundary = np.abs(np.diff(values, axis=0, prepend=0))
        boundary = boundary + np.abs(np.diff(values, axis=1, prepend=0))
        boundary = boundary + np.abs(np.diff(values, axis=1, append=0))
        boundary = boundary + np.abs(np.diff(values, axis=0, append=0))

        boundary[boundary != 0] = 1

        return boundary

    def plot_slices(self, xyz, *args, **kwargs):
        """
        From a single coordinate, plots the 3 slices that intersect at this point in a single
        matplotlib figure
        :param xyz: mlapdv coordinate in m
        :param args: arguments to be forwarded to plot slices
        :param kwargs: keyword arguments to be forwarded to plot slices
        :return: 2 by 2 array of axes
        """
        fig, axs = plt.subplots(2, 2)
        self.plot_cslice(xyz[1], *args, ax=axs[0, 0], **kwargs)
        self.plot_sslice(xyz[0], *args, ax=axs[0, 1], **kwargs)
        self.plot_hslice(xyz[2], *args, ax=axs[1, 0], **kwargs)
        # asarray so that plain lists / tuples of coordinates are accepted
        xyz_um = np.asarray(xyz) * 1e6
        axs[0, 0].plot(xyz_um[0], xyz_um[2], 'g*')
        axs[0, 1].plot(xyz_um[1], xyz_um[2], 'g*')
        axs[1, 0].plot(xyz_um[0], xyz_um[1], 'g*')
        return axs

    def plot_cslice(self, ap_coordinate, volume='image', mapping=None, region_values=None, **kwargs):
        """
        Plot coronal slice through atlas at given ap_coordinate

        :param: ap_coordinate (m)
        :param volume:
            - 'image' - allen image volume
            - 'annotation' - allen annotation volume
            - 'surface' - outer surface of mesh
            - 'boundary' - outline of boundaries between all regions
            - 'volume' - custom volume, must pass in volume of shape ba.image.shape as regions_value argument
            - 'value' - custom value per allen region, must pass in array of shape ba.regions.id as regions_value argument
        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param region_values: custom values to plot
            - if volume='volume', region_values must have shape ba.image.shape
            - if volume='value', region_values must have shape ba.regions.id
        :param **kwargs: matplotlib.pyplot.imshow kwarg arguments
        :return: matplotlib ax object
        """

        cslice = self.slice(ap_coordinate, axis=1, volume=volume, mapping=mapping, region_values=region_values)
        return self._plot_slice(np.moveaxis(cslice, 0, 1), extent=self.extent(axis=1), volume=volume, **kwargs)

    def plot_hslice(self, dv_coordinate, volume='image', mapping=None, region_values=None, **kwargs):
        """
        Plot horizontal slice through atlas at given dv_coordinate

        :param: dv_coordinate (m)
        :param volume:
            - 'image' - allen image volume
            - 'annotation' - allen annotation volume
            - 'surface' - outer surface of mesh
            - 'boundary' - outline of boundaries between all regions
            - 'volume' - custom volume, must pass in volume of shape ba.image.shape as regions_value argument
            - 'value' - custom value per allen region, must pass in array of shape ba.regions.id as regions_value argument
        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param region_values: custom values to plot
            - if volume='volume', region_values must have shape ba.image.shape
            - if volume='value', region_values must have shape ba.regions.id
        :param **kwargs: matplotlib.pyplot.imshow kwarg arguments
        :return: matplotlib ax object
        """

        hslice = self.slice(dv_coordinate, axis=2, volume=volume, mapping=mapping, region_values=region_values)
        return self._plot_slice(hslice, extent=self.extent(axis=2), volume=volume, **kwargs)

    def plot_sslice(self, ml_coordinate, volume='image', mapping=None, region_values=None, **kwargs):
        """
        Plot sagittal slice through atlas at given ml_coordinate

        :param: ml_coordinate (m)
        :param volume:
            - 'image' - allen image volume
            - 'annotation' - allen annotation volume
            - 'surface' - outer surface of mesh
            - 'boundary' - outline of boundaries between all regions
            - 'volume' - custom volume, must pass in volume of shape ba.image.shape as regions_value argument
            - 'value' - custom value per allen region, must pass in array of shape ba.regions.id as regions_value argument
        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param region_values: custom values to plot
            - if volume='volume', region_values must have shape ba.image.shape
            - if volume='value', region_values must have shape ba.regions.id
        :param **kwargs: matplotlib.pyplot.imshow kwarg arguments
        :return: matplotlib ax object
        """

        sslice = self.slice(ml_coordinate, axis=0, volume=volume, mapping=mapping, region_values=region_values)
        return self._plot_slice(np.swapaxes(sslice, 0, 1), extent=self.extent(axis=0), volume=volume, **kwargs)

    def plot_top(self, volume='annotation', mapping=None, region_values=None, ax=None, **kwargs):
        """
        Plot top view of atlas
        :param volume:
            - 'image' - allen image volume
            - 'annotation' - allen annotation volume
            - 'boundary' - outline of boundaries between all regions
            - 'volume' - custom volume, must pass in volume of shape ba.image.shape as regions_value argument
            - 'value' - custom value per allen region, must pass in array of shape ba.regions.id as regions_value argument

        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param region_values: custom values, see volume
        :param ax: matplotlib axis, defaults to the current axis
        :param kwargs: matplotlib.pyplot.imshow kwarg arguments
        :return: matplotlib ax object
        :raises ValueError: for an unknown volume
        """

        self.compute_surface()
        ix, iy = np.meshgrid(np.arange(self.bc.nx), np.arange(self.bc.ny))
        # dv index of the top surface voxel for each column of the volume
        iz = self.bc.z2i(self.top)
        inds = self._lookup_inds(np.stack((ix, iy, iz), axis=-1))

        regions = self._get_mapping(mapping=mapping)[self.label.flat[inds]]

        if volume == 'annotation':
            im = self._label2rgb(regions)
        elif volume == 'image':
            im = self.top
        elif volume == 'value':
            im = region_values[regions]
        elif volume == 'volume':
            # vectorised equivalent of: im[x, y] = region_values[x, y, iz[x, y]] for every (x, y)
            rows, cols = np.indices(iz.shape)
            im = np.asarray(region_values[rows, cols, iz], dtype=float)
        elif volume == 'boundary':
            im = self.compute_boundaries(regions)
        else:
            raise ValueError(f"Unknown volume type: {volume}")

        return self._plot_slice(im, self.extent(axis=2), ax=ax, volume=volume, **kwargs)
685

686

687
@dataclass
class Trajectory:
    """
    3D Trajectory (usually for a linear probe). Minimally defined by a vector and a point.
    instantiate from a best fit from a n by 3 array containing xyz coordinates:
        trj = Trajectory.fit(xyz)
    """
    vector: np.ndarray  # direction of the line, 3 elements (not required to be unit norm)
    point: np.ndarray  # any point on the line, 3 elements

    @staticmethod
    def fit(xyz):
        """
        fits a line to a 3D cloud of points, returns a Trajectory object
        :param xyz: n by 3 numpy array containing cloud of points
        :returns: a Trajectory object
        """
        xyz_mean = np.mean(xyz, axis=0)
        # the first right-singular vector of the centred cloud is the direction of largest variance
        return Trajectory(vector=np.linalg.svd(xyz - xyz_mean)[2][0], point=xyz_mean)

    def eval_x(self, x):
        """
        given an array of x coordinates, returns the xyz array of coordinates along the insertion
        :param x: n by 1 or numpy array containing x-coordinates
        :return: n by 3 numpy array containing xyz-coordinates
        """
        return self._eval(x, axis=0)

    def eval_y(self, y):
        """
        given an array of y coordinates, returns the xyz array of coordinates along the insertion
        :param y: n by 1 or numpy array containing y-coordinates
        :return: n by 3 numpy array containing xyz-coordinates
        """
        return self._eval(y, axis=1)

    def eval_z(self, z):
        """
        given an array of z coordinates, returns the xyz array of coordinates along the insertion
        :param z: n by 1 or numpy array containing z-coordinates
        :return: n by 3 numpy array containing xyz-coordinates
        """
        return self._eval(z, axis=2)

    def project(self, point):
        """
        projects one or several points onto the trajectory line
        :param point: np.array(x, y, z) coordinates, or any [..., 3] array of points
        :return: projected coordinates, same shape as point
        """
        # https://mathworld.wolfram.com/Point-LineDistance3-Dimensional.html
        point = np.asarray(point)
        # scalar position of each point along the line, in units of self.vector
        t = np.dot(point - self.point, self.vector) / np.dot(self.vector, self.vector)
        return self.point + np.asarray(t)[..., np.newaxis] * self.vector

    def mindist(self, xyz, bounds=None):
        """
        Computes the minimum distance to the trajectory line for one or a set of points.
        If bounds are provided, computes the minimum distance to the segment instead of an
        infinite line.
        :param xyz: [..., 3]
        :param bounds: defaults to None.  np.array [2, 3]: segment boundaries, inf line if None
        :return: minimum distance [...]
        """
        xyz = np.asarray(xyz)
        proj = self.project(xyz)
        d = np.sqrt(np.sum((proj - xyz) ** 2, axis=-1))
        if bounds is not None:
            bounds = np.asarray(bounds)
            # project the boundaries and the points along the traj
            b = np.dot(bounds, self.vector)
            ob = np.argsort(b)
            p = np.dot(xyz, self.vector)
            # for points below and above boundaries, compute cartesian distance to the boundary
            d_low = np.sqrt(np.sum((xyz - bounds[ob[0], :]) ** 2, axis=-1))
            d_high = np.sqrt(np.sum((xyz - bounds[ob[1], :]) ** 2, axis=-1))
            d = np.where(p < b[ob[0]], d_low, d)
            d = np.where(p > b[ob[1]], d_high, d)
        return d

    def _eval(self, c, axis):
        # uses symmetric form of 3d line equation to get xyz coordinates given one coordinate
        if not isinstance(c, np.ndarray):
            c = np.array(c)
        while c.ndim < 2:
            c = c[..., np.newaxis]
        # there are cases where it's impossible to project if a line is // to the axis
        if self.vector[axis] == 0:
            return np.nan * np.zeros((c.shape[0], 3))
        else:
            return (c - self.point[axis]) * self.vector / self.vector[axis] + self.point

    def exit_points(self, bc):
        """
        Given a Trajectory and a BrainCoordinates object, computes the intersection of the
        trajectory with the brain coordinates bounding box
        :param bc: BrainCoordinate objects
        :return: np.ndarray 2 y 3 corresponding to exit points xyz coordinates
        """
        bounds = np.c_[bc.xlim, bc.ylim, bc.zlim]
        # the limits along an axis may be stored in decreasing order (e.g. negative voxel step),
        # so compare against the per-axis min / max rather than the first / second row
        lower, upper = np.min(bounds, axis=0), np.max(bounds, axis=0)
        epoints = np.r_[self.eval_x(bc.xlim), self.eval_y(bc.ylim), self.eval_z(bc.zlim)]
        # drop the rows produced for axes the line is parallel to (all NaN)
        epoints = epoints[~np.all(np.isnan(epoints), axis=1)]
        inside = np.all((lower <= epoints) & (epoints <= upper), axis=1)
        return epoints[inside, :]
×
790

791

792
@dataclass
class Insertion:
    """
    Defines an ephys probe insertion in 3D coordinate. IBL conventions.
    To instantiate, use the static methods:
    Insertion.from_track
    Insertion.from_dict
    """
    x: float
    y: float
    z: float
    phi: float
    theta: float
    depth: float
    label: str = ''
    beta: float = 0

    @staticmethod
    def from_track(xyzs, brain_atlas=None):
        """
        Fits a trajectory to a track of points and builds the matching Insertion.
        :param xyzs: n by 3 numpy array of track coordinates
        :param brain_atlas: BrainAtlas object. Required despite the None default: the entry point
         is the intersection of the fitted trajectory with the brain surface
        :return: Insertion object
        """
        assert brain_atlas, 'Input argument brain_atlas must be defined'
        traj = Trajectory.fit(xyzs)
        # project the deepest point into the vector to get the tip coordinate
        tip = traj.project(xyzs[np.argmin(xyzs[:, 2]), :])
        # get intersection with the brain surface as an entry point
        entry = Insertion.get_brain_entry(traj, brain_atlas)
        # convert to spherical system to store the insertion
        depth, theta, phi = cart2sph(*(entry - tip))
        insertion_dict = {'x': entry[0], 'y': entry[1], 'z': entry[2],
                          'phi': phi, 'theta': theta, 'depth': depth}
        return Insertion(**insertion_dict)

    @staticmethod
    def from_dict(d, brain_atlas=None):
        """
        Constructs an Insertion object from the json information stored in probes.description file
        :param d: dictionary containing at least the following keys, in um
           {
            'x': 544.0,
            'y': 1285.0,
            'z': 0.0,
            'phi': 0.0,
            'theta': 5.0,
            'depth': 4501.0
            }
          optional keys 'beta' (default 0) and 'label' (default '') are also read
        :param brain_atlas: None. If provided, disregards the z coordinate and locks the insertion
        point to the z of the brain surface
        :return: Insertion object (x, y, z and depth divided by 1e6)
        """
        z = d['z'] / 1e6
        if brain_atlas:
            iy = brain_atlas.bc.y2i(d['y'] / 1e6)
            ix = brain_atlas.bc.x2i(d['x'] / 1e6)
            # Only use the brain surface value as z if it isn't NaN (this happens when the surface touches the edges
            # of the atlas volume
            if not np.isnan(brain_atlas.top[iy, ix]):
                z = brain_atlas.top[iy, ix]
        return Insertion(x=d['x'] / 1e6, y=d['y'] / 1e6, z=z,
                         phi=d['phi'], theta=d['theta'], depth=d['depth'] / 1e6,
                         beta=d.get('beta', 0), label=d.get('label', ''))

    @property
    def trajectory(self):
        """
        Gets the trajectory object matching insertion coordinates
        :return: atlas.Trajectory
        """
        return Trajectory.fit(self.xyz)

    @property
    def xyz(self):
        """2 by 3 array: entry point on the first row, tip on the second"""
        return np.c_[self.entry, self.tip].transpose()

    @property
    def entry(self):
        """3 element array with the (x, y, z) entry coordinates"""
        return np.array((self.x, self.y, self.z))

    @property
    def tip(self):
        """3 element array: entry point offset by the spherical vector (-depth, theta, phi)"""
        return sph2cart(- self.depth, self.theta, self.phi) + np.array((self.x, self.y, self.z))

    @staticmethod
    def _get_surface_intersection(traj, brain_atlas, surface='top'):
        """
        Finds where a trajectory crosses the top or bottom brain surface.
        :param traj: Trajectory object
        :param brain_atlas: BrainAtlas object
        :param surface: 'top' (entry, highest z) or 'bottom' (exit, lowest z)
        :return: 3 element array x, y, z
        :raises ValueError: if surface is neither 'top' nor 'bottom', or if no surface voxel lies
         within one voxel of the trajectory
        """
        if surface not in ('top', 'bottom'):
            raise ValueError(f"surface must be either 'top' or 'bottom', got {surface!r}")

        brain_atlas.compute_surface()

        distance = traj.mindist(brain_atlas.srf_xyz)
        dist_sort = np.argsort(distance)
        # In some cases the nearest two intersection points are not the top and bottom of brain
        # So we find all intersection points that fall within one voxel and take the one with
        # highest dV to be entry and lowest dV to be exit
        idx_lim = np.sum(distance[dist_sort] * 1e6 < np.max(brain_atlas.res_um))
        if idx_lim == 0:
            raise ValueError('No brain surface voxel lies within one voxel of the trajectory')
        dist_lim = dist_sort[0:idx_lim]
        z_val = brain_atlas.srf_xyz[dist_lim, 2]
        if surface == 'top':
            ma = np.argmax(z_val)
            _xyz = brain_atlas.srf_xyz[dist_lim[ma], :]
            _ixyz = brain_atlas.bc.xyz2i(_xyz)
            # shift the top point by one voxel index along the z dimension
            _ixyz[brain_atlas.xyz2dims[2]] += 1
        else:
            ma = np.argmin(z_val)
            _xyz = brain_atlas.srf_xyz[dist_lim[ma], :]
            _ixyz = brain_atlas.bc.xyz2i(_xyz)

        xyz = brain_atlas.bc.i2xyz(_ixyz.astype(float))

        return xyz

    @staticmethod
    def get_brain_exit(traj, brain_atlas):
        """
        Given a Trajectory and a BrainAtlas object, computes the brain exit coordinate as the
        intersection of the trajectory and the brain surface (brain_atlas.surface)
        :param brain_atlas:
        :return: 3 element array x,y,z
        """
        # Find point where trajectory intersects with bottom of brain
        return Insertion._get_surface_intersection(traj, brain_atlas, surface='bottom')

    @staticmethod
    def get_brain_entry(traj, brain_atlas):
        """
        Given a Trajectory and a BrainAtlas object, computes the brain entry coordinate as the
        intersection of the trajectory and the brain surface (brain_atlas.surface)
        :param brain_atlas:
        :return: 3 element array x,y,z
        """
        # Find point where trajectory intersects with top of brain
        return Insertion._get_surface_intersection(traj, brain_atlas, surface='top')
×
925

926

927
class AllenAtlas(BrainAtlas):
    """
    Instantiates an atlas.BrainAtlas corresponding to the Allen CCF at the given resolution
    using the IBL Bregma and coordinate system
    """

    def __init__(self, res_um=25, scaling=np.array([1, 1, 1]), mock=False, hist_path=None):
        """
        :param res_um: 10, 25 or 50 um
        :param scaling: scale factor along ml, ap, dv for squeeze and stretch ([1, 1, 1])
        :param mock: for testing purpose: uses zero volumes with a single labelled band instead
         of reading / downloading the Allen files
        :param hist_path: optional path (str or Path) to the template image volume, defaults to
         average_template_{res_um}.nrrd in the ONE cache directory
        :return: atlas.BrainAtlas
        """

        par = one.params.get(silent=True)
        FLAT_IRON_ATLAS_REL_PATH = PurePosixPath('histology', 'ATLAS', 'Needles', 'Allen')
        LUT_VERSION = "v01"  # version 01 is the lateralized version
        regions = BrainRegions()
        xyz2dims = np.array([1, 0, 2])  # this is the c-contiguous ordering
        dims2xyz = np.array([1, 0, 2])
        # we use Bregma as the origin
        self.res_um = res_um
        ibregma = (ALLEN_CCF_LANDMARKS_MLAPDV_UM['bregma'] / self.res_um)
        dxyz = self.res_um * 1e-6 * np.array([1, -1, -1]) * scaling
        if mock:
            image, label = [np.zeros((528, 456, 320), dtype=np.int16) for _ in range(2)]
            label[:, :, 100:105] = 1327  # lookup index for retina, id 304325711 (no id 1327)
        else:
            path_atlas = Path(par.CACHE_DIR).joinpath(FLAT_IRON_ATLAS_REL_PATH)
            # wrap in Path so that str paths also support .exists() below
            file_image = Path(hist_path) if hist_path else path_atlas.joinpath(f'average_template_{res_um}.nrrd')
            # get the image volume
            if not file_image.exists():
                _download_atlas_allen(file_image, FLAT_IRON_ATLAS_REL_PATH, par)
            # get the remapped label volume
            file_label = path_atlas.joinpath(f'annotation_{res_um}.nrrd')
            if not file_label.exists():
                _download_atlas_allen(file_label, FLAT_IRON_ATLAS_REL_PATH, par)
            file_label_remap = path_atlas.joinpath(f'annotation_{res_um}_lut_{LUT_VERSION}.npz')
            if not file_label_remap.exists():
                label = self._read_volume(file_label).astype(dtype=np.int32)
                _logger.info("computing brain atlas annotations lookup table")
                # lateralize atlas: for this the regions of the left hemisphere have primary
                # keys opposite to to the normal ones
                lateral = np.zeros(label.shape[xyz2dims[0]])
                lateral[int(np.floor(ibregma[0]))] = 1
                lateral = np.sign(np.cumsum(lateral)[np.newaxis, :, np.newaxis] - 0.5)
                label = label * lateral.astype(np.int32)
                # the 10 um atlas is too big to fit in memory so work by chunks instead
                if res_um == 10:
                    first, ncols = (0, 10)
                    while True:
                        last = np.minimum(first + ncols, label.shape[-1])
                        _logger.info(f"Computing... {last} on {label.shape[-1]}")
                        _, im = ismember(label[:, :, first:last], regions.id)
                        label[:, :, first:last] = np.reshape(im, label[:, :, first:last].shape)
                        if last == label.shape[-1]:
                            break
                        first += ncols
                    label = label.astype(dtype=np.uint16)
                    _logger.info("Saving npz, this can take a long time")
                else:
                    _, im = ismember(label, regions.id)
                    label = np.reshape(im.astype(np.uint16), label.shape)
                np.savez_compressed(file_label_remap, label)
                _logger.info(f"Cached remapping file {file_label_remap} ...")
            # loads the files
            label = self._read_volume(file_label_remap)
            image = self._read_volume(file_image)

        super().__init__(image, label, dxyz, regions, ibregma,
                         dims2xyz=dims2xyz, xyz2dims=xyz2dims)

    @staticmethod
    def _read_volume(file_volume):
        """
        Reads a volume from a .nrrd (transposed to [iap, iml, idv]) or .npz ('arr_0') file.
        :param file_volume: pathlib.Path of the volume file
        :return: numpy array
        :raises ValueError: if the file suffix is neither .nrrd nor .npz
        """
        if file_volume.suffix == '.nrrd':
            volume, _ = nrrd.read(file_volume, index_order='C')  # ml, dv, ap
            # we want the coronal slice to be the most contiguous
            volume = np.transpose(volume, (2, 0, 1))  # image[iap, iml, idv]
        elif file_volume.suffix == '.npz':
            volume = np.load(file_volume)['arr_0']
        else:
            raise ValueError(f"Unsupported volume file format {file_volume.suffix!r}, expected .nrrd or .npz")
        return volume

    def xyz2ccf(self, xyz, ccf_order='mlapdv', mode='raise'):
        """
        Converts coordinates to the CCF coordinates, which is assumed to be the cube indices
        times the spacing.
        :param xyz: mlapdv coordinates in meters, origin Bregma
        :param ccf_order: order that you want values returned 'mlapdv' (ibl) or 'apdvml'
        (Allen mcc vertices)
        :param mode: {'raise', 'clip', 'wrap'} determines what to do when determined index lies outside the atlas volume
                     'raise' will raise a ValueError
                     'clip' will replace the index with the closest index inside the volume
                     'wrap' will wrap around to the other side of the volume. This is only here for legacy reasons
        :return: coordinates in CCF space um, origin is the front left top corner of the data
        volume, order determined by ccf_order
        :raises ValueError: if ccf_order is not 'mlapdv' or 'apdvml'
        """
        ordre = self._ccf_order(ccf_order)
        ccf = self.bc.xyz2i(xyz, round=False, mode=mode) * float(self.res_um)
        return ccf[..., ordre]

    def ccf2xyz(self, ccf, ccf_order='mlapdv'):
        """
        Converts coordinates from the CCF coordinates, which is assumed to be the cube indices
        times the spacing.
        :param ccf coordinates in CCF space in um, origin is the front left top corner of the data
        volume
        :param ccf_order: order of ccf coordinates given 'mlapdv' (ibl) or 'apdvml'
        (Allen mcc vertices)
        :return: xyz: mlapdv coordinates in m, origin Bregma
        :raises ValueError: if ccf_order is not 'mlapdv' or 'apdvml'
        """
        ordre = self._ccf_order(ccf_order, reverse=True)
        return self.bc.i2xyz((ccf[..., ordre] / float(self.res_um)))

    @staticmethod
    def _ccf_order(ccf_order, reverse=False):
        """
        Returns the mapping to go from CCF coordinates order to the brain atlas xyz
        :param ccf_order: 'mlapdv' or 'apdvml'
        :param reverse: defaults to False.
            If False, returns from CCF to brain atlas
            If True, returns from brain atlas to CCF
        :return: list of 3 axis indices
        :raises ValueError: if ccf_order is not 'mlapdv' or 'apdvml'
        """
        if ccf_order == 'mlapdv':
            return [0, 1, 2]
        elif ccf_order == 'apdvml':
            return [2, 0, 1] if reverse else [1, 2, 0]
        # previously the exception was built but never raised, silently returning None
        raise ValueError("ccf_order needs to be either 'mlapdv' or 'apdvml'")

    def compute_regions_volume(self):
        """
        Sums the number of voxels in the labels volume for each region.
        Then compute volumes for all of the levels of hierarchy in cubic mm.
        The result is stored in self.regions.volume.
        :return: None
        """
        nr = self.regions.id.shape[0]
        # ravel avoids copying the (potentially very large) label volume when it is contiguous
        count = np.bincount(self.label.ravel(), minlength=nr)
        self.regions.compute_hierarchy()
        # NOTE(review): the voxel volume uses res_um only; any `scaling` (e.g. NeedlesAtlas) is
        # not applied here - confirm this is intended
        self.regions.volume = np.sum(count[self.regions.hierarchy], axis=0) * (self.res_um / 1e3) ** 3
×
1071

1072

1073
def NeedlesAtlas(*args, **kwargs):
    """
    Instantiates an atlas.BrainAtlas corresponding to the Allen CCF at the given resolution
    using the IBL Bregma and coordinate system. The Needles atlas applies a stretch along the AP
    axis and a squeeze along the DV axis. Any `scaling` keyword argument is overridden.
    :param res_um: 10, 25 or 50 um
    :return: atlas.BrainAtlas
    """
    # (ML, AP, DV) multiplicative factors; DV was determined from the MRI->CCF transform
    kwargs.update(scaling=np.array([1, 1.087, 0.952]))
    return AllenAtlas(*args, **kwargs)
×
1085

1086

1087
def MRITorontoAtlas(*args, **kwargs):
    """
    Instantiates an atlas.BrainAtlas corresponding to the Allen CCF at the given resolution
    using the IBL Bregma and coordinate system. The MRI Toronto atlas applies a stretch along AP,
    a squeeze along DV *and* a squeeze along ML. These are based on 12 p65 mice MRIs averaged.
    See: https://www.nature.com/articles/s41467-018-04921-2 DB has access to the dataset.
    Any `scaling` keyword argument is overridden.
    :param res_um: 10, 25 or 50 um
    :return: atlas.BrainAtlas
    """
    # (ML, AP, DV) multiplicative factors; DV was determined from the MRI->CCF transform
    toronto_scaling = np.array([0.952, 1.031, 0.885])
    return AllenAtlas(*args, **{**kwargs, 'scaling': toronto_scaling})
×
1101

1102

1103
def _download_atlas_allen(file_image, FLAT_IRON_ATLAS_REL_PATH, par):
1✔
1104
    """
1105
    © 2015 Allen Institute for Brain Science. Allen Mouse Brain Atlas (2015)
1106
    with region annotations (2017).
1107
    Available from: http://download.alleninstitute.org/informatics-archive/current-release/
1108
    mouse_ccf/annotation/
1109
    See Allen Mouse Common Coordinate Framework Technical White Paper for details
1110
    http://help.brain-map.org/download/attachments/8323525/
1111
    Mouse_Common_Coordinate_Framework.pdf?version=3&modificationDate=1508178848279&api=v2
1112
    """
1113

1114
    file_image.parent.mkdir(exist_ok=True, parents=True)
×
1115

1116
    template_url = ('http://download.alleninstitute.org/informatics-archive/'
×
1117
                    'current-release/mouse_ccf/average_template')
1118
    annotation_url = ('http://download.alleninstitute.org/informatics-archive/'
×
1119
                      'current-release/mouse_ccf/annotation/ccf_2017')
1120

1121
    if file_image.name.split('_')[0] == 'average':
×
1122
        url = template_url + '/' + file_image.name
×
1123
    elif file_image.name.split('_')[0] == 'annotation':
×
1124
        url = annotation_url + '/' + file_image.name
×
1125
    else:
1126
        raise ValueError('Unrecognized file image')
×
1127

1128
    cache_dir = Path(par.CACHE_DIR).joinpath(FLAT_IRON_ATLAS_REL_PATH)
×
1129
    return http_download_file(url, target_dir=cache_dir)
×
1130

1131

1132
class FlatMap(AllenAtlas):

    def __init__(self, flatmap='dorsal_cortex', res_um=25):
        """
        Available flatmaps are currently 'dorsal_cortex', 'circles' and 'pyramid'
        :param flatmap: name of the flatmap
        :param res_um: atlas resolution in um; 'circles' and 'pyramid' only support 25
        :raises ValueError: if flatmap is not one of the available flatmaps
        :raises NotImplementedError: for 'circles' / 'pyramid' at a resolution other than 25 um
        """
        available = ('dorsal_cortex', 'circles', 'pyramid')
        # validate before loading the atlas volumes so bad arguments fail fast
        if flatmap not in available:
            raise ValueError(f"Unknown flatmap {flatmap!r}, available flatmaps are {available}")
        if flatmap != 'dorsal_cortex' and res_um != 25:
            raise NotImplementedError(f'{flatmap} flatmap not implemented for resolution other than 25um')
        super().__init__(res_um=res_um)
        self.name = flatmap
        if flatmap == 'dorsal_cortex':
            self._get_flatmap_from_file()
        else:
            from ibllib.atlas.flatmaps import circles
            display = 'flat' if flatmap == 'circles' else 'pyramid'
            self.flatmap, self.ml_scale, self.ap_scale = circles(N=5, atlas=self, display=display)

    def _get_flatmap_from_file(self):
        # gets the file in the ONE cache for the flatmap name in the property, downloads it if needed
        file_flatmap = self._get_cache_dir().joinpath(f'{self.name}_{self.res_um}.nrrd')
        if not file_flatmap.exists():
            file_flatmap.parent.mkdir(exist_ok=True, parents=True)
            aws.s3_download_file(f'atlas/{file_flatmap.name}', file_flatmap)
        self.flatmap, _ = nrrd.read(file_flatmap)

    def plot_flatmap(self, depth=0, volume='annotation', mapping='Allen', region_values=None, ax=None, **kwargs):
        """
        Displays the 2D image corresponding to the flatmap. If there are several depths, by default it
        will display the first one
        :param depth: index of the depth to display in the flatmap volume (the last dimension)
        :param volume: 'annotation', 'value', 'boundary' or 'image'
        :param mapping: mapping to use. Options can be found using ba.regions.mappings.keys()
        :param region_values: array of values per region, used when volume='value'
        :param ax: matplotlib axis, defaults to the current axis
        :param kwargs: forwarded to self._plot_slice
        :return: the output of self._plot_slice
        :raises ValueError: if volume is not one of the options above
        """
        if self.flatmap.ndim == 3:
            inds = np.int32(self.flatmap[:, :, depth])
        else:
            inds = np.int32(self.flatmap[:, :])
        regions = self._get_mapping(mapping=mapping)[self.label.flat[inds]]
        if volume == 'annotation':
            im = self._label2rgb(regions)
        elif volume == 'value':
            im = region_values[regions]
        elif volume == 'boundary':
            im = self.compute_boundaries(regions)
        elif volume == 'image':
            im = self.image.flat[inds]
        else:
            raise ValueError(f"volume must be one of 'annotation', 'value', 'boundary', 'image', got {volume!r}")
        if not ax:
            ax = plt.gca()

        return self._plot_slice(im, self.extent_flmap(), ax=ax, volume=volume, **kwargs)

    def extent_flmap(self):
        """Returns the [left, right, bottom, top] extent of the flatmap image in pixels"""
        extent = np.r_[0, self.flatmap.shape[1], 0, self.flatmap.shape[0]]
        return extent
×
1196

1197

1198
class FranklinPaxinosAtlas(BrainAtlas):
    """
    Instantiates an atlas.BrainAtlas corresponding to the Franklin & Paxinos atlas at the given
    (per-axis) resolution, using the IBL Bregma and coordinate system
    """

    def __init__(self, res_um=np.array([10, 100, 10]), scaling=np.array([1, 1, 1]), mock=False, hist_path=None):
        """
        :param res_um: 3 element voxel size in um along ml, ap, dv (default [10, 100, 10]); the
         values are also used to build the names of the volume files to load / download
        :param scaling: scale factor along ml, ap, dv for squeeze and stretch ([1, 1, 1])
        :param mock: for testing purpose: uses zero volumes with a single labelled band instead
         of reading / downloading the files
        :param hist_path: optional path (str or Path) to the image volume, defaults to the
         average_template npz file in the ONE cache directory
        :return: atlas.BrainAtlas
        """
        # TODO interpolate?
        par = one.params.get(silent=True)
        FLAT_IRON_ATLAS_REL_PATH = PurePosixPath('histology', 'ATLAS', 'Needles', 'FranklinPaxinos')
        LUT_VERSION = "v01"  # version 01 is the lateralized version
        regions = FranklinPaxinosRegions()
        xyz2dims = np.array([1, 0, 2])  # this is the c-contiguous ordering
        dims2xyz = np.array([1, 0, 2])
        # copy so that the instance never aliases the shared default argument array
        res_um = np.array(res_um)
        # we use Bregma as the origin
        self.res_um = res_um
        ibregma = (PAXINOS_CCF_LANDMARKS_MLAPDV_UM['bregma'] / self.res_um)
        dxyz = self.res_um * 1e-6 * np.array([1, -1, -1]) * scaling
        if mock:
            image, label = [np.zeros((528, 456, 320), dtype=np.int16) for _ in range(2)]
            label[:, :, 100:105] = 1327  # arbitrary non-zero label for the mock volume (same value as AllenAtlas)
        else:
            path_atlas = Path(par.CACHE_DIR).joinpath(FLAT_IRON_ATLAS_REL_PATH)
            res_tag = f'{res_um[0]}_{res_um[1]}_{res_um[2]}'
            # wrap in Path so that str paths also support .exists() below
            file_image = Path(hist_path) if hist_path else path_atlas.joinpath(f'average_template_{res_tag}.npz')
            # get the image volume
            if not file_image.exists():
                path_atlas.mkdir(exist_ok=True, parents=True)
                aws.s3_download_file(f'atlas/FranklinPaxinos/{file_image.name}', str(file_image))
            # get the remapped label volume
            file_label = path_atlas.joinpath(f'annotation_{res_tag}.npz')
            if not file_label.exists():
                path_atlas.mkdir(exist_ok=True, parents=True)
                aws.s3_download_file(f'atlas/FranklinPaxinos/{file_label.name}', str(file_label))

            file_label_remap = path_atlas.joinpath(f'annotation_{res_tag}_lut_{LUT_VERSION}.npz')

            if not file_label_remap.exists():
                label = self._read_volume(file_label).astype(dtype=np.int32)
                _logger.info("computing brain atlas annotations lookup table")
                # lateralize atlas: for this the regions of the left hemisphere have primary
                # keys opposite to to the normal ones
                lateral = np.zeros(label.shape[xyz2dims[0]])
                lateral[int(np.floor(ibregma[0]))] = 1
                lateral = np.sign(np.cumsum(lateral)[np.newaxis, :, np.newaxis] - 0.5)
                label = label * lateral.astype(np.int32)
                _, im = ismember(label, regions.id)
                label = np.reshape(im.astype(np.uint16), label.shape)
                np.savez_compressed(file_label_remap, label)
                _logger.info(f"Cached remapping file {file_label_remap} ...")
            # loads the files
            label = self._read_volume(file_label_remap)
            image = self._read_volume(file_image)

        super().__init__(image, label, dxyz, regions, ibregma,
                         dims2xyz=dims2xyz, xyz2dims=xyz2dims)

    @staticmethod
    def _read_volume(file_volume):
        """
        Reads a volume from a .nrrd (transposed to [iap, iml, idv]) or .npz ('arr_0') file.
        :param file_volume: pathlib.Path of the volume file
        :return: numpy array
        :raises ValueError: if the file suffix is neither .nrrd nor .npz
        """
        if file_volume.suffix == '.nrrd':
            volume, _ = nrrd.read(file_volume, index_order='C')  # ml, dv, ap
            # we want the coronal slice to be the most contiguous
            volume = np.transpose(volume, (2, 0, 1))  # image[iap, iml, idv]
        elif file_volume.suffix == '.npz':
            volume = np.load(file_volume)['arr_0']
        else:
            raise ValueError(f"Unsupported volume file format {file_volume.suffix!r}, expected .nrrd or .npz")
        return volume
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc