• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18477247352

13 Oct 2025 08:18PM UTC coverage: 72.391% (+0.4%) from 71.943%
18477247352

push

github

web-flow
Merge pull request #140 from paulmthompson/kdtree

Jules PR

164 of 287 new or added lines in 3 files covered. (57.14%)

350 existing lines in 9 files now uncovered.

51889 of 71679 relevant lines covered (72.39%)

63071.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.36
/src/SpatialIndex/KdTree.cpp
1
//
2
// Kd-Tree implementation.
3
//
4
// Copyright: Christoph Dalitz, 2018-2023
5
//            Jens Wilberg, 2018
6
// Version:   1.3
7
// License:   BSD style license
8
//            (see the file LICENSE for details)
9
//
10

11
#include "KdTree.hpp"

#include <math.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>

#include "CoreGeometry/points.hpp"
18

19
namespace Kdtree {
20

21
//--------------------------------------------------------------
22
// function object for comparing only dimension d of two vecotrs
23
//--------------------------------------------------------------
24
template<typename T>
25
class compare_dimension {
26
 public:
27
  compare_dimension(size_t dim) { d = dim; }
2✔
28
  bool operator()(const KdNode<T>& p, const KdNode<T>& q) {
8✔
29
    if (d == 0)
8✔
30
      return (p.point.x < q.point.x);
8✔
31
    else
NEW
32
      return (p.point.y < q.point.y);
×
33
  }
34
  size_t d;
35
};
36

37
//--------------------------------------------------------------
38
// internal node structure used by kdtree
39
//--------------------------------------------------------------
40
template<typename T>
41
class kdtree_node {
42
 public:
43
  kdtree_node() {
6✔
44
    dataindex = cutdim = 0;
6✔
45
    loson = hison = (kdtree_node<T>*)NULL;
6✔
46
  }
6✔
47
  ~kdtree_node() {
6✔
48
    if (loson) delete loson;
6✔
49
    if (hison) delete hison;
6✔
50
  }
6✔
51
  // index of node data in kdtree array "allnodes"
52
  size_t dataindex;
53
  // cutting dimension
54
  size_t cutdim;
55
  // value of point
56
  // double cutval; // == point[cutdim]
57
  CoordPoint<T> point;
58
  //  roots of the two subtrees
59
  kdtree_node<T> *loson, *hison;
60
  // bounding rectangle of this node's subtree
61
  CoordPoint<T> lobound, upbound;
62
};
63

64
//--------------------------------------------------------------
65
// different distance metrics
66
//--------------------------------------------------------------
67
template<typename T>
68
class DistanceMeasure {
69
 public:
70
  DistanceMeasure() {}
2✔
71
  virtual ~DistanceMeasure() {}
2✔
72
  virtual double distance(const CoordPoint<T>& p, const CoordPoint<T>& q) = 0;
73
  virtual double coordinate_distance(T x, T y, size_t dim) = 0;
74
};
75
// Maximum distance (Linfinite norm)
76
template<typename T>
77
class DistanceL0 : virtual public DistanceMeasure<T> {
78
  WeightVector* w;
79

80
 public:
NEW
81
  DistanceL0(const WeightVector* weights = NULL) {
×
NEW
82
    if (weights)
×
NEW
83
      w = new WeightVector(*weights);
×
84
    else
NEW
85
      w = (WeightVector*)NULL;
×
NEW
86
  }
×
NEW
87
  ~DistanceL0() {
×
NEW
88
    if (w) delete w;
×
NEW
89
  }
×
NEW
90
  double distance(const CoordPoint<T>& p, const CoordPoint<T>& q) {
×
91
    double dist, test;
NEW
92
    if (w) {
×
NEW
93
      dist = (*w)[0] * fabs(p.x - q.x);
×
NEW
94
      test = (*w)[1] * fabs(p.y - q.y);
×
NEW
95
      if (test > dist) dist = test;
×
96
    } else {
NEW
97
      dist = fabs(p.x - q.x);
×
NEW
98
      test = fabs(p.y - q.y);
×
NEW
99
      if (test > dist) dist = test;
×
100
    }
NEW
101
    return dist;
×
102
  }
NEW
103
  double coordinate_distance(T x, T y, size_t dim) {
×
NEW
104
    if (w)
×
NEW
105
      return (*w)[dim] * fabs(x - y);
×
106
    else
NEW
107
      return fabs(x - y);
×
108
  }
109
};
110
// Manhatten distance (L1 norm)
111
template<typename T>
112
class DistanceL1 : virtual public DistanceMeasure<T> {
113
  WeightVector* w;
114

115
 public:
NEW
116
  DistanceL1(const WeightVector* weights = NULL) {
×
NEW
117
    if (weights)
×
NEW
118
      w = new WeightVector(*weights);
×
119
    else
NEW
120
      w = (WeightVector*)NULL;
×
NEW
121
  }
×
NEW
122
  ~DistanceL1() {
×
NEW
123
    if (w) delete w;
×
NEW
124
  }
×
NEW
125
  double distance(const CoordPoint<T>& p, const CoordPoint<T>& q) {
×
NEW
126
    double dist = 0.0;
×
NEW
127
    if (w) {
×
NEW
128
      dist += (*w)[0] * fabs(p.x - q.x);
×
NEW
129
      dist += (*w)[1] * fabs(p.y - q.y);
×
130
    } else {
NEW
131
      dist += fabs(p.x - q.x);
×
NEW
132
      dist += fabs(p.y - q.y);
×
133
    }
NEW
134
    return dist;
×
135
  }
NEW
136
  double coordinate_distance(T x, T y, size_t dim) {
×
NEW
137
    if (w)
×
NEW
138
      return (*w)[dim] * fabs(x - y);
×
139
    else
NEW
140
      return fabs(x - y);
×
141
  }
142
};
143
// Euklidean distance (L2 norm) (squared)
144
template<typename T>
145
class DistanceL2 : virtual public DistanceMeasure<T> {
146
  WeightVector* w;
147

148
 public:
149
  DistanceL2(const WeightVector* weights = NULL) {
2✔
150
    if (weights)
2✔
NEW
151
      w = new WeightVector(*weights);
×
152
    else
153
      w = (WeightVector*)NULL;
2✔
154
  }
2✔
155
  ~DistanceL2() {
4✔
NEW
156
    if (w) delete w;
×
157
  }
4✔
158
  double distance(const CoordPoint<T>& p, const CoordPoint<T>& q) {
2✔
159
    double dist = 0.0;
2✔
160
    if (w) {
2✔
NEW
161
      dist += (*w)[0] * (p.x - q.x) * (p.x - q.x);
×
NEW
162
      dist += (*w)[1] * (p.y - q.y) * (p.y - q.y);
×
163
    } else {
164
      dist += (p.x - q.x) * (p.x - q.x);
2✔
165
      dist += (p.y - q.y) * (p.y - q.y);
2✔
166
    }
167
    return dist;
2✔
168
  }
169
  double coordinate_distance(T x, T y, size_t dim) {
3✔
170
    if (w)
3✔
NEW
171
      return (*w)[dim] * (x - y) * (x - y);
×
172
    else
173
      return (x - y) * (x - y);
3✔
174
  }
175
};
176

177
//--------------------------------------------------------------
178
// destructor and constructor of kdtree
179
//--------------------------------------------------------------
180
template<typename T>
181
KdTree<T>::~KdTree() {
2✔
182
  if (root) delete root;
2✔
183
  delete distance;
2✔
184
}
2✔
185
// distance_type can be 0 (Maximum), 1 (Manhatten), or 2 (Euklidean [squared])
186
template<typename T>
187
KdTree<T>::KdTree(const KdNodeVector<T>* nodes, int distance_type /*=2*/) {
2✔
188
  size_t i;
189
  // copy over input data
190
  if (!nodes || nodes->empty())
2✔
NEW
191
    throw std::invalid_argument(
×
192
        "kdtree::KdTree(): argument nodes must not be empty");
193
  dimension = 2;
2✔
194
  allnodes = *nodes;
2✔
195
  // initialize distance values
196
  distance = NULL;
2✔
197
  this->distance_type = -1;
2✔
198
  set_distance(distance_type);
2✔
199
  // compute global bounding box
200
  lobound = nodes->begin()->point;
2✔
201
  upbound = nodes->begin()->point;
2✔
202
  for (i = 1; i < nodes->size(); i++) {
6✔
203
    if (allnodes[i].point.x < lobound.x) lobound.x = allnodes[i].point.x;
4✔
204
    if (allnodes[i].point.y < lobound.y) lobound.y = allnodes[i].point.y;
4✔
205
    if (allnodes[i].point.x > upbound.x) upbound.x = allnodes[i].point.x;
4✔
206
    if (allnodes[i].point.y > upbound.y) upbound.y = allnodes[i].point.y;
4✔
207
  }
208
  // build tree recursively
209
  root = build_tree(0, 0, allnodes.size());
2✔
210
}
2✔
211

212
// distance_type can be 0 (Maximum), 1 (Manhatten), or 2 (Euklidean [squared])
213
template<typename T>
214
void KdTree<T>::set_distance(int distance_type,
2✔
215
                          const WeightVector* weights /*=NULL*/) {
216
  if (distance) delete distance;
2✔
217
  this->distance_type = distance_type;
2✔
218
  if (distance_type == 0) {
2✔
NEW
219
    distance = (DistanceMeasure<T>*)new DistanceL0<T>(weights);
×
220
  } else if (distance_type == 1) {
2✔
NEW
221
    distance = (DistanceMeasure<T>*)new DistanceL1<T>(weights);
×
222
  } else {
223
    distance = (DistanceMeasure<T>*)new DistanceL2<T>(weights);
2✔
224
  }
225
}
2✔
226

227
//--------------------------------------------------------------
228
// recursive build of tree
229
// "a" and "b"-1 are the lower and upper indices
230
// from "allnodes" from which the subtree is to be built
231
//--------------------------------------------------------------
232
template<typename T>
233
kdtree_node<T>* KdTree<T>::build_tree(size_t depth, size_t a, size_t b) {
6✔
234
  size_t m;
235
  T temp, cutval;
236
  kdtree_node<T>* node = new kdtree_node<T>();
6✔
237
  node->lobound = lobound;
6✔
238
  node->upbound = upbound;
6✔
239
  node->cutdim = depth % dimension;
6✔
240
  if (b - a <= 1) {
6✔
241
    node->dataindex = a;
4✔
242
    node->point = allnodes[a].point;
4✔
243
  } else {
244
    m = (a + b) / 2;
2✔
245
    std::nth_element(allnodes.begin() + a, allnodes.begin() + m,
6✔
246
                     allnodes.begin() + b, compare_dimension<T>(node->cutdim));
4✔
247
    node->point = allnodes[m].point;
2✔
248
    if (node->cutdim == 0)
2✔
249
        cutval = allnodes[m].point.x;
2✔
250
    else
NEW
251
        cutval = allnodes[m].point.y;
×
252
    node->dataindex = m;
2✔
253
    if (m - a > 0) {
2✔
254
      if (node->cutdim == 0) {
2✔
255
        temp = upbound.x;
2✔
256
        upbound.x = cutval;
2✔
257
        node->loson = build_tree(depth + 1, a, m);
2✔
258
        upbound.x = temp;
2✔
259
      } else {
NEW
260
        temp = upbound.y;
×
NEW
261
        upbound.y = cutval;
×
NEW
262
        node->loson = build_tree(depth + 1, a, m);
×
NEW
263
        upbound.y = temp;
×
264
      }
265
    }
266
    if (b - m > 1) {
2✔
267
      if (node->cutdim == 0) {
2✔
268
        temp = lobound.x;
2✔
269
        lobound.x = cutval;
2✔
270
        node->hison = build_tree(depth + 1, m + 1, b);
2✔
271
        lobound.x = temp;
2✔
272
      } else {
NEW
273
        temp = lobound.y;
×
NEW
274
        lobound.y = cutval;
×
NEW
275
        node->hison = build_tree(depth + 1, m + 1, b);
×
NEW
276
        lobound.y = temp;
×
277
      }
278
    }
279
  }
280
  return node;
6✔
281
}
282

283
//--------------------------------------------------------------
284
// k nearest neighbor search
285
// returns the *k* nearest neighbors of *point* in O(log(n))
286
// time. The result is returned in *result* and is sorted by
287
// distance from *point*.
288
// The optional search predicate is a callable class (aka "functor")
289
// derived from KdNodePredicate. When Null (default, no search
290
// predicate is applied).
291
//--------------------------------------------------------------
292
template<typename T>
293
void KdTree<T>::k_nearest_neighbors(const CoordPoint<T>& point, size_t k,
1✔
294
                                 KdNodeVector<T>* result,
295
                                 KdNodePredicate<T>* pred /*=NULL*/) {
296
  size_t i;
297
  KdNode<T> temp;
1✔
298
  searchpredicate = pred;
1✔
299

300
  result->clear();
1✔
301
  if (k < 1) return;
1✔
302

303
  // collect result of k values in neighborheap
304
  SearchQueue* neighborheap = new SearchQueue();
1✔
305
  if (k > allnodes.size()) {
1✔
306
    // when more neighbors asked than nodes in tree, return everything
NEW
307
    k = allnodes.size();
×
NEW
308
    for (i = 0; i < k; i++) {
×
NEW
309
      if (!(searchpredicate && !(*searchpredicate)(allnodes[i])))
×
NEW
310
        neighborheap->push(
×
NEW
311
            nn4heap(i, distance->distance(allnodes[i].point, point)));
×
312
    }
313
  } else {
314
    neighbor_search(point, root, k, neighborheap);
1✔
315
  }
316

317
  // copy over result sorted by distance
318
  // (we must revert the vector for ascending order)
319
  while (!neighborheap->empty()) {
2✔
320
    i = neighborheap->top().dataindex;
1✔
321
    neighborheap->pop();
1✔
322
    result->push_back(allnodes[i]);
1✔
323
  }
324
  // beware that less than k results might have been returned
325
  k = result->size();
1✔
326
  for (i = 0; i < k / 2; i++) {
1✔
NEW
327
    temp = (*result)[i];
×
NEW
328
    (*result)[i] = (*result)[k - 1 - i];
×
NEW
329
    (*result)[k - 1 - i] = temp;
×
330
  }
331
  delete neighborheap;
1✔
332
}
333

334
//--------------------------------------------------------------
335
// range nearest neighbor search
336
// returns the nearest neighbors of *point* in the given range
337
// *r*. The result is returned in *result* and is sorted by
338
// distance from *point*.
339
//--------------------------------------------------------------
340
template<typename T>
NEW
341
void KdTree<T>::range_nearest_neighbors(const CoordPoint<T>& point, double r,
×
342
                                     KdNodeVector<T>* result) {
NEW
343
  result->clear();
×
NEW
344
  if (this->distance_type == 2) {
×
345
    // if euclidien distance is used the range must be squared because we
346
    // get squared distances from this implementation
NEW
347
    r *= r;
×
348
  }
349

350
  // collect result in range_result
NEW
351
  std::vector<size_t> range_result;
×
NEW
352
  range_search(point, root, r, &range_result);
×
353

354
  // copy over result
NEW
355
  for (std::vector<size_t>::iterator i = range_result.begin();
×
NEW
356
       i != range_result.end(); ++i) {
×
NEW
357
    result->push_back(allnodes[*i]);
×
358
  }
359

360
  // clear vector
NEW
361
  range_result.clear();
×
NEW
362
}
×
363

364
//--------------------------------------------------------------
365
// recursive function for nearest neighbor search in subtree
366
// under *node*. Stores result in *neighborheap*.
367
// returns "true" when no nearer neighbor elsewhere possible
368
//--------------------------------------------------------------
369
template<typename T>
370
bool KdTree<T>::neighbor_search(const CoordPoint<T>& point, kdtree_node<T>* node,
2✔
371
                             size_t k, SearchQueue* neighborheap) {
372
  double curdist, dist;
373

374
  curdist = distance->distance(point, node->point);
2✔
375
  if (!(searchpredicate && !(*searchpredicate)(allnodes[node->dataindex]))) {
2✔
376
    if (neighborheap->size() < k) {
2✔
377
      neighborheap->push(nn4heap(node->dataindex, curdist));
1✔
378
    } else if (curdist < neighborheap->top().distance) {
1✔
379
      neighborheap->pop();
1✔
380
      neighborheap->push(nn4heap(node->dataindex, curdist));
1✔
381
    }
382
  }
383

384
  T p_dim;
385
  T n_dim;
386

387
  if(node->cutdim == 0) {
2✔
388
    p_dim = point.x;
1✔
389
    n_dim = node->point.x;
1✔
390
  } else {
391
    p_dim = point.y;
1✔
392
    n_dim = node->point.y;
1✔
393
  }
394
  // first search on side closer to point
395
  if (p_dim < n_dim) {
2✔
396
    if (node->loson)
1✔
397
      if (neighbor_search(point, node->loson, k, neighborheap)) return true;
1✔
398
  } else {
399
    if (node->hison)
1✔
NEW
400
      if (neighbor_search(point, node->hison, k, neighborheap)) return true;
×
401
  }
402
  // second search on farther side, if necessary
403
  if (neighborheap->size() < k) {
2✔
NEW
404
    dist = std::numeric_limits<double>::max();
×
405
  } else {
406
    dist = neighborheap->top().distance;
2✔
407
  }
408
  if (p_dim < n_dim) {
2✔
409
    if (node->hison && bounds_overlap_ball(point, dist, node->hison))
1✔
NEW
410
      if (neighbor_search(point, node->hison, k, neighborheap)) return true;
×
411
  } else {
412
    if (node->loson && bounds_overlap_ball(point, dist, node->loson))
1✔
NEW
413
      if (neighbor_search(point, node->loson, k, neighborheap)) return true;
×
414
  }
415

416
  if (neighborheap->size() == k) dist = neighborheap->top().distance;
2✔
417
  return ball_within_bounds(point, dist, node);
2✔
418
}
419

420
//--------------------------------------------------------------
421
// recursive function for range search in subtree under *node*.
422
// Stores result in *range_result*.
423
//--------------------------------------------------------------
424
template<typename T>
NEW
425
void KdTree<T>::range_search(const CoordPoint<T>& point, kdtree_node<T>* node,
×
426
                          double r, std::vector<size_t>* range_result) {
NEW
427
  double curdist = distance->distance(point, node->point);
×
NEW
428
  if (curdist <= r) {
×
NEW
429
    range_result->push_back(node->dataindex);
×
430
  }
NEW
431
  if (node->loson != NULL && this->bounds_overlap_ball(point, r, node->loson)) {
×
NEW
432
    range_search(point, node->loson, r, range_result);
×
433
  }
NEW
434
  if (node->hison != NULL && this->bounds_overlap_ball(point, r, node->hison)) {
×
NEW
435
    range_search(point, node->hison, r, range_result);
×
436
  }
NEW
437
}
×
438

439
// returns true when the bounds of *node* overlap with the
440
// ball with radius *dist* around *point*
441
template<typename T>
442
bool KdTree<T>::bounds_overlap_ball(const CoordPoint<T>& point, double dist,
1✔
443
                                 kdtree_node<T>* node) {
444
  if (distance_type != 0) {
1✔
445
    double distsum = 0.0;
1✔
446
    if (point.x < node->lobound.x) {  // lower than low boundary
1✔
447
      distsum += distance->coordinate_distance(point.x, node->lobound.x, 0);
1✔
448
      if (distsum > dist) return false;
1✔
NEW
449
    } else if (point.x > node->upbound.x) {  // higher than high boundary
×
NEW
450
      distsum += distance->coordinate_distance(point.x, node->upbound.x, 0);
×
NEW
451
      if (distsum > dist) return false;
×
452
    }
NEW
453
    if (point.y < node->lobound.y) {  // lower than low boundary
×
NEW
454
      distsum += distance->coordinate_distance(point.y, node->lobound.y, 1);
×
NEW
455
      if (distsum > dist) return false;
×
NEW
456
    } else if (point.y > node->upbound.y) {  // higher than high boundary
×
NEW
457
      distsum += distance->coordinate_distance(point.y, node->upbound.y, 1);
×
NEW
458
      if (distsum > dist) return false;
×
459
    }
NEW
460
    return true;
×
461
  } else { // maximum distance needs different treatment
NEW
462
    double max_dist = 0.0;
×
NEW
463
    double curr_dist = 0.0;
×
NEW
464
    if (point.x < node->lobound.x) {  // lower than low boundary
×
NEW
465
      curr_dist = distance->coordinate_distance(point.x, node->lobound.x, 0);
×
NEW
466
    } else if (point.x > node->upbound.x) {  // higher than high boundary
×
NEW
467
      curr_dist = distance->coordinate_distance(point.x, node->upbound.x, 0);
×
468
    }
NEW
469
    if(curr_dist > max_dist) {
×
NEW
470
      max_dist = curr_dist;
×
471
    }
NEW
472
    if (max_dist > dist) return false;
×
NEW
473
    if (point.y < node->lobound.y) {  // lower than low boundary
×
NEW
474
      curr_dist = distance->coordinate_distance(point.y, node->lobound.y, 1);
×
NEW
475
    } else if (point.y > node->upbound.y) {  // higher than high boundary
×
NEW
476
      curr_dist = distance->coordinate_distance(point.y, node->upbound.y, 1);
×
477
    }
NEW
478
    if(curr_dist > max_dist) {
×
NEW
479
        max_dist = curr_dist;
×
480
    }
NEW
481
    if (max_dist > dist) return false;
×
482

NEW
483
    return true;
×
484
  }
485
}
486

487
// returns true when the bounds of *node* completely contain the
488
// ball with radius *dist* around *point*
489
template<typename T>
490
bool KdTree<T>::ball_within_bounds(const CoordPoint<T>& point, double dist,
2✔
491
                                kdtree_node<T>* node) {
492

493
  if (distance->coordinate_distance(point.x, node->lobound.x, 0) <= dist ||
2✔
NEW
494
      distance->coordinate_distance(point.x, node->upbound.x, 0) <= dist ||
×
495
      distance->coordinate_distance(point.y, node->lobound.y, 1) <= dist ||
2✔
NEW
496
      distance->coordinate_distance(point.y, node->upbound.y, 1) <= dist)
×
497
    return false;
2✔
NEW
498
  return true;
×
499
}
500

501
}  // namespace Kdtree
502

503
template class Kdtree::KdTree<float>;
504
template class Kdtree::KdTree<uint32_t>;
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc