• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18246927847

04 Oct 2025 04:44PM UTC coverage: 71.826% (+0.6%) from 71.188%
18246927847

push

github

paulmthompson
refactor out media producer consumer pipeline

0 of 120 new or added lines in 2 files covered. (0.0%)

646 existing lines in 14 files now uncovered.

48895 of 68074 relevant lines covered (71.83%)

1193.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.17
/src/StateEstimation/Tracker.hpp
1
#ifndef TRACKER_HPP
2
#define TRACKER_HPP
3

4
#include "Assignment/IAssigner.hpp"
#include "DataSource.hpp"
#include "Features/IFeatureExtractor.hpp"
#include "Filter/IFilter.hpp"
#include "IdentityConfidence.hpp"

#include "spdlog/sinks/basic_file_sink.h"
#include "spdlog/sinks/rotating_file_sink.h"
#include "spdlog/spdlog.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <memory>
#include <optional>
#include <ranges>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
27

28
namespace StateEstimation {
29

30
// The return type: a map from each GroupId to its series of smoothed states.
31
using SmoothedResults = std::map<GroupId, std::vector<FilterState>>;
32

33
// Progress callback: takes percentage (0-100) and current frame
34
using ProgressCallback = std::function<void(int)>;
35

36
// Forward declaration for the state structure
37
template<typename DataType>
38
struct TrackedGroupState;
39

40
/**
41
 * @brief Helper structure for batching group assignment updates.
42
 * 
43
 * Accumulates entity-to-group assignments during tracking and flushes them
44
 * to the EntityGroupManager at strategic points (anchor frames). This provides
45
 * significant performance benefits by:
46
 * - Avoiding O(G × E_g × log E) cost of rebuilding group membership set every frame
47
 * - Providing O(1) membership checks via hash set
48
 * - Batching updates for better cache locality
49
 */
50
struct PendingGroupUpdates {
51
    // Frame-aware pending additions per group
52
    std::unordered_map<GroupId, std::vector<std::pair<TimeFrameIndex, EntityId>>> pending_additions;
53

54
    // Fast O(1) lookup for entities assigned during this pass
55
    std::unordered_set<EntityId> entities_added_this_pass;
56

57
    void addPending(GroupId group_id, EntityId entity_id, TimeFrameIndex frame) {
1,466✔
58
        pending_additions[group_id].emplace_back(frame, entity_id);
1,466✔
59
        entities_added_this_pass.insert(entity_id);
1,466✔
60
    }
1,466✔
61

62
    // Replace the entity assigned for a given group and frame, if present
63
    void replaceForFrame(GroupId group_id, TimeFrameIndex frame, EntityId new_entity_id) {
1,296✔
64
        auto it = pending_additions.find(group_id);
1,296✔
65
        if (it == pending_additions.end()) return;
1,296✔
66
        for (auto & entry: it->second) {
12,960✔
67
            if (entry.first == frame) {
11,664✔
68
                entry.second = new_entity_id;
1,296✔
69
            }
70
        }
71
        entities_added_this_pass.insert(new_entity_id);
1,296✔
72
    }
73

74
    void flushToManager(EntityGroupManager & manager) {
86✔
75
        for (auto const & [group_id, entries]: pending_additions) {
252✔
76
            for (auto const & [/*frame*/ _, entity_id]: entries) {
1,632✔
77
                manager.addEntityToGroup(group_id, entity_id);
1,466✔
78
            }
79
        }
80
        pending_additions.clear();
86✔
81
        entities_added_this_pass.clear();
86✔
82
    }
86✔
83

84
    bool contains(EntityId entity_id) const {
1,666✔
85
        return entities_added_this_pass.find(entity_id) != entities_added_this_pass.end();
1,666✔
86
    }
87

88
    std::unordered_set<EntityId> const & getAddedEntities() const {
73✔
89
        return entities_added_this_pass;
73✔
90
    }
91
};
92

93
/**
94
 * @brief The central orchestrator for the tracking process.
95
 *
96
 * This class manages the lifecycle of tracked objects (groups) and coordinates
97
 * the filter, feature extraction, and assignment components to process data
98
 * across multiple time frames. It is templated on the raw data type it operates on.
99
 *
100
 * @tparam DataType The raw data type (e.g., Line2D, Point2D).
101
 */
102
template<typename DataType>
103
class Tracker {
104
public:
105
    using GroundTruthMap = std::map<TimeFrameIndex, std::map<GroupId, EntityId>>;
106

107
    /**
108
     * @brief Constructs a Tracker.
109
     *
110
     * @param filter_prototype A prototype of the filter to be used for each track. It will be cloned.
111
     * @param feature_extractor A unique_ptr to the feature extractor strategy.
112
     * @param assigner A unique_ptr to the assignment strategy. Can be nullptr for smoothing-only tasks.
113
     */
114
    Tracker(std::unique_ptr<IFilter> filter_prototype,
13✔
115
            std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor,
116
            std::unique_ptr<IAssigner> assigner = nullptr)
117
        : filter_prototype_(std::move(filter_prototype)),
13✔
118
          feature_extractor_(std::move(feature_extractor)),
13✔
119
          assigner_(std::move(assigner)) {}
13✔
120

121
    /**
122
     * @brief Enable detailed debug logging to a file.
123
     * @pre File path is writable.
124
     * @post Subsequent calls to process will emit per-frame diagnostics.
125
     */
126
    void enableDebugLogging(std::string const & file_path) {
9✔
127
        try {
128
            // Use a rotating sink to avoid losing logs in long runs
129
            std::size_t const max_bytes = static_cast<std::size_t>(500) * 1024 * 1024;// 500 MB
9✔
130
            std::size_t const max_files = 3;
9✔
131
            auto sink = std::make_shared<spdlog::sinks::rotating_file_sink_mt>(file_path, max_bytes, max_files);
9✔
132
            auto existing = spdlog::get("TrackerLogger");
27✔
133
            if (existing) {
9✔
134
                logger_ = existing;
7✔
135
                logger_->sinks().clear();
7✔
136
                logger_->sinks().push_back(sink);
7✔
137
            } else {
138
                logger_ = std::make_shared<spdlog::logger>("TrackerLogger", sink);
2✔
139
                spdlog::register_logger(logger_);
2✔
140
            }
141
            logger_->set_pattern("[%Y-%m-%d %H:%M:%S.%e] [%l] %v");
27✔
142
            logger_->set_level(spdlog::level::debug);
9✔
143
            logger_->flush_on(spdlog::level::debug);
9✔
144
            spdlog::flush_every(std::chrono::seconds(1));
9✔
145
        } catch (spdlog::spdlog_ex const &) {
9✔
UNCOV
146
            logger_ = spdlog::get("TrackerLogger");
×
UNCOV
147
            if (logger_) {
×
UNCOV
148
                logger_->set_level(spdlog::level::debug);
×
UNCOV
149
                logger_->flush_on(spdlog::level::debug);
×
150
            }
151
        }
152
    }
9✔
153

154
    /**
155
     * @brief Main processing entry point. Runs the tracking algorithm using a zero-copy data source.
156
     *
157
     * @tparam Source A range type satisfying the DataSource concept
158
     * @param data_source A range providing tuples of (DataType, EntityId, TimeFrameIndex)
159
     * @param group_manager The EntityGroupManager containing group assignments. Will be modified for new assignments.
160
     * @param ground_truth A map indicating ground-truth labels for specific groups at specific frames.
161
     * @param start_frame The first frame to process.
162
     * @param end_frame The last frame to process.
163
     * @param progress_callback Optional callback for progress reporting (percentage 0-100).
164
     * @return A map from GroupId to a vector of smoothed filter states.
165
     */
166
    template<typename Source>
167
        requires DataSource<Source, DataType>
168
    SmoothedResults process(Source && data_source,
13✔
169
                            EntityGroupManager & group_manager,
170
                            GroundTruthMap const & ground_truth,
171
                            TimeFrameIndex start_frame,
172
                            TimeFrameIndex end_frame,
173
                            ProgressCallback progress_callback = nullptr) {
174

175
        // Build frame-indexed lookup for efficient access
176
        std::map<TimeFrameIndex, std::vector<std::tuple<DataType const *, EntityId, TimeFrameIndex>>> frame_data_lookup;
13✔
177

178
        for (auto const & item: data_source) {
1,876✔
179
            TimeFrameIndex time = getTimeFrameIndex(item);
1,854✔
180
            if (time >= start_frame && time <= end_frame) {
1,854✔
181
                frame_data_lookup[time].emplace_back(&getData(item), getEntityId(item), time);
1,854✔
182
            }
183
        }
184

185
        // Initialize tracks from EntityGroupManager
186
        for (auto group_id: group_manager.getAllGroupIds()) {
37✔
187
            if (active_tracks_.find(group_id) == active_tracks_.end()) {
24✔
188
                active_tracks_[group_id] = TrackedGroupState<DataType>{
48✔
189
                        .group_id = group_id,
190
                        .filter = filter_prototype_->clone(),
48✔
191
                        .is_active = false,
192
                        .frames_since_last_seen = 0,
193
                        .confidence = 1.0,
194
                        .identity_confidence = IdentityConfidence{},
195
                        .anchor_frames = {},
196
                        .forward_pass_history = {},
197
                        .forward_prediction_history = {},
198
                        .processed_frames_history = {},
199
                        .identity_confidence_history = {},
200
                        .assigned_entity_history = {}};
201
            }
202
        }
203

204
        // OPTIMIZATION 1: Build initial grouped entities set ONCE
205
        // This avoids O(G × E_g × log E) rebuild every frame
206
        std::unordered_set<EntityId> initially_grouped_entities;
13✔
207
        for (auto group_id: group_manager.getAllGroupIds()) {
61✔
208
            auto entities = group_manager.getEntitiesInGroup(group_id);
24✔
209
            initially_grouped_entities.insert(entities.begin(), entities.end());
24✔
210
        }
211

212
        // OPTIMIZATION 1: Deferred group updates for batch processing
213
        PendingGroupUpdates pending_updates;
13✔
214

215
        SmoothedResults all_smoothed_results;
13✔
216

217
        TimeFrameIndex const total_frames = end_frame - start_frame + TimeFrameIndex(1);
13✔
218
        int64_t frames_processed = 0;
13✔
219

220
        for (TimeFrameIndex current_frame = start_frame; current_frame <= end_frame; ++current_frame) {
1,867✔
221
            auto frame_data_it = frame_data_lookup.find(current_frame);
927✔
222
            auto const & all_frame_data = (frame_data_it != frame_data_lookup.end())
1,854✔
223
                                                  ? frame_data_it->second
927✔
224
                                                  : std::vector<std::tuple<DataType const *, EntityId, TimeFrameIndex>>{};
225

226
            // OPTIMIZATION 2: Build per-frame entity index for O(1) entity lookup
227
            // Eliminates O(E) linear searches repeated 3+ times per frame
228
            std::unordered_map<EntityId, size_t> entity_to_index;
927✔
229
            for (size_t i = 0; i < all_frame_data.size(); ++i) {
2,781✔
230
                EntityId eid = std::get<1>(all_frame_data[i]);
1,854✔
231
                entity_to_index[eid] = i;
1,854✔
232
            }
233

234
            // Report progress
235
            if (progress_callback) {
927✔
236
                ++frames_processed;
900✔
237
                int const percentage = static_cast<int>((frames_processed * 100) / total_frames.getValue());
900✔
238
                progress_callback(percentage);
900✔
239
            }
240

241
            if (logger_) {
927✔
242
                logger_->debug("frame={} entities={} active_groups={}",
900✔
243
                               current_frame.getValue(),
900✔
244
                               all_frame_data.size(),
1,800✔
245
                               active_tracks_.size());
1,800✔
246
            }
247

248
            // --- Predictions ---
249
            std::map<GroupId, FilterState> predictions;
927✔
250
            for (auto & [group_id, track]: active_tracks_) {
2,581✔
251
                if (track.is_active) {
1,654✔
252
                    predictions[group_id] = track.filter->predict();
1,630✔
253
                    track.frames_since_last_seen++;
1,630✔
254
                }
255
            }
256

257
            auto gt_frame_it = ground_truth.find(current_frame);
927✔
258

259
            std::set<GroupId> updated_groups_this_frame;
927✔
260
            std::set<EntityId> assigned_entities_this_frame;
927✔
261
            std::unordered_map<GroupId, std::optional<EntityId>> group_assigned_entity_in_frame;
927✔
262

263
            // --- Ground Truth Updates & Activation ---
264
            processGroundTruthUpdates(current_frame,
927✔
265
                                      ground_truth,
266
                                      all_frame_data,
267
                                      entity_to_index,
268
                                      predictions,
269
                                      updated_groups_this_frame,
270
                                      assigned_entities_this_frame);
271

272
            // **FIX 3: SYNCHRONIZE PREDICTION HISTORY**
273
            // For any group updated by GT, overwrite its prediction with the certain, GT-updated state.
274
            for (GroupId group_id: updated_groups_this_frame) {
1,097✔
275
                if (active_tracks_.count(group_id)) {
170✔
276
                    predictions[group_id] = active_tracks_.at(group_id).filter->getState();
170✔
277
                }
278
            }
279

280
            // --- Assignment for Ungrouped Data ---
281
            if (assigner_) {
927✔
282
                std::vector<Observation> observations;
916✔
283
                std::map<EntityId, FeatureCache> feature_cache;
916✔
284

285
                // OPTIMIZATION 1: O(1) membership checks using cached sets
286
                // Avoids O(G × E_g × log E) rebuild every frame
287
                for (auto const & [data_ptr, entity_id, time]: all_frame_data) {
2,748✔
288
                    if (assigned_entities_this_frame.find(entity_id) == assigned_entities_this_frame.end() &&
8,662✔
289
                        initially_grouped_entities.find(entity_id) == initially_grouped_entities.end() &&
8,662✔
290
                        !pending_updates.contains(entity_id)) {
1,666✔
291
                        observations.push_back({entity_id});
1,666✔
292
                        feature_cache[entity_id] = feature_extractor_->getAllFeatures(*data_ptr);
1,666✔
293
                    }
294
                }
295

296
                std::vector<Prediction> prediction_list;
916✔
297
                for (auto const & [group_id, pred_state]: predictions) {
2,548✔
298
                    if (active_tracks_.at(group_id).is_active &&
8,160✔
299
                        updated_groups_this_frame.find(group_id) == updated_groups_this_frame.end()) {
6,528✔
300
                        // Do not allow assignment to groups already updated by ground truth this frame
301
                        prediction_list.push_back({group_id, pred_state});
2,932✔
302
                    }
303
                }
304

305
                if (!observations.empty() && !prediction_list.empty()) {
916✔
306
                    Assignment assignment = assigner_->solve(prediction_list, observations, feature_cache);
733✔
307

308
                    for (auto const & [obs_idx, pred_idx]: assignment.observation_to_prediction) {
2,199✔
309
                        auto const & obs = observations[static_cast<size_t>(obs_idx)];
1,466✔
310
                        auto const & pred = prediction_list[static_cast<size_t>(pred_idx)];
1,466✔
311

312
                        auto & track = active_tracks_.at(pred.group_id);
1,466✔
313

314
                        // OPTIMIZATION 2: O(1) entity lookup instead of O(E) linear search
315
                        auto entity_it = entity_to_index.find(obs.entity_id);
1,466✔
316
                        if (entity_it == entity_to_index.end()) continue;
1,466✔
317

318
                        DataType const * obs_data = std::get<0>(all_frame_data[entity_it->second]);
1,466✔
319

320
                        // Update identity confidence based on assignment cost
321
                        auto cost_it = assignment.assignment_costs.find(obs_idx);
1,466✔
322
                        if (cost_it != assignment.assignment_costs.end()) {
1,466✔
323
                            track.identity_confidence.updateOnAssignment(cost_it->second, assignment.cost_threshold);
1,466✔
324

325
                            // Allow slow recovery for excellent assignments
326
                            double excellent_threshold = assignment.cost_threshold * 0.1;
1,466✔
327
                            track.identity_confidence.allowSlowRecovery(cost_it->second, excellent_threshold);
1,466✔
328
                            if (logger_) {
1,466✔
329
                                logger_->debug("assign f={} g={} obs={} cost={:.3f} conf={:.3f}",
1,440✔
330
                                               current_frame.getValue(),
1,440✔
331
                                               pred.group_id,
1,440✔
332
                                               obs.entity_id,
1,440✔
333
                                               cost_it->second,
1,440✔
334
                                               track.identity_confidence.getConfidence());
2,880✔
335
                            }
336
                        }
337

338
                        // Scale measurement noise based on identity confidence
339
                        double noise_scale = track.identity_confidence.getMeasurementNoiseScale();
1,466✔
340
                        Measurement measurement = {feature_extractor_->getFilterFeatures(*obs_data)};
1,466✔
341

342
                        // Apply noise scaling to the measurement
343
                        track.filter->update(pred.filter_state, measurement, noise_scale);
1,466✔
344

345
                        if (logger_) {
1,466✔
346
                            double cov_tr = track.filter->getState().state_covariance.trace();
1,440✔
347
                            logger_->debug("update f={} g={} obs={} noise_scale={:.3f} cov_tr={:.3f}",
1,440✔
348
                                           current_frame.getValue(),
1,440✔
349
                                           pred.group_id,
1,440✔
350
                                           obs.entity_id,
1,440✔
351
                                           noise_scale, cov_tr);
352
                        }
353

354
                        // OPTIMIZATION 1: Defer update to batch flush at anchor frames (track frame)
355
                        pending_updates.addPending(pred.group_id, obs.entity_id, current_frame);
1,466✔
356

357
                        updated_groups_this_frame.insert(pred.group_id);
1,466✔
358
                        assigned_entities_this_frame.insert(obs.entity_id);
1,466✔
359
                        group_assigned_entity_in_frame[pred.group_id] = obs.entity_id;
1,466✔
360
                        track.frames_since_last_seen = 0;
1,466✔
361
                    }
362
                }
733✔
363
            }// end of assignment for ungrouped data
916✔
364

365
            // --- Finalize Frame State, Log History, and Handle Smoothing ---
366
            bool any_smoothing_this_frame = false;
927✔
367
            for (auto & [group_id, track]: active_tracks_) {
2,581✔
368
                if (!track.is_active) continue;
1,654✔
369

370
                // If a track was NOT updated this frame, its state is the predicted state. Commit it.
371
                if (updated_groups_this_frame.find(group_id) == updated_groups_this_frame.end()) {
1,654✔
372
                    track.filter->initialize(predictions.at(group_id));
18✔
373
                }
374

375
                // Record histories aligned by frame
376
                track.forward_pass_history.push_back(track.filter->getState());
1,654✔
377
                auto pred_it_for_hist = predictions.find(group_id);
1,654✔
378
                if (pred_it_for_hist != predictions.end()) {
1,654✔
379
                    track.forward_prediction_history.push_back(pred_it_for_hist->second);
1,654✔
380
                } else {
381
                    // At activation frames we may not have a prediction; use current state as placeholder
UNCOV
382
                    track.forward_prediction_history.push_back(track.filter->getState());
×
383
                }
384
                track.processed_frames_history.push_back(current_frame);
1,654✔
385
                track.identity_confidence_history.push_back(track.identity_confidence.getConfidence());
1,654✔
386
                if (auto it_assigned = group_assigned_entity_in_frame.find(group_id); it_assigned != group_assigned_entity_in_frame.end()) {
1,654✔
387
                    track.assigned_entity_history.push_back(it_assigned->second);
1,466✔
388
                } else {
389
                    track.assigned_entity_history.push_back(std::nullopt);
188✔
390
                }
391

392
                // Check for smoothing trigger on new anchor frames
393
                bool is_anchor = (gt_frame_it != ground_truth.end() && gt_frame_it->second.count(group_id));
1,654✔
394
                if (is_anchor) {
1,654✔
395
                    if (std::find(track.anchor_frames.begin(), track.anchor_frames.end(), current_frame) == track.anchor_frames.end()) {
170✔
396
                        track.anchor_frames.push_back(current_frame);
170✔
397
                    }
398

399
                    if (track.anchor_frames.size() >= 2) {
170✔
400
                        /**************************************************************************
401
                         * RE-IMPLEMENTED SMOOTHING AND RE-ASSIGNMENT BLOCK
402
                         **************************************************************************/
403
                        any_smoothing_this_frame = true;
146✔
404
                        size_t const interval_size = track.processed_frames_history.size();
146✔
405

406
                        if (logger_) {
146✔
407
                            logger_->info("SMOOTH_BLOCK START g={} | interval=[{}, {}] | size={}",
144✔
408
                                          group_id, track.anchor_frames.front().getValue(),
144✔
409
                                          current_frame.getValue(), interval_size);
288✔
410
                        }
411

412
                        if (interval_size <= 1 || !assigner_ || !track.filter->supportsBackwardPrediction()) {
146✔
413
                            if (logger_) logger_->warn("SMOOTH_BLOCK SKIP g={} | interval too small or backward prediction not supported. Applying standard smoothing.", group_id);
2✔
414
                            auto smoothed = track.filter->smooth(track.forward_pass_history);
2✔
415
                            if (!smoothed.empty()) {
2✔
416
                                // Only skip the first element if we already have results (to avoid duplication)
417
                                auto start_it = all_smoothed_results[group_id].empty() ? smoothed.begin() : std::next(smoothed.begin());
2✔
418
                                all_smoothed_results[group_id].insert(all_smoothed_results[group_id].end(), start_it, smoothed.end());
2✔
419
                            }
420
                        } else {
2✔
421
                            // --- STEP 1A: GENERATE A TRUE BACKWARD-FILTERED HYPOTHESIS ---
422
                            std::vector<FilterState> bwd_predictions(interval_size);
432✔
423

424
                            auto bwd_filter = filter_prototype_->createBackwardFilter();
144✔
425
                            IdentityConfidence bwd_identity_confidence;// Backward filter gets its own confidence tracker
144✔
426

427
                            bwd_filter->initialize(track.forward_pass_history.back());
144✔
428
                            bwd_predictions[interval_size - 1] = track.forward_pass_history.back();
144✔
429
                            bwd_identity_confidence.resetOnGroundTruth();// Start with high confidence at the anchor
144✔
430

431
                            for (size_t i = interval_size - 1; i-- > 0;) {
3,024✔
432
                                // Predict state for frame i from the filter's state at i+1
433
                                auto pred_for_i = bwd_filter->predict();
1,440✔
434

435
                                bwd_predictions[i] = pred_for_i;// Store the prediction
1,440✔
436

437
                                // Now, perform a measurement update using data from frame i to constrain the backward filter's uncertainty
438
                                TimeFrameIndex frame_i = track.processed_frames_history[i];
1,440✔
439
                                auto fd_it = frame_data_lookup.find(frame_i);
1,440✔
440

441
                                // Find best assignment for this backward prediction
442
                                if (fd_it != frame_data_lookup.end() && !fd_it->second.empty()) {
1,440✔
443
                                    std::vector<Observation> observations;
1,440✔
444
                                    std::map<EntityId, FeatureCache> feature_cache;
1,440✔
445
                                    for (auto const & item: fd_it->second) {
4,320✔
446
                                        observations.push_back({std::get<1>(item)});
2,880✔
447
                                        feature_cache[std::get<1>(item)] = feature_extractor_->getAllFeatures(*std::get<0>(item));
2,880✔
448
                                    }
449

450
                                    auto bwd_assign = assigner_->solve({{group_id, pred_for_i}}, observations, feature_cache);
5,760✔
451
                                    if (!bwd_assign.observation_to_prediction.empty()) {
1,440✔
452
                                        auto const & [obs_idx, pred_idx] = *bwd_assign.observation_to_prediction.begin();
1,440✔
453
                                        EntityId entity_id = observations[obs_idx].entity_id;
1,440✔
454

455
                                        auto cost_it = bwd_assign.assignment_costs.find(obs_idx);
1,440✔
456
                                        if (cost_it != bwd_assign.assignment_costs.end()) {
1,440✔
457
                                            bwd_identity_confidence.updateOnAssignment(cost_it->second, bwd_assign.cost_threshold);
1,440✔
458
                                        }
459

460
                                        DataType const * data = nullptr;
1,440✔
461
                                        for (auto const & item: fd_it->second) {
2,250✔
462
                                            if (std::get<1>(item) == entity_id) {
2,250✔
463
                                                data = std::get<0>(item);
1,440✔
464
                                                break;
1,440✔
465
                                            }
466
                                        }
467

468
                                        if (data) {
1,440✔
469
                                            Measurement m = {feature_extractor_->getFilterFeatures(*data)};
1,440✔
470
                                            // USE THE BACKWARD CONFIDENCE to scale measurement noise
471
                                            double noise_scale = bwd_identity_confidence.getMeasurementNoiseScale();
1,440✔
472
                                            bwd_filter->update(pred_for_i, m, noise_scale);
1,440✔
473
                                        } else {
1,440✔
UNCOV
474
                                            bwd_filter->initialize(pred_for_i);
×
475
                                        }
476
                                    } else {
UNCOV
477
                                        bwd_filter->initialize(pred_for_i);
×
478
                                    }
479
                                } else {
1,440✔
UNCOV
480
                                    bwd_filter->initialize(pred_for_i);
×
481
                                }
482
                            }
483

484

485
                            // --- STEP 1B: RE-ASSIGNMENT BY RECONCILING FORWARD & BACKWARD HYPOTHESES ---
486
                            std::map<TimeFrameIndex, EntityId> revised_assignments;
144✔
487
                            std::map<TimeFrameIndex, double> revised_confidences;
144✔
488

489
                            double const interval_duration = static_cast<double>((track.processed_frames_history.back() - track.processed_frames_history.front()).getValue());
144✔
490

491

492
                            for (size_t i = 0; i < interval_size; ++i) {
3,024✔
493
                                TimeFrameIndex frame = track.processed_frames_history[i];
1,584✔
494

495
                                if (ground_truth.count(frame) && ground_truth.at(frame).count(group_id)) {
1,584✔
496
                                    revised_assignments[frame] = ground_truth.at(frame).at(group_id);
288✔
497
                                    revised_confidences[frame] = 1.0;
288✔
498
                                    continue;
288✔
499
                                }
500

501
                                auto fd_it = frame_data_lookup.find(frame);
1,296✔
502
                                if (fd_it == frame_data_lookup.end() || fd_it->second.empty()) continue;
1,296✔
503

504
                                std::vector<Observation> observations;
1,296✔
505
                                std::map<EntityId, FeatureCache> feature_cache;
1,296✔
506
                                for (auto const & item: fd_it->second) {
3,888✔
507
                                    observations.push_back({std::get<1>(item)});
2,592✔
508
                                    feature_cache[std::get<1>(item)] = feature_extractor_->getAllFeatures(*std::get<0>(item));
2,592✔
509
                                }
510

511
                                auto get_best_pick = [&](Assignment const & a, std::vector<Observation> const & obs) -> std::optional<std::pair<EntityId, double>> {
1,296✔
512
                                    if (a.observation_to_prediction.empty()) return std::nullopt;
2,592✔
513
                                    auto it = a.observation_to_prediction.begin();
2,592✔
514
                                    auto cost_it = a.assignment_costs.find(it->first);
2,592✔
515
                                    if (cost_it == a.assignment_costs.end()) return std::nullopt;
2,592✔
516
                                    return std::make_pair(obs.at(it->first).entity_id, cost_it->second);
2,592✔
517
                                };
518

519
                                auto fwd_assign = assigner_->solve({{group_id, track.forward_prediction_history[i]}}, observations, feature_cache);
5,184✔
520
                                auto fwd_pick = get_best_pick(fwd_assign, observations);
1,296✔
521

522
                                std::optional<std::pair<EntityId, double>> bwd_pick;
1,296✔
523
                                auto bwd_assign = assigner_->solve({{group_id, bwd_predictions[i]}}, observations, feature_cache);
5,184✔
524
                                bwd_pick = get_best_pick(bwd_assign, observations);
1,296✔
525

526

527
                                double fwd_cov_tr = track.forward_prediction_history[i].state_covariance.trace();
1,296✔
528
                                double bwd_cov_tr = bwd_predictions[i].state_covariance.trace();
1,296✔
529

530
                                if (logger_) {
1,296✔
531
                                    EntityId fwd_entity = fwd_pick ? fwd_pick->first : -1;
1,296✔
532
                                    EntityId bwd_entity = bwd_pick ? bwd_pick->first : -1;
1,296✔
533
                                    logger_->debug("RECONCILE f={} g={} | FWD: entity={}, cost={:.4f}, cov_tr={:.4f} | BWD: entity={}, cost={:.4f}, cov_tr={:.4f}",
1,296✔
534
                                                   frame.getValue(), group_id, fwd_entity, fwd_pick->second, fwd_cov_tr, bwd_entity, bwd_pick->second, bwd_cov_tr);
1,296✔
535
                                }
536

537
                                // --- NEW WEIGHTED DECISION LOGIC ---
538
                                double forward_weight = 1.0;
1,296✔
539
                                if (interval_duration > 0) {
1,296✔
540
                                    double frame_pos = static_cast<double>((frame - track.processed_frames_history.front()).getValue());
1,296✔
541
                                    forward_weight = 1.0 - (frame_pos / interval_duration);
1,296✔
542
                                }
543

544
                                // Add a small epsilon to avoid division by zero
545
                                double const epsilon = 1e-9;
1,296✔
546
                                double fwd_trust = forward_weight + epsilon;
1,296✔
547
                                double bwd_trust = (1.0 - forward_weight) + epsilon;
1,296✔
548

549
                                // Score is uncertainty divided by our trust in the hypothesis. Lower score is better.
550
                                double fwd_score = fwd_cov_tr / fwd_trust;
1,296✔
551
                                double bwd_score = bwd_cov_tr / bwd_trust;
1,296✔
552

553
                                bool use_bwd = false;
1,296✔
554
                                if (fwd_pick && bwd_pick) {
1,296✔
555
                                    if (bwd_score < fwd_score) {
1,296✔
556
                                        use_bwd = true;
566✔
557
                                    }
UNCOV
558
                                } else if (bwd_pick) {
×
UNCOV
559
                                    use_bwd = true;
×
560
                                }
561

562

563
                                auto const & winner_pick = use_bwd ? bwd_pick : fwd_pick;
1,296✔
564
                                if (!winner_pick) continue;
1,296✔
565

566
                                revised_assignments[frame] = winner_pick->first;
1,296✔
567
                                pending_updates.replaceForFrame(group_id, frame, winner_pick->first);
1,296✔
568

569
                                if (logger_) {
1,296✔
570
                                    EntityId original_entity = track.assigned_entity_history[i].has_value() ? track.assigned_entity_history[i].value() : -1;
1,296✔
571
                                    if (original_entity != winner_pick->first) {
1,296✔
572
                                         logger_->info("RECONCILE_WINNER f={} g={} | winner={} chosen_entity={} (original={}) | Decision: BWD_score={:.2f} < FWD_score={:.2f}",
160✔
573
                                                      frame.getValue(), group_id, (use_bwd ? "BWD" : "FWD"), winner_pick->first, original_entity, bwd_score, fwd_score);
160✔
574
                                    }
575
                                }
576

577
                                IdentityConfidence temp_conf;
1,296✔
578
                                temp_conf.updateOnAssignment(winner_pick->second, assigner_->getCostThreshold());
1,296✔
579
                                revised_confidences[frame] = temp_conf.getConfidence();
1,296✔
580
                            }
581

582
                            // --- STEP 2: RE-FILTER THE INTERVAL WITH THE CORRECTED ASSIGNMENTS ---
583
                            // This creates a single, consistent history based on the best assignments.
584
                            std::vector<FilterState> corrected_history;
144✔
585
                            auto temp_filter = filter_prototype_->clone();
144✔
586
                            temp_filter->initialize(track.forward_pass_history.front());// Start from the first anchor's state
144✔
587
                            corrected_history.push_back(temp_filter->getState());
144✔
588

589
                            for (size_t i = 1; i < interval_size; ++i) {
3,024✔
590
                                TimeFrameIndex frame = track.processed_frames_history[i];
1,440✔
591
                                FilterState pred = temp_filter->predict();
1,440✔
592

593
                                if (revised_assignments.count(frame)) {
1,440✔
594
                                    EntityId entity_id = revised_assignments.at(frame);
1,440✔
595

596
                                    // Find the data for the revised entity in its corresponding historical frame
597
                                    DataType const * data = nullptr;
1,440✔
598
                                    auto const & past_frame_data_it = frame_data_lookup.find(frame);
1,440✔
599
                                    if (past_frame_data_it != frame_data_lookup.end()) {
1,440✔
600
                                        for (auto const & item: past_frame_data_it->second) {
2,240✔
601
                                            if (std::get<1>(item) == entity_id) {
2,240✔
602
                                                data = std::get<0>(item);
1,440✔
603
                                                break;
1,440✔
604
                                            }
605
                                        }
606
                                    }
607

608
                                    if (data) {
1,440✔
609
                                        Measurement m = {feature_extractor_->getFilterFeatures(*data)};
1,440✔
610
                                        // The noise scale is based on the confidence of the *winning* assignment
611
                                        double confidence = revised_confidences.count(frame) ? revised_confidences.at(frame) : 0.5;
1,440✔
612
                                        double noise_scale = std::pow(10.0, 2.0 * (1.0 - confidence));
1,440✔
613
                                        temp_filter->update(pred, m, noise_scale);
1,440✔
614
                                        if (logger_) {
1,440✔
615
                                            logger_->debug("RE-FILTER f={} g={} | entity={} noise_scale={:.3f} new_cov_tr={:.4f}",
1,440✔
616
                                                           frame.getValue(), group_id, entity_id, noise_scale, temp_filter->getState().state_covariance.trace());
1,440✔
617
                                        }
618
                                    } else {
1,440✔
UNCOV
619
                                        temp_filter->initialize(pred);// Coast if data not found
×
UNCOV
620
                                        if (logger_) logger_->warn("RE-FILTER f={} g={} | entity {} not found in frame data, coasting.", frame.getValue(), group_id, entity_id);
×
621
                                    }
622
                                } else {
UNCOV
623
                                    temp_filter->initialize(pred);// Coast if no assignment was made
×
UNCOV
624
                                    if (logger_) logger_->debug("RE-FILTER f={} g={} | no revised assignment, coasting.", frame.getValue(), group_id);
×
625
                                }
626
                                corrected_history.push_back(temp_filter->getState());
1,440✔
627
                            }
628

629
                            // --- STEP 3: (REGRESSION) SMOOTH THE CORRECTED HISTORY ---
630
                            // Now that assignments are fixed, run a standard RTS smoother to get the best state estimates.
631
                            auto smoothed = track.filter->smooth(corrected_history);
144✔
632

633
                            // --- STEP 4: APPLY ASSIGNMENT-AWARE COVARIANCE INFLATION ---
634
                            // The smoother often produces over-confident covariances. We inflate them based on the
635
                            // confidence of our assignment choice to better reflect the true uncertainty.
636
                            if (smoothed.size() == corrected_history.size()) {
144✔
637
                                for (size_t i = 0; i < smoothed.size(); ++i) {
1,728✔
638
                                    TimeFrameIndex frame = track.processed_frames_history[i];
1,584✔
639
                                    if (revised_confidences.count(frame)) {
1,584✔
640
                                        double conf = revised_confidences.at(frame);
1,584✔
641
                                        // Inflate covariance more for low-confidence assignments.
642
                                        // Example: 1.0 confidence -> 1.0x inflation. 0.5 confidence -> 1.5x inflation.
643
                                        double inflation_factor = 1.0 + (1.0 - conf);
1,584✔
644
                                        smoothed[i].state_covariance *= inflation_factor;
1,584✔
645
                                    }
646
                                }
647
                            }
648

649
                            // Store results, excluding the first element only if we already have results (to avoid duplication)
650
                            if (!smoothed.empty()) {
144✔
651
                                auto start_it = all_smoothed_results[group_id].empty() ? smoothed.begin() : std::next(smoothed.begin());
272✔
652
                                all_smoothed_results[group_id].insert(all_smoothed_results[group_id].end(), start_it, smoothed.end());
144✔
653
                            }
654
                        }
144✔
655

656

657
                        /**************************************************************************
658
                         * END OF RE-IMPLEMENTED BLOCK
659
                         **************************************************************************/
660

661
                        // Collapse histories to keep only the last element for continuity
662
                        track.forward_pass_history = {track.forward_pass_history.back()};
438✔
663
                        track.forward_prediction_history = {track.forward_prediction_history.back()};
438✔
664
                        track.processed_frames_history = {track.processed_frames_history.back()};
146✔
665
                        track.assigned_entity_history = {track.assigned_entity_history.back()};
146✔
666
                        track.anchor_frames = {current_frame};
146✔
667
                    }
668
                }// end of anchor frame backward pass
669
            }
670

671
            // OPTIMIZATION 1: Flush pending updates at anchor frames (smoothing boundaries)
672
            // This is the natural synchronization point between tracking intervals
673
            if (any_smoothing_this_frame) {
927✔
674
                pending_updates.flushToManager(group_manager);
73✔
675

676
                // Update the cached grouped entities set for the next interval
677
                auto const & newly_added = pending_updates.getAddedEntities();
73✔
678
                initially_grouped_entities.insert(newly_added.begin(), newly_added.end());
73✔
679
            }
680
        }
681

682
        // Final flush of any remaining pending updates
683
        pending_updates.flushToManager(group_manager);
13✔
684

685
        return all_smoothed_results;
26✔
686
    }
5,827✔
687

688
    /**
689
     * @brief Gets the current identity confidence for a specific group.
690
     * @param group_id The group to query
691
     * @return The current identity confidence [0.1, 1.0], or -1.0 if group not found
692
     */
693
    double
694
    getIdentityConfidence(GroupId group_id) const {
4✔
695
        auto it = active_tracks_.find(group_id);
4✔
696
        if (it == active_tracks_.end()) return -1.0;
4✔
697
        return it->second.identity_confidence.getConfidence();
4✔
698
    }
699

700
    /**
701
     * @brief Gets the measurement noise scale factor for a specific group.
702
     * @param group_id The group to query
703
     * @return The current noise scale factor, or -1.0 if group not found
704
     */
705
    double getMeasurementNoiseScale(GroupId group_id) const {
4✔
706
        auto it = active_tracks_.find(group_id);
4✔
707
        if (it == active_tracks_.end()) return -1.0;
4✔
708
        return it->second.identity_confidence.getMeasurementNoiseScale();
4✔
709
    }
710

711
    /**
712
     * @brief Gets the minimum confidence since last anchor for a specific group.
713
     * @param group_id The group to query
714
     * @return The minimum confidence since last anchor, or -1.0 if group not found
715
     */
716
    double getMinConfidenceSinceAnchor(GroupId group_id) const {
2✔
717
        auto it = active_tracks_.find(group_id);
2✔
718
        if (it == active_tracks_.end()) return -1.0;
2✔
719
        return it->second.identity_confidence.getMinConfidenceSinceAnchor();
2✔
720
    }
721

722
private:
723
    /**
724
     * @brief Processes ground truth updates and activates tracks for the current frame.
725
     * 
726
     * @param current_frame The current frame being processed
727
     * @param ground_truth The ground truth map for all frames
728
     * @param all_frame_data The data for the current frame
729
     * @param entity_to_index Fast lookup map from EntityId to data index
730
     * @param predictions The predicted states for all active tracks
731
     * @param updated_groups_this_frame Set of groups updated this frame (modified)
732
     * @param assigned_entities_this_frame Set of entities assigned this frame (modified)
733
     */
734
    void processGroundTruthUpdates(
927✔
735
            TimeFrameIndex current_frame,
736
            GroundTruthMap const & ground_truth,
737
            std::vector<std::tuple<DataType const *, EntityId, TimeFrameIndex>> const & all_frame_data,
738
            std::unordered_map<EntityId, size_t> const & entity_to_index,
739
            std::map<GroupId, FilterState> const & predictions,
740
            std::set<GroupId> & updated_groups_this_frame,
741
            std::set<EntityId> & assigned_entities_this_frame) {
742

743
        auto gt_frame_it = ground_truth.find(current_frame);
927✔
744
        if (gt_frame_it != ground_truth.end()) {
927✔
745
            for (auto const & [group_id, entity_id]: gt_frame_it->second) {
255✔
746
                auto track_it = active_tracks_.find(group_id);
170✔
747
                if (track_it == active_tracks_.end()) continue;
170✔
748
                auto & track = track_it->second;
170✔
749

750
                // OPTIMIZATION 2: O(1) entity lookup instead of O(E) linear search
751
                auto entity_it = entity_to_index.find(entity_id);
170✔
752
                if (entity_it == entity_to_index.end()) continue;
170✔
753

754
                DataType const * gt_item = std::get<0>(all_frame_data[entity_it->second]);
170✔
755

756
                if (!track.is_active) {
170✔
757
                    track.filter->initialize(feature_extractor_->getInitialState(*gt_item));
24✔
758
                    track.is_active = true;
24✔
759
                } else {
760
                    Measurement measurement = {feature_extractor_->getFilterFeatures(*gt_item)};
146✔
761
                    // Strengthen anchor certainty by reducing measurement noise at GT frames
762
                    double const anchor_noise_scale = 0.25;
146✔
763
                    track.filter->update(predictions.at(group_id), measurement, anchor_noise_scale);
146✔
764
                }
146✔
765

766
                // Reset identity confidence on ground truth updates
767
                track.identity_confidence.resetOnGroundTruth();
170✔
768

769
                track.frames_since_last_seen = 0;
170✔
770
                updated_groups_this_frame.insert(group_id);
170✔
771
                assigned_entities_this_frame.insert(entity_id);
170✔
772
            }
773
        }
774
    }
927✔
775

776
    std::unique_ptr<IFilter> filter_prototype_;
777
    std::unique_ptr<IFeatureExtractor<DataType>> feature_extractor_;
778
    std::unique_ptr<IAssigner> assigner_;
779
    std::unordered_map<GroupId, TrackedGroupState<DataType>> active_tracks_;
780
    std::shared_ptr<spdlog::logger> logger_;
781
};
782

783

784
/**
785
 * @brief Holds the state for a single tracked group.
786
 * @tparam DataType The raw data type (e.g., Line2D, Point2D).
787
 */
788
template<typename DataType>
789
struct TrackedGroupState {
790
    GroupId group_id;
791
    std::unique_ptr<IFilter> filter;
792
    bool is_active = false;
793
    int frames_since_last_seen = 0;
794
    double confidence = 1.0;
795

796
    // Identity confidence tracking for assignment uncertainty
797
    IdentityConfidence identity_confidence;
798

799
    // History for smoothing between anchors
800
    std::vector<TimeFrameIndex> anchor_frames = {};
801
    std::vector<FilterState> forward_pass_history = {};
802
    // Auxiliary histories aligned with forward_pass_history indices
803
    std::vector<FilterState> forward_prediction_history = {};
804
    std::vector<TimeFrameIndex> processed_frames_history = {};
805
    std::vector<double> identity_confidence_history = {};
806
    std::vector<std::optional<EntityId>> assigned_entity_history = {};
807
};
808

809
}// namespace StateEstimation
810

811

812
#endif// TRACKER_HPP
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc