• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

neurospin-deepinsight / surfify / 14535991368

18 Apr 2025 01:35PM UTC coverage: 84.325% (+0.2%) from 84.083%
14535991368

push

github

AGrigis
surfify/datasets/_generic: fix dataset.

1 of 9 new or added lines in 1 file covered. (11.11%)

47 existing lines in 3 files now uncovered.

2410 of 2858 relevant lines covered (84.32%)

0.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.43
/surfify/utils/sampling.py
1
# -*- coding: utf-8 -*-
2
##########################################################################
3
# NSAp - Copyright (C) CEA, 2021
4
# Distributed under the terms of the CeCILL-B license, as published by
5
# the CEA-CNRS-INRIA. Refer to the LICENSE file or to
6
# http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
7
# for details.
8
##########################################################################
9

10
"""
11
Spherical sampling & associated utilities.
12
"""
13

14
# Imports
15
import os
1✔
16
import tempfile
1✔
17
import collections
1✔
18
import numpy as np
1✔
19
import networkx as nx
1✔
20
from scipy.spatial import transform
1✔
21
from joblib import Parallel, delayed
1✔
22
from sklearn.neighbors import BallTree, NearestNeighbors
1✔
23
from .io import HidePrints
1✔
24

25

26
def normalize(vertex):
    """ Rescale a 3D vertex so that it lies on the unit sphere.

    Parameters
    ----------
    vertex: sequence of 3 numbers
        the (x, y, z) coordinates of the vertex.

    Returns
    -------
    unit_vertex: list of 3 floats
        the input coordinates divided by their Euclidean norm.
    """
    x, y, z = vertex
    norm = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    unit_vertex = []
    for coord in (x, y, z):
        unit_vertex.append(coord / norm)
    return unit_vertex
32

33

34
# Golden ratio: the 12 corners of a regular icosahedron are the cyclic
# permutations of (0, +-1, +-R).
R = (1 + np.sqrt(5)) / 2
STANDARD_ICO = {
    # The 12 corners, rescaled to lie on the unit sphere.
    "vertices": [normalize(corner) for corner in (
        [-1, R, 0], [1, R, 0], [-1, -R, 0], [1, -R, 0],
        [0, -1, R], [0, 1, R], [0, -1, -R], [0, 1, -R],
        [R, 0, -1], [R, 0, 1], [-R, 0, -1], [-R, 0, 1])],
    # The 20 faces, given as triplets of row indexes into "vertices".
    "triangles": [
        [0, 11, 5], [0, 5, 1], [0, 1, 7], [0, 7, 10], [0, 10, 11],
        [1, 5, 9], [5, 11, 4], [11, 10, 2], [10, 7, 6], [7, 1, 8],
        [3, 9, 4], [3, 4, 2], [3, 2, 6], [3, 6, 8], [3, 8, 9],
        [4, 9, 5], [2, 4, 11], [6, 2, 10], [8, 6, 7], [9, 8, 1]]
}
71

72

73
def neighbors(vertices, triangles, depth=1, direct_neighbor=False):
    """ Build mesh vertices neighbors.

    This is the base function to build Direct Neighbors (DiNe) kernels.

    See Also
    --------
    neighbors_rec

    Examples
    --------
    >>> from surfify.utils import icosahedron, neighbors
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> neighs = neighbors(ico2_verts, ico2_tris, direct_neighbor=True)
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico2_verts, triangles=ico2_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> center = ico2_verts[0]
    >>> for cnt, idx in enumerate(neighs[0]):
    >>>     point = ico2_verts[idx]
    >>>     ax.scatter(point[0], point[1], point[2], marker="o", c="red",
                       s=100)
    >>> ax.scatter(center[0], center[1], center[2], marker="o", c="blue",
                   s=100)
    >>> plt.show()

    Parameters
    ----------
    vertices: array (N, 3)
        the icosahedron vertices.
    triangles: array (M, 3)
        the icosahedron triangles.
    depth: int, default 1
        depth to stop the neighbors search, only paths of length <= depth are
        returned.
    direct_neighbor: bool, default False
        each spherical surface is composed of two types of vertices: 1) 12
        vertices with each having only 5 direct neighbors; and 2) the
        remaining vertices with each having 6 direct neighbors. For those
        vertices with 6 neighbors, DiNe assigns the index 1 to the center
        vertex and the indices 2-7 to its neighbors sequentially according
        to the angle between the vector of center vertex to neighboring vertex
        and the x-axis in the tangent plane. For the 12 vertices with only
        5 neighbors, DiNe assigns the indices both 1 and 2 to the center
        vertex, and indices 3-7 to the neighbors in the same way as those
        vertices with 6 neighbors.

    Returns
    -------
    neighs: OrderedDict
        a dictionary with vertices row index as keys, in increasing order.
        If 'direct_neighbor' is False, each value is a dictionary mapping a
        ring number (graph distance from 1 to 'depth') to the list of
        neighbors vertices row indexes in that ring. If 'direct_neighbor' is
        True, each value is a flat list: the center vertex (twice when it
        has only 5 direct neighbors), then the vertices of each ring sorted
        by angle with the tangent plane x-axis, where vertices having only 5
        direct neighbors are repeated to pad the list.

    Raises
    ------
    ValueError
        if 'direct_neighbor' is True and a padded list does not contain
        1 + 6 * (1 + ... + r) entries, r being the deepest ring found.
    """
    graph = vertex_adjacency_graph(vertices, triangles)
    degrees = dict((node, val) for node, val in graph.degree())
    neighs = collections.OrderedDict()
    for node in sorted(graph.nodes):
        node_neighs = {}
        # Group the vertices within 'depth' hops by graph distance (ring);
        # rings are visited in increasing order (BFS).
        for neigh, ring in nx.single_source_shortest_path_length(
                graph, node, cutoff=depth).items():
            if ring == 0:
                continue
            node_neighs.setdefault(ring, []).append(neigh)
        if direct_neighbor:
            # _missing_neighs: 5-neighbor vertex -> remaining padding counts;
            # the first count is the number of copies prepended to the next
            # ring.
            _node_neighs, _missing_neighs = [], {}
            n_neighs, center_missing_neighs = 0, False
            for ring, ring_neighs in node_neighs.items():
                # NOTE(review): the center vertex is used as its own tangent
                # plane normal, which assumes vertices on the unit sphere.
                angles = np.asarray([
                    get_angle_with_xaxis(vertices[node], vertices[node], vec)
                    for vec in vertices[ring_neighs]])
                angles = np.degrees(angles)
                ring_neighs = [x for _, x in sorted(
                    zip(angles, ring_neighs), key=lambda pair: pair[0])]
                node_neighs[ring] = ring_neighs
                # Expected ring size on a regular (6-neighbor) tiling
                n_neighs += 6 * ring
                _center_neighs = node_neighs[ring - 1] if ring > 1 else [node]
                _node_missing_neighs = [
                    _node for _node in _center_neighs if degrees[_node] == 5]
                # Pad the current ring for 5-neighbor vertices found earlier
                for _node, _counts in _missing_neighs.items():
                    ring_neighs = [_node] * _counts[0] + ring_neighs
                    _missing_neighs[_node] = _counts[1:]
                # Duplicate 5-neighbor vertices of the previous ring in place
                # (the center itself is handled after the loop)
                for _node in _node_missing_neighs:
                    _missing_neighs[_node] = list(range(2, depth + 2 - ring))
                    if _node == node:
                        center_missing_neighs = True
                        continue
                    _node_neighs.insert(_node_neighs.index(_node), _node)
                _node_neighs.extend(ring_neighs)
            _node_neighs.insert(0, node)
            if center_missing_neighs:
                _node_neighs.insert(0, node)
            if len(_node_neighs) != n_neighs + 1:
                raise ValueError("Mesh is not an icosahedron.")
            node_neighs = _node_neighs
        neighs[node] = node_neighs
    return neighs
173

174

175
def vertex_adjacency_graph(vertices, triangles):
    """ Build a networkx graph representation of the vertices and
    their connections in the mesh.

    This is useful for getting nearby vertices for a given vertex,
    potentially for some simple smoothing techniques.

    Examples
    --------
    >>> from surfify.utils.sampling import (
            icosahedron, vertex_adjacency_graph)
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> graph = vertex_adjacency_graph(ico2_verts, ico2_tris)
    >>> print(sorted(graph.neighbors(0)))

    Parameters
    ----------
    vertices: array (N, 3)
        the icosahedron vertices (only their number is used).
    triangles: array (M, 3)
        the icosahedron triangles.

    Returns
    -------
    graph: networkx.Graph
        undirected graph with one node per vertex row index (0 to N - 1)
        and one edge per distinct mesh edge.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(len(vertices)))
    edges, _ = triangles_to_edges(np.asarray(triangles))
    # Adjacent triangles share edges: keep only the first occurrence of each
    # undirected edge. A set keeps the membership test O(1), so the loop is
    # linear in the number of edges.
    seen_edges = set()
    for idx1, idx2 in edges:
        edge = (min(idx1, idx2), max(idx1, idx2))
        if edge in seen_edges:
            continue
        seen_edges.add(edge)
        graph.add_edge(*edge)
    return graph
213

214

215
def get_angle_with_xaxis(center, normal, point):
    """ Project a point to the sphere tangent plane and compute the angle
    with the x-axis.

    The plane x-axis is the direction of cross([0, 0, 1], center), or
    [1, 0, 0] when 'center' lies on the z-axis. The plane y-axis is
    cross(center, x-axis), or [0, 1, 0] in that case.

    Parameters
    ----------
    center: array (3, )
        a point in the plane.
    normal: array (3, )
        the normal to the plane. It is not rescaled, so it presumably has
        unit length.
    point: array (3, )
        the point to be projected.

    Returns
    -------
    angle: float
        the angle in radians, in [0, 2 * pi], measured from the x-axis
        towards the y-axis, of the projected point seen from 'center'. When
        the projection coincides with 'center', pi / 2 is returned.
    """
    center = np.asarray(center)
    normal = np.asarray(normal)
    point = np.asarray(point)

    # Project the point onto the plane
    dist = np.dot(point - center, normal)
    projection = point - normal * dist

    # Tangent plane axes (named x_axis/y_axis so that the module-level
    # networkx 'nx' import is not shadowed)
    if center[0] != 0 or center[1] != 0:
        x_axis = np.cross(np.array([0, 0, 1]), center)
        y_axis = np.cross(center, x_axis)
    else:
        x_axis = np.array([1, 0, 0])
        y_axis = np.array([0, 1, 0])

    # Angle between the in-plane direction and the x-axis
    direction = projection - center
    length = np.linalg.norm(direction)
    if length != 0:
        direction = direction / length
    unit_x_axis = x_axis / np.linalg.norm(x_axis)
    # Clip to guard arccos against rounding slightly outside [-1, 1]
    cos_theta = np.clip(np.dot(direction, unit_x_axis), -1., 1.)
    angle = np.arccos(cos_theta)
    # arccos only covers [0, pi]: use the y-axis side to get the full turn
    if np.dot(direction, y_axis) < 0:
        angle = 2 * np.pi - angle

    return angle
262

263

264
def triangles_to_edges(triangles, return_index=False):
    """ Given a list of triangles, return a list of edges.

    Parameters
    ----------
    triangles: array int (N, 3)
        Vertex indices representing triangles (lists are accepted).
    return_index: bool, default False
        unused: the triangle indexes are always returned. This parameter is
        only kept for backward compatibility.

    Returns
    -------
    edges: array int (N * 3, 2)
        Vertex indices representing edges: the (0, 1), (1, 2) and (2, 0)
        edges of each triangle, in triangle order.
    triangles_index: array (N * 3, )
        Triangle indexes, i.e. the row of 'triangles' each edge comes from.
    """
    triangles = np.asarray(triangles)

    # Each triangle has three edges, kept grouped by triangle by the reshape
    edges = triangles[:, [0, 1, 1, 2, 2, 0]].reshape((-1, 2))

    # Edges are in order of triangles: each triangle index appears 3 times
    triangles_index = np.repeat(np.arange(len(triangles)), 3)

    return edges, triangles_index
287

288

289
def neighbors_rec(vertices, triangles, size=5, zoom=5):
    """ Build rectangular grid neighbors and weights.

    This is the base function to build Rectangular Patch (RePa) kernels: a
    'size' x 'size' grid is laid in the tangent plane of each vertex,
    projected back on the unit sphere, and each grid sample is associated
    with its three closest mesh vertices.

    See Also
    --------
    neighbors

    Examples
    --------
    >>> from surfify.utils import icosahedron, neighbors_rec
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> neighs = neighbors_rec(ico2_verts, ico2_tris, size=3, zoom=3)
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico2_verts, triangles=ico2_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> center = ico2_verts[0]
    >>> for cnt, point in enumerate(neighs[2][0]):
    >>>     ax.scatter(point[0], point[1], point[2], marker="o", c="red",
                       s=100)
    >>> ax.scatter(center[0], center[1], center[2], marker="o", c="blue",
                   s=100)
    >>> plt.show()

    Parameters
    ----------
    vertices: array (N, 3)
        the icosahedron vertices.
    triangles: array (M, 3)
        the icosahedron triangles (not used by the computation).
    size: int, default 5
        the rectangular grid size; an even size is rejected with a
        ValueError by 'get_rectangular_projection'.
    zoom: int, default 5
        scale factor applied on the unit sphere to control the neighborhood
        density: the grid spacing is 1 / zoom.

    Returns
    -------
    neighs: array (N, size**2, 3)
        for each vertex and grid sample, the row indexes of the three
        closest vertices, by increasing distance.
    weights: array (N, size**2, 3)
        the Euclidean distances to those three vertices divided by their
        sum.
    grid_in_sphere: array (N, size**2, 3)
        zoomed rectangular grid on the sphere vertices.
    """
    n_vertices = len(vertices)
    n_samples = size ** 2
    grid_in_sphere = np.zeros((n_vertices, n_samples, 3), dtype=float)
    neighs = np.zeros((n_vertices, n_samples, 3), dtype=int)
    weights = np.zeros((n_vertices, n_samples, 3), dtype=float)
    for vertex_idx in range(n_vertices):
        grid_in_sphere[vertex_idx] = get_rectangular_projection(
            vertices[vertex_idx], size=size, zoom=zoom)[0]
        for sample_idx, sample in enumerate(grid_in_sphere[vertex_idx]):
            dist = np.linalg.norm(vertices - sample, axis=1)
            closest = np.argsort(dist)[:3]
            closest_dist = dist[closest]
            neighs[vertex_idx, sample_idx] = closest
            # NOTE(review): weights grow with the distance, so the closest
            # vertex gets the smallest weight - confirm this is intended.
            weights[vertex_idx, sample_idx] = (
                closest_dist / np.sum(closest_dist))
    return neighs, weights, grid_in_sphere
351

352

353
def get_rectangular_projection(node, size=5, zoom=5):
    """ Project 2D rectangular grid defined in node tangent space into 3D
    spherical space.

    The grid is centered on 'node', has a spacing of 1 / zoom, and is
    stored row by row: rows run along the negative tangent y-axis and
    columns along the tangent x-axis.

    Parameters
    ----------
    node: array (3, )
        a point in the sphere.
    size: int, default 5
        the rectangular grid size, must be odd.
    zoom: int, default 5
        scale factor applied on the unit sphere to control the neighborhood
        density.

    Returns
    -------
    grid_in_sphere: array (size**2, 3)
        zoomed rectangular grid on the sphere.
    grid_in_tplane: array (size**2, 3)
        zoomed rectangular grid in the tangent space.

    Raises
    ------
    ValueError
        if 'size' is even.
    """
    if size % 2 == 0:
        raise ValueError("An odd kernel size is expected.")

    # Unit tangent plane axes at 'node'
    node = node.copy()
    if node[0] == 0 and node[1] == 0:
        x_axis = np.array([1, 0, 0])
        y_axis = np.array([0, 1, 0])
    else:
        x_axis = np.cross(np.array([0, 0, 1]), node)
        y_axis = np.cross(node, x_axis)
    x_axis = x_axis / np.linalg.norm(x_axis)
    y_axis = y_axis / np.linalg.norm(y_axis)

    # Lay the grid in the tangent plane, then push each point on the sphere
    spacing = 1 / zoom
    half_width = (size // 2) * spacing
    corner = node - half_width * x_axis + half_width * y_axis
    n_samples = size ** 2
    grid_in_tplane = np.zeros((n_samples, 3))
    grid_in_sphere = np.zeros((n_samples, 3))
    for flat_idx in range(n_samples):
        row, column = divmod(flat_idx, size)
        point = corner - row * spacing * y_axis + column * spacing * x_axis
        grid_in_tplane[flat_idx] = point
        grid_in_sphere[flat_idx] = point / np.linalg.norm(point)

    return grid_in_sphere, grid_in_tplane
404

405

406
def find_neighbors(start_node, order, neighbors):
    """ Find neighbors from a starting node up to a certain order.

    The result is the set of nodes reached by following exactly 'order'
    successive neighbor links from 'start_node' (a node belongs to its own
    result when the neighbor lists contain the node itself, as with DiNe
    neighbors).

    See Also
    --------
    neighbors, neighbors_rec

    Examples
    --------
    >>> from surfify.utils import icosahedron, neighbors_rec, find_neighbors
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> neighs = neighbors_rec(ico2_verts, ico2_tris, size=3, zoom=3)[0]
    >>> neighs = neighs.reshape(len(neighs), -1)
    >>> neighs = neighbors(ico2_verts, ico2_tris, depth=1,
                           direct_neighbor=True)
    >>> node = 0
    >>> node_neighs = find_neighbors(node, order=3, neighbors=neighs)
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico2_verts, triangles=ico2_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> center = ico2_verts[node]
    >>> for cnt, idx in enumerate(node_neighs):
    >>>     point = ico2_verts[idx]
    >>>     ax.scatter(point[0], point[1], point[2], marker="o", c="red",
                       s=100)
    >>> ax.scatter(center[0], center[1], center[2], marker="o", c="blue",
                   s=100)
    >>> plt.show()

    Parameters
    ----------
    start_node: int
        node index to start search from.
    order: int
        order up to which to look for neighbors.
    neighbors: dict
        neighbors for each node as generated by the 'neighbors' or
        'neighbors_rec' functions.

    Returns
    -------
    indices: list of int
        the n-ring neighbors indices, sorted in increasing order;
        '[start_node]' when 'order' <= 0.
    """
    if order <= 0:
        return [start_node]
    # Expand a frontier one order at a time. This yields the same set as a
    # recursive union over every path, but each node is expanded at most
    # once per order instead of once per path (exponential in 'order').
    frontier = {start_node}
    remaining = order
    while remaining > 0:
        frontier = {neigh for node in frontier for neigh in neighbors[node]}
        remaining -= 1
    return sorted(frontier)
462

463

464
def build_freesurfer_ico(ico_file=None):
    """ Build FreeSurfer reference icosahedron by fetching existing data
    and building lower orders using downsampling.

    Orders 7 to 3 are fetched with nilearn (left sphere of the
    fsaverage<order> meshes). Orders 2 to 0 are then derived one after the
    other by downsampling the next finer order. Freesurfer coordinates are
    between -100 and 100, and are rescaled between -1 and 1.

    Parameters
    ----------
    ico_file: str, default None
        path to the generated FreeSurfer reference icosahedron topologies.
        By default, the package 'resources/freesurfer_icos.npz' file is
        written.
    """
    from nilearn.surface import load_surf_mesh
    from nilearn.datasets import fetch_surf_fsaverage

    if ico_file is None:
        package_dir = os.path.dirname(os.path.dirname(__file__))
        ico_file = os.path.join(package_dir, "resources", "freesurfer_icos.npz")
    data = {}

    # Fetched orders: downloads go to a throw-away directory
    for order in (7, 6, 5, 4, 3):
        name = f"fsaverage{order}"
        with HidePrints(hide_err=True):
            with tempfile.TemporaryDirectory() as tmpdir:
                fsaverage = fetch_surf_fsaverage(mesh=name, data_dir=tmpdir)
                vertices, triangles = load_surf_mesh(fsaverage["sphere_left"])
            vertices /= 100.
            data[f"{name}.vertices"] = vertices.astype(np.float32)
            data[f"{name}.triangles"] = triangles

    # Lower orders: downsample the next finer icosahedron
    for order in (2, 1, 0):
        name = f"fsaverage{order}"
        finer = f"fsaverage{order + 1}"
        vertices, triangles = downsample_ico(
            data[f"{finer}.vertices"], data[f"{finer}.triangles"], by=1)
        data[f"{name}.vertices"] = vertices
        data[f"{name}.triangles"] = triangles

    np.savez(ico_file, **data)
502

503

504
def build_fslr_ref(ref_file=None):
    """ Build FSLR reference by fetching existing data.

    The left sphere of the 4k, 8k, 32k and 164k fsLR densities is fetched
    with neuromaps; vertices are stored as float32.

    Parameters
    ----------
    ref_file: str, default None
        path to the generated FSLR reference topologies. By default, the
        package 'resources/fslr_refs.npz' file is written.
    """
    from nilearn.surface import load_surf_mesh
    from neuromaps.datasets import fetch_fslr

    if ref_file is None:
        package_dir = os.path.dirname(os.path.dirname(__file__))
        ref_file = os.path.join(package_dir, "resources", "fslr_refs.npz")
    data = {}
    for density in ("4k", "8k", "32k", "164k"):
        name = f"fslr{density}"
        with HidePrints(hide_err=True):
            # Downloads go to a throw-away directory
            with tempfile.TemporaryDirectory() as tmpdir:
                fslr = fetch_fslr(density=density, data_dir=tmpdir)
                vertices, triangles = load_surf_mesh(fslr["sphere"].L)
            data[f"{name}.vertices"] = vertices.astype(np.float32)
            data[f"{name}.triangles"] = triangles
    np.savez(ref_file, **data)
×
529

530

531
def icosahedron(order=3, standard_ico=False):
    """ Define an icosahedron mesh of any order.

    Examples
    --------
    >>> from surfify.utils import icosahedron
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> print(ico3_verts.shape, ico3_tris.shape)
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> plt.show()

    Parameters
    ----------
    order: int, default 3
        the icosahedron order.
    standard_ico: bool, default False
        optionally uses a standard icosahedron tessellation, built by
        recursively subdividing 'STANDARD_ICO' 'order' times. Otherwise the
        FreeSurfer tessellation stored in the package
        'resources/freesurfer_icos.npz' file is loaded.

    Returns
    -------
    vertices: array (N, 3)
        the icosahedron vertices.
    triangles: array (M, 3)
        the icosahedron triangles.

    Raises
    ------
    KeyError
        if the requested FreeSurfer order is not stored in the resource
        file (the available topologies are printed first).
    """
    if standard_ico:
        vertices = STANDARD_ICO["vertices"].copy()
        triangles = STANDARD_ICO["triangles"].copy()
        # Shared across triangles so that an edge is only split once
        middle_point_cache = {}
        for _ in range(order):
            subdiv = []
            for tri in triangles:
                v1 = middle_point(tri[0], tri[1], vertices, middle_point_cache)
                v2 = middle_point(tri[1], tri[2], vertices, middle_point_cache)
                v3 = middle_point(tri[2], tri[0], vertices, middle_point_cache)
                subdiv.append([tri[0], v1, v3])
                subdiv.append([tri[1], v2, v1])
                subdiv.append([tri[2], v3, v2])
                subdiv.append([v1, v2, v3])
            triangles = subdiv
        vertices = np.asarray(vertices)
        triangles = np.asarray(triangles)
    else:
        resource_dir = os.path.join(
            os.path.dirname(os.path.dirname(__file__)), "resources")
        resource_file = os.path.join(resource_dir, "freesurfer_icos.npz")
        surf_name = "fsaverage{0}".format(order)
        # The NpzFile holds an open zip file handle: close it once the
        # arrays (read in memory on access) have been extracted.
        with np.load(resource_file) as icos:
            try:
                vertices = icos[surf_name + ".vertices"]
                triangles = icos[surf_name + ".triangles"]
            except KeyError:
                print("-- available topologies:", icos.files)
                raise

    return vertices, triangles
592

593

594
def middle_point(point_1, point_2, vertices, middle_point_cache=None):
    """ Find a middle point and project it to the unit sphere.

    This function is only used to build an icosahedron geometry.

    Parameters
    ----------
    point_1, point_2: int
        row indexes in 'vertices' of the edge end points.
    vertices: list
        the current vertices; a new middle point is appended in place.
    middle_point_cache: dict, default None
        optional cache mapping an edge key ("<smaller>-<greater>" index
        string) to the index of its middle point, so that an edge shared by
        two triangles is only split once.

    Returns
    -------
    index: int
        the row index of the middle point in 'vertices'.
    """
    low, high = sorted((point_1, point_2))
    key = f"{low}-{high}"
    use_cache = middle_point_cache is not None
    if use_cache and key in middle_point_cache:
        return middle_point_cache[key]

    # Not split yet: add the edge midpoint, pushed back on the sphere
    middle = [sum(pair) / 2. for pair in zip(vertices[point_1],
                                             vertices[point_2])]
    vertices.append(normalize(middle))
    new_index = len(vertices) - 1
    if use_cache:
        middle_point_cache[key] = new_index
    return new_index
616

617

618
def patch_tri(order=6, standard_ico=False, size=3, direct_neighbor=False,
              n_jobs=1):
    """ Build triangular patches that map the icosahedron.

    This is the base function for Vision Transformers. Each triangle of the
    icosahedron of order 'order - size' is subdivided 'size' times, and the
    resulting points are mapped to their nearest vertex of the icosahedron
    of order 'order'. This gives one patch per low order triangle.

    See Also
    --------
    icosahedron

    Examples
    --------
    >>> from surfify.utils import icosahedron, patch_tri
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> patches = patch_tri(order=3, size=1)
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> for cnt, idx in enumerate(patches[10]):
    >>>     point = ico3_verts[idx]
    >>>     ax.scatter(point[0], point[1], point[2], marker="o", s=100)
    >>> plt.show()

    Parameters
    ----------
    order: int, default 6
        the icosahedron order.
    standard_ico: bool, default False
        optionally uses a standard icosahedron tessellation. FreeSurfer
        tessellation is used by default.
    size: int, default 3
        the patch size, i.e. the number of subdivisions between the patch
        triangles and the icosahedron, must not exceed 'order'.
    direct_neighbor: bool, default False
        order patch vertices by angle around the patch center, instead of
        by index.
    n_jobs: int, default 1
        the maximum number of concurrently running jobs.

    Returns
    -------
    patches: array
        triangular patches containing icosahedron indices, one row per
        triangle of the order 'order - size' icosahedron.
    """
    assert (order - size) >= 0, "Wrong patch definition!"
    vertices, _ = icosahedron(order, standard_ico)
    lower_vertices, lower_triangles = icosahedron(order - size, standard_ico)
    # Maps any point to its closest vertex of the target icosahedron
    neigh = NearestNeighbors(n_neighbors=1)
    neigh.fit(vertices)
    patches = Parallel(n_jobs=n_jobs)(delayed(_patch_tri_iter)(
        vertices, lower_vertices, tri, size, neigh, direct_neighbor)
            for tri in lower_triangles)
    return np.array(patches)
673

674

675
def _patch_tri_iter(vertices, lower_vertices, tri, size, neigh,
                    direct_neighbor):
    """ Build a triangular patch from input triangle.

    The triangle is subdivided 'size' times. No middle point cache is used,
    so points on shared edges are duplicated. They collapse after the
    nearest vertex lookup.

    See Also
    --------
    patch_tri

    Parameters
    ----------
    vertices: array (N, 3)
        the vertices the 'neigh' estimator was fitted on.
    lower_vertices: array (L, 3)
        the lower order icosahedron vertices.
    tri: array (3, )
        the lower order triangle vertex indexes.
    size: int
        the number of subdivisions applied to the triangle.
    neigh: NearestNeighbors
        fitted nearest neighbor estimator used to map the subdivided points
        to 'vertices' rows.
    direct_neighbor: bool
        if set, order the patch vertices by angle around the triangle
        center, otherwise they are sorted by index.

    Returns
    -------
    locs: array or list of int
        the patch vertex indexes.
    """
    _vertices = [lower_vertices[idx] for idx in tri]
    _triangles = [[0, 1, 2]]
    for _ in range(size):
        subdiv = []
        for _tri in _triangles:
            v1 = middle_point(_tri[0], _tri[1], _vertices)
            v2 = middle_point(_tri[1], _tri[2], _vertices)
            v3 = middle_point(_tri[2], _tri[0], _vertices)
            subdiv.append([_tri[0], v1, v3])
            subdiv.append([_tri[1], v2, v1])
            subdiv.append([_tri[2], v3, v2])
            subdiv.append([v1, v2, v3])
        _triangles = subdiv
    locs = neigh.kneighbors(_vertices, return_distance=False)
    # np.unique removes duplicates and sorts the indexes
    locs = np.unique(locs.squeeze())
    if direct_neighbor:
        # Triangle centroid: average over the three vertices (rows, i.e.
        # axis 0), projected back on the unit sphere. Averaging over axis 1
        # would average the x/y/z coordinates of each vertex instead.
        center = np.mean(lower_vertices[tri], axis=0)
        center /= np.linalg.norm(center)
        angles = np.asarray([
            get_angle_with_xaxis(center, center, vec)
            for vec in vertices[locs]])
        angles = np.degrees(angles)
        locs = [x for _, x in sorted(
            zip(angles, locs), key=lambda pair: pair[0])]
    return locs
708

709

710
def number_of_ico_vertices(order=3):
    """ Get the number of vertices of an icosahedron of specific order.

    Each subdivision multiplies the number of faces by 4, so an order 'k'
    icosahedron has 20 * 4 ** k faces and 30 * 4 ** k edges. By Euler's
    formula (V - E + F = 2), it has 10 * 4 ** k + 2 vertices.

    See Also
    --------
    order_of_ico_from_vertices

    Examples
    --------
    >>> from surfify.utils import number_of_ico_vertices, icosahedron
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> n_verts = number_of_ico_vertices(order=3)
    >>> print(n_verts, ico3_verts.shape)

    Parameters
    ----------
    order: int, default 3
        the icosahedron order.

    Returns
    -------
    n_vertices: int
        number of vertices of the corresponding icosahedron
    """
    growth = 4 ** order
    return 2 + 10 * growth
735

736

737
def order_of_ico_from_vertices(n_vertices):
    """ Get the order of an icosahedron from its number of vertices.

    An order 'k' icosahedron has 10 * 4 ** k + 2 vertices: 'k' is recovered
    with exact arithmetic rather than a floating point logarithm, whose
    rounding may turn a valid count into a non-integer order.

    See Also
    --------
    number_of_ico_vertices

    Examples
    --------
    >>> from surfify.utils import order_of_ico_from_vertices, icosahedron
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> order = order_of_ico_from_vertices(len(ico3_verts))
    >>> print(order)

    Parameters
    ----------
    n_vertices: int
        the number of vertices of an icosahedron.

    Returns
    -------
    order: int
        the order of the icosahedron

    Raises
    ------
    ValueError
        if 'n_vertices' is not of the form 10 * 4 ** order + 2.
    """
    message = ("This number of vertices does not correspond to those of a "
               "regular icosahedron.")
    # 'not >=' also rejects NaN; the smallest icosahedron has 12 vertices
    if not n_vertices >= 12:
        raise ValueError(message)
    quotient, remainder = divmod(n_vertices - 2, 10)
    if remainder != 0:
        raise ValueError(message)
    # quotient >= 1 here, so the loop terminates; it must be a power of 4
    order = 0
    while quotient % 4 == 0:
        quotient //= 4
        order += 1
    if quotient != 1:
        raise ValueError(message)
    return order
767

768

769
def number_of_neighbors(depth):
    """ Get the number of neighbors up to a certain depth.

    Ring 'r' is counted as holding '6 * r' vertices, as around a vertex
    with 6 direct neighbors. The count, center vertex included, is
    therefore 1 + 6 * (1 + ... + depth).

    See Also
    --------
    min_depth_to_get_n_neighbors

    Examples
    --------
    >>> from surfify.utils import number_of_neighbors
    >>> for depth in range(4):
    >>>     n_neighs = number_of_neighbors(depth)
    >>>     print(n_neighs)

    Parameters
    ----------
    depth: int
        the neighborhood depth, i.e. the number of rings around the center.

    Returns
    -------
    n_neighs: int
        the number of neighbors, center vertex included (1 when
        depth <= 0).
    """
    return 1 + sum(6 * ring for ring in range(1, depth + 1))
797

798

799
def min_depth_to_get_n_neighbors(n_neighs):
    """ Get the minimal depth of neighborhood to get a desired number of
    neighbors.

    This is the inverse of 'number_of_neighbors': the smallest depth 'd'
    such that number_of_neighbors(d) >= n_neighs, ring 'r' being counted
    as holding '6 * r' vertices.

    See Also
    --------
    number_of_neighbors

    Examples
    --------
    >>> from surfify.utils import min_depth_to_get_n_neighbors, icosahedron
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> depth = min_depth_to_get_n_neighbors(len(ico3_verts) / 2)
    >>> print(depth)

    Parameters
    ----------
    n_neighs: int
        the desired number of neighbors, center vertex included.

    Returns
    -------
    depth: int
        the minimal neighborhood depth (0 when n_neighs <= 1).
    """
    depth, cum_n_neighs = 0, 1
    # Add ring 'depth' right after incrementing it, so that the loop stops
    # on the first depth whose cumulative count reaches 'n_neighs'.
    while cum_n_neighs < n_neighs:
        depth += 1
        cum_n_neighs += 6 * depth
    return depth
830

831

832
def interpolate(vertices, target_vertices, target_triangles):
    """ Interpolate icosahedron missing data by finding nearest neighbors.

    Interpolation weights are set to 1 for a regular icosahedron geometry.

    See Also
    --------
    interpolate_data, downsample, downsample_data, downsample_ico

    Examples
    --------
    >>> from surfify.utils import icosahedron, interpolate
    >>> from surfify.datasets import make_classification
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> X, y = make_classification(ico2_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> up_indices = interpolate(ico2_verts, ico3_verts, ico3_tris)
    >>> up_indices = np.asarray(list(up_indices.values()))
    >>> y_up = y[up_indices.reshape(-1)].reshape(up_indices.shape)
    >>> y_up = np.mean(y_up, axis=-1)
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, texture=y_up,
                     is_label=False)
    >>> plt.show()

    Parameters
    ----------
    vertices: array (n_samples, n_dim)
        points of data set.
    target_vertices: array (n_query, n_dim)
        points to find interpolated texture for.
    target_triangles: array (n_query, 3)
        the mesh geometry definition.

    Returns
    -------
    interp_indices: OrderedDict
        for each node of the target mesh adjacency graph (in increasing
        order), the list of target-mesh indices used to interpolate it:
        the node itself twice if it matches one of `vertices`, otherwise its
        direct neighbors (in the target mesh) that match one of `vertices`.
        NOTE(review): callers such as `interpolate_data` use these indices
        to index data defined on `vertices`, which presumably relies on the
        matched vertices coming first in the target mesh — confirm.
    """
    interp_indices = collections.OrderedDict()
    graph = vertex_adjacency_graph(target_vertices, target_triangles)
    # Target-mesh indices of the vertices matched by `vertices`. Stored in a
    # set so the membership tests below are O(1) instead of a linear scan of
    # a numpy array for every node and every neighbor.
    common_vertices = set(
        np.ravel(downsample(target_vertices, vertices)).tolist())
    for node in sorted(graph.nodes):
        if node in common_vertices:
            # A shared vertex keeps its own value; it is duplicated so the
            # entry has two indices.
            interp_indices[node] = [node] * 2
        else:
            # A new vertex is interpolated from its direct neighbors that
            # are shared with `vertices` (presumably two for a regular
            # icosahedron subdivision — `interpolate_data` stacks these
            # lists into one array, which needs equal lengths).
            interp_indices[node] = [
                idx for idx in graph.neighbors(node)
                if idx in common_vertices]
    return interp_indices
1✔
886

887

888
def interpolate_data(data, by=1, up_indices=None):
    """ Interpolate data/texture on the icosahedron to an upper order.

    Each upsampling step gathers, for every vertex of the finer
    icosahedron, the values of its source vertices and averages them.

    See Also
    --------
    interpolate, downsample, downsample_data, downsample_ico

    Examples
    --------
    >>> from surfify.utils import icosahedron, interpolate_data
    >>> from surfify.datasets import make_classification
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> ico4_verts, ico4_tris = icosahedron(order=4)
    >>> X, y = make_classification(ico2_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> y = y.reshape(1, -1, 1)
    >>> y_up = interpolate_data(y, by=2).squeeze()
    >>> plot_trisurf(ico4_verts, triangles=ico4_tris, texture=y_up,
                     is_label=False)
    >>> plt.show()

    Parameters
    ----------
    data: array (n_samples, n_vertices, n_features)
        data to be upsampled.
    by: int, default 1
        number of orders to increase the icosahedron by. Only used when
        `up_indices` is not provided.
    up_indices: list of array, default None
        optionally specify the list of consecutive upsampling vertices
        indices, each an integer array of shape (n_new_vertices, n_neighs).
        When None, they are computed with `interpolate` starting from the
        icosahedron order inferred from the number of vertices in `data`.

    Returns
    -------
    upsampled_data: array (n_samples, new_n_vertices, n_features)
        upsampled data.

    Raises
    ------
    ValueError
        if `data` is not 3-dimensional.
    """
    if len(data.shape) != 3:
        raise ValueError(
            "Unexpected input data. Must be (n_samples, n_vertices, "
            "n_features) but got '{0}'.".format(data.shape))
    if up_indices is None:
        start_order = order_of_ico_from_vertices(data.shape[1])
        coarse_verts, _ = icosahedron(start_order)
        up_indices = []
        for fine_order in range(start_order + 1, start_order + by + 1):
            fine_verts, fine_tris = icosahedron(fine_order)
            mapping = interpolate(coarse_verts, fine_verts, fine_tris)
            up_indices.append(np.asarray(list(mapping.values())))
            # The next step starts from the mesh we just produced.
            coarse_verts = fine_verts
    n_samples, n_features = len(data), data.shape[-1]
    for step_indices in up_indices:
        n_new_vertices, n_neighs = step_indices.shape
        # Gather every source value, regroup them per new vertex, average.
        gathered = data[:, step_indices.reshape(-1)]
        gathered = gathered.reshape(
            n_samples, n_new_vertices, n_neighs, n_features)
        data = np.mean(gathered, axis=2)
    return data
1✔
947

948

949
def downsample(vertices, target_vertices):
    """ Downsample icosahedron vertices by finding nearest neighbors.

    See Also
    --------
    downsample_data, downsample_ico, interpolate, interpolate_data

    Examples
    --------
    >>> from surfify.utils import icosahedron, downsample
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> down3to2 = downsample(ico3_verts, ico2_verts)
    >>> ico3_down_vertices = ico3_verts[down3to2]
    >>> fig, ax = plt.subplots(1, 1, subplot_kw={
            "projection": "3d", "aspect": "auto"}, figsize=(10, 10))
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, colorbar=False, fig=fig,
                     ax=ax)
    >>> for cnt, point in enumerate(ico3_down_vertices):
    >>>     ax.scatter(point[0], point[1], point[2], marker="o", c="red",
                       s=100)
    >>> plt.show()

    Parameters
    ----------
    vertices: array (n_samples, n_dim)
        points of data set.
    target_vertices: array (n_query, n_dim)
        points to find nearest neighbors for.

    Returns
    -------
    nearest_idx: array (n_query, )
        for every point in `target_vertices`, the index of its nearest
        neighbor in `vertices`. An empty integer array is returned when
        either input is empty.

    Raises
    ------
    RuntimeError
        if several target points share the same nearest neighbor, i.e. the
        mapping is not injective (not a coarser icosahedron).
    """
    if vertices.size == 0 or target_vertices.size == 0:
        # Same type as the regular return value, so callers can index with
        # it or test membership against it (the former version returned an
        # (indices, distances) tuple here).
        return np.array([], dtype=int)
    tree = BallTree(vertices, leaf_size=2)
    # Only the indices are used: skip computing the distances.
    nearest_idx = tree.query(target_vertices, k=1, return_distance=False)
    # reshape (not squeeze) keeps a 1-d result even for a single query.
    nearest_idx = nearest_idx.reshape(-1)
    n_duplicates = len(nearest_idx) - len(np.unique(nearest_idx))
    if n_duplicates:
        raise RuntimeError("Could not downsample proprely, '{0}' duplicates "
                           "were found. Are you using an icosahedron "
                           "mesh?".format(n_duplicates))
    return nearest_idx
1✔
998

999

1000
def downsample_data(data, by=1, down_indices=None):
    """ Downsample data/texture on the icosahedron to a lower order.

    Each downsampling step keeps only the values of the vertices that
    belong to the next coarser icosahedron.

    See Also
    --------
    downsample, downsample_ico, interpolate, interpolate_data

    Examples
    --------
    >>> from surfify.utils import icosahedron, downsample_data
    >>> from surfify.datasets import make_classification
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico2_verts, ico2_tris = icosahedron(order=2)
    >>> ico4_verts, ico4_tris = icosahedron(order=4)
    >>> X, y = make_classification(ico4_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> y = y.reshape(1, -1, 1)
    >>> y_down = downsample_data(y, by=2).squeeze()
    >>> plot_trisurf(ico2_verts, triangles=ico2_tris, texture=y_down,
                     is_label=True)
    >>> plt.show()

    Parameters
    ----------
    data: array (n_samples, n_vertices, n_features)
        data to be downsampled.
    by: int, default 1
        number of orders to reduce the icosahedron by. Only used when
        `down_indices` is not provided.
    down_indices: list of array, default None
        optionally specify the list of consecutive downsampling vertices
        indices. When None, they are computed with `downsample` starting
        from the icosahedron order inferred from the number of vertices in
        `data`.

    Returns
    -------
    downsampled_data: array (n_samples, new_n_vertices, n_features)
        downsampled data.

    Raises
    ------
    ValueError
        if `data` is not 3-dimensional.
    """
    if len(data.shape) != 3:
        raise ValueError(
            "Unexpected input data. Must be (n_samples, n_vertices, "
            "n_features) but got '{0}'.".format(data.shape))
    if down_indices is None:
        high_order = order_of_ico_from_vertices(data.shape[1])
        high_verts, _ = icosahedron(high_order)
        down_indices = []
        for step in range(1, by + 1):
            low_verts, _ = icosahedron(high_order - step)
            # Index, in the current mesh, of each vertex of the coarser one.
            down_indices.append(downsample(high_verts, low_verts))
            high_verts = low_verts
    for step_indices in down_indices:
        data = data[:, step_indices]
    return data
1✔
1053

1054

1055
def downsample_ico(vertices, triangles, by=1, down_indices=None):
    """ Downsample an icosahedron full geometry: vertices and triangles.

    For each kept (central) vertex k, we look at its neighborhood in the
    finer mesh. For each consecutive pair of neighbors, we search each of
    their own neighborhoods for a vertex that is kept and is not k; this
    trio gives a triangle of the coarser icosahedron. Triangles are
    deduplicated regardless of vertex order.

    See Also
    --------
    downsample, downsample_data, interpolate, interpolate_data

    Examples
    --------
    >>> from surfify.utils import icosahedron, downsample_ico
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico4_verts, ico4_tris = icosahedron(order=4)
    >>> ico2_down_verts, ico2_down_tris = downsample_ico(
            ico4_verts, ico4_tris, by=2)
    >>> plot_trisurf(ico2_down_verts, triangles=ico2_down_tris, colorbar=False)
    >>> plt.show()

    Parameters
    ----------
    vertices: array (N, 3)
        vertices of the icosahedron to reduce.
    triangles: array (M, 3)
        triangles of the icosahedron to reduce.
    by: int, default 1
        number of orders to reduce the icosahedron by. With ``by=0`` the
        input geometry is returned unchanged.
    down_indices: list of array, default None
        optionally specify the list of consecutive downsampling vertices
        indices. When None, the first vertices of each finer mesh are kept
        (as many as the coarser icosahedron has).

    Returns
    -------
    new_vertices: array (N', 3)
        vertices of the newly downsampled icosahedron.
    new_triangles: array (M', 3)
        triangles of the newly downsampled icosahedron; their vertex
        indices refer to `new_vertices`.
    """
    for idx_order in range(by):
        # Also validates that `vertices` is a regular icosahedron.
        former_order = order_of_ico_from_vertices(len(vertices))
        n_new_vertices = number_of_ico_vertices(former_order - 1)
        if down_indices is None:
            indices = np.arange(n_new_vertices)
        else:
            # asarray: accept plain lists as well as numpy arrays.
            indices = np.asarray(down_indices[idx_order])
        new_vertices = vertices[indices]
        # Position of each kept vertex in the downsampled mesh (first
        # occurrence wins, like `list.index`): O(1) lookups instead of
        # scanning `indices` for every candidate.
        positions = {}
        for position, node in enumerate(indices.tolist()):
            positions.setdefault(node, position)
        former_neighbors = neighbors(vertices, triangles, direct_neighbor=True)
        former_neighbors = np.array(list(former_neighbors.values()))
        new_triangles = []
        # Unordered view of the triangles already found, for O(1) dedup.
        seen_triangles = set()
        for idx_down, down_node in enumerate(indices):
            down_neighs = former_neighbors[down_node]
            for idx_neigh, neigh_node in enumerate(down_neighs):
                # The neighbor list may contain the central node itself.
                if neigh_node == down_node:
                    continue
                next_neigh_node = down_neighs[
                    (idx_neigh + 1) % len(down_neighs)]
                candidates = [idx_down]
                for neighs in (former_neighbors[neigh_node],
                               former_neighbors[next_neigh_node]):
                    for neigh_idx in neighs:
                        position = positions.get(neigh_idx)
                        if position is not None and neigh_idx != down_node:
                            candidates.append(position)
                            break
                triangle_key = frozenset(candidates)
                if (len(candidates) == 3 and
                        triangle_key not in seen_triangles):
                    seen_triangles.add(triangle_key)
                    new_triangles.append(set(candidates))
        new_triangles = np.array([list(tri) for tri in new_triangles])
        vertices = new_vertices
        triangles = new_triangles
    # After the loop (vertices, triangles) hold the last downsampled mesh,
    # or the unchanged input when `by` is 0.
    return vertices, triangles
1✔
1133

1134

1135
def find_rotation_interpol_coefs(vertices, triangles, angles,
                                 interpolation="barycentric"):
    """ Function to compute interpolation coefficient asssociated to
    a rotation of the provided icosahedron. Used by the 'rotate_data'
    function.

    The vertices are rotated; then, for each original vertex position, we
    find the rotated vertices it is interpolated from and their weights.

    Parameters
    ----------
    vertices: array (N, 3)
        vertices of the icosahedron.
    triangles: array (M, 3)
        triangles of the icosahedron.
    angles: 3-uplet
        the rotation angles in degrees for each axis (Euler representation).
    interpolation: str, default 'barycentric'
        type of interpolation to use: 'euclidian' (also accepted as
        'euclidean') or 'barycentric'.
        - 'euclidian': the three closest rotated vertices, weighted by
          inverse distance (an exact match gets all the weight).
        - 'barycentric': the vertices of the rotated triangle hit by the
          ray from the center through the vertex, weighted by normalized
          barycentric coordinates.

    Returns
    -------
    dict:
        neighs: array (N, 3)
            indices of the three closest neighbors on the rotated icosahedron
            for each vertice
        weights: array (N, 3)
            weights associated to each of these neighbors; each row sums
            to 1.

    Raises
    ------
    ValueError
        if the interpolation type is unknown.
    RuntimeError
        if no enclosing triangle is found for a vertex with barycentric
        interpolation.
    """
    if interpolation == "euclidean":
        # Accept the standard spelling as an alias of the historical one.
        interpolation = "euclidian"
    if interpolation not in ["euclidian", "barycentric"]:
        raise ValueError("The interpolation should be one of 'euclidian' "
                         "or 'barycentric'.")

    n_vertices = len(vertices)
    neighs = np.zeros((n_vertices, 3), dtype=int)
    weights = np.zeros((n_vertices, 3), dtype=float)
    eps = np.finfo(np.float32).eps

    rotation = transform.Rotation.from_euler("xyz", angles, degrees=True)
    rotated_vertices = rotation.apply(vertices)

    if interpolation == "euclidian":
        for idx, point in enumerate(vertices):
            dist = np.linalg.norm(rotated_vertices - point, axis=1)
            closest = np.argsort(dist)[:3]
            neighs[idx] = closest
            closest_dist = dist[closest]
            if closest_dist[0] <= eps:
                # The point coincides with a rotated vertex: copy its value.
                weights[idx] = [1., 0., 0.]
            else:
                # Inverse distance weighting: closer vertices weigh more
                # (the former weights were proportional to the distance).
                inv_dist = 1. / closest_dist
                weights[idx] = inv_dist / np.sum(inv_dist)
    else:
        triangles = order_triangles(rotated_vertices, triangles)

        candidate_triangles = [[] for _ in range(n_vertices)]
        for tri in triangles:
            for node in tri:
                candidate_triangles[node].append(tri)
        for idx, point in enumerate(vertices):
            # in order not to look in all the triangles for the barycentric
            # coordinates, we only consider the triangles associated with
            # the closest rotated vertice
            closest_point = np.argmin(
                np.linalg.norm(point - rotated_vertices, axis=1))
            for triangle in candidate_triangles[closest_point]:
                coefs = np.linalg.solve(rotated_vertices[triangle].T, point)
                # The ray from the center through `point` crosses the
                # triangle iff all coefficients are non-negative (up to eps).
                if np.all((coefs >= 0) | (np.abs(coefs) <= eps)):
                    neighs[idx] = triangle
                    # `point` lies on the sphere, outside the triangle
                    # plane, so the raw coefficients sum to more than 1;
                    # normalizing yields the barycentric coordinates of the
                    # ray/plane intersection and keeps the data scale.
                    weights[idx] = coefs / np.sum(coefs)
                    break
            else:
                raise RuntimeError(
                    "Barycentric coordinate for vertex {} was not found. "
                    "It may be due to a numerical error. You might want "
                    "to consider an other type of interpolation.".format(
                        idx
                    ))
    return {"neighs": neighs, "weights": weights}
1✔
1209

1210

1211
def rotate_data(data, vertices, triangles, angles,
                interpolation="barycentric", neighs=None,
                weights=None):
    """ Rotate data/texture on an icosahedron.

    Each rotated value is the weighted sum of the data at the interpolation
    neighbors. Unless both `neighs` and `weights` are provided, they are
    computed with `find_rotation_interpol_coefs`; pass precomputed values
    to avoid recomputing them when rotating several times with the same
    angles (nothing is cached by this function).

    Examples
    --------
    >>> from surfify.utils import icosahedron, rotate_data
    >>> from surfify.datasets import make_classification
    >>> import matplotlib.pyplot as plt
    >>> from surfify.plotting import plot_trisurf
    >>> ico3_verts, ico3_tris = icosahedron(order=3)
    >>> X, y = make_classification(ico3_verts, n_samples=1, n_classes=3,
                                   scale=1, seed=42)
    >>> y_rot = rotate_data(y.reshape(1, -1, 1), ico3_verts, ico3_tris,
                            (45, 0, 0)).squeeze()
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, texture=y,
                     is_label=False)
    >>> plot_trisurf(ico3_verts, triangles=ico3_tris, texture=y_rot,
                     is_label=False)
    >>> plt.show()

    Parameters
    ----------
    data: array (n_samples, N, n_features)
        data to be rotated.
    vertices: array (N, 3)
        vertices of the icosahedron.
    triangles: array (M, 3)
        triangles of the icosahedron.
    angles: 3-uplet
        the rotation angles in degrees for each axis (Euler representation).
    interpolation: str, default 'barycentric'.
        the type of interpolation to use: 'euclidian' or 'barycentric'
        (see `find_rotation_interpol_coefs`).
    neighs: array (N, 3) or None, default None
        neighbors to interpolate from for each vertex. If None, the function
        computes the neighbors via the provided interpolation method.
    weights: array (N, 3) or None, default None
        weights associated to each neighbors for each vertex.  If None, the
        function computes the weights via the provided interpolation method.

    Returns
    -------
    rotated_data: array (n_samples, n_vertices, n_features)
        rotated data.

    Raises
    ------
    ValueError
        if `data` is not 3-dimensional.
    """
    if len(data.shape) != 3:
        raise ValueError(
            "Unexpected input data. Must be (n_samples, n_vertices, "
            "n_features) but got '{0}'.".format(data.shape))

    if neighs is None or weights is None:
        interp_coefs = find_rotation_interpol_coefs(
            vertices, triangles, angles, interpolation)
        neighs = interp_coefs["neighs"]
        weights = interp_coefs["weights"]
    n_samples = len(data)
    n_features = data.shape[-1]
    n_vertices, n_neighs = neighs.shape
    flat_neighs = neighs.reshape(-1)
    # (1, n_vertices * n_neighs, 1): broadcasts over samples and features,
    # no need to materialize one copy of the weights per sample.
    flat_weights = weights.reshape(1, -1, 1)
    rotated_data = data[:, flat_neighs] * flat_weights
    rotated_data = rotated_data.reshape(n_samples, n_vertices, n_neighs,
                                        n_features)
    rotated_data = np.sum(rotated_data, axis=2)

    return rotated_data
1✔
1281

1282

1283
def order_triangles(vertices, triangles, clockwise_from_center=True):
    """ Order the icosahedron triangles to be in a clockwise order when viewed
    from the center of the sphere. Used by the 'find_rotation_interpol_coefs'
    for barycentric interpolation.

    The orientation of each triangle is read from the sign of the dot
    product between its normal (cross product of two edges) and its first
    vertex; triangles with the unwanted orientation get their last two
    vertices swapped. The input array is not modified.

    Examples
    --------
    >>> from surfify.utils import icosahedron, order_triangles
    >>> ico0_verts, ico0_tris = icosahedron(order=0)
    >>> clockwise_ico0_tris = order_triangles(
            ico0_verts, ico0_tris, clockwise_from_center=True)
    >>> counter_clockwise_ico0_tris = order_triangles(
            ico0_verts, ico0_tris, clockwise_from_center=False)
    >>> print(clockwise_ico0_tris)
    >>> print(counter_clockwise_ico0_tris)

    Parameters
    ----------
    vertices: array (N, 3)
        the icosahedron's vertices.
    triangles: array (M, 3)
        the icosahedron's triangles.
    clockwise_from_center: bool, default True
        optionally use counter clockwise order.

    Returns
    -------
    reordered_triangles: array (M, 3)
        reordered triangles.
    """
    ordered = triangles.copy()
    for tri_idx, tri in enumerate(triangles):
        corner_a, corner_b, corner_c = vertices[tri]
        normal = np.cross((corner_b - corner_a), (corner_c - corner_a))
        orientation = np.dot(normal, corner_a)
        # A zero orientation (degenerate triangle) is swapped in both modes.
        if clockwise_from_center:
            must_swap = orientation >= 0
        else:
            must_swap = orientation <= 0
        if must_swap:
            ordered[tri_idx] = tri[[0, 2, 1]]
    return ordered
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc