• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18477247352

13 Oct 2025 08:18PM UTC coverage: 72.391% (+0.4%) from 71.943%
18477247352

push

github

web-flow
Merge pull request #140 from paulmthompson/kdtree

Jules PR

164 of 287 new or added lines in 3 files covered. (57.14%)

350 existing lines in 9 files now uncovered.

51889 of 71679 relevant lines covered (72.39%)

63071.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.84
/src/StateEstimation/MinCostFlowTracker.hpp
1
#ifndef MIN_COST_FLOW_TRACKER_HPP
2
#define MIN_COST_FLOW_TRACKER_HPP
3

4
#include "Assignment/IAssigner.hpp"
5
#include "Assignment/hungarian.hpp"
6
#include "Cost/CostFunctions.hpp"
7
#include "DataSource.hpp"
8
#include "Entity/EntityGroupManager.hpp"
9
#include "Features/IFeatureExtractor.hpp"
10
#include "Filter/IFilter.hpp"
11
#include "Filter/Kalman/KalmanMatrixBuilder.hpp"
12
#include "MinCostFlowSolver.hpp"
13
#include "TimeFrame/TimeFrame.hpp"
14

15
#include "spdlog/sinks/basic_file_sink.h"
16
#include "spdlog/spdlog.h"
17
#include <Eigen/Dense>
18

19
#include <algorithm>
20
#include <chrono>
21
#include <cmath>
22
#include <functional>
23
#include <limits>
24
#include <map>
25
#include <memory>
26
#include <optional>
27
#include <set>
28
#include <unordered_map>
29
#include <unordered_set>
30
#include <utility>
31
#include <vector>
32

33
namespace StateEstimation {
34

35
/**
36
 * @brief A tracker that uses a global min-cost flow optimization to solve data association.
37
 *
38
 * This tracker formulates the tracking problem as a graph problem, finding the globally
39
 * optimal set of tracks over an entire interval between anchors. It is more robust to
40
 * ambiguities and identity swaps than iterative, frame-by-frame methods.
41
 *
42
 * @tparam DataType raw observation type (e.g., Line2D)
43
 */
44
template<typename DataType>
45
class MinCostFlowTracker {
46
public:
47
    using GroundTruthMap = std::map<TimeFrameIndex, std::map<GroupId, EntityId>>;
48
    using FrameBucket = std::vector<std::tuple<DataType const *, EntityId, TimeFrameIndex>>;
49

50
    /**
51
     * @brief Construct a new MinCostFlowTracker
52
     *
53
     * @param filter_prototype Prototype filter (cloned for prediction and final smoothing). 
54
     *        If nullptr, prediction is skipped in cost calculation (cost function must handle this)
55
     *        and no smoothing is performed. The filter's uncertainty automatically scales with
56
     *        gap size through process noise accumulation.
57
     * @param feature_extractor Feature extractor for DataType
58
     * @param cost_function Function to compute cost between predicted state and observation
59
     * @param cost_scale_factor Multiplier to convert floating-point costs to integers for the solver.
60
     */
61
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
2✔
62
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
63
                       CostFunction cost_function,
64
                       double cost_scale_factor = 100.0,
65
                       double cheap_assignment_threshold = 5.0)
66
        : _filter_prototype(std::move(filter_prototype)),
2✔
67
          _feature_extractor(std::move(feature_extractor)),
2✔
68
          _chain_cost_function(cost_function),
2✔
69
          _transition_cost_function(std::move(cost_function)),
2✔
70
          _cost_scale_factor(cost_scale_factor),
2✔
71
          _cheap_assignment_threshold(cheap_assignment_threshold) {}
2✔
72

73
    /**
74
     * @brief Construct with separate cost functions for greedy chaining and meta-node transitions.
75
     *
76
     * @param chain_cost_function Cost for frame-to-frame greedy chaining (typically 1-step)
77
     * @param transition_cost_function Cost for meta-node transitions across k-step gaps
78
     */
79
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
9✔
80
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
81
                       CostFunction chain_cost_function,
82
                       CostFunction transition_cost_function,
83
                       double cost_scale_factor,
84
                       double cheap_assignment_threshold)
85
        : _filter_prototype(std::move(filter_prototype)),
9✔
86
          _feature_extractor(std::move(feature_extractor)),
9✔
87
          _chain_cost_function(std::move(chain_cost_function)),
9✔
88
          _transition_cost_function(std::move(transition_cost_function)),
9✔
89
          _cost_scale_factor(cost_scale_factor),
9✔
90
          _cheap_assignment_threshold(cheap_assignment_threshold) {}
9✔
91

92
    /**
93
     * @brief Convenience constructor using default Mahalanobis distance cost function.
94
     *
95
     * @param filter_prototype Prototype filter (cloned for prediction and final smoothing).
96
     *        If nullptr, Mahalanobis distance cannot be computed properly (requires filter state covariance).
97
     *        The filter's uncertainty automatically scales with gap size.
98
     * @param feature_extractor Feature extractor for DataType
99
     * @param measurement_matrix H matrix for Mahalanobis distance calculation
100
     * @param measurement_noise_covariance R matrix for Mahalanobis distance calculation
101
     * @param cost_scale_factor Multiplier to convert floating-point costs to integers for the solver.
102
     */
103
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
2✔
104
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
105
                       Eigen::MatrixXd const & measurement_matrix,
106
                       Eigen::MatrixXd const & measurement_noise_covariance,
107
                       double cost_scale_factor = 100.0,
108
                       double cheap_assignment_threshold = 5.0)
109
        : MinCostFlowTracker(std::move(filter_prototype),
2✔
110
                             std::move(feature_extractor),
2✔
111
                             createMahalanobisCostFunction(measurement_matrix, measurement_noise_covariance),
112
                             cost_scale_factor,
113
                             cheap_assignment_threshold) {}
6✔
114

115
    /**
116
     * @brief Process a range of frames using min-cost flow optimization.
117
     *
118
     * @param data_source Zero-copy data source
119
     * @param group_manager Group manager to record final assignments
120
     * @param ground_truth Ground truth at specific frames (anchors)
121
     * @param start_frame Inclusive start frame
122
     * @param end_frame Inclusive end frame
123
     * @param progress Optional progress callback
124
     * @return Smoothed states per group across processed frames
125
     */
126
    template<typename Source>
127
        requires DataSource<Source, DataType>
128
    [[nodiscard]] SmoothedResults process(Source && data_source,
162✔
129
                                          EntityGroupManager & group_manager,
130
                                          GroundTruthMap const & ground_truth,
131
                                          TimeFrameIndex start_frame,
132
                                          TimeFrameIndex end_frame,
133
                                          ProgressCallback progress = nullptr,
134
                                          std::map<GroupId, GroupId> const * output_group_ids = nullptr,
135
                                          std::unordered_set<EntityId> const * excluded_entities = nullptr,
136
                                          std::unordered_set<EntityId> const * include_entities = nullptr) {
137
        if (_logger) {
162✔
138
            _logger->debug("MCF process: start={} end={}", start_frame.getValue(), end_frame.getValue());
162✔
139
        }
140

141
        auto frame_lookup = buildFrameLookup(data_source, start_frame, end_frame);
162✔
142
        auto start_anchors_it = ground_truth.find(start_frame);
162✔
143
        auto end_anchors_it = ground_truth.find(end_frame);
162✔
144

145
        if (start_anchors_it == ground_truth.end() || end_anchors_it == ground_truth.end()) {
162✔
UNCOV
146
            if (_logger) _logger->error("Min-cost flow requires anchors at both start and end frames.");
×
UNCOV
147
            return {};
×
148
        }
149

150
        // 1. --- Build and Solve the Graph ---
151
        auto solved_paths = solve_flow_problem(frame_lookup,
162✔
152
                                               start_anchors_it->second,
162✔
153
                                               end_anchors_it->second,
162✔
154
                                               start_frame,
155
                                               end_frame,
156
                                               excluded_entities,
157
                                               include_entities);
158

159
        if (solved_paths.empty()) {
162✔
160
            if (_logger) _logger->error("Min-cost flow solver failed or found no paths.");
16✔
161
            return {};
16✔
162
        }
163

164
        // 2. --- Update Group Manager with Solved Tracks ---
165
        for (auto const & [group_id, path]: solved_paths) {
293✔
166
            GroupId write_group = group_id;
147✔
167
            if (output_group_ids) {
147✔
168
                auto it = output_group_ids->find(group_id);
144✔
169
                if (it != output_group_ids->end()) write_group = it->second;
144✔
170
            }
171
            for (auto const & node: path) {
3,059✔
172
                // Never overwrite anchors or any labeled entity: only add unlabeled entities
173
                auto groups = group_manager.getGroupsContainingEntity(node.entity_id);
1,456✔
174
                if (!groups.empty()) continue;
1,456✔
175
                group_manager.addEntityToGroup(write_group, node.entity_id);
1,166✔
176
            }
177
        }
178

179
        // 3. --- Final Forward/Backward Smoothing Pass ---
180
        // Now that we have the globally optimal assignments, run a final KF pass to get the smoothed states.
181
        return generate_smoothed_results(solved_paths, frame_lookup, start_frame, end_frame);
146✔
182
    }
162✔
183

184
    void enableDebugLogging(std::string const & file_path) {
11✔
185
        _logger = std::make_shared<spdlog::logger>("MinCostFlowTracker", std::make_shared<spdlog::sinks::basic_file_sink_mt>(file_path, true));
11✔
186
        _logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%l] %v");
33✔
187
        _logger->set_level(spdlog::level::debug);
11✔
188
        _logger->flush_on(spdlog::level::debug);
11✔
189
    }
11✔
190

191
private:
192
    struct NodeInfo {
193
        TimeFrameIndex frame = TimeFrameIndex(0);
194
        EntityId entity_id = 0;
195

196
        bool operator<(NodeInfo const & other) const {
197
            if (frame != other.frame) return frame < other.frame;
198
            return entity_id < other.entity_id;
199
        }
200
    };
201

202
    using Path = std::vector<NodeInfo>;
203

204
    /**
205
     * @brief Represents a greedy-linked sequence (meta-node) of cheap assignments across consecutive frames.
206
     * 
207
     * Each meta-node aggregates observations that are very likely to belong to the same chain, so that
208
     * min-cost flow can operate sparsely on these chains instead of per-observation nodes.
209
     */
210
    struct MetaNode {
211
        std::vector<NodeInfo> members;// consecutive observations included in this chain
212
        FilterState start_state;      // filter state after initializing on first observation
213
        FilterState end_state;        // filter state after updating on last observation
214
        TimeFrameIndex start_frame = TimeFrameIndex(0);
215
        TimeFrameIndex end_frame = TimeFrameIndex(0);
216
        EntityId start_entity = 0;
217
        EntityId end_entity = 0;
218
    };
219

220
    // Arc metadata: stores the actual chain of entities represented by this arc
221
    struct ArcChain {
222
        std::vector<NodeInfo> entities;// All entities along this arc (including endpoints)
223
        int64_t cost;
224
    };
225

226
    // Helper: Build a chain of entities from start_node to end_node by greedily following best matches
227
    // Returns the chain and accumulated cost
228
    ArcChain build_entity_chain(
229
            NodeInfo const & start_node,
230
            NodeInfo const & end_node,
231
            std::map<TimeFrameIndex, FrameBucket> const & frame_lookup) {
232

233
        ArcChain chain;
234
        chain.entities.push_back(start_node);
235
        chain.cost = 0;
236

237
        // If start and end are consecutive frames or same frame, just connect directly
238
        if (end_node.frame <= start_node.frame + TimeFrameIndex(1)) {
239
            if (end_node.frame == start_node.frame + TimeFrameIndex(1)) {
240
                chain.entities.push_back(end_node);
241
            }
242
            return chain;
243
        }
244

245
        // Build chain frame-by-frame using greedy best match
246
        TimeFrameIndex current_frame = start_node.frame;
247
        EntityId current_entity = start_node.entity_id;
248

249
        while (current_frame + TimeFrameIndex(1) < end_node.frame) {
250
            TimeFrameIndex next_frame = current_frame + TimeFrameIndex(1);
251

252
            // Skip frames with no observations
253
            if (!frame_lookup.count(next_frame)) {
254
                current_frame = next_frame;
255
                continue;
256
            }
257

258
            // Get current entity data
259
            DataType const * current_data = nullptr;
260
            if (frame_lookup.count(current_frame)) {
261
                for (auto const & item: frame_lookup.at(current_frame)) {
262
                    if (std::get<1>(item) == current_entity) {
263
                        current_data = std::get<0>(item);
264
                        break;
265
                    }
266
                }
267
            }
268

269
            if (!current_data) break;// Can't continue chain
270

271
            // Find best match at next frame
272
            double best_cost = std::numeric_limits<double>::max();
273
            EntityId best_entity = 0;
274

275
            for (auto const & candidate: frame_lookup.at(next_frame)) {
276
                EntityId candidate_id = std::get<1>(candidate);
277
                DataType const * candidate_data = std::get<0>(candidate);
278

279
                double cost = 0.0;
280
                if (_filter_prototype) {
281
                    auto temp_filter = _filter_prototype->clone();
282
                    temp_filter->initialize(_feature_extractor->getInitialState(*current_data));
283
                    FilterState predicted = temp_filter->predict();
284
                    Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*candidate_data);
285
                    cost = _chain_cost_function(predicted, obs, 1);
286
                } else {
287
                    // Simple distance without filter
288
                    Eigen::VectorXd feat_current = _feature_extractor->getFilterFeatures(*current_data);
289
                    Eigen::VectorXd feat_candidate = _feature_extractor->getFilterFeatures(*candidate_data);
290
                    cost = (feat_current - feat_candidate).norm();
291
                }
292

293
                if (cost < best_cost) {
294
                    best_cost = cost;
295
                    best_entity = candidate_id;
296
                }
297
            }
298

299
            if (best_cost < std::numeric_limits<double>::max()) {
300
                chain.entities.push_back({next_frame, best_entity});
301
                chain.cost += static_cast<int64_t>(best_cost * _cost_scale_factor);
302
                current_entity = best_entity;
303
                current_frame = next_frame;
304
            } else {
305
                break;// No valid match found
306
            }
307
        }
308

309
        // Add end node
310
        chain.entities.push_back(end_node);
311

312
        return chain;
313
    }
314

315
    // --- Main Graph Building and Solving Logic ---
316
    std::map<GroupId, Path> solve_flow_problem(
162✔
317
            std::map<TimeFrameIndex, FrameBucket> const & frame_lookup,
318
            std::map<GroupId, EntityId> const & start_anchors,
319
            std::map<GroupId, EntityId> const & end_anchors,
320
            TimeFrameIndex start_frame,
321
            TimeFrameIndex end_frame,
322
            std::unordered_set<EntityId> const * excluded_entities,
323
            std::unordered_set<EntityId> const * include_entities) {
324

325
        // 1) Build greedy meta-nodes (cheap consecutive links) independent of groups
326
        auto meta_nodes = build_meta_nodes(frame_lookup, start_frame, end_frame, excluded_entities, include_entities);
162✔
327

328
        std::map<GroupId, Path> all_solved_paths;
162✔
329

330
        // 2) Solve a separate min-cost flow problem for each group over meta-nodes
331
        for (auto const & [group_id, start_entity_id]: start_anchors) {
488✔
332
            auto end_anchors_it = end_anchors.find(group_id);
163✔
333
            if (end_anchors_it == end_anchors.end()) {
163✔
UNCOV
334
                if (_logger) {
×
UNCOV
335
                    _logger->error("No end anchor found for group {}", static_cast<unsigned long long>(group_id));
×
336
                }
UNCOV
337
                continue;
×
338
            }
339
            EntityId end_entity_id = end_anchors_it->second;
163✔
340

341
            Path solved_path = solve_single_group_flow_over_meta(meta_nodes, frame_lookup, group_id, start_entity_id, end_entity_id, start_frame, end_frame);
163✔
342
            if (!solved_path.empty()) {
163✔
343
                all_solved_paths[group_id] = solved_path;
147✔
344
            }
345
        }
346

347
        return all_solved_paths;
162✔
348
    }
162✔
349

350
    // Solve min-cost flow for a single group over meta-nodes
351
    Path solve_single_group_flow_over_meta(
163✔
352
            std::vector<MetaNode> const & meta_nodes,
353
            std::map<TimeFrameIndex, FrameBucket> const & frame_lookup,
354
            GroupId group_id,
355
            EntityId start_entity_id,
356
            EntityId end_entity_id,
357
            TimeFrameIndex start_frame,
358
            TimeFrameIndex end_frame) {
359

360
        // Map each meta-node to an index
361
        int const num_meta = static_cast<int>(meta_nodes.size());
163✔
362
        auto get_start_meta_index = [&]() -> std::optional<int> {
163✔
363
            for (int i = 0; i < num_meta; ++i) {
164✔
364
                if (meta_nodes[i].start_frame == start_frame && meta_nodes[i].start_entity == start_entity_id) {
164✔
365
                    return i;
163✔
366
                }
367
            }
UNCOV
368
            return std::nullopt;
×
369
        };
370
        auto get_end_meta_index = [&]() -> std::optional<int> {
163✔
371
            for (int i = 0; i < num_meta; ++i) {
202✔
372
                // End meta-node must contain the end anchor entity at end_frame
373
                if (meta_nodes[i].end_frame == end_frame && meta_nodes[i].end_entity == end_entity_id) {
202✔
374
                    return i;
163✔
375
                }
376
            }
UNCOV
377
            return std::nullopt;
×
378
        };
379

380
        auto start_meta_opt = get_start_meta_index();
163✔
381
        auto end_meta_opt = get_end_meta_index();
163✔
382
        if (!start_meta_opt.has_value() || !end_meta_opt.has_value()) {
163✔
UNCOV
383
            if (_logger) {
×
UNCOV
384
                _logger->error("Group {} missing start or end meta-node anchor (start={}, end={})",
×
UNCOV
385
                               static_cast<unsigned long long>(group_id),
×
UNCOV
386
                               start_meta_opt.has_value(), end_meta_opt.has_value());
×
387
            }
UNCOV
388
            return {};
×
389
        }
390

391
        // Node indexing: 0..num_meta-1 are meta-nodes, plus source and sink
392
        int const source_node = num_meta;
163✔
393
        int const sink_node = num_meta + 1;
163✔
394

395
        // Build arcs for the abstract solver
396
        std::vector<ArcSpec> arcs;
163✔
397
        arcs.reserve(static_cast<size_t>(num_meta * num_meta / 4 + 4));
163✔
398
        // Source -> start meta
399
        arcs.push_back({source_node, *start_meta_opt, 1, 0});
163✔
400
        // End meta -> sink
401
        arcs.push_back({*end_meta_opt, sink_node, 1, 0});
163✔
402

403
        // Transitions between meta-nodes (only forward in time)
404
        int num_transition_arcs = 0;
163✔
405
        constexpr int64_t max_prediction_horizon = 50;// allow longer jumps across blackouts
163✔
406
        for (int i = 0; i < num_meta; ++i) {
492✔
407
            MetaNode const & from = meta_nodes[i];
329✔
408
            for (int j = 0; j < num_meta; ++j) {
1,048✔
409
                MetaNode const & to = meta_nodes[j];
673✔
410
                if (to.start_frame <= from.end_frame) continue;// must go forward
673✔
411
                int num_steps = (to.start_frame - from.end_frame).getValue();
74✔
412
                if (num_steps <= 0 || num_steps > max_prediction_horizon) continue;
74✔
413

414
                // Predict from the end of 'from' to the start of 'to'
415
                FilterState predicted_state;
46✔
416
                if (_filter_prototype) {
46✔
417
                    auto temp_filter = _filter_prototype->clone();
46✔
418
                    // Coerce the saved end_state to the filter's expected state dimension if needed
419
                    FilterState const proto_state = temp_filter->getState();
46✔
420
                    int const target_dim = static_cast<int>(proto_state.state_mean.size());
46✔
421
                    FilterState init_state = from.end_state;
46✔
422
                    if (static_cast<int>(init_state.state_mean.size()) != target_dim ||
46✔
423
                        init_state.state_covariance.rows() != target_dim ||
92✔
424
                        init_state.state_covariance.cols() != target_dim) {
46✔
425
                        // Build a compatible state by copying what fits and padding the rest
UNCOV
426
                        FilterState coerced;
×
UNCOV
427
                        coerced.state_mean = Eigen::VectorXd::Zero(target_dim);
×
UNCOV
428
                        int const copy_dim = std::min<int>(target_dim, static_cast<int>(init_state.state_mean.size()));
×
UNCOV
429
                        if (copy_dim > 0) coerced.state_mean.head(copy_dim) = init_state.state_mean.head(copy_dim);
×
430

UNCOV
431
                        coerced.state_covariance = Eigen::MatrixXd::Zero(target_dim, target_dim);
×
UNCOV
432
                        int const cr = std::min<int>(target_dim, init_state.state_covariance.rows());
×
UNCOV
433
                        int const cc = std::min<int>(target_dim, init_state.state_covariance.cols());
×
UNCOV
434
                        if (cr > 0 && cc > 0) {
×
UNCOV
435
                            int const b = std::min(cr, cc);
×
UNCOV
436
                            coerced.state_covariance.topLeftCorner(b, b) = init_state.state_covariance.topLeftCorner(b, b);
×
437
                        }
438
                        // Pad remaining diagonal to a large uncertainty to remain conservative
UNCOV
439
                        constexpr double kPadVar = 1e6;
×
UNCOV
440
                        for (int d = 0; d < target_dim; ++d) {
×
UNCOV
441
                            if (coerced.state_covariance(d, d) <= 0.0) coerced.state_covariance(d, d) = kPadVar;
×
442
                        }
UNCOV
443
                        init_state = std::move(coerced);
×
UNCOV
444
                        if (_logger) {
×
UNCOV
445
                            _logger->warn("State dimension coerced for transition prediction: was {} -> now {}",
×
UNCOV
446
                                          static_cast<int>(from.end_state.state_mean.size()), target_dim);
×
447
                        }
UNCOV
448
                    }
×
449

450
                    temp_filter->initialize(init_state);
46✔
451
                    for (int s = 0; s < num_steps; ++s) {
1,192✔
452
                        predicted_state = temp_filter->predict();
1,146✔
453
                    }
454
                }
46✔
455
                // Compute cost to the first observation of 'to'
456
                DataType const * to_start_data = findEntity(frame_lookup.at(to.start_frame), to.start_entity);
46✔
457
                if (!to_start_data) continue;
46✔
458
                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*to_start_data);
46✔
459
                double dist = _transition_cost_function(predicted_state, obs, num_steps);
46✔
460
                int64_t arc_cost = static_cast<int64_t>(dist * _cost_scale_factor);
46✔
461
                arcs.push_back({i, j, 1, arc_cost});
46✔
462
                num_transition_arcs++;
46✔
463
            }
464
        }
465

466
        if (_logger) {
163✔
467
            _logger->debug("Group {} meta-graph: {} meta-nodes, transitions={}",
163✔
468
                           static_cast<unsigned long long>(group_id), num_meta, num_transition_arcs);
163✔
469
        }
470

471
        // Solve using private solver and reconstruct meta-node path
472
        auto const seq_opt = solveMinCostSingleUnitPath(num_meta + 2, source_node, sink_node, arcs);
163✔
473
        if (!seq_opt.has_value()) {
163✔
474
            if (_logger) {
16✔
475
                _logger->error("Min-cost flow (meta) failed: no optimal path");
16✔
476
            }
477
            return {};
16✔
478
        }
479

480
        Path expanded_path;
147✔
481
        auto const & sequence = *seq_opt;
147✔
482
        for (size_t idx = 1; idx < sequence.size(); ++idx) {// skip the source at index 0
461✔
483
            int node_index = sequence[idx];
314✔
484
            if (node_index >= 0 && node_index < num_meta) {
314✔
485
                for (auto const & n: meta_nodes[static_cast<size_t>(node_index)].members) {
1,623✔
486
                    expanded_path.push_back(n);
1,456✔
487
                }
488
            }
489
        }
490

491
        return expanded_path;
147✔
492
    }
163✔
493

494
    /**
495
     * @brief Build meta-nodes using Hungarian algorithm for optimal chain extension.
496
     * 
497
     * Unlike greedy assignment, this uses Hungarian algorithm at each frame to ensure
498
     * global optimal assignment of chains to candidates, preventing "stealing" where
499
     * one chain takes another's best match.
500
     * 
501
     * @pre frame_lookup contains observations in [start_frame, end_frame]
502
     * @post Each observation belongs to at most one meta-node
503
     */
504
    std::vector<MetaNode> build_meta_nodes(
162✔
505
            std::map<TimeFrameIndex, FrameBucket> const & frame_lookup,
506
            TimeFrameIndex start_frame,
507
            TimeFrameIndex end_frame,
508
            std::unordered_set<EntityId> const * excluded_entities,
509
            std::unordered_set<EntityId> const * include_entities) {
510

511
        // Structure to track active chains being built
512
        struct ActiveChain {
513
            size_t meta_node_idx;// Index in meta_nodes vector
514
            TimeFrameIndex curr_frame;
515
            EntityId curr_entity;
516
            DataType const * curr_data;
517
            std::unique_ptr<IFilter> filter;// Cloned filter for this chain
518
            FilterState predicted;          // Cached prediction for next frame
519

520
            // Constructor to properly initialize TimeFrameIndex
521
            ActiveChain()
325✔
522
                : meta_node_idx(0),
325✔
523
                  curr_frame(TimeFrameIndex(0)),
325✔
524
                  curr_entity(0),
325✔
525
                  curr_data(nullptr) {}
325✔
526
        };
527

528
        std::vector<MetaNode> meta_nodes;
162✔
529
        std::set<std::pair<long long, EntityId>> used;// key: (frame, entity)
162✔
530
        std::vector<ActiveChain> active_chains;
162✔
531

532
        // Process frame by frame, using Hungarian algorithm to extend chains optimally
533
        for (TimeFrameIndex f = start_frame; f <= end_frame; ++f) {
3,144✔
534
            if (!frame_lookup.count(f)) continue;
1,664✔
535

536
            // Step 1: Start new chains for unused observations in current frame
537
            for (auto const & item: frame_lookup.at(f)) {
5,282✔
538
                EntityId entity_id = std::get<1>(item);
3,298✔
539
                auto used_key = std::make_pair(static_cast<long long>(f.getValue()), entity_id);
3,298✔
540
                if (used.count(used_key)) continue;
3,298✔
541

542
                if (excluded_entities && excluded_entities->count(entity_id) > 0) {
1,001✔
543
                    if (!(include_entities && include_entities->count(entity_id) > 0)) {
870✔
544
                        continue;
676✔
545
                    }
546
                }
547

548
                DataType const * start_data = std::get<0>(item);
325✔
549

550
                // Initialize filter for this new chain
551
                FilterState start_state;
325✔
552
                std::unique_ptr<IFilter> chain_filter;
325✔
553
                if (_filter_prototype) {
325✔
554
                    chain_filter = _filter_prototype->clone();
325✔
555
                    chain_filter->initialize(_feature_extractor->getInitialState(*start_data));
325✔
556
                    start_state = chain_filter->getState();
325✔
557
                }
558

559
                // Create new meta-node
560
                MetaNode node;
325✔
561
                node.start_frame = f;
325✔
562
                node.start_entity = entity_id;
325✔
563
                node.members.push_back(NodeInfo{f, entity_id});
325✔
564
                node.start_state = start_state;
325✔
565
                // Seed end_state so it's never empty even for single-frame nodes
566
                node.end_state = start_state;
325✔
567
                used.insert(used_key);
325✔
568

569
                size_t node_idx = meta_nodes.size();
325✔
570
                meta_nodes.push_back(std::move(node));
325✔
571

572
                // Add to active chains for extension
573
                ActiveChain chain;
325✔
574
                chain.meta_node_idx = node_idx;
325✔
575
                chain.curr_frame = f;
325✔
576
                chain.curr_entity = entity_id;
325✔
577
                chain.curr_data = start_data;
325✔
578
                chain.filter = std::move(chain_filter);
325✔
579
                active_chains.push_back(std::move(chain));
325✔
580
            }
581

582
            // Step 2: Try to extend all active chains to next frame using Hungarian algorithm
583
            TimeFrameIndex next_frame = f + TimeFrameIndex(1);
1,659✔
584
            if (next_frame > end_frame || !frame_lookup.count(next_frame)) {
1,659✔
585
                continue;// No next frame, chains will terminate
163✔
586
            }
587

588
            if (active_chains.empty()) continue;
1,496✔
589

590
            // Predict all active chains
591
            for (auto & chain: active_chains) {
3,777✔
592
                if (chain.filter) {
2,459✔
593
                    chain.predicted = chain.filter->predict();
2,459✔
594
                }
595
            }
596

597
            // Collect available candidates in next frame
598
            std::vector<std::tuple<EntityId, DataType const *, size_t>> candidates;// entity_id, data, index
1,318✔
599
            for (size_t cand_idx = 0; cand_idx < frame_lookup.at(next_frame).size(); ++cand_idx) {
3,935✔
600
                auto const & cand = frame_lookup.at(next_frame)[cand_idx];
2,617✔
601
                EntityId cand_id = std::get<1>(cand);
2,617✔
602
                auto key = std::make_pair(static_cast<long long>(next_frame.getValue()), cand_id);
2,617✔
603
                if (used.count(key)) continue;
2,617✔
604
                if (excluded_entities && excluded_entities->count(cand_id) > 0) {
2,617✔
605
                    if (!(include_entities && include_entities->count(cand_id) > 0)) {
320✔
606
                        continue;
194✔
607
                    }
608
                }
609
                DataType const * cand_data = std::get<0>(cand);
2,423✔
610
                candidates.emplace_back(cand_id, cand_data, cand_idx);
2,423✔
611
            }
612

613
            if (candidates.empty()) {
1,318✔
614
                // No candidates available - all chains terminate
615
                if (_logger) {
34✔
616
                    _logger->debug("{} active chains terminating at frame {} (no available candidates in frame {})",
34✔
617
                                   active_chains.size(), f.getValue(), next_frame.getValue());
34✔
618
                }
619
                active_chains.clear();
34✔
620
                continue;
34✔
621
            }
622

623
            // Build cost matrix for Hungarian algorithm
624
            // Rows = active chains, Cols = candidates
625
            int const cost_scaling_factor = 1000;
1,284✔
626
            int const max_cost = static_cast<int>(_cheap_assignment_threshold * cost_scaling_factor);
1,284✔
627
            std::vector<std::vector<int>> cost_matrix(active_chains.size(),
5,136✔
628
                                                      std::vector<int>(candidates.size()));
5,136✔
629

630
            for (size_t chain_idx = 0; chain_idx < active_chains.size(); ++chain_idx) {
3,709✔
631
                auto const & chain = active_chains[chain_idx];
2,425✔
632
                for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) {
7,004✔
633
                    DataType const * cand_data = std::get<1>(candidates[cand_idx]);
4,579✔
634
                    Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*cand_data);
4,579✔
635

636
                    double cost_double;
637
                    if (chain.filter) {
4,579✔
638
                        cost_double = _chain_cost_function(chain.predicted, obs, 1);
4,579✔
639
                    } else {
UNCOV
640
                        Eigen::VectorXd curr_obs = _feature_extractor->getFilterFeatures(*chain.curr_data);
×
UNCOV
641
                        cost_double = (curr_obs - obs).norm();
×
642
                    }
×
643

644
                    int cost = static_cast<int>(cost_double * cost_scaling_factor);
4,579✔
645
                    if (cost >= std::numeric_limits<int>::max()) {
4,579✔
UNCOV
646
                        cost = std::numeric_limits<int>::max() - 1;
×
647
                    }
648
                    cost_matrix[chain_idx][cand_idx] = cost;
4,579✔
649
                }
650
            }
651

652
            // Solve Hungarian assignment
653
            std::vector<std::vector<int>> assignment_matrix;
1,284✔
654
            Munkres::hungarian_with_assignment(cost_matrix, assignment_matrix);
1,284✔
655

656
            // Process assignments
657
            std::vector<bool> chain_extended(active_chains.size(), false);
3,852✔
658
            std::vector<ActiveChain> remaining_chains;
1,284✔
659

660
            for (size_t chain_idx = 0; chain_idx < assignment_matrix.size(); ++chain_idx) {
3,709✔
661
                bool found_assignment = false;
2,425✔
662
                int assigned_cand_idx = -1;
2,425✔
663

664
                for (size_t cand_idx = 0; cand_idx < assignment_matrix[chain_idx].size(); ++cand_idx) {
3,631✔
665
                    if (assignment_matrix[chain_idx][cand_idx] == 1) {
3,503✔
666
                        // Check if cost is within threshold
667
                        if (cost_matrix[chain_idx][cand_idx] <= max_cost) {
2,297✔
668
                            found_assignment = true;
2,297✔
669
                            assigned_cand_idx = static_cast<int>(cand_idx);
2,297✔
670
                        }
671
                        break;
2,297✔
672
                    }
673
                }
674

675
                auto & chain = active_chains[chain_idx];
2,425✔
676
                auto & node = meta_nodes[chain.meta_node_idx];
2,425✔
677

678
                if (found_assignment) {
2,425✔
679
                    // Extend chain
680
                    EntityId best_entity = std::get<0>(candidates[assigned_cand_idx]);
2,297✔
681
                    DataType const * best_data = std::get<1>(candidates[assigned_cand_idx]);
2,297✔
682

683
                    Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*best_data);
2,297✔
684
                    if (chain.filter) {
2,297✔
685
                        chain.filter->update(chain.predicted, {obs});
4,594✔
686

687
                        // Check covariance health
688
                        auto updated_state = chain.filter->getState();
2,297✔
689
                        double determinant = updated_state.state_covariance.determinant();
2,297✔
690

691
                        if (std::abs(determinant) < 1e-10 && _logger) {
2,297✔
UNCOV
692
                            Eigen::JacobiSVD<Eigen::MatrixXd> svd(updated_state.state_covariance);
×
UNCOV
693
                            double condition_number = svd.singularValues()(0) /
×
UNCOV
694
                                                      (svd.singularValues()(svd.singularValues().size() - 1) + 1e-20);
×
695

UNCOV
696
                            _logger->warn("State covariance singular at frame {} entity {}: det={:.2e}, cond={:.2e}",
×
UNCOV
697
                                          next_frame.getValue(), best_entity, determinant, condition_number);
×
698

UNCOV
699
                            if (condition_number > 1e12) {
×
UNCOV
700
                                _logger->warn("  Terminating chain due to ill-conditioned covariance");
×
UNCOV
701
                                found_assignment = false;// Terminate this chain
×
702
                            }
UNCOV
703
                        }
×
704
                    }
2,297✔
705

706
                    if (found_assignment) {
2,297✔
707
                        NodeInfo next_info{next_frame, best_entity};
2,297✔
708
                        node.members.push_back(next_info);
2,297✔
709
                        used.insert(std::make_pair(static_cast<long long>(next_frame.getValue()), best_entity));
2,297✔
710

711
                        chain.curr_frame = next_frame;
2,297✔
712
                        chain.curr_entity = best_entity;
2,297✔
713
                        chain.curr_data = best_data;
2,297✔
714
                        chain_extended[chain_idx] = true;
2,297✔
715
                        remaining_chains.push_back(std::move(chain));
2,297✔
716
                    }
717
                }
2,297✔
718

719
                if (!found_assignment) {
2,425✔
720
                    // Chain terminates - finalize meta-node
721
                    node.end_frame = node.members.back().frame;
128✔
722
                    node.end_entity = node.members.back().entity_id;
128✔
723
                    if (chain.filter) {
128✔
724
                        node.end_state = chain.filter->getState();
128✔
725
                    }
726

727
                    if (_logger) {
128✔
728
                        // Log why chain ended
729
                        double best_cost_double = std::numeric_limits<double>::infinity();
128✔
730
                        if (assigned_cand_idx >= 0) {
128✔
UNCOV
731
                            best_cost_double = cost_matrix[chain_idx][assigned_cand_idx] / static_cast<double>(cost_scaling_factor);
×
732
                        }
733

734
                        _logger->debug("Meta-node #{}: frames {} to {} ({} frames), entities {} to {}, {} members - terminated (best_cost={:.2f}, threshold={:.2f})",
128✔
735
                                       chain.meta_node_idx,
128✔
736
                                       node.start_frame.getValue(),
128✔
737
                                       node.end_frame.getValue(),
256✔
738
                                       node.end_frame.getValue() - node.start_frame.getValue() + 1,
256✔
739
                                       node.start_entity,
128✔
740
                                       node.end_entity,
128✔
741
                                       node.members.size(),
256✔
742
                                       best_cost_double,
743
                                       _cheap_assignment_threshold);
128✔
744
                    }
745
                }
746
            }
747

748
            active_chains = std::move(remaining_chains);
1,284✔
749
        }
750

751
        // Finalize any remaining active chains at end of range
752
        for (auto & chain: active_chains) {
325✔
753
            auto & node = meta_nodes[chain.meta_node_idx];
163✔
754
            node.end_frame = node.members.back().frame;
163✔
755
            node.end_entity = node.members.back().entity_id;
163✔
756
            if (chain.filter) {
163✔
757
                node.end_state = chain.filter->getState();
163✔
758
            }
759

760
            if (_logger) {
163✔
761
                _logger->debug("Meta-node #{}: frames {} to {} ({} frames), entities {} to {}, {} members - reached end",
163✔
762
                               chain.meta_node_idx,
163✔
763
                               node.start_frame.getValue(),
163✔
764
                               node.end_frame.getValue(),
326✔
765
                               node.end_frame.getValue() - node.start_frame.getValue() + 1,
326✔
766
                               node.start_entity,
163✔
767
                               node.end_entity,
163✔
768
                               node.members.size());
326✔
769
            }
770
        }
771

772
        if (_logger) {
162✔
773
            _logger->debug("Built {} meta-nodes using Hungarian assignment", meta_nodes.size());
162✔
774

775
            // Compute statistics on meta-node lengths
776
            if (!meta_nodes.empty()) {
162✔
777
                std::vector<int> lengths;
162✔
778
                for (auto const & mn: meta_nodes) {
487✔
779
                    lengths.push_back(static_cast<int>(mn.members.size()));
325✔
780
                }
781
                std::sort(lengths.begin(), lengths.end());
162✔
782

783
                int total_length = 0;
162✔
784
                for (int len: lengths) total_length += len;
487✔
785
                double mean_length = static_cast<double>(total_length) / lengths.size();
162✔
786

787
                int median_length = lengths[lengths.size() / 2];
162✔
788
                int min_length = lengths.front();
162✔
789
                int max_length = lengths.back();
162✔
790

791
                _logger->debug("Meta-node length statistics: min={}, median={}, mean={:.1f}, max={}",
162✔
792
                               min_length, median_length, mean_length, max_length);
793

794
                // Count single-frame meta-nodes
795
                int single_frame_count = 0;
162✔
796
                for (int len: lengths) {
487✔
797
                    if (len == 1) single_frame_count++;
325✔
798
                }
799
                if (single_frame_count > 0) {
162✔
800
                    _logger->debug("  {} single-frame meta-nodes ({:.1f}%)",
34✔
801
                                   single_frame_count,
802
                                   100.0 * single_frame_count / meta_nodes.size());
34✔
803
                }
804
            }
162✔
805
        }
806

807
        return meta_nodes;
324✔
808
    }
2,459✔
809

810
    // NOTE: the old greedy code that used to follow here (everything from "double best_cost" to the old "return meta_nodes") has been removed.
811
    /*
812
    OLD GREEDY CODE REMOVED - replaced with Hungarian-based approach above
813
    The new algorithm:
814
    1. Starts new chains for all unused observations at each frame
815
    2. Predicts all active chains forward one frame
816
    3. Builds cost matrix (chains x candidates)
817
    4. Uses Hungarian algorithm for optimal assignment
818
    5. Only accepts assignments below threshold
819
    6. Chains that don't get assigned (or exceed threshold) terminate
820
    
821
    This prevents "stealing" where long chains take candidates that would be better matches for other chains.
822
    */
823

824
    // --- Final Smoothing Step ---
825
    /**
     * @brief Runs a forward filtering pass followed by a backward smoothing pass over each solved path.
     *
     * For every non-empty path a fresh clone of the filter prototype is created. The clone is
     * initialized from the first node whose observation can be resolved in @p frame_lookup, and for
     * each later resolvable node it is advanced with one predict() call per elapsed frame and then
     * updated with that node's filter features. The filter state after each processed node is
     * collected; when more than one state was collected they are passed to the filter's smooth().
     *
     * Nodes whose observation cannot be resolved (findEntity() returns nullptr) are skipped. The
     * number of prediction steps is measured from the last node that was actually fed to the filter,
     * so a skipped node can neither leave the filter uninitialized nor shorten the prediction
     * horizon of the node that follows it.
     *
     * @param solved_paths Node sequence per group; each node names a frame and an entity id.
     * @param frame_lookup Per-frame observation buckets. Every frame named in a path must be present
     *                     (std::map::at throws std::out_of_range otherwise).
     * @param start_frame  Currently unused by this method.
     * @param end_frame    Currently unused by this method.
     * @return Per-group states: smoothed when more than one forward state exists, otherwise the
     *         forward states as collected. Empty when no filter prototype is configured.
     */
    SmoothedResults generate_smoothed_results(
            std::map<GroupId, Path> const & solved_paths,
            std::map<TimeFrameIndex, FrameBucket> const & frame_lookup,
            TimeFrameIndex start_frame,
            TimeFrameIndex end_frame) {

        SmoothedResults final_results;

        // Skip smoothing if no filter is provided
        if (!_filter_prototype) {
            return final_results;
        }

        for (auto const & [group_id, path]: solved_paths) {
            if (path.empty()) continue;

            auto filter = _filter_prototype->clone();
            std::vector<FilterState> forward_states;
            forward_states.reserve(path.size());

            // Frame of the last observation that was actually incorporated into the filter.
            // Empty until the filter has been initialized.
            std::optional<TimeFrameIndex> last_filtered_frame;

            // Forward pass using the solved path
            for (auto const & node: path) {
                auto const * data = findEntity(frame_lookup.at(node.frame), node.entity_id);
                if (!data) continue;// Observation missing from the bucket: nothing to feed the filter

                if (!last_filtered_frame) {
                    // First resolvable observation on this path seeds the filter state
                    filter->initialize(_feature_extractor->getInitialState(*data));
                } else {
                    int num_steps = (node.frame - *last_filtered_frame).getValue();

                    if (num_steps <= 0) {
                        if (_logger) _logger->error("Invalid num_steps in smoothing: {}", num_steps);
                        continue;// Skip invalid steps
                    }

                    // Multi-step prediction: call predict() once per elapsed frame so the filter's
                    // internal state ends up at the frame of this observation.
                    FilterState pred = filter->getState();// Initialize with current state
                    for (int step = 0; step < num_steps; ++step) {
                        pred = filter->predict();
                    }
                    // Filter's internal state is now at 'pred'; fold in the measurement
                    filter->update(pred, {_feature_extractor->getFilterFeatures(*data)});
                }

                last_filtered_frame = node.frame;
                forward_states.push_back(filter->getState());
            }

            // Backward smoothing pass
            if (forward_states.size() > 1) {
                final_results[group_id] = filter->smooth(forward_states);
            } else {
                final_results[group_id] = forward_states;
            }
        }
        return final_results;
    }
1,309✔
882

883
    // --- Utility Functions ---
884
    [[nodiscard]] std::map<TimeFrameIndex, FrameBucket>
885
    buildFrameLookup(auto && data_source, TimeFrameIndex start_frame, TimeFrameIndex end_frame) const {
162✔
886
        std::map<TimeFrameIndex, FrameBucket> lookup;
162✔
887
        for (auto const & item: data_source) {
64,356✔
888
            TimeFrameIndex t = getTimeFrameIndex(item);
32,034✔
889
            if (t >= start_frame && t <= end_frame) {
32,034✔
890
                lookup[t].emplace_back(&getData(item), getEntityId(item), t);
3,298✔
891
            }
892
        }
893
        return lookup;
162✔
UNCOV
894
    }
×
895

896
    /**
     * @brief Finds the observation stored for entity @p id inside @p bucket.
     *
     * @param bucket Entries of a single frame; element 0 of each entry is the data pointer and
     *               element 1 is the entity id.
     * @param id     Entity id to look for.
     * @return Data pointer of the first entry whose entity id equals @p id, or nullptr if none does.
     */
    static DataType const * findEntity(FrameBucket const & bucket, EntityId id) {
        auto const match = std::find_if(std::begin(bucket), std::end(bucket),
                                        [&id](auto const & entry) { return std::get<1>(entry) == id; });
        if (match == std::end(bucket)) {
            return nullptr;
        }
        return std::get<0>(*match);
    }
902

903
private:
904
    // Prototype filter. It is cloned once per solved path for the smoothing pass; when it is
    // null, generate_smoothed_results() returns an empty result.
    std::unique_ptr<IFilter> _filter_prototype;
905
    // Converts raw observations into filter inputs (getInitialState / getFilterFeatures).
    std::unique_ptr<IFeatureExtractor<DataType>> _feature_extractor;
906
    // Cost between a chain's predicted state and a candidate observation; fills the
    // chain-by-candidate matrix handed to the Hungarian solver.
    CostFunction _chain_cost_function;
907
    // NOTE(review): not referenced in this part of the file; presumably scores transitions
    // in the min-cost flow graph - confirm against its use sites.
    CostFunction _transition_cost_function;
908
    // NOTE(review): not referenced in this part of the file; the chain-building step uses its
    // own local scaling constant (1000) rather than this member - confirm intended use.
    double _cost_scale_factor;
909
    // Upper bound on the chain cost for a Hungarian assignment to be accepted; it is multiplied
    // by the local scaling constant before being compared with the integer cost matrix.
    double _cheap_assignment_threshold;
910
    // Optional logger; every visible use is guarded by a null check.
    std::shared_ptr<spdlog::logger> _logger;
911
};
912

913
}// namespace StateEstimation
914

915
#endif// MIN_COST_FLOW_TRACKER_HPP
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc