• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18685379784

21 Oct 2025 01:25PM UTC coverage: 72.522% (+0.1%) from 72.391%
18685379784

push

github

paulmthompson
fix failing tests

18 of 40 new or added lines in 1 file covered. (45.0%)

1765 existing lines in 32 files now uncovered.

53998 of 74457 relevant lines covered (72.52%)

46177.73 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

72.87
/src/StateEstimation/MinCostFlowTracker.hpp
1
#ifndef MIN_COST_FLOW_TRACKER_HPP
2
#define MIN_COST_FLOW_TRACKER_HPP
3

4
#include "Assignment/IAssigner.hpp"
5
#include "Assignment/NScanLookahead.hpp"
6
#include "Assignment/hungarian.hpp"
7
#include "Cost/CostFunctions.hpp"
8
#include "DataSource.hpp"
9
#include "Entity/EntityGroupManager.hpp"
10
#include "Features/IFeatureExtractor.hpp"
11
#include "Filter/IFilter.hpp"
12
#include "Filter/Kalman/KalmanMatrixBuilder.hpp"
13
#include "MinCostFlowSolver.hpp"
14
#include "TimeFrame/TimeFrame.hpp"
15
#include "Tracking/AnchorUtils.hpp"
16
#include "Tracking/Tracklet.hpp"
17

18
#include "spdlog/sinks/basic_file_sink.h"
19
#include "spdlog/spdlog.h"
20
#include <Eigen/Dense>
21

22
#include <algorithm>
23
#include <chrono>
24
#include <cmath>
25
#include <functional>
26
#include <limits>
27
#include <map>
28
#include <memory>
29
#include <optional>
30
#include <set>
31
#include <unordered_map>
32
#include <unordered_set>
33
#include <utility>
34
#include <vector>
35

36
/*
37
 data objects have features extracted in a time series. 
38
 These time series of features have a filter applied to find "tracklets" or "meta nodes" that represent small 
39
 time series of features that represent the same object across multiple frames. Once the tracklets are determined, 
40
 sparse labels are used to try to assign IDs to each entity in the tracklets. To do this, we perform the 
41
 following: we order our labels into pairs of consecutive (nearest-neighbor) label 
42
 times (e.g. 1-1000, 1000-4000, 4000-10000, etc.). We identify the frames with labels for a group. We construct a 
43
 subset of meta nodes that represent 1) the first frame in the tracklet assigned to that label along with the rest 
44
 of the tracklet (source tracklet). If the label does not exist on the left side of this tracklet, the tracklet is 
45
 modified into a sliced meta node; 2) the tracklet assigned to the last frame, from its start to the end frame, sliced 
46
 to remove frames after the anchor; 3) all meta nodes in between these. We then apply a minimum cost flow solver to 
47
 find which tracklets can "link" the left and right sliced meta nodes. If a minimum cost flow solution is found, we add all 
48
 entity IDs on that path to the group that label corresponds to. If the min cost flow solver fails (i.e. because there is a 
49
 large gap somewhere between tracklets), we will just assign all the entities in the sliced tracklets attached to the 
50
 anchors to the label group. We then repeat this procedure for all label pairs.
51

52
*/
53

54
namespace StateEstimation {
55

56
/**
 * @brief Selects how MinCostFlowTracker reacts when one of its internal invariants is violated.
 */
enum class TrackerContractPolicy {
    Throw,          ///< Raise a std::logic_error
    LogAndContinue, ///< Log the error, then carry on with a best-effort result
    Abort           ///< Log at critical level, then abort the process
};
64

65
/**
 * @brief Counters accumulated by the tracker while solving (exposed through getDiagnostics()).
 */
struct TrackerDiagnostics {
    /// Incremented each time the min-cost flow solver finds no optimal path for a segment.
    std::size_t noOptimalPathCount{0};
};
68

69
/**
70
 * @brief A tracker that uses a global min-cost flow optimization to solve data association.
71
 *
72
 * This tracker formulates the tracking problem as a graph problem, finding the globally
73
 * optimal set of tracks over an entire interval between anchors. It is more robust to
74
 * ambiguities and identity swaps than iterative, frame-by-frame methods.
75
 *
76
 * @tparam DataType raw observation type (e.g., Line2D)
77
 */
78
template<typename DataType>
79
class MinCostFlowTracker {
80
public:
81

82
    /**
83
     * @brief Construct a new MinCostFlowTracker
84
     *
85
     * @param filter_prototype Prototype filter (cloned for prediction and final smoothing). 
86
     *        If nullptr, prediction is skipped in cost calculation (cost function must handle this)
87
     *        and no smoothing is performed. The filter's uncertainty automatically scales with
88
     *        gap size through process noise accumulation.
89
     * @param feature_extractor Feature extractor for DataType
90
     * @param cost_function Function to compute cost between predicted state and observation
91
     * @param cost_scale_factor Multiplier to convert floating-point costs to integers for the solver.
92
     * @param cheap_assignment_threshold Threshold for greedy chaining
93
     * @param policy Contract violation policy
94
     * @param n_scan_depth Number of frames to look ahead when assignments are ambiguous (default 3)
95
     * @param enable_n_scan Enable N-scan lookahead for ambiguous assignments (default true)
96
     * @param max_gap_frames Maximum frames a chain can skip before terminating (default 3, set to -1 for unlimited)
97
     */
98
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
3✔
99
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
100
                       CostFunction cost_function,
101
                       double cost_scale_factor = 100.0,
102
                       double cheap_assignment_threshold = 5.0,
103
                       TrackerContractPolicy policy = TrackerContractPolicy::Throw,
104
                       int n_scan_depth = 3,
105
                       bool enable_n_scan = true,
106
                       int max_gap_frames = 3)
107
        : _filter_prototype(std::move(filter_prototype)),
3✔
108
          _feature_extractor(std::move(feature_extractor)),
3✔
109
          _chain_cost_function(cost_function),
3✔
110
          _transition_cost_function(cost_function),
3✔
111
          _lookahead_cost_function(cost_function),
3✔
112
          _cost_scale_factor(cost_scale_factor),
3✔
113
          _cheap_assignment_threshold(cheap_assignment_threshold),
3✔
114
          _policy(policy),
3✔
115
          _n_scan_depth(n_scan_depth),
3✔
116
          _enable_n_scan(enable_n_scan),
3✔
117
          _max_gap_frames(max_gap_frames) {}
3✔
118

119
    /**
120
     * @brief Construct with separate cost functions for greedy chaining and meta-node transitions.
121
     *
122
     * @param chain_cost_function Cost for frame-to-frame greedy chaining (typically 1-step)
123
     * @param transition_cost_function Cost for meta-node transitions across k-step gaps
124
     * @param n_scan_depth Number of frames to look ahead when assignments are ambiguous (default 3)
125
     * @param enable_n_scan Enable N-scan lookahead for ambiguous assignments (default true)
126
     */
127
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
9✔
128
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
129
                       CostFunction chain_cost_function,
130
                       CostFunction transition_cost_function,
131
                       double cost_scale_factor,
132
                       double cheap_assignment_threshold,
133
                       TrackerContractPolicy policy = TrackerContractPolicy::Throw,
134
                       int n_scan_depth = 3,
135
                       bool enable_n_scan = true,
136
                       int max_gap_frames = 3)
137
        : _filter_prototype(std::move(filter_prototype)),
9✔
138
          _feature_extractor(std::move(feature_extractor)),
9✔
139
          _chain_cost_function(chain_cost_function),
9✔
140
          _transition_cost_function(std::move(transition_cost_function)),
9✔
141
          _lookahead_cost_function(chain_cost_function),
9✔
142
          _cost_scale_factor(cost_scale_factor),
9✔
143
          _cheap_assignment_threshold(cheap_assignment_threshold),
9✔
144
          _policy(policy),
9✔
145
          _n_scan_depth(n_scan_depth),
9✔
146
          _enable_n_scan(enable_n_scan),
9✔
147
          _max_gap_frames(max_gap_frames) {}
9✔
148

149
    /**
150
     * @brief Convenience constructor using default Mahalanobis distance cost function.
151
     *
152
     * @param filter_prototype Prototype filter (cloned for prediction and final smoothing).
153
     *        If nullptr, Mahalanobis distance cannot be computed properly (requires filter state covariance).
154
     *        The filter's uncertainty automatically scales with gap size.
155
     * @param feature_extractor Feature extractor for DataType
156
     * @param measurement_matrix H matrix for Mahalanobis distance calculation
157
     * @param measurement_noise_covariance R matrix for Mahalanobis distance calculation
158
     * @param cost_scale_factor Multiplier to convert floating-point costs to integers for the solver.
159
     * @param cheap_assignment_threshold Threshold for greedy chaining
160
     * @param policy Contract violation policy
161
     * @param n_scan_depth Number of frames to look ahead when assignments are ambiguous (default 3)
162
     * @param enable_n_scan Enable N-scan lookahead for ambiguous assignments (default true)
163
     */
164
    MinCostFlowTracker(std::unique_ptr<IFilter> filter_prototype,
3✔
165
                       std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
166
                       Eigen::MatrixXd const & measurement_matrix,
167
                       Eigen::MatrixXd const & measurement_noise_covariance,
168
                       double cost_scale_factor = 100.0,
169
                       double cheap_assignment_threshold = 5.0,
170
                       TrackerContractPolicy policy = TrackerContractPolicy::Throw,
171
                       int n_scan_depth = 3,
172
                       bool enable_n_scan = true,
173
                       int max_gap_frames = 3)
174
        : MinCostFlowTracker(std::move(filter_prototype),
3✔
175
                             std::move(feature_extractor),
3✔
176
                             createMahalanobisCostFunction(measurement_matrix, measurement_noise_covariance),
177
                             cost_scale_factor,
178
                             cheap_assignment_threshold,
179
                             policy,
180
                             n_scan_depth,
181
                             enable_n_scan,
182
                             max_gap_frames) {}
9✔
183

184
    /**
185
     * @brief Set a dedicated cost function for N-scan lookahead scoring.
186
     *
187
     * This function is used exclusively inside the lookahead expansion and can
188
     * differ from the greedy chaining or meta-node transition costs. It is
189
     * useful to introduce dynamics-aware penalties (velocity/acceleration) only
190
     * for ambiguity resolution while keeping cheaper costs elsewhere.
191
     *
192
     * @pre cost_fn should be a valid callable; behavior is undefined if empty
193
     * @post Subsequent N-scan calls will use the provided function
194
     */
195
    void setLookaheadCostFunction(CostFunction cost_fn) {
196
        _lookahead_cost_function = std::move(cost_fn);
197
    }
198

199
    /**
200
     * @brief Override the transition cost used between meta-nodes in the MCF graph.
201
     */
202
    void setTransitionCostFunction(CostFunction cost_fn) {
1✔
203
        _transition_cost_function = std::move(cost_fn);
1✔
204
    }
1✔
205

206
    /**
207
     * @brief Set the acceptance threshold for N-scan lookahead costs.
208
     *
209
     * Use a larger threshold for dynamics-aware costs whose scale exceeds the
210
     * cheap chaining threshold. Set to infinity to disable pruning by threshold.
211
     */
212
    void setLookaheadThreshold(double threshold) { _lookahead_threshold = threshold; }
1✔
213

214
    /**
215
     * @brief Set ambiguity threshold and optional margin to decide when to run N-scan.
216
     * If the best cost < ambiguity_threshold and (second_best - best) >= ambiguity_margin,
217
     * the chain is considered certain and N-scan is skipped.
218
     */
219
    void setAmbiguityThreshold(double threshold) { _ambiguity_threshold = threshold; }
220
    void setAmbiguityMargin(double margin) { _ambiguity_margin = margin; }
221

222
    /**
223
     * @brief Process a range of frames using min-cost flow optimization.
224
     *
225
     * @param data_source Zero-copy data source
226
     * @param group_manager Group manager to record final assignments
227
     * @param ground_truth Ground truth at specific frames (anchors)
228
     * @param start_frame Inclusive start frame
229
     * @param end_frame Inclusive end frame
230
     * @param progress Optional progress callback
231
     * @return Smoothed states per group across processed frames
232
     */
233
    template<typename Source>
234
        requires DataSource<Source, DataType>
235
    [[nodiscard]] SmoothedResults process(Source && data_source,
21✔
236
                                          EntityGroupManager & group_manager,
237
                                          GroundTruthMap const & ground_truth,
238
                                          TimeFrameIndex start_frame,
239
                                          TimeFrameIndex end_frame,
240
                                          ProgressCallback progress,
241
                                          std::map<GroupId, GroupId> const * output_group_ids = nullptr,
242
                                          std::unordered_set<EntityId> const * excluded_entities = nullptr,
243
                                          std::unordered_set<EntityId> const * include_entities = nullptr) {
244
        if (_logger) {
21✔
245
            _logger->debug("MCF process: start={} end={}", start_frame.getValue(), end_frame.getValue());
6✔
246
        }
247

248
        auto frame_lookup = buildFrameLookup<Source, DataType>(data_source, start_frame, end_frame);
21✔
249

250
        // Print ground truth map contents
251
        if (_logger) {
21✔
252
            _logger->debug("Ground truth map contents:");
6✔
253
            for (auto const & [frame, group_entities]: ground_truth) {
50✔
254
                _logger->debug("  Frame {}:", frame.getValue());
44✔
255
                for (auto const & [group_id, entity_id]: group_entities) {
90✔
256
                    _logger->debug("    Group {}: Entity {}", static_cast<unsigned long long>(group_id), static_cast<unsigned long long>(entity_id));
46✔
257
                }
258
            }
259
        }
260

261
        // 1. --- Build and Solve the Graph ---
262
        auto solved_paths = solve_flow_problem_new(frame_lookup,
21✔
263
                                                   ground_truth,
264
                                                   start_frame,
265
                                                   end_frame,
266
                                                   progress);
267

268
        if (solved_paths.empty()) {
21✔
UNCOV
269
            if (_logger) _logger->error("Min-cost flow solver failed or found no paths.");
×
UNCOV
270
            return {};
×
271
        }
272

273
        // 2. --- Update Group Manager with Solved Tracks ---
274
        for (auto const & [group_id, path]: solved_paths) {
44✔
275
            GroupId write_group = group_id;
23✔
276
            if (output_group_ids) {
23✔
277
                auto it = output_group_ids->find(group_id);
18✔
278
                if (it != output_group_ids->end()) write_group = it->second;
18✔
279
            }
280
            for (auto const & node: path) {
3,301✔
281
                // Never overwrite anchors or any labeled entity: only add unlabeled entities
282
                // Additionally, skip any frame that already has ground truth for this group to avoid double assignment
283
                auto gt_frame_it = ground_truth.find(node.frame);
1,740✔
284
                if (gt_frame_it != ground_truth.end()) {
1,740✔
285
                    auto const & gt_map = gt_frame_it->second;
202✔
286
                    if (gt_map.find(group_id) != gt_map.end()) continue;
202✔
287
                }
288
                auto groups = group_manager.getGroupsContainingEntity(node.entity_id);
1,538✔
289
                if (!groups.empty()) continue;
1,538✔
290
                group_manager.addEntityToGroup(write_group, node.entity_id);
1,360✔
291
            }
292
        }
293

294
        // 3. --- Final Forward/Backward Smoothing Pass ---
295
        // Now that we have the globally optimal assignments, run a final KF pass to get the smoothed states.
296
        return generate_smoothed_results(solved_paths, frame_lookup, start_frame, end_frame);
21✔
297
    }
21✔
298

299
    void enableDebugLogging(std::string const & file_path) {
4✔
300
        _logger = std::make_shared<spdlog::logger>("MinCostFlowTracker", std::make_shared<spdlog::sinks::basic_file_sink_mt>(file_path, true));
4✔
301
        _logger->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%l] %v");
12✔
302
        _logger->set_level(spdlog::level::debug);
4✔
303
        _logger->flush_on(spdlog::level::debug);
4✔
304
    }
4✔
305

306
    [[nodiscard]] TrackerDiagnostics getDiagnostics() const { return _diagnostics; }
307

308
private:
309
    // Structure to track active chains being built
310
    struct ActiveChain {
311
        size_t meta_node_idx;// Index in meta_nodes vector
312
        TimeFrameIndex curr_frame;
313
        EntityId curr_entity;
314
        DataType const * curr_data;
315
        std::unique_ptr<IFilter> filter;// Cloned filter for this chain
316
        FilterState predicted;          // Cached prediction for next frame
317
        std::vector<NodeInfo> members;  // Collected nodes for this chain
318
        FilterState start_state;        // Initial state at chain start (for meta-node)
319

320
        // Constructor to properly initialize TimeFrameIndex
321
        ActiveChain()
42✔
322
            : meta_node_idx(0),
42✔
323
              curr_frame(TimeFrameIndex(0)),
42✔
324
              curr_entity(0),
42✔
325
              curr_data(nullptr) {}
42✔
326
    };
327

328

329
    /**
330
     * @brief Solve a single ground-truth segment over trimmed meta-nodes.
331
     *
332
     * Expects the anchors to be present in the provided meta-nodes (ideally after slicing).
333
     * Builds a min-cost single-unit path through the meta-graph and expands it to a full path.
334
     *
335
     * @param meta_nodes_trimmed Meta-nodes restricted to the segment range
336
     * @param frame_lookup Frame data lookup for cost evaluation
337
     * @param group_id Group identifier (for logging/diagnostics)
338
     * @param segment Ground truth segment describing anchors
339
     * @return Expanded path (sequence of NodeInfo) for the segment; empty on failure
340
     */
341
    Path solve_single_segment_flow_over_meta(
165✔
342
            std::vector<MetaNode> const & meta_nodes_trimmed,
343
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
344
            GroupId group_id,
345
            GroundTruthSegment const & segment) {
346

347
        if (_logger) {
165✔
348
            _logger->debug("Solving single segment flow over meta: group={} start=({}, {}) end=({}, {})",
39✔
349
                           static_cast<unsigned long long>(group_id),
39✔
350
                           segment.start_frame.getValue(), segment.start_entity,
78✔
351
                           segment.end_frame.getValue(), segment.end_entity);
78✔
352
        }
353

354
        // Fast path: check if a single meta-node spans the segment exactly
355
        for (auto const & mn: meta_nodes_trimmed) {
167✔
356
            if (!mn.members.empty() &&
332✔
357
                mn.members.front().frame == segment.start_frame &&
166✔
358
                mn.members.front().entity_id == segment.start_entity &&
330✔
359
                mn.members.back().frame == segment.end_frame &&
497✔
360
                mn.members.back().entity_id == segment.end_entity) {
164✔
361
                return mn.members;
164✔
362
            }
363
        }
364

365
        // Find anchor positions within trimmed set
366
        auto const pos_opt = findAnchorPositions(meta_nodes_trimmed, segment);
1✔
367
        if (!pos_opt.has_value()) {
1✔
368
            if (_logger) {
×
UNCOV
369
                _logger->error("Segment anchors not found in trimmed meta-nodes: group={} start=({}, {}) end=({}, {})",
×
UNCOV
370
                               static_cast<unsigned long long>(group_id),
×
UNCOV
371
                               segment.start_frame.getValue(), segment.start_entity,
×
UNCOV
372
                               segment.end_frame.getValue(), segment.end_entity);
×
373
            }
UNCOV
374
            return {};
×
375
        }
376
        int const start_meta_index = pos_opt->start_meta_index;
1✔
377
        size_t const start_member_index = pos_opt->start_member_index;
1✔
378
        int const end_meta_index = pos_opt->end_meta_index;
1✔
379
        size_t const end_member_index = pos_opt->end_member_index;
1✔
380

381
        int const num_meta = static_cast<int>(meta_nodes_trimmed.size());
1✔
382
        int const source_node = num_meta;
1✔
383
        int const sink_node = num_meta + 1;
1✔
384

385
        std::vector<ArcSpec> arcs;
1✔
386
        arcs.reserve(static_cast<size_t>(num_meta * num_meta / 4 + 4));
1✔
387
        arcs.push_back({source_node, start_meta_index, 1, 0});
1✔
388
        arcs.push_back({end_meta_index, sink_node, 1, 0});
1✔
389

390
        // Build transition arcs (forward in time only)
391
        int num_transition_arcs = 0;
1✔
392
        constexpr int64_t kMaxHorizon = 50;
1✔
393
        for (int i = 0; i < num_meta; ++i) {
3✔
394
            MetaNode const & from = meta_nodes_trimmed[static_cast<size_t>(i)];
2✔
395
            for (int j = 0; j < num_meta; ++j) {
7✔
396
                MetaNode const & to = meta_nodes_trimmed[static_cast<size_t>(j)];
4✔
397
                if (to.start_frame <= from.end_frame) continue;
4✔
398
                int const steps = (to.start_frame - from.end_frame).getValue();
1✔
399
                if (steps <= 0 || steps > kMaxHorizon) continue;
1✔
400

401
                FilterState predicted_state;
1✔
402
                if (_filter_prototype) {
1✔
403
                    auto temp_filter = _filter_prototype->clone();
1✔
404
                    // Coerce end_state if needed
405
                    FilterState init_state = from.end_state;
1✔
406
                    int const target_dim = static_cast<int>(temp_filter->getState().state_mean.size());
1✔
407
                    if (static_cast<int>(init_state.state_mean.size()) != target_dim ||
1✔
408
                        init_state.state_covariance.rows() != target_dim ||
2✔
409
                        init_state.state_covariance.cols() != target_dim) {
1✔
UNCOV
410
                        FilterState coerced;
×
UNCOV
411
                        coerced.state_mean = Eigen::VectorXd::Zero(target_dim);
×
UNCOV
412
                        int const copy_dim = std::min<int>(target_dim, static_cast<int>(init_state.state_mean.size()));
×
UNCOV
413
                        if (copy_dim > 0) coerced.state_mean.head(copy_dim) = init_state.state_mean.head(copy_dim);
×
UNCOV
414
                        coerced.state_covariance = Eigen::MatrixXd::Zero(target_dim, target_dim);
×
UNCOV
415
                        int const cr = std::min<int>(target_dim, init_state.state_covariance.rows());
×
UNCOV
416
                        int const cc = std::min<int>(target_dim, init_state.state_covariance.cols());
×
UNCOV
417
                        if (cr > 0 && cc > 0) {
×
UNCOV
418
                            int const b = std::min(cr, cc);
×
UNCOV
419
                            coerced.state_covariance.topLeftCorner(b, b) = init_state.state_covariance.topLeftCorner(b, b);
×
420
                        }
UNCOV
421
                        constexpr double kPadVar = 1e6;
×
UNCOV
422
                        for (int d = 0; d < target_dim; ++d) {
×
UNCOV
423
                            if (coerced.state_covariance(d, d) <= 0.0) coerced.state_covariance(d, d) = kPadVar;
×
424
                        }
UNCOV
425
                        init_state = std::move(coerced);
×
426
                    }
×
427
                    temp_filter->initialize(init_state);
1✔
428
                    for (int s = 0; s < steps; ++s) {
2✔
429
                        predicted_state = temp_filter->predict();
1✔
430
                    }
431
                }
1✔
432

433
                DataType const * to_start_data = findEntity(frame_lookup.at(to.start_frame), to.start_entity);
1✔
434
                if (!to_start_data) continue;
1✔
435
                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*to_start_data);
1✔
436
                double const dist = _transition_cost_function(predicted_state, obs, steps);
1✔
437
                int64_t const arc_cost = static_cast<int64_t>(dist * _cost_scale_factor);
1✔
438
                arcs.push_back({i, j, 1, arc_cost});
1✔
439
                num_transition_arcs++;
1✔
440
            }
441
        }
442

443
        auto const seq_opt = solveMinCostSingleUnitPath(num_meta + 2, source_node, sink_node, arcs);
1✔
444
        if (!seq_opt.has_value()) {
1✔
445
            _diagnostics.noOptimalPathCount += 1;
×
446
            if (_logger) {
×
UNCOV
447
                _logger->error("Min-cost flow failed for segment: group={} metaNodes={} arcs={} — falling back to anchors only",
×
448
                               static_cast<unsigned long long>(group_id), num_meta, arcs.size());
×
449
            }
450
            // If start/end nodes overlap, split them and return only the anchor-containing clipped nodes
UNCOV
451
            auto resolved = resolveOverlappingAnchorNodes(meta_nodes_trimmed, start_meta_index, end_meta_index);
×
452

453
            // Find nodes that contain the anchors in the resolved list
UNCOV
454
            int start_idx_res = -1;
×
UNCOV
455
            int end_idx_res = -1;
×
UNCOV
456
            for (int i = 0; i < static_cast<int>(resolved.size()); ++i) {
×
UNCOV
457
                auto const & mn = resolved[static_cast<size_t>(i)];
×
UNCOV
458
                for (auto const & m : mn.members) {
×
UNCOV
459
                    if (start_idx_res == -1 && m.frame == segment.start_frame && m.entity_id == segment.start_entity) {
×
UNCOV
460
                        start_idx_res = i;
×
461
                    }
UNCOV
462
                    if (end_idx_res == -1 && m.frame == segment.end_frame && m.entity_id == segment.end_entity) {
×
UNCOV
463
                        end_idx_res = i;
×
464
                    }
UNCOV
465
                    if (start_idx_res != -1 && end_idx_res != -1) break;
×
466
                }
UNCOV
467
                if (start_idx_res != -1 && end_idx_res != -1) break;
×
468
            }
469

UNCOV
470
            Path fallback_path;
×
UNCOV
471
            if (start_idx_res >= 0 && start_idx_res < static_cast<int>(resolved.size())) {
×
UNCOV
472
                auto const & mem = resolved[static_cast<size_t>(start_idx_res)].members;
×
UNCOV
473
                fallback_path.insert(fallback_path.end(), mem.begin(), mem.end());
×
474
            }
UNCOV
475
            if (end_idx_res >= 0 && end_idx_res < static_cast<int>(resolved.size()) && end_idx_res != start_idx_res) {
×
UNCOV
476
                auto const & mem = resolved[static_cast<size_t>(end_idx_res)].members;
×
UNCOV
477
                fallback_path.insert(fallback_path.end(), mem.begin(), mem.end());
×
478
            }
UNCOV
479
            if (!fallback_path.empty()) return fallback_path;
×
480
            // Fall back to simple concatenation if anchors not found (shouldn't happen)
UNCOV
481
            return buildFallbackPathFromTrimmed(meta_nodes_trimmed, start_meta_index, end_meta_index);
×
UNCOV
482
        }
×
483

484
        Path expanded_path;
1✔
485
        auto const & sequence = *seq_opt;
1✔
486
        for (size_t idx = 1; idx < sequence.size(); ++idx) {
4✔
487
            int const node_index = sequence[idx];
3✔
488
            if (node_index >= 0 && node_index < num_meta) {
3✔
489
                auto const & mem = meta_nodes_trimmed[static_cast<size_t>(node_index)].members;
2✔
490
                for (auto const & n: mem) expanded_path.push_back(n);
103✔
491
            }
492
        }
493
        return expanded_path;
1✔
494
    }
1✔
495

496
    /**
497
     * @brief Solve paths for all groups by iterating ground-truth segments and concatenating.
498
     *
499
     * For each consecutive labeled segment per group, slice meta-nodes to the segment,
500
     * run a per-segment min-cost path, and append the nodes to the group's output path.
501
     *
502
     * Deduplicates a single overlapping anchor node at segment boundaries.
503
     */
504
    std::map<GroupId, Path> solve_flow_over_segments(
21✔
505
            std::vector<MetaNode> const & meta_nodes,
506
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
507
            GroundTruthMap const & ground_truth,
508
            TimeFrameIndex start_frame,
509
            TimeFrameIndex end_frame) {
510

511
        std::map<GroupId, Path> solved_paths;
21✔
512
        auto const segments = extractGroundTruthSegments(ground_truth);
21✔
513

514
        // Process segments in chronological order per group
515
        std::map<GroupId, std::vector<GroundTruthSegment>> by_group;
21✔
516
        for (auto const & seg: segments) {
186✔
517
            // Optionally filter to the requested range
518
            if (seg.end_frame < start_frame || seg.start_frame > end_frame) continue;
165✔
519
            by_group[seg.group_id].push_back(seg);
165✔
520
        }
521
        for (auto & [gid, segs]: by_group) {
67✔
522
            std::sort(segs.begin(), segs.end(), [](auto const & a, auto const & b) {
23✔
523
                return a.start_frame < b.start_frame;
284✔
524
            });
525
            Path out_path;
23✔
526
            for (auto const & seg: segs) {
353✔
527
                auto trimmed = sliceMetaNodesToSegment(meta_nodes, seg);
165✔
528
                if (trimmed.empty()) {
165✔
UNCOV
529
                    if (_logger) {
×
UNCOV
530
                        _logger->warn("No trimmed meta-nodes for segment: group={} start=({}, {}) end=({}, {})",
×
UNCOV
531
                                      static_cast<unsigned long long>(gid),
×
UNCOV
532
                                      seg.start_frame.getValue(), seg.start_entity,
×
UNCOV
533
                                      seg.end_frame.getValue(), seg.end_entity);
×
534
                    }
UNCOV
535
                    continue;
×
536
                }
537

538
                if (_logger) {
165✔
539
                    _logger->debug("Solving segment: group={} start=({}, {}) end=({}, {})",
39✔
540
                                   static_cast<unsigned long long>(gid),
39✔
541
                                   seg.start_frame.getValue(), seg.start_entity,
78✔
542
                                   seg.end_frame.getValue(), seg.end_entity);
78✔
543
                }
544

545
                Path segment_path = solve_single_segment_flow_over_meta(trimmed, frame_lookup, gid, seg);
165✔
546
                if (segment_path.empty()) continue;
165✔
547

548
                // Deduplicate overlapping anchor between consecutive segments
549
                if (!out_path.empty() && !segment_path.empty()) {
165✔
550
                    auto const & last = out_path.back();
142✔
551
                    auto const & first = segment_path.front();
142✔
552
                    if (last.frame == first.frame && last.entity_id == first.entity_id) {
142✔
553
                        segment_path.erase(segment_path.begin());
128✔
554
                    }
555
                }
556
                // Append
557
                out_path.insert(out_path.end(), segment_path.begin(), segment_path.end());
165✔
558
            }
559
            if (!out_path.empty()) {
23✔
560
                solved_paths.emplace(gid, std::move(out_path));
23✔
561
            }
562
        }
563

564
        return solved_paths;
21✔
565
    }
21✔
566

567
    // --- Main Graph Building and Solving Logic ---
568
    std::map<GroupId, Path> solve_flow_problem_new(
21✔
569
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
570
            GroundTruthMap const & ground_truth,
571
            TimeFrameIndex start_frame,
572
            TimeFrameIndex end_frame,
573
            ProgressCallback progress) {
574

575

576
        auto start_anchors_it = ground_truth.find(start_frame);
21✔
577
        auto end_anchors_it = ground_truth.find(end_frame);
21✔
578

579
        if (start_anchors_it == ground_truth.end() || end_anchors_it == ground_truth.end()) {
21✔
UNCOV
580
            if (_logger) _logger->error("Min-cost flow requires anchors at both start and end frames.");
×
UNCOV
581
            return {};
×
582
        }
583

584
        auto start_anchors = start_anchors_it->second;
21✔
585
        auto end_anchors = end_anchors_it->second;
21✔
586

587
        // 1) Build greedy meta-nodes (cheap consecutive links) independent of groups
588
        auto meta_nodes = build_meta_nodes(frame_lookup, start_frame, end_frame, progress);
21✔
589

590
        // 2) Solve paths per group by iterating ground-truth segments and concatenating
591
        auto all_solved_paths = solve_flow_over_segments(meta_nodes,
21✔
592
                                                         frame_lookup,
593
                                                         ground_truth,
594
                                                         start_frame,
595
                                                         end_frame);
596

597
        return all_solved_paths;
21✔
598
    }
21✔
599

600
    /**
601
     * @brief Build meta-nodes using Hungarian algorithm for optimal chain extension.
602
     * 
603
     * Unlike greedy assignment, this uses Hungarian algorithm at each frame to ensure
604
     * global optimal assignment of chains to candidates, preventing "stealing" where
605
     * one chain takes another's best match.
606
     * 
607
     * @pre frame_lookup contains observations in [start_frame, end_frame]
608
     * @post Each observation belongs to at most one meta-node
609
     */
610
    std::vector<MetaNode> build_meta_nodes(
21✔
611
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
612
            TimeFrameIndex start_frame,
613
            TimeFrameIndex end_frame,
614
            ProgressCallback progress) {
615

616

617
        progress(0);
21✔
618

619
        std::vector<MetaNode> meta_nodes;
21✔
620
        std::set<std::pair<long long, EntityId>> used;// key: (frame, entity)
21✔
621
        std::vector<ActiveChain> active_chains;
21✔
622

623

624
        // Process frame by frame, using Hungarian algorithm to extend chains optimally
625
        for (TimeFrameIndex f = start_frame; f <= end_frame; ++f) {
3,550✔
626
            if (!frame_lookup.count(f)) continue;
1,767✔
627

628
            if (_logger) {
1,762✔
629
                _logger->debug("Processing frame {}: {} active chains, {} observations",
391✔
630
                               f.getValue(), active_chains.size(), frame_lookup.at(f).size());
391✔
631
            }
632

633
            std::unordered_set<EntityId> this_frame_entities;
1,762✔
634
            for (size_t cand_idx = 0; cand_idx < frame_lookup.at(f).size(); ++cand_idx) {
5,266✔
635
                auto const & cand = frame_lookup.at(f)[cand_idx];
3,504✔
636
                this_frame_entities.insert(std::get<1>(cand));
3,504✔
637
            }
638

639
            // Step 1: Try to extend existing active chains to current frame (if any)
640
            // This must happen BEFORE creating new chains, so that chains can jump gaps
641
            if (!active_chains.empty() && f > start_frame) {
1,762✔
642

643
                // Predict all remaining active chains forward to current frame
644
                for (size_t chain_idx = 0; chain_idx < active_chains.size(); ++chain_idx) {
5,204✔
645
                    auto & chain = active_chains[chain_idx];
3,463✔
646
                    if (chain.filter) {
3,463✔
647
                        // Predict forward from last known frame to current frame
648
                        int gap_frames = static_cast<int>(f.getValue() - chain.curr_frame.getValue());
3,463✔
649
                        if (_logger && gap_frames > 0) {
3,463✔
650
                            auto initial_state = chain.filter->getState();
751✔
651
                            _logger->debug("Chain {} at frame {} before predictions: state=[{:.2f},{:.2f},{:.2f},{:.2f}], curr_entity={}",
751✔
652
                                           chain_idx, f.getValue(),
751✔
653
                                           initial_state.state_mean(0), initial_state.state_mean(1),
751✔
654
                                           initial_state.state_mean(2), initial_state.state_mean(3),
751✔
655
                                           chain.curr_entity);
751✔
656
                        }
751✔
657
                        for (int step = 0; step < gap_frames; ++step) {
6,936✔
658
                            chain.predicted = chain.filter->predict();
3,473✔
659
                            if (_logger && gap_frames > 1) {
3,473✔
660
                                _logger->debug("  After predict step {}/{}: state=[{:.2f},{:.2f},{:.2f},{:.2f}]",
12✔
661
                                               step + 1, gap_frames,
12✔
662
                                               chain.predicted.state_mean(0), chain.predicted.state_mean(1),
12✔
663
                                               chain.predicted.state_mean(2), chain.predicted.state_mean(3));
12✔
664
                            }
665
                        }
666
                        if (_logger && gap_frames > 0) {
3,463✔
667
                            _logger->debug("  Final predicted state: [{:.2f},{:.2f},{:.2f},{:.2f}]",
751✔
668
                                           chain.predicted.state_mean(0), chain.predicted.state_mean(1),
751✔
669
                                           chain.predicted.state_mean(2), chain.predicted.state_mean(3));
751✔
670
                        }
671
                    }
672
                }
673

674
                // Collect available candidates at current frame
675
                std::vector<std::tuple<EntityId, DataType const *, size_t>> candidates;
1,741✔
676
                for (size_t cand_idx = 0; cand_idx < frame_lookup.at(f).size(); ++cand_idx) {
5,204✔
677
                    auto const & cand = frame_lookup.at(f)[cand_idx];
3,463✔
678
                    EntityId cand_id = std::get<1>(cand);
3,463✔
679
                    auto key = std::make_pair(static_cast<long long>(f.getValue()), cand_id);
3,463✔
680
                    if (used.count(key)) continue;// How could this be true?
3,463✔
681

682
                    DataType const * cand_data = std::get<0>(cand);
3,463✔
683
                    candidates.emplace_back(cand_id, cand_data, cand_idx);
3,463✔
684
                }
685

686
                if (!candidates.empty() && !active_chains.empty()) {
1,741✔
687
                    // Build cost matrix for Hungarian algorithm
688
                    int const cost_scaling_factor = 1000;
1,741✔
689
                    int const max_cost = static_cast<int>(_cheap_assignment_threshold * cost_scaling_factor);
1,741✔
690
                    std::vector<std::vector<int>> cost_matrix(active_chains.size(),
6,964✔
691
                                                              std::vector<int>(candidates.size()));
6,964✔
692

693
                    for (size_t chain_idx = 0; chain_idx < active_chains.size(); ++chain_idx) {
5,204✔
694
                        auto const & chain = active_chains[chain_idx];
3,463✔
695
                        for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) {
10,370✔
696
                            DataType const * cand_data = std::get<1>(candidates[cand_idx]);
6,907✔
697
                            Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*cand_data);
6,907✔
698

699
                            double cost_double;
700
                            if (chain.filter) {
6,907✔
701
                                int gap_frames = static_cast<int>(f.getValue() - chain.curr_frame.getValue());
6,907✔
702
                                cost_double = _chain_cost_function(chain.predicted, obs, gap_frames);
6,907✔
703
                            } else {
UNCOV
704
                                Eigen::VectorXd curr_obs = _feature_extractor->getFilterFeatures(*chain.curr_data);
×
UNCOV
705
                                cost_double = (curr_obs - obs).norm();
×
UNCOV
706
                            }
×
707

708
                            int cost = static_cast<int>(cost_double * cost_scaling_factor);
6,907✔
709
                            if (cost >= std::numeric_limits<int>::max()) {
6,907✔
UNCOV
710
                                cost = std::numeric_limits<int>::max() - 1;
×
711
                            }
712
                            cost_matrix[chain_idx][cand_idx] = cost;
6,907✔
713
                        }
714
                    }
715

716
                    // Solve Hungarian assignment
717
                    std::vector<std::vector<int>> assignment_matrix;
1,741✔
718
                    Munkres::hungarian_with_assignment(cost_matrix, assignment_matrix);
1,741✔
719

720
                    // Check for ambiguity and trigger N-scan if enabled
721
                    std::unordered_set<size_t> ambiguous_chain_indices;
1,741✔
722
                    if (_enable_n_scan && _filter_prototype) {
1,741✔
723
                        Eigen::MatrixXd cost_matrix_eigen(active_chains.size(), candidates.size());
1,741✔
724
                        for (size_t i = 0; i < active_chains.size(); ++i) {
5,204✔
725
                            for (size_t j = 0; j < candidates.size(); ++j) {
10,370✔
726
                                cost_matrix_eigen(static_cast<Eigen::Index>(i), static_cast<Eigen::Index>(j)) =
6,907✔
727
                                        cost_matrix[i][j] / static_cast<double>(cost_scaling_factor);
6,907✔
728
                            }
729
                        }
730
                        ambiguous_chain_indices = detect_ambiguous_chains(cost_matrix_eigen, _ambiguity_threshold);
1,741✔
731
                        // Apply certainty margin: drop chains whose best is clearly better than next-best
732
                        if (_ambiguity_margin > 0.0) {
1,741✔
UNCOV
733
                            std::unordered_set<size_t> pruned;
×
UNCOV
734
                            for (size_t i = 0; i < active_chains.size(); ++i) {
×
UNCOV
735
                                if (ambiguous_chain_indices.find(i) == ambiguous_chain_indices.end()) continue;
×
736
                                // compute best and second best
UNCOV
737
                                double best = std::numeric_limits<double>::infinity();
×
UNCOV
738
                                double second = std::numeric_limits<double>::infinity();
×
UNCOV
739
                                for (size_t j = 0; j < candidates.size(); ++j) {
×
UNCOV
740
                                    double c = cost_matrix_eigen(static_cast<Eigen::Index>(i), static_cast<Eigen::Index>(j));
×
UNCOV
741
                                    if (c < best) {
×
UNCOV
742
                                        second = best;
×
UNCOV
743
                                        best = c;
×
UNCOV
744
                                    } else if (c < second) {
×
UNCOV
745
                                        second = c;
×
746
                                    }
747
                                }
UNCOV
748
                                if (best < _ambiguity_threshold && (second - best) >= _ambiguity_margin) {
×
UNCOV
749
                                    pruned.insert(i);
×
750
                                }
751
                            }
UNCOV
752
                            for (auto idx: pruned) ambiguous_chain_indices.erase(idx);
×
UNCOV
753
                        }
×
754

755
                        if (_logger && !ambiguous_chain_indices.empty()) {
1,741✔
756
                            _logger->debug("Frame {}: Detected {} ambiguous chains (threshold={:.3f})",
1✔
757
                                           f.getValue(), ambiguous_chain_indices.size(), _cheap_assignment_threshold);
1✔
758

759
                            // Check if we can run N-scan
760
                            int frames_ahead = (end_frame - f).getValue();
1✔
761
                            if (frames_ahead < _n_scan_depth) {
1✔
UNCOV
762
                                _logger->debug("  N-scan SKIPPED: need {} frames ahead, only have {} (end_frame={})",
×
UNCOV
763
                                               _n_scan_depth, frames_ahead, end_frame.getValue());
×
764
                            }
765
                        }
766
                    }
1,741✔
767

768
                    // If there are ambiguous chains, run N-scan for ALL of them FIRST, then assign globally
769
                    std::map<size_t, std::pair<std::vector<NodeInfo>, double>> n_scan_results;// chain_idx -> (path, total_cost)
1,741✔
770
                    // Variable-depth lookahead: allow shorter depth near tail
771
                    int frames_ahead_var = (end_frame - f).getValue();
1,741✔
772
                    int allowable_depth = std::min(_n_scan_depth, frames_ahead_var + 1);
1,741✔
773
                    if (!ambiguous_chain_indices.empty() && allowable_depth >= 1) {
1,741✔
774
                        if (_logger) {
1✔
775
                            _logger->debug("  Running N-scan with depth={} (need {} future frames, have {})",
1✔
776
                                           allowable_depth, allowable_depth - 1, (end_frame - f).getValue());
1✔
777
                        }
778
                        // Step 1: Run N-scan for each ambiguous chain independently
779
                        std::map<size_t, std::vector<std::pair<std::vector<NodeInfo>, double>>> all_paths;// chain_idx -> [(path, cost), ...]
1✔
780

781
                        int const saved_depth = _n_scan_depth;
1✔
782
                        _n_scan_depth = allowable_depth;// temporary override for lookahead
1✔
783
                        for (size_t chain_idx: ambiguous_chain_indices) {
5✔
784
                            auto & chain = active_chains[chain_idx];
2✔
785

786
                            if (_logger) {
2✔
787
                                _logger->debug("  N-scan for chain {} (curr_entity={}, curr_frame={})",
2✔
788
                                               chain_idx, chain.curr_entity, chain.curr_frame.getValue());
2✔
789
                            }
790

791
                            // Collect viable candidates with their costs using lookahead cost
792
                            std::vector<std::tuple<EntityId, DataType const *, double>> viable_candidates;
2✔
793
                            int const gap_frames = static_cast<int>(f.getValue() - chain.curr_frame.getValue());
2✔
794
                            for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) {
10✔
795
                                DataType const * cand_data = std::get<1>(candidates[cand_idx]);
4✔
796
                                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*cand_data);
4✔
797
                                double cost_double = _lookahead_cost_function(chain.predicted, obs, std::max(1, gap_frames));
4✔
798
                                if (cost_double < _lookahead_threshold || !std::isfinite(_lookahead_threshold)) {
4✔
799
                                    viable_candidates.emplace_back(
4✔
800
                                            std::get<0>(candidates[cand_idx]),
4✔
801
                                            cand_data,
802
                                            cost_double);
803
                                }
804
                            }
805

806
                            // Each chain gets its own copy of 'used' to explore independently
807
                            std::set<std::pair<long long, EntityId>> chain_used = used;
2✔
808
                            auto [n_scan_path, path_cost] = run_n_scan_lookahead(chain, viable_candidates, f, end_frame,
2✔
809
                                                                                 frame_lookup, chain_used);
810

811
                            if (!n_scan_path.empty()) {
2✔
812
                                n_scan_results[chain_idx] = {n_scan_path, path_cost};
2✔
UNCOV
813
                            } else if (allowable_depth == 1) {
×
814
                                // One-step fallback: pick best single candidate not used
UNCOV
815
                                double best_c = std::numeric_limits<double>::infinity();
×
UNCOV
816
                                EntityId best_eid = 0;
×
UNCOV
817
                                for (auto const & tup: viable_candidates) {
×
UNCOV
818
                                    EntityId eid = std::get<0>(tup);
×
UNCOV
819
                                    auto key = std::make_pair(static_cast<long long>(f.getValue()), eid);
×
UNCOV
820
                                    if (used.count(key)) continue;
×
UNCOV
821
                                    double c = std::get<2>(tup);
×
UNCOV
822
                                    if (c < best_c) {
×
UNCOV
823
                                        best_c = c;
×
UNCOV
824
                                        best_eid = eid;
×
825
                                    }
826
                                }
UNCOV
827
                                if (best_eid != 0 && (best_c < _lookahead_threshold || !std::isfinite(_lookahead_threshold))) {
×
UNCOV
828
                                    std::vector<NodeInfo> single{{f, best_eid}};
×
UNCOV
829
                                    n_scan_results[chain_idx] = {single, best_c};
×
UNCOV
830
                                }
×
831
                            }
832
                        }
833
                        _n_scan_depth = saved_depth;// restore
1✔
834

835
                        // Step 2: Detect conflicts - check if multiple chains want the same observations
836
                        if (_logger && !n_scan_results.empty()) {
1✔
837
                            _logger->debug("N-scan completed for {} chains at frame {}", n_scan_results.size(), f.getValue());
1✔
838
                            for (auto const & [chain_idx, path_and_cost]: n_scan_results) {
3✔
839
                                _logger->debug("  Chain {}: cost={:.2f}, path length={}",
2✔
840
                                               chain_idx, path_and_cost.second, path_and_cost.first.size());
2✔
841
                            }
842
                        }
843

844
                        std::map<std::pair<long long, EntityId>, std::vector<size_t>> obs_to_chains;
1✔
845
                        for (auto const & [chain_idx, path_and_cost]: n_scan_results) {
3✔
846
                            // Only the current frame decision participates in conflicts
847
                            if (!path_and_cost.first.empty()) {
2✔
848
                                auto const & first_node = path_and_cost.first.front();
2✔
849
                                auto key = std::make_pair(static_cast<long long>(first_node.frame.getValue()), first_node.entity_id);
2✔
850
                                obs_to_chains[key].push_back(chain_idx);
2✔
851
                            }
852
                        }
853

854
                        if (_logger && !obs_to_chains.empty()) {
1✔
855
                            _logger->debug("Observation assignment: {} unique observations claimed", obs_to_chains.size());
1✔
856
                            for (auto const & [obs_key, claiming_chains]: obs_to_chains) {
3✔
857
                                if (claiming_chains.size() > 1) {
2✔
858
                                    _logger->debug("  Frame {}, entity {}: {} chains want it",
×
UNCOV
859
                                                   obs_key.first, obs_key.second, claiming_chains.size());
×
860
                                }
861
                            }
862
                        }
863

864
                        // Step 3: Resolve conflicts - if multiple chains want same observation, keep lowest cost
865
                        std::set<size_t> rejected_chains;
1✔
866
                        for (auto const & [obs_key, claiming_chains]: obs_to_chains) {
3✔
867
                            if (claiming_chains.size() > 1) {
2✔
868
                                // Conflict! Keep chain with lowest cost, reject others
UNCOV
869
                                if (_logger) {
×
UNCOV
870
                                    _logger->debug("N-scan conflict at frame {}, entity {}: {} chains competing",
×
UNCOV
871
                                                   obs_key.first, obs_key.second, claiming_chains.size());
×
UNCOV
872
                                    for (size_t chain_idx: claiming_chains) {
×
UNCOV
873
                                        _logger->debug("  Chain {} has cost {:.2f}", chain_idx, n_scan_results[chain_idx].second);
×
874
                                    }
875
                                }
876

877
                                size_t best_chain = claiming_chains[0];
×
UNCOV
878
                                double best_cost = n_scan_results[best_chain].second;
×
UNCOV
879
                                for (size_t chain_idx: claiming_chains) {
×
UNCOV
880
                                    if (n_scan_results[chain_idx].second < best_cost) {
×
UNCOV
881
                                        best_chain = chain_idx;
×
UNCOV
882
                                        best_cost = n_scan_results[chain_idx].second;
×
883
                                    }
884
                                }
885

UNCOV
886
                                if (_logger) {
×
UNCOV
887
                                    _logger->debug("  Keeping chain {} (cost {:.2f}), rejecting others", best_chain, best_cost);
×
888
                                }
889

890
                                // Reject all other chains
UNCOV
891
                                for (size_t chain_idx: claiming_chains) {
×
UNCOV
892
                                    if (chain_idx != best_chain) {
×
UNCOV
893
                                        rejected_chains.insert(chain_idx);
×
894
                                        if (_logger) {
×
UNCOV
895
                                            _logger->debug("  Rejected chain {}", chain_idx);
×
896
                                        }
897
                                    }
898
                                }
899
                            }
900
                        }
901

902
                        // Step 4: Remove rejected chains from results
903
                        for (size_t rejected: rejected_chains) {
1✔
UNCOV
904
                            n_scan_results.erase(rejected);
×
905
                        }
906

907
                        // Step 5: Mark accepted N-scan selections (current frame only) as used
908
                        for (auto const & [chain_idx, path_and_cost]: n_scan_results) {
3✔
909
                            if (!path_and_cost.first.empty()) {
2✔
910
                                auto const & node = path_and_cost.first.front();
2✔
911
                                used.insert(std::make_pair(static_cast<long long>(node.frame.getValue()), node.entity_id));
2✔
912
                            }
913
                        }
914

915
                        // Step 5b: Attempt fallback N-scan for rejected/failed ambiguous chains
916
                        // Re-run N-scan for chains that were ambiguous but have no accepted result,
917
                        // now honoring the updated 'used' set (to avoid prior conflicts).
918
                        for (size_t chain_idx: ambiguous_chain_indices) {
3✔
919
                            if (n_scan_results.count(chain_idx) > 0) continue;// already accepted
2✔
UNCOV
920
                            auto & chain = active_chains[chain_idx];
×
921

922
                            // Rebuild viable candidates using lookahead cost and current 'used'
UNCOV
923
                            std::vector<std::tuple<EntityId, DataType const *, double>> viable_candidates_alt;
×
UNCOV
924
                            int const gap_frames_alt = static_cast<int>(f.getValue() - chain.curr_frame.getValue());
×
UNCOV
925
                            for (size_t cand_idx = 0; cand_idx < candidates.size(); ++cand_idx) {
×
UNCOV
926
                                EntityId eid = std::get<0>(candidates[cand_idx]);
×
UNCOV
927
                                auto key = std::make_pair(static_cast<long long>(f.getValue()), eid);
×
UNCOV
928
                                if (used.count(key)) continue;// avoid already claimed obs
×
UNCOV
929
                                DataType const * cand_data = std::get<1>(candidates[cand_idx]);
×
UNCOV
930
                                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*cand_data);
×
UNCOV
931
                                double cost_double = _lookahead_cost_function(chain.predicted, obs, std::max(1, gap_frames_alt));
×
UNCOV
932
                                if (cost_double < _lookahead_threshold || !std::isfinite(_lookahead_threshold)) {
×
UNCOV
933
                                    viable_candidates_alt.emplace_back(eid, cand_data, cost_double);
×
934
                                }
935
                            }
936

UNCOV
937
                            if (!viable_candidates_alt.empty()) {
×
UNCOV
938
                                int const saved_depth2 = _n_scan_depth;
×
UNCOV
939
                                _n_scan_depth = allowable_depth;// use same allowable depth
×
940
                                // Use the updated 'used' set so we avoid previous conflicts
UNCOV
941
                                auto alt_used = used;// COPY used set
×
UNCOV
942
                                auto [alt_path, alt_cost] = run_n_scan_lookahead(chain, viable_candidates_alt, f, end_frame,
×
943
                                                                                 frame_lookup, alt_used);
UNCOV
944
                                _n_scan_depth = saved_depth2;
×
UNCOV
945
                                if (!alt_path.empty()) {
×
946
                                    // Accept alternate but commit only current frame
UNCOV
947
                                    std::vector<NodeInfo> single{alt_path.front()};
×
UNCOV
948
                                    n_scan_results[chain_idx] = {single, alt_cost};
×
UNCOV
949
                                    used.insert(std::make_pair(static_cast<long long>(single.front().frame.getValue()), single.front().entity_id));
×
UNCOV
950
                                    if (_logger) {
×
UNCOV
951
                                        _logger->debug("  Fallback N-scan accepted for chain {}: cost={:.2f}, eid={}",
×
UNCOV
952
                                                       chain_idx, alt_cost, single.front().entity_id);
×
953
                                    }
UNCOV
954
                                } else if (allowable_depth == 1) {
×
955
                                    // One-step fallback here too
UNCOV
956
                                    double best_c = std::numeric_limits<double>::infinity();
×
UNCOV
957
                                    EntityId best_eid = 0;
×
UNCOV
958
                                    for (auto const & tup: viable_candidates_alt) {
×
UNCOV
959
                                        EntityId eid = std::get<0>(tup);
×
UNCOV
960
                                        auto key = std::make_pair(static_cast<long long>(f.getValue()), eid);
×
UNCOV
961
                                        if (used.count(key)) continue;
×
UNCOV
962
                                        double c = std::get<2>(tup);
×
UNCOV
963
                                        if (c < best_c) {
×
UNCOV
964
                                            best_c = c;
×
UNCOV
965
                                            best_eid = eid;
×
966
                                        }
967
                                    }
UNCOV
968
                                    if (best_eid != 0 && (best_c < _lookahead_threshold || !std::isfinite(_lookahead_threshold))) {
×
UNCOV
969
                                        std::vector<NodeInfo> single{{f, best_eid}};
×
UNCOV
970
                                        n_scan_results[chain_idx] = {single, best_c};
×
UNCOV
971
                                        used.insert(std::make_pair(static_cast<long long>(f.getValue()), best_eid));
×
UNCOV
972
                                        if (_logger) {
×
UNCOV
973
                                            _logger->debug("  Fallback single-step accepted for chain {}: eid={}, cost={:.2f}",
×
974
                                                           chain_idx, best_eid, best_c);
975
                                        }
UNCOV
976
                                    }
×
977
                                }
UNCOV
978
                            }
×
979
                        }
980
                    }// if (!ambiguous_chain_indices.empty() && allowable_depth >= 1)
1✔
981

982
                    // Process assignments
983
                    std::vector<ActiveChain> remaining_chains;
1,741✔
984

985
                    for (size_t chain_idx = 0; chain_idx < assignment_matrix.size(); ++chain_idx) {
5,204✔
986
                        // Check if this chain has N-scan results
987
                        if (n_scan_results.count(chain_idx) > 0) {
3,463✔
988
                            auto & chain = active_chains[chain_idx];
2✔
989
                            auto const & n_scan_path = n_scan_results[chain_idx].first;// current-frame selection
2✔
990

991
                            // Extend chain with the single current-frame decision
992
                            if (!n_scan_path.empty()) {
2✔
993
                                auto const & sel = n_scan_path.front();
2✔
994
                                chain.members.push_back(sel);
2✔
995
                                //this_frame_entities.erase(sel.entity_id);
996
                            }
997

998
                            // Update chain to the selected node at current frame
999
                            auto const & last_node = n_scan_path.front();
2✔
1000
                            chain.curr_frame = last_node.frame;
2✔
1001
                            chain.curr_entity = last_node.entity_id;
2✔
1002
                            this_frame_entities.erase(last_node.entity_id);
2✔
1003
                            chain.curr_data = findEntity(frame_lookup.at(last_node.frame), last_node.entity_id);
2✔
1004

1005
                            // Re-sync filter
1006
                            if (chain.filter && chain.curr_data) {
2✔
1007
                                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*chain.curr_data);
2✔
1008
                                // At current frame, update with the current predicted
1009
                                chain.filter->update(chain.predicted, Measurement{obs});
4✔
1010
                            }
2✔
1011

1012
                            remaining_chains.push_back(std::move(chain));
2✔
1013
                            continue;
2✔
1014
                        }
2✔
1015

1016
                        // Check if this chain was ambiguous but N-scan failed
1017
                        if (ambiguous_chain_indices.count(chain_idx) > 0) {
3,461✔
UNCOV
1018
                            auto & chain = active_chains[chain_idx];
×
UNCOV
1019
                            auto & node = meta_nodes[chain.meta_node_idx];
×
1020

1021
                            // N-scan failed - terminate chain
1022
                            // Finalize chain into a meta-node upon termination
UNCOV
1023
                            MetaNode term;
×
UNCOV
1024
                            term.start_frame = chain.members.front().frame;
×
UNCOV
1025
                            term.start_entity = chain.members.front().entity_id;
×
UNCOV
1026
                            term.members = chain.members;
×
UNCOV
1027
                            term.end_frame = chain.members.back().frame;
×
UNCOV
1028
                            term.end_entity = chain.members.back().entity_id;
×
UNCOV
1029
                            term.start_state = chain.start_state;
×
UNCOV
1030
                            if (chain.filter) term.end_state = chain.filter->getState();
×
UNCOV
1031
                            meta_nodes.push_back(std::move(term));
×
UNCOV
1032
                            this_frame_entities.erase(chain.members.back().entity_id);
×
UNCOV
1033
                            continue;
×
UNCOV
1034
                        }
×
1035

1036
                        // Normal assignment processing
1037
                        bool found_assignment = false;
3,461✔
1038
                        int assigned_cand_idx = -1;
3,461✔
1039

1040
                        for (size_t cand_idx = 0; cand_idx < assignment_matrix[chain_idx].size(); ++cand_idx) {
5,182✔
1041
                            if (assignment_matrix[chain_idx][cand_idx] == 1) {
5,182✔
1042
                                if (cost_matrix[chain_idx][cand_idx] <= max_cost) {
3,461✔
1043
                                    found_assignment = true;
3,460✔
1044
                                    assigned_cand_idx = static_cast<int>(cand_idx);
3,460✔
1045
                                }
1046
                                break;
3,461✔
1047
                            }
1048
                        }
1049

1050
                        auto & chain = active_chains[chain_idx];
3,461✔
1051

1052
                        if (found_assignment) {
3,461✔
1053
                            // Extend chain
1054
                            EntityId best_entity = std::get<0>(candidates[assigned_cand_idx]);
3,460✔
1055
                            DataType const * best_data = std::get<1>(candidates[assigned_cand_idx]);
3,460✔
1056

1057
                            if (_logger) {
3,460✔
1058
                                double cost_unscaled = static_cast<double>(cost_matrix[chain_idx][assigned_cand_idx]) / cost_scaling_factor;
749✔
1059
                                _logger->debug("  Chain {} (entity {}) → entity {} (cost={:.3f}, threshold={:.3f})",
749✔
1060
                                               chain_idx, chain.curr_entity, best_entity, cost_unscaled, _cheap_assignment_threshold);
749✔
1061
                            }
1062

1063
                            // Guard: if N-scan (or another chain) already committed this obs at current frame,
1064
                            // do not double-claim it. Treat as no assignment and let fallback/termination handle it.
1065
                            auto used_key_current = std::make_pair(static_cast<long long>(f.getValue()), best_entity);
3,460✔
1066
                            if (used.count(used_key_current)) {
3,460✔
1067
                                // I think this might be a false positive
1068
                                //
UNCOV
1069
                                found_assignment = false;
×
1070
                            }
1071

1072
                            Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*best_data);
3,460✔
1073
                            if (chain.filter) {
3,460✔
1074
                                chain.filter->update(chain.predicted, {obs});
6,920✔
1075

1076
                                // Check covariance health
1077
                                auto updated_state = chain.filter->getState();
3,460✔
1078
                                double determinant = updated_state.state_covariance.determinant();
3,460✔
1079

1080
                                if (std::abs(determinant) < 1e-10 && _logger) {
3,460✔
UNCOV
1081
                                    Eigen::JacobiSVD<Eigen::MatrixXd> svd(updated_state.state_covariance);
×
UNCOV
1082
                                    double condition_number = svd.singularValues()(0) /
×
UNCOV
1083
                                                              (svd.singularValues()(svd.singularValues().size() - 1) + 1e-20);
×
1084

UNCOV
1085
                                    _logger->warn("State covariance singular: det={:.2e}, cond={:.2e}",
×
1086
                                                  determinant, condition_number);
1087

UNCOV
1088
                                    if (condition_number > 1e12) {
×
UNCOV
1089
                                        _logger->warn("  Terminating chain due to ill-conditioned covariance");
×
UNCOV
1090
                                        found_assignment = false;
×
1091
                                    }
UNCOV
1092
                                }
×
1093
                            }
3,460✔
1094

1095
                            if (found_assignment) {
3,460✔
1096
                                NodeInfo next_info{f, best_entity};
3,460✔
1097
                                chain.members.push_back(next_info);
3,460✔
1098
                                used.insert(std::make_pair(static_cast<long long>(f.getValue()), best_entity));
3,460✔
1099

1100
                                chain.curr_frame = f;
3,460✔
1101
                                chain.curr_entity = best_entity;
3,460✔
1102
                                chain.curr_data = best_data;
3,460✔
1103
                                this_frame_entities.erase(best_entity);
3,460✔
1104
                                remaining_chains.push_back(std::move(chain));
3,460✔
1105
                            }
1106
                        }
3,460✔
1107

1108
                        if (!found_assignment) {
3,461✔
1109
                            // Chain terminates -> emit meta-node
1110
                            if (_logger) {
1✔
UNCOV
1111
                                _logger->debug("  Chain {} (entity {}) terminated at frame {} - emit meta-node",
×
UNCOV
1112
                                               chain_idx, chain.curr_entity, chain.curr_frame.getValue());
×
1113
                            }
1114
                            MetaNode term;
1✔
1115
                            term.start_frame = chain.members.front().frame;
1✔
1116
                            term.start_entity = chain.members.front().entity_id;
1✔
1117
                            term.members = chain.members;
1✔
1118
                            term.end_frame = chain.members.back().frame;
1✔
1119
                            term.end_entity = chain.members.back().entity_id;
1✔
1120
                            term.start_state = chain.start_state;
1✔
1121
                            if (chain.filter) term.end_state = chain.filter->getState();
1✔
1122
                            meta_nodes.push_back(std::move(term));
1✔
1123
                            this_frame_entities.erase(chain.members.back().entity_id);
1✔
1124
                        }
1✔
1125
                    }
1126

1127
                    active_chains = std::move(remaining_chains);
1,741✔
1128
                }
1,741✔
1129
            }
1,741✔
1130

1131
            // Step 2: Start new chains for any remaining unused observations in current frame
1132
            for (auto const & item: frame_lookup.at(f)) {
5,308✔
1133
                EntityId entity_id = std::get<1>(item);
3,504✔
1134
                auto used_key = std::make_pair(static_cast<long long>(f.getValue()), entity_id);
3,504✔
1135
                if (used.count(used_key)) continue;
3,504✔
1136

1137
                DataType const * start_data = std::get<0>(item);
42✔
1138

1139
                // Initialize filter for this new chain
1140
                FilterState start_state;
42✔
1141
                FilterState updated_state;
42✔
1142
                std::unique_ptr<IFilter> chain_filter;
42✔
1143
                if (_filter_prototype) {
42✔
1144
                    chain_filter = _filter_prototype->clone();
42✔
1145
                    FilterState initial_state = _feature_extractor->getInitialState(*start_data);
42✔
1146
                    chain_filter->initialize(initial_state);
42✔
1147
                    start_state = chain_filter->getState();
42✔
1148

1149
                    // Immediately update filter with the first observation
1150
                    // This ensures single-frame meta-nodes have correct end_state
1151
                    Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*start_data);
42✔
1152
                    updated_state = chain_filter->update(start_state, Measurement{obs});
84✔
1153
                }
42✔
1154

1155
                // Start new active chain (defer meta-node until termination)
1156
                this_frame_entities.erase(entity_id);
42✔
1157
                used.insert(used_key);
42✔
1158

1159
                ActiveChain chain;
42✔
1160
                chain.meta_node_idx = static_cast<size_t>(-1);
42✔
1161
                chain.curr_frame = f;
42✔
1162
                chain.curr_entity = entity_id;
42✔
1163
                chain.curr_data = start_data;
42✔
1164
                chain.filter = std::move(chain_filter);
42✔
1165
                chain.members.push_back(NodeInfo{f, entity_id});
42✔
1166
                chain.start_state = start_state;
42✔
1167
                active_chains.push_back(std::move(chain));
42✔
1168
            }
1169

1170
            if (this_frame_entities.size() != 0) {
1,762✔
1171
                // ERROR WE LEFT A MAN BEHIND
1172
                // Error out
UNCOV
1173
                if (_logger) {
×
UNCOV
1174
                    _logger->error("We left a man behind at frame {}", f.getValue(),
×
UNCOV
1175
                                   " with entities: ", this_frame_entities.size());
×
1176

UNCOV
1177
                    for (auto const & entity: this_frame_entities) {
×
UNCOV
1178
                        _logger->error("  Entity {}", entity);
×
1179
                    }
1180

1181
                    // Is it in used?
UNCOV
1182
                    for (auto const & entity: this_frame_entities) {
×
UNCOV
1183
                        if (used.count(std::make_pair(static_cast<long long>(f.getValue()), entity))) {
×
UNCOV
1184
                            _logger->error("  Entity {} is in used", entity);
×
1185
                        }
1186
                    }
1187

1188
                    // Is it in active_chains?
UNCOV
1189
                    for (auto const & entity: this_frame_entities) {
×
UNCOV
1190
                        for (auto const & chain: active_chains) {
×
UNCOV
1191
                            if (chain.curr_entity == entity) {
×
UNCOV
1192
                                _logger->error("  Entity {} is in active_chains", entity);
×
UNCOV
1193
                                break;
×
1194
                            }
1195
                        }
1196
                    }
1197
                }
UNCOV
1198
                throw std::runtime_error(
×
1199
                        "We left a man behind at frame " + std::to_string(f.getValue()) +
1200
                        " with entities: " + std::to_string(this_frame_entities.size()));
1201
            }
1202

1203
            //Update progress every 1000 frames
1204
            if (f.getValue() % 1000 == 0) {
1,762✔
1205
                progress(static_cast<int>(f.getValue()) / (end_frame.getValue() - start_frame.getValue() + 1) * 100);
19✔
1206
            }
1207
        }
1208

1209
        // Finalize any remaining active chains at end of range
1210
        for (auto & chain: active_chains) {
103✔
1211
            MetaNode node;
41✔
1212
            node.start_frame = chain.members.front().frame;
41✔
1213
            node.start_entity = chain.members.front().entity_id;
41✔
1214
            node.members = chain.members;
41✔
1215
            node.end_frame = chain.members.back().frame;
41✔
1216
            node.end_entity = chain.members.back().entity_id;
41✔
1217
            node.start_state = chain.start_state;
41✔
1218
            if (chain.filter) node.end_state = chain.filter->getState();
41✔
1219
            if (_logger) {
41✔
1220
                _logger->debug("Meta-node (finalized): frames {} to {} ({} frames), entities {} to {}, {} members - reached end",
11✔
1221
                               node.start_frame.getValue(),
11✔
1222
                               node.end_frame.getValue(),
22✔
1223
                               node.end_frame.getValue() - node.start_frame.getValue() + 1,
22✔
1224
                               node.start_entity,
1225
                               node.end_entity,
1226
                               node.members.size());
22✔
1227
            }
1228
            meta_nodes.push_back(std::move(node));
41✔
1229
        }
1230

1231
        if (_logger) {
21✔
1232
            _logger->debug("Built {} meta-nodes using Hungarian assignment", meta_nodes.size());
6✔
1233

1234
            // Compute statistics on meta-node lengths
1235
            if (!meta_nodes.empty()) {
6✔
1236
                std::vector<int> lengths;
6✔
1237
                for (auto const & mn: meta_nodes) {
17✔
1238
                    lengths.push_back(static_cast<int>(mn.members.size()));
11✔
1239
                }
1240
                std::sort(lengths.begin(), lengths.end());
6✔
1241

1242
                int total_length = 0;
6✔
1243
                for (int len: lengths) total_length += len;
17✔
1244
                double mean_length = static_cast<double>(total_length) / lengths.size();
6✔
1245

1246
                int median_length = lengths[lengths.size() / 2];
6✔
1247
                int min_length = lengths.front();
6✔
1248
                int max_length = lengths.back();
6✔
1249

1250
                _logger->debug("Meta-node length statistics: min={}, median={}, mean={:.1f}, max={}",
6✔
1251
                               min_length, median_length, mean_length, max_length);
1252

1253
                // Count single-frame meta-nodes
1254
                int single_frame_count = 0;
6✔
1255
                for (int len: lengths) {
17✔
1256
                    if (len == 1) single_frame_count++;
11✔
1257
                }
1258
                if (single_frame_count > 0) {
6✔
UNCOV
1259
                    _logger->debug("  {} single-frame meta-nodes ({:.1f}%)",
×
1260
                                   single_frame_count,
UNCOV
1261
                                   100.0 * single_frame_count / meta_nodes.size());
×
1262
                }
1263
            }
6✔
1264
        }
1265
        return meta_nodes;
42✔
1266
    }
3,525✔
1267

1268
    // The legacy greedy chaining code that used to follow here (from "double best_cost" to its "return meta_nodes") has been removed.
1269
    /*
1270
    OLD GREEDY CODE REMOVED - replaced with Hungarian-based approach above
1271
    The new algorithm:
1272
    1. Starts new chains for all unused observations at each frame
1273
    2. Predicts all active chains forward one frame
1274
    3. Builds cost matrix (chains x candidates)
1275
    4. Uses Hungarian algorithm for optimal assignment
1276
    5. Only accepts assignments below threshold
1277
    6. Chains that don't get assigned (or exceed threshold) terminate
1278
    
1279
    This prevents "stealing" where long chains take candidates that would be better matches for other chains.
1280
    */
1281

1282
    // --- N-Scan Lookahead Functions ---
1283

1284

1285
    /**
1286
     * @brief Detect if assignments are ambiguous at current frame.
1287
     * Ambiguous if: (1) A chain has ≥2 candidates below threshold, OR
1288
     *               (2) Multiple chains compete for the same candidate.
1289
     * 
1290
     * @param cost_matrix Cost matrix (chains x candidates)
1291
     * @param threshold Assignment threshold
1292
     * @return Indices of chains involved in ambiguous assignments
1293
     */
1294
    std::unordered_set<size_t> detect_ambiguous_chains(
1,741✔
1295
            Eigen::MatrixXd const & cost_matrix,
1296
            double threshold) const {
1297

1298
        std::unordered_set<size_t> ambiguous_chains;
1,741✔
1299
        size_t num_chains = static_cast<size_t>(cost_matrix.rows());
1,741✔
1300
        size_t num_candidates = static_cast<size_t>(cost_matrix.cols());
1,741✔
1301

1302
        // Check condition 1: chain has ≥2 candidates below threshold
1303
        for (size_t chain_idx = 0; chain_idx < num_chains; ++chain_idx) {
5,204✔
1304
            int count_below_threshold = 0;
3,463✔
1305
            for (size_t cand_idx = 0; cand_idx < num_candidates; ++cand_idx) {
10,370✔
1306
                if (cost_matrix(static_cast<Eigen::Index>(chain_idx),
6,907✔
1307
                                static_cast<Eigen::Index>(cand_idx)) < threshold) {
6,907✔
1308
                    ++count_below_threshold;
3,427✔
1309
                }
1310
            }
1311
            if (count_below_threshold >= 2) {
3,463✔
1312
                ambiguous_chains.insert(chain_idx);
2✔
1313
            }
1314
        }
1315

1316
        // Check condition 2: multiple chains compete for same candidate
1317
        for (size_t cand_idx = 0; cand_idx < num_candidates; ++cand_idx) {
5,204✔
1318
            int count_below_threshold = 0;
3,463✔
1319
            size_t competing_chain = 0;
3,463✔
1320
            for (size_t chain_idx = 0; chain_idx < num_chains; ++chain_idx) {
10,370✔
1321
                if (cost_matrix(static_cast<Eigen::Index>(chain_idx),
6,907✔
1322
                                static_cast<Eigen::Index>(cand_idx)) < threshold) {
6,907✔
1323
                    ++count_below_threshold;
3,427✔
1324
                    competing_chain = chain_idx;
3,427✔
1325
                }
1326
            }
1327
            if (count_below_threshold >= 2) {
3,463✔
1328
                // Mark all competing chains as ambiguous
1329
                for (size_t chain_idx = 0; chain_idx < num_chains; ++chain_idx) {
6✔
1330
                    if (cost_matrix(static_cast<Eigen::Index>(chain_idx),
4✔
1331
                                    static_cast<Eigen::Index>(cand_idx)) < threshold) {
4✔
1332
                        ambiguous_chains.insert(chain_idx);
4✔
1333
                    }
1334
                }
1335
            }
1336
        }
1337

1338
        return ambiguous_chains;
1,741✔
UNCOV
1339
    }
×
1340

1341
    /**
1342
     * @brief Expand hypotheses by one frame: predict, compute costs, and branch.
1343
     * 
1344
     * @param hypotheses Current hypotheses for a chain
1345
     * @param candidates Available observations at next frame
1346
     * @param next_frame Frame index of candidates
1347
     * @param frame_lookup Frame data lookup
1348
     * @param scoring_fn Function to compute total cost from frame costs
1349
     * @return Updated hypotheses (terminated branches are marked)
1350
     */
1351
    std::vector<Hypothesis> expand_hypotheses(
4✔
1352
            std::vector<Hypothesis> && hypotheses,
1353
            std::vector<std::pair<EntityId, DataType const *>> const & candidates,
1354
            TimeFrameIndex next_frame,
1355
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
1356
            HypothesisScoringFunction const & scoring_fn) {
1357

1358
        std::vector<Hypothesis> expanded;
4✔
1359

1360
        for (auto & hyp: hypotheses) {
28✔
1361
            if (hyp.terminated) {
12✔
UNCOV
1362
                expanded.push_back(std::move(hyp));
×
UNCOV
1363
                continue;
×
1364
            }
1365

1366
            // Predict forward one frame
1367
            FilterState predicted_state = hyp.filter->predict();
12✔
1368

1369
            if (_logger) {
12✔
1370
                _logger->debug("      Expanding hyp (current_path_length={}): predicted_mean=[{:.2f},{:.2f}]",
12✔
1371
                               hyp.path.size(), predicted_state.state_mean(0), predicted_state.state_mean(1));
12✔
1372
            }
1373

1374
            // Try each candidate
1375
            bool found_valid_branch = false;
12✔
1376
            for (auto const & [cand_entity_id, cand_data]: candidates) {
60✔
1377
                // Extract features
1378
                Eigen::VectorXd measurement = _feature_extractor->getFilterFeatures(*cand_data);
24✔
1379

1380
                // Compute cost
1381
                double cost = _lookahead_cost_function(predicted_state, measurement, 1);
24✔
1382

1383
                if (_logger && cost < _cheap_assignment_threshold) {
24✔
1384
                    _logger->debug("        → entity {}: obs=[{:.2f},{:.2f}], cost={:.3f}",
23✔
1385
                                   cand_entity_id, measurement(0), measurement(1), cost);
23✔
1386
                }
1387

1388
                // Prune if exceeds lookahead threshold
1389
                if (std::isfinite(_lookahead_threshold) && cost >= _lookahead_threshold) {
24✔
UNCOV
1390
                    continue;
×
1391
                }
1392

1393
                // Clone hypothesis and extend
1394
                Hypothesis new_hyp;
24✔
1395
                new_hyp.filter = hyp.filter->clone();
24✔
1396
                new_hyp.current_state = new_hyp.filter->update(predicted_state, Measurement{measurement});
48✔
1397
                new_hyp.path = hyp.path;
24✔
1398
                new_hyp.path.push_back({next_frame, cand_entity_id});
24✔
1399
                new_hyp.frame_costs = hyp.frame_costs;
24✔
1400
                new_hyp.frame_costs.push_back(cost);
24✔
1401
                new_hyp.total_cost = scoring_fn(new_hyp.frame_costs);
24✔
1402
                new_hyp.terminated = false;
24✔
1403

1404
                expanded.push_back(std::move(new_hyp));
24✔
1405
                found_valid_branch = true;
24✔
1406
            }
1407

1408
            // If no valid branches, terminate this hypothesis
1409
            if (!found_valid_branch) {
12✔
UNCOV
1410
                hyp.terminated = true;
×
UNCOV
1411
                expanded.push_back(std::move(hyp));
×
1412
            }
1413
        }
1414

1415
        return expanded;
4✔
1416
    }
24✔
1417

1418
    /**
1419
     * @brief Run N-scan lookahead for an ambiguous chain.
1420
     * Explores multiple hypothesis paths over the next N frames and selects the best.
1421
     * 
1422
     * @param chain The active chain to run N-scan on
1423
     * @param candidates_at_next_frame Initial candidates at the first lookahead frame
1424
     * @param start_scan_frame The frame where N-scan starts
1425
     * @param frame_lookup All frame data
1426
     * @param used Set of already-used (frame, entity) pairs (will be updated)
1427
     * @return Pair of (path, total_cost), or ({}, 0.0) if chain should terminate
1428
     */
1429
    std::pair<std::vector<NodeInfo>, double> run_n_scan_lookahead(
2✔
1430
            ActiveChain const & chain,
1431
            std::vector<std::tuple<EntityId, DataType const *, double>> const & candidates_with_costs,
1432
            TimeFrameIndex start_scan_frame,
1433
            TimeFrameIndex end_frame,
1434
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
1435
            std::set<std::pair<long long, EntityId>> used// pass by value to make sure we don't modify the original set
1436
    ) {
1437

1438
        // Early return if we can't scan ahead (at or near end frame)
1439
        if (start_scan_frame + TimeFrameIndex(1) > end_frame) {
2✔
UNCOV
1440
            if (_logger) {
×
UNCOV
1441
                _logger->debug("N-scan skipped at frame {}: no future frames to scan", start_scan_frame.getValue());
×
1442
            }
UNCOV
1443
            return {{}, 0.0};
×
1444
        }
1445

1446
        // Initialize hypotheses for each viable candidate
1447
        std::vector<Hypothesis> hypotheses;
2✔
1448
        for (auto const & [cand_entity, cand_data, cost_double]: candidates_with_costs) {
10✔
1449
            if (cost_double >= _cheap_assignment_threshold) continue;
4✔
1450

1451
            Hypothesis hyp;
4✔
1452
            hyp.filter = chain.filter ? chain.filter->clone() : nullptr;
4✔
1453
            if (hyp.filter) {
4✔
1454
                Eigen::VectorXd obs = _feature_extractor->getFilterFeatures(*cand_data);
4✔
1455

1456
                // Check if cloned filter state matches chain.predicted
1457
                auto cloned_state = hyp.filter->getState();
4✔
1458

1459
                hyp.current_state = hyp.filter->update(chain.predicted, Measurement{obs});
8✔
1460

1461
                if (_logger) {
4✔
1462
                    _logger->debug("    Init hyp for entity {}: chain.predicted=[{:.2f},{:.2f},{:.2f},{:.2f}], cloned_filter=[{:.2f},{:.2f},{:.2f},{:.2f}], obs=[{:.2f},{:.2f}], cost={:.3f}",
4✔
1463
                                   cand_entity,
1464
                                   chain.predicted.state_mean(0), chain.predicted.state_mean(1),
4✔
1465
                                   chain.predicted.state_mean(2), chain.predicted.state_mean(3),
4✔
1466
                                   cloned_state.state_mean(0), cloned_state.state_mean(1),
4✔
1467
                                   cloned_state.state_mean(2), cloned_state.state_mean(3),
4✔
1468
                                   obs(0), obs(1), cost_double);
4✔
1469
                    _logger->debug("       After update: state=[{:.2f},{:.2f},{:.2f},{:.2f}]",
4✔
1470
                                   hyp.current_state.state_mean(0), hyp.current_state.state_mean(1),
4✔
1471
                                   hyp.current_state.state_mean(2), hyp.current_state.state_mean(3));
4✔
1472
                }
1473
            }
4✔
1474
            hyp.path.push_back({start_scan_frame, cand_entity});
4✔
1475
            hyp.frame_costs.push_back(cost_double);
4✔
1476
            hyp.total_cost = score_hypothesis_simple_sum(hyp.frame_costs);
4✔
1477
            hyp.terminated = false;
4✔
1478

1479
            hypotheses.push_back(std::move(hyp));
4✔
1480
        }
1481

1482
        if (hypotheses.empty()) {
2✔
UNCOV
1483
            return {{}, 0.0};// No viable paths
×
1484
        }
1485

1486
        if (_logger) {
2✔
1487
            _logger->debug("Starting N-scan at frame {} with {} initial hypotheses",
2✔
1488
                           start_scan_frame.getValue(), hypotheses.size());
2✔
1489
        }
1490

1491
        // Expand hypotheses over N frames
1492
        for (int depth = 1; depth < _n_scan_depth; ++depth) {
10✔
1493
            TimeFrameIndex scan_frame = start_scan_frame + TimeFrameIndex(depth);
4✔
1494
            if (scan_frame > end_frame || !frame_lookup.count(scan_frame)) {
4✔
UNCOV
1495
                break;// Reached end of available frames
×
1496
            }
1497

1498
            // Collect available candidates
1499
            std::vector<std::pair<EntityId, DataType const *>> scan_candidates;
4✔
1500
            for (auto const & item: frame_lookup.at(scan_frame)) {
12✔
1501
                EntityId cand_id = std::get<1>(item);
8✔
1502
                auto key = std::make_pair(static_cast<long long>(scan_frame.getValue()), cand_id);
8✔
1503
                if (used.count(key)) continue;
8✔
1504

1505
                scan_candidates.emplace_back(cand_id, std::get<0>(item));
8✔
1506
            }
1507

1508
            if (scan_candidates.empty()) {
4✔
UNCOV
1509
                break;// No candidates available
×
1510
            }
1511

1512
            // Expand all hypotheses
1513
            hypotheses = expand_hypotheses(std::move(hypotheses), scan_candidates, scan_frame,
4✔
1514
                                           frame_lookup, score_hypothesis_simple_sum);
1515

1516
            // Check for early termination
1517
            int viable_count = 0;
4✔
1518
            for (auto const & hyp: hypotheses) {
28✔
1519
                if (!hyp.terminated) viable_count++;
24✔
1520
            }
1521

1522
            if (_logger) {
4✔
1523
                _logger->debug("N-scan depth {}: {} viable hypotheses at frame {}",
4✔
1524
                               depth, viable_count, scan_frame.getValue());
4✔
1525
            }
1526

1527
            if (viable_count <= 1) {
4✔
UNCOV
1528
                break;// Only one path remains, can commit early
×
1529
            }
1530
        }
1531

1532
        // Select best hypothesis
1533
        bool reached_n = (hypotheses.empty() ? false : (hypotheses[0].path.size() >= static_cast<size_t>(_n_scan_depth)));
2✔
1534
        auto best_hyp_opt = select_best_hypothesis(hypotheses, reached_n);
2✔
1535

1536
        if (!best_hyp_opt.has_value()) {
2✔
UNCOV
1537
            if (_logger) {
×
UNCOV
1538
                _logger->debug("N-scan terminated: ambiguity persists or no viable paths");
×
1539
            }
UNCOV
1540
            return {{}, 0.0};
×
1541
        }
1542

1543
        // Mark used entities from the selected path
1544
        auto const & best_path = best_hyp_opt->path;
2✔
1545
        double best_cost = best_hyp_opt->total_cost;
2✔
1546
        for (auto const & node: best_path) {
8✔
1547
            used.insert(std::make_pair(static_cast<long long>(node.frame.getValue()), node.entity_id));
6✔
1548
        }
1549

1550
        if (_logger) {
2✔
1551
            _logger->debug("N-scan committed path with {} nodes, total cost {:.2f}",
2✔
1552
                           best_path.size(), best_cost);
2✔
1553
        }
1554

1555
        return {best_path, best_cost};
2✔
1556
    }
6✔
1557

1558
    // --- Final Smoothing Step ---
1559
    /**
     * @brief Forward-filter and then smooth every solved path.
     *
     * Forward pass, per non-empty path:
     *  - a fresh clone of _filter_prototype is initialised from the first node for which
     *    findEntity() returns data;
     *  - for each later node the filter is predicted once per elapsed frame and then
     *    updated with that node's features.
     *
     * The per-node states are passed to smooth() when there are at least two of them;
     * otherwise they are stored as-is.
     *
     * Nodes are skipped, and contribute no state, when findEntity() returns null or when
     * their frame does not advance past the last processed node. A group's state vector
     * can therefore be shorter than its path.
     *
     * @param solved_paths Per-group node paths.
     * @param frame_lookup Frame data used to resolve path nodes (at() throws if a frame is missing).
     * @param start_frame  Not read by this function.
     * @param end_frame    Not read by this function.
     * @return States keyed by group id; empty when no filter prototype is configured.
     */
    SmoothedResults generate_smoothed_results(
            std::map<GroupId, Path> const & solved_paths,
            std::map<TimeFrameIndex, FrameBucket<DataType>> const & frame_lookup,
            TimeFrameIndex start_frame,
            TimeFrameIndex end_frame) {

        SmoothedResults final_results;

        // Skip smoothing if no filter is provided
        if (!_filter_prototype) {
            return final_results;
        }

        for (auto const & [group_id, path]: solved_paths) {
            if (path.empty()) continue;

            auto filter = _filter_prototype->clone();
            std::vector<FilterState> forward_states;
            forward_states.reserve(path.size());

            // Frame of the last node actually folded into the filter. Using this rather than
            // path[i - 1] keeps initialisation and the predict-step count correct when an
            // earlier node was skipped.
            std::optional<TimeFrameIndex> last_frame;

            // Forward pass using the solved path
            for (auto const & node: path) {
                auto const * data = findEntity(frame_lookup.at(node.frame), node.entity_id);
                if (!data) continue;

                if (!last_frame) {
                    filter->initialize(_feature_extractor->getInitialState(*data));
                } else {
                    int const num_steps = static_cast<int>((node.frame - *last_frame).getValue());

                    if (num_steps <= 0) {
                        if (_logger) _logger->error("Invalid num_steps in smoothing: {}", num_steps);
                        continue;// Skip invalid steps
                    }

                    // Multi-step prediction: one predict() per elapsed frame. The last call
                    // leaves the filter's internal state at 'pred', which is then updated.
                    FilterState pred = filter->getState();
                    for (int step = 0; step < num_steps; ++step) {
                        pred = filter->predict();
                    }
                    filter->update(pred, {_feature_extractor->getFilterFeatures(*data)});
                }

                last_frame = node.frame;
                forward_states.push_back(filter->getState());
            }

            // Backward smoothing pass
            if (forward_states.size() > 1) {
                final_results[group_id] = filter->smooth(forward_states);
            } else {
                final_results[group_id] = forward_states;
            }
        }
        return final_results;
    }
1,717✔
1616

1617
private:
1618
    std::unique_ptr<IFilter> _filter_prototype;
1619
    std::unique_ptr<IFeatureExtractor<DataType>> _feature_extractor;
1620
    CostFunction _chain_cost_function;
1621
    CostFunction _transition_cost_function;
1622
    CostFunction _lookahead_cost_function;
1623
    double _cost_scale_factor;
1624
    double _cheap_assignment_threshold;
1625
    std::shared_ptr<spdlog::logger> _logger;
1626
    TrackerContractPolicy _policy = TrackerContractPolicy::Throw;
1627
    TrackerDiagnostics _diagnostics{};
1628
    int _n_scan_depth = 3;
1629
    bool _enable_n_scan = true;
1630
    int _max_gap_frames = 3;// Maximum frames to skip before terminating chain (-1 = unlimited)
1631
    double _lookahead_threshold = std::numeric_limits<double>::infinity();
1632
    double _ambiguity_threshold = 1.0;// default: stricter than cheap assignment
1633
    double _ambiguity_margin = 0.0;   // default off
1634
};
1635

1636
}// namespace StateEstimation
1637

1638
#endif// MIN_COST_FLOW_TRACKER_HPP
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc