• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18477247352

13 Oct 2025 08:18PM UTC coverage: 72.391% (+0.4%) from 71.943%
18477247352

push

github

web-flow
Merge pull request #140 from paulmthompson/kdtree

Jules PR

164 of 287 new or added lines in 3 files covered. (57.14%)

350 existing lines in 9 files now uncovered.

51889 of 71679 relevant lines covered (72.39%)

63071.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.39
/src/DataManager/transforms/Lines/Line_Kalman_Grouping/line_kalman_grouping.cpp
1
#include "line_kalman_grouping.hpp"

#include "Entity/EntityGroupManager.hpp"
#include "Entity/EntityTypes.hpp"
#include "Lines/Line_Data.hpp"
#include "StateEstimation/DataAdapter.hpp"
#include "StateEstimation/Features/CompositeFeatureExtractor.hpp"
#include "StateEstimation/Features/LineBasePointExtractor.hpp"
#include "StateEstimation/Features/LineCentroidExtractor.hpp"
#include "StateEstimation/Features/LineLengthExtractor.hpp"
#include "StateEstimation/Filter/Kalman/KalmanMatrixBuilder.hpp"
#include "StateEstimation/Features/IFeatureExtractor.hpp"
#include "StateEstimation/Filter/Kalman/KalmanFilter.hpp"
#include "StateEstimation/MinCostFlowTracker.hpp"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <optional>
#include <ranges>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
21

22
namespace {
23

24
/**
25
 * @brief Statistics for a feature extracted from ground truth data
26
 */
27
struct FeatureStatistics {
28
    double mean = 0.0;
29
    double variance = 0.0;
30
    double std_dev = 0.0;
31
    double mean_frame_to_frame_change = 0.0;
32
    double variance_frame_to_frame_change = 0.0;
33
    int num_samples = 0;
34
    int num_transitions = 0;
35
};
36

37
/**
38
 * @brief Cross-correlation statistics between two features
39
 */
40
struct CrossCorrelationStatistics {
41
    double pearson_correlation = 0.0;  // Pearson correlation coefficient (-1 to 1)
42
    int num_paired_samples = 0;
43
    bool is_valid = false;
44
};
45

46
/**
47
 * @brief Analyze ground truth data to estimate realistic noise parameters
48
 * 
49
 * For static features (like length), computes:
50
 * - Mean and variance of the feature across all ground truth
51
 * - Mean and variance of frame-to-frame changes (process noise estimate)
52
 * 
53
 * @param line_data The LineData containing lines
54
 * @param ground_truth Map of ground truth assignments
55
 * @param feature_extractor The feature extractor to analyze
56
 * @param feature_name Name of the feature for logging
57
 * @return Statistics computed from ground truth data
58
 */
59
template<typename FeatureExtractor>
60
FeatureStatistics analyzeGroundTruthFeatureStatistics(
1✔
61
        std::shared_ptr<LineData> const & line_data,
62
        std::map<TimeFrameIndex, std::map<GroupId, EntityId>> const & ground_truth,
63
        FeatureExtractor const & feature_extractor,
64
        std::string const & feature_name) {
65
    (void)feature_name;
66

67
    FeatureStatistics stats;
1✔
68

69
    // Collect all feature values and frame-to-frame changes per group
70
    std::map<GroupId, std::vector<double>> group_feature_values;
1✔
71

72
    for (auto const & [time, group_assignments]: ground_truth) {
11✔
73
        for (auto const & [group_id, entity_id]: group_assignments) {
50✔
74
            // Get the line for this entity
75
            auto line = line_data->getLineByEntityId(entity_id);
20✔
76
            if (!line.has_value()) continue;// Entity not found
20✔
77

78
            // Extract feature
79
            Eigen::VectorXd features = feature_extractor.getFilterFeatures(line.value());
20✔
80

81
            // For now, assume 1D features (extend later if needed)
82
            if (features.size() == 1) {
20✔
83
                group_feature_values[group_id].push_back(features(0));
20✔
84
            }
85
        }
86
    }
87

88
    // Compute statistics
89
    std::vector<double> all_values;
1✔
90
    std::vector<double> all_changes;
1✔
91

92
    for (auto const & [group_id, values]: group_feature_values) {
3✔
93
        if (values.empty()) continue;
2✔
94

95
        // Collect all values
96
        all_values.insert(all_values.end(), values.begin(), values.end());
2✔
97

98
        // Compute frame-to-frame changes
99
        for (size_t i = 1; i < values.size(); ++i) {
20✔
100
            double change = std::abs(values[i] - values[i - 1]);
18✔
101
            all_changes.push_back(change);
18✔
102
        }
103
    }
104

105
    if (all_values.empty()) {
1✔
UNCOV
106
        return stats;// No data
×
107
    }
108

109
    // Compute mean
110
    stats.mean = std::accumulate(all_values.begin(), all_values.end(), 0.0) / all_values.size();
1✔
111
    stats.num_samples = static_cast<int>(all_values.size());
1✔
112

113
    // Compute variance
114
    double sum_sq_diff = 0.0;
1✔
115
    for (double val: all_values) {
21✔
116
        double diff = val - stats.mean;
20✔
117
        sum_sq_diff += diff * diff;
20✔
118
    }
119
    stats.variance = sum_sq_diff / all_values.size();
1✔
120
    stats.std_dev = std::sqrt(stats.variance);
1✔
121

122
    // Compute frame-to-frame change statistics (process noise estimate)
123
    if (!all_changes.empty()) {
1✔
124
        stats.mean_frame_to_frame_change = std::accumulate(all_changes.begin(), all_changes.end(), 0.0) / all_changes.size();
1✔
125
        stats.num_transitions = static_cast<int>(all_changes.size());
1✔
126

127
        double sum_sq_change_diff = 0.0;
1✔
128
        for (double change: all_changes) {
19✔
129
            double diff = change - stats.mean_frame_to_frame_change;
18✔
130
            sum_sq_change_diff += diff * diff;
18✔
131
        }
132
        stats.variance_frame_to_frame_change = sum_sq_change_diff / all_changes.size();
1✔
133
    }
134

135
    return stats;
1✔
136
}
1✔
137

138
/**
139
 * @brief Compute empirical correlation between two features from ground truth data
140
 * 
141
 * Uses Pearson correlation coefficient to measure linear relationship between features.
142
 * This is computed from actual observed data, not assumptions.
143
 * 
144
 * @param line_data The LineData containing lines
145
 * @param ground_truth Map of ground truth assignments
146
 * @param extractor_a First feature extractor
147
 * @param extractor_b Second feature extractor
148
 * @param feature_a_name Name of first feature (for logging)
149
 * @param feature_b_name Name of second feature (for logging)
150
 * @return Cross-correlation statistics
151
 */
152
template<typename ExtractorA, typename ExtractorB>
UNCOV
153
CrossCorrelationStatistics computeFeatureCrossCorrelation(
×
154
        std::shared_ptr<LineData> const & line_data,
155
        std::map<TimeFrameIndex, std::map<GroupId, EntityId>> const & ground_truth,
156
        ExtractorA const & extractor_a,
157
        ExtractorB const & extractor_b,
158
        std::string const & feature_a_name,
159
        std::string const & feature_b_name) {
160
    (void)feature_a_name;
161
    (void)feature_b_name;
162

163
    CrossCorrelationStatistics stats;
×
164
    
165
    // Collect paired feature values across all groups and times
166
    std::vector<double> values_a, values_b;
×
167
    
168
    for (auto const & [time, group_assignments]: ground_truth) {
×
UNCOV
169
        for (auto const & [group_id, entity_id]: group_assignments) {
×
UNCOV
170
            auto line = line_data->getLineByEntityId(entity_id);
×
171
            if (!line.has_value()) continue;
×
172
            
173
            // Extract both features
UNCOV
174
            Eigen::VectorXd feat_a = extractor_a.getFilterFeatures(line.value());
×
175
            Eigen::VectorXd feat_b = extractor_b.getFilterFeatures(line.value());
×
176
            
177
            // For multi-dimensional features, use first component or magnitude
178
            double val_a = (feat_a.size() == 1) ? feat_a(0) : feat_a.norm();
×
179
            double val_b = (feat_b.size() == 1) ? feat_b(0) : feat_b.norm();
×
180
            
UNCOV
181
            values_a.push_back(val_a);
×
UNCOV
182
            values_b.push_back(val_b);
×
183
        }
184
    }
185
    
UNCOV
186
    if (values_a.size() < 3) {
×
187
        return stats;  // Not enough data for meaningful correlation
×
188
    }
189
    
190
    stats.num_paired_samples = static_cast<int>(values_a.size());
×
191
    
192
    // Compute means
UNCOV
193
    double mean_a = std::accumulate(values_a.begin(), values_a.end(), 0.0) / values_a.size();
×
194
    double mean_b = std::accumulate(values_b.begin(), values_b.end(), 0.0) / values_b.size();
×
195
    
196
    // Compute covariance and standard deviations
UNCOV
197
    double cov_ab = 0.0;
×
198
    double var_a = 0.0;
×
199
    double var_b = 0.0;
×
200
    
UNCOV
201
    for (size_t i = 0; i < values_a.size(); ++i) {
×
202
        double diff_a = values_a[i] - mean_a;
×
203
        double diff_b = values_b[i] - mean_b;
×
204
        
UNCOV
205
        cov_ab += diff_a * diff_b;
×
UNCOV
206
        var_a += diff_a * diff_a;
×
207
        var_b += diff_b * diff_b;
×
208
    }
209
    
UNCOV
210
    cov_ab /= values_a.size();
×
UNCOV
211
    var_a /= values_a.size();
×
212
    var_b /= values_b.size();
×
213
    
214
    // Compute Pearson correlation: ρ = cov(A,B) / (σ_A × σ_B)
215
    double std_a = std::sqrt(var_a);
×
216
    double std_b = std::sqrt(var_b);
×
217
    
UNCOV
218
    if (std_a > 1e-10 && std_b > 1e-10) {
×
UNCOV
219
        stats.pearson_correlation = cov_ab / (std_a * std_b);
×
220
        stats.is_valid = true;
×
221
    }
222
    
UNCOV
223
    return stats;
×
UNCOV
224
}
×
225

226
}// anonymous namespace
227

228

229
std::shared_ptr<LineData> lineKalmanGrouping(std::shared_ptr<LineData> line_data,
13✔
230
                                             LineKalmanGroupingParameters const * params) {
231
    // No-op progress callback
232
    return ::lineKalmanGrouping(std::move(line_data), params, [](int) { /* no progress reporting */ });
13✔
233
}
234

235
std::shared_ptr<LineData> lineKalmanGrouping(std::shared_ptr<LineData> line_data,
13✔
236
                                             LineKalmanGroupingParameters const * params,
237
                                             ProgressCallback const & progressCallback) {
238
    if (!line_data || !params) {
13✔
239
        return line_data;
2✔
240
    }
241

242
    // Check if group manager is valid (required for grouping operations)
243
    if (!params->hasValidGroupManager()) {
11✔
244
        std::cerr << "lineKalmanGrouping: EntityGroupManager is required but not set. Call setGroupManager() on parameters before execution." << std::endl;
1✔
245
        return line_data;
1✔
246
    }
247

248
    using namespace StateEstimation;
249

250
    auto group_manager = params->getGroupManager();
10✔
251

252
    // Get all time frames with data
253
    auto times_view = line_data->getTimesWithData();
10✔
254
    std::vector<TimeFrameIndex> all_times(times_view.begin(), times_view.end());
30✔
255
    if (all_times.empty()) {
10✔
256
        progressCallback(100);
1✔
257
        return line_data;
1✔
258
    }
259

260
    std::sort(all_times.begin(), all_times.end());
9✔
261
    TimeFrameIndex start_frame = all_times.front();
9✔
262
    TimeFrameIndex end_frame = all_times.back();
9✔
263

264
    if (params->verbose_output) {
9✔
265
        std::cout << "Processing " << all_times.size() << " frames from "
2✔
266
                  << start_frame.getValue() << " to " << end_frame.getValue() << std::endl;
2✔
267
    }
268

269
    // Get natural iterator from LineData and flatten to individual items
270
    // This provides zero-copy access to Line2D objects
271
    auto line_entries_range = line_data->GetAllLineEntriesAsRange();
9✔
272
    auto data_source = StateEstimation::flattenLineData(line_entries_range);
9✔
273

274
    if (params->verbose_output) {
9✔
275
        std::cout << "Created zero-copy data source from LineData" << std::endl;
2✔
276
    }
277

278
    // Build GroundTruthMap: frames where entities are already grouped
279
    std::map<TimeFrameIndex, std::map<GroupId, EntityId>> ground_truth;
9✔
280
    auto all_group_ids = group_manager->getAllGroupIds();
9✔
281

282
    for (auto group_id: all_group_ids) {
27✔
283
        auto const & entities_in_group = group_manager->getEntitiesInGroup(group_id);
18✔
284

285
        for (auto entity_id: entities_in_group) {
340✔
286
            // Find which frame this entity belongs to
287
            auto time_info = line_data->getTimeAndIndexByEntityId(entity_id);
322✔
288
            if (time_info.has_value()) {
322✔
289
                ground_truth[time_info->first][group_id] = entity_id;
322✔
290
            }
291
        }
292
    }
18✔
293

294
    if (params->verbose_output) {
9✔
295
        std::cout << "Found " << all_group_ids.size() << " existing groups with "
2✔
296
                  << ground_truth.size() << " ground truth frames" << std::endl;
2✔
297
    }
298

299
    // Create composite feature extractor with centroid + base point + length
300
    // Uses metadata-driven approach to handle different feature types
301
    // - Centroid & base point: KINEMATIC_2D (position + velocity)
302
    // - Length: STATIC (no velocity tracking)
303
    auto composite_extractor = std::make_unique<StateEstimation::CompositeFeatureExtractor<Line2D>>();
9✔
304
    composite_extractor->addExtractor(std::make_unique<StateEstimation::LineCentroidExtractor>());
9✔
305
    composite_extractor->addExtractor(std::make_unique<StateEstimation::LineBasePointExtractor>());
9✔
306
    composite_extractor->addExtractor(std::make_unique<StateEstimation::LineLengthExtractor>());
9✔
307

308
    // Auto-estimate cross-feature correlations from ground truth data if requested
309
    std::map<std::pair<int, int>, double> estimated_correlations;
9✔
310
    
311
    if (params->enable_cross_feature_covariance && !ground_truth.empty()) {
9✔
UNCOV
312
        if (params->verbose_output) {
×
313
            std::cout << "\n=== Auto-Estimating Cross-Feature Correlations ===" << std::endl;
×
314
        }
315
        
316
        // Create extractors for correlation analysis
UNCOV
317
        StateEstimation::LineCentroidExtractor centroid_extractor;
×
318
        StateEstimation::LineBasePointExtractor base_point_extractor;
×
UNCOV
319
        StateEstimation::LineLengthExtractor length_extractor;
×
320
        
321
        // Compute centroid-length correlation
322
        auto centroid_length_corr = computeFeatureCrossCorrelation(
×
323
            line_data, ground_truth, centroid_extractor, length_extractor,
324
            "centroid", "length");
×
325
        
326
        // Compute base_point-length correlation
UNCOV
327
        auto base_point_length_corr = computeFeatureCrossCorrelation(
×
328
            line_data, ground_truth, base_point_extractor, length_extractor,
UNCOV
329
            "base_point", "length");
×
330
        
331
        if (params->verbose_output) {
×
332
            std::cout << "Centroid-Length correlation: " << centroid_length_corr.pearson_correlation
×
333
                     << " (n=" << centroid_length_corr.num_paired_samples << ")" << std::endl;
×
334
            std::cout << "BasePoint-Length correlation: " << base_point_length_corr.pearson_correlation
×
335
                     << " (n=" << base_point_length_corr.num_paired_samples << ")" << std::endl;
×
336
        }
337
        
338
        // Apply correlations above threshold
339
        // Feature indices: 0 = centroid, 1 = base_point, 2 = length
340
        if (centroid_length_corr.is_valid && 
×
341
            std::abs(centroid_length_corr.pearson_correlation) >= params->min_correlation_threshold) {
×
342
            estimated_correlations[{0, 2}] = centroid_length_corr.pearson_correlation;
×
343
            if (params->verbose_output) {
×
344
                std::cout << "  → Using centroid-length correlation: " 
×
345
                         << centroid_length_corr.pearson_correlation << std::endl;
×
346
            }
347
        }
348
        
349
        if (base_point_length_corr.is_valid && 
×
350
            std::abs(base_point_length_corr.pearson_correlation) >= params->min_correlation_threshold) {
×
351
            estimated_correlations[{1, 2}] = base_point_length_corr.pearson_correlation;
×
UNCOV
352
            if (params->verbose_output) {
×
353
                std::cout << "  → Using base_point-length correlation: " 
×
UNCOV
354
                         << base_point_length_corr.pearson_correlation << std::endl;
×
355
            }
356
        }
357
        
358
        if (estimated_correlations.empty() && params->verbose_output) {
×
359
            std::cout << "  → No significant correlations found (all below threshold "
×
UNCOV
360
                     << params->min_correlation_threshold << ")" << std::endl;
×
361
        }
362
    }
×
363
    
364
    // Configure cross-feature covariance in composite extractor
365
    if (!estimated_correlations.empty()) {
9✔
UNCOV
366
        StateEstimation::CompositeFeatureExtractor<Line2D>::CrossCovarianceConfig cross_cov_config;
×
UNCOV
367
        cross_cov_config.feature_correlations = estimated_correlations;
×
UNCOV
368
        composite_extractor->setCrossCovarianceConfig(std::move(cross_cov_config));
×
369
        
UNCOV
370
        if (params->verbose_output) {
×
UNCOV
371
            std::cout << "Configured initial cross-feature covariance from empirical correlations" << std::endl;
×
372
        }
UNCOV
373
    }
×
374

375
    // Get metadata from all child extractors
376
    // This automatically handles different temporal behaviors (kinematic, static, etc.)
377
    auto metadata_list = composite_extractor->getChildMetadata();
9✔
378

379
    if (params->verbose_output) {
9✔
380
        std::cout << "Building Kalman filter for " << metadata_list.size() << " features:" << std::endl;
2✔
381
        int total_meas = 0, total_state = 0;
2✔
382
        for (auto const & meta: metadata_list) {
8✔
383
            std::cout << "  - " << meta.name << ": "
6✔
384
                      << meta.measurement_size << "D measurement → "
6✔
385
                      << meta.state_size << "D state";
6✔
386
            if (meta.hasDerivatives()) {
6✔
387
                std::cout << " (with derivatives)";
4✔
388
            }
389
            std::cout << std::endl;
6✔
390
            total_meas += meta.measurement_size;
6✔
391
            total_state += meta.state_size;
6✔
392
        }
393
        std::cout << "Total measurement space: " << total_meas << "D" << std::endl;
2✔
394
        std::cout << "Total state space: " << total_state << "D" << std::endl;
2✔
395
    }
396

397
    // Auto-estimate noise parameters from ground truth data if requested
398
    double estimated_length_process_noise_scale = params->static_feature_process_noise_scale;
9✔
399
    double estimated_length_measurement_noise = params->measurement_noise_length;
9✔
400

401
    if (params->auto_estimate_static_noise || params->auto_estimate_measurement_noise) {
9✔
402
        StateEstimation::LineLengthExtractor length_extractor;
1✔
403
        auto length_stats = analyzeGroundTruthFeatureStatistics(
3✔
404
                line_data, ground_truth, length_extractor, "line_length");
1✔
405

406
        if (length_stats.num_samples > 0) {
1✔
407
            if (params->verbose_output) {
1✔
408
                std::cout << "\n=== Ground Truth Length Statistics ===" << std::endl;
1✔
409
                std::cout << "Samples: " << length_stats.num_samples << std::endl;
1✔
410
                std::cout << "Mean length: " << length_stats.mean << " pixels" << std::endl;
1✔
411
                std::cout << "Std dev: " << length_stats.std_dev << " pixels" << std::endl;
1✔
412
                std::cout << "Frame-to-frame changes: " << length_stats.num_transitions << " transitions" << std::endl;
1✔
413
                std::cout << "Mean absolute change: " << length_stats.mean_frame_to_frame_change << " pixels/frame" << std::endl;
1✔
414
                std::cout << "Std dev of changes: " << std::sqrt(length_stats.variance_frame_to_frame_change) << " pixels/frame" << std::endl;
1✔
415
            }
416

417
            if (params->auto_estimate_static_noise && length_stats.num_transitions > 0) {
1✔
418
                // Use the observed frame-to-frame variance as basis for process noise
419
                // Apply the percentile scaling (e.g., 10% of observed variation)
420
                double observed_change_variance = length_stats.variance_frame_to_frame_change;
1✔
421

422
                // Scale: we want Q = (percentile × change_std_dev)²
423
                // But Q is scaled by static_noise_scale × position_var
424
                // So: static_noise_scale × position_var² = (percentile × change_std_dev)²
425
                // Therefore: static_noise_scale = (percentile × change_std_dev)² / position_var²
426

427
                double change_std_dev = std::sqrt(observed_change_variance);
1✔
428
                double target_process_std = params->static_noise_percentile * change_std_dev;
1✔
429

430
                estimated_length_process_noise_scale =
1✔
431
                        (target_process_std * target_process_std) /
1✔
432
                        (params->process_noise_position * params->process_noise_position);
1✔
433

434
                if (params->verbose_output) {
1✔
435
                    std::cout << "\nAuto-estimated static noise:" << std::endl;
1✔
436
                    std::cout << "  Target process std dev: " << target_process_std << " pixels/frame" << std::endl;
1✔
437
                    std::cout << "  Computed scale factor: " << estimated_length_process_noise_scale << std::endl;
1✔
438
                    std::cout << "  (was: " << params->static_feature_process_noise_scale << ")" << std::endl;
1✔
439
                }
440
            }
441

442
            if (params->auto_estimate_measurement_noise) {
1✔
443
                // Use the percentile of the overall standard deviation as measurement noise
444
                estimated_length_measurement_noise = params->static_noise_percentile * length_stats.std_dev;
1✔
445
                // Clamp to a small positive floor for numerical stability
446
                double constexpr kMinMeasNoise = 1.0; // pixels
1✔
447
                if (estimated_length_measurement_noise < kMinMeasNoise) {
1✔
448
                    estimated_length_measurement_noise = kMinMeasNoise;
1✔
449
                }
450

451
                if (params->verbose_output) {
1✔
452
                    std::cout << "\nAuto-estimated measurement noise:" << std::endl;
1✔
453
                    std::cout << "  Estimated: " << estimated_length_measurement_noise << " pixels" << std::endl;
1✔
454
                    std::cout << "  (was: " << params->measurement_noise_length << ")" << std::endl;
1✔
455
                }
456
            }
UNCOV
457
        } else if (params->verbose_output) {
×
UNCOV
458
            std::cout << "\nWarning: No ground truth data found for noise estimation" << std::endl;
×
459
        }
460
    }
1✔
461

462
    // Build Kalman matrices from metadata with per-feature noise configuration
463
    // This automatically creates correct block-diagonal structure
464
    StateEstimation::KalmanMatrixBuilder::PerFeatureConfig config;
9✔
465
    config.dt = params->dt;
9✔
466
    config.process_noise_position = params->process_noise_position;
9✔
467
    config.process_noise_velocity = params->process_noise_velocity;
9✔
468
    config.static_noise_scale = estimated_length_process_noise_scale;// Use estimated or default
9✔
469
    config.measurement_noise = params->measurement_noise_position;   // Default for position features
9✔
470

471
    // Set feature-specific measurement noise
472
    config.feature_measurement_noise["line_centroid"] = params->measurement_noise_position;
27✔
473
    config.feature_measurement_noise["line_base_point"] = params->measurement_noise_position;
27✔
474
    config.feature_measurement_noise["line_length"] = estimated_length_measurement_noise;// Use estimated or default
27✔
475

476
    auto [F, H, Q, R] = StateEstimation::KalmanMatrixBuilder::buildAllMatricesFromMetadataPerFeature(
18✔
477
            metadata_list, config);
9✔
478

479
    // Add cross-feature process noise using estimated correlations
480
    if (!estimated_correlations.empty()) {
9✔
UNCOV
481
        Q = StateEstimation::KalmanMatrixBuilder::addCrossFeatureProcessNoise(
×
UNCOV
482
            Q, metadata_list, estimated_correlations);
×
483
        
UNCOV
484
        if (params->verbose_output) {
×
UNCOV
485
            std::cout << "\nAdded cross-feature process noise covariance based on empirical correlations" << std::endl;
×
486
        }
487
    }
488

489
    if (params->verbose_output) {
9✔
490
        std::cout << "\nNoise configuration:" << std::endl;
2✔
491
        std::cout << "  Process noise - position: " << params->process_noise_position << std::endl;
2✔
492
        std::cout << "  Process noise - velocity: " << params->process_noise_velocity << std::endl;
2✔
493
        std::cout << "  Process noise - static scale: " << estimated_length_process_noise_scale;
2✔
494
        if (params->auto_estimate_static_noise) {
2✔
495
            std::cout << " (auto-estimated, parameter was: " << params->static_feature_process_noise_scale << ")";
1✔
496
        }
497
        std::cout << std::endl;
2✔
498
        std::cout << "  Measurement noise - position: " << params->measurement_noise_position << std::endl;
2✔
499
        std::cout << "  Measurement noise - length: " << estimated_length_measurement_noise;
2✔
500
        if (params->auto_estimate_measurement_noise) {
2✔
501
            std::cout << " (auto-estimated, parameter was: " << params->measurement_noise_length << ")";
1✔
502
        }
503
        std::cout << std::endl;
2✔
504
        std::cout << "\nResulting Q (process noise covariance) diagonal:" << std::endl;
2✔
505
        for (int i = 0; i < Q.rows(); ++i) {
20✔
506
            std::cout << "    Q[" << i << "," << i << "] = " << Q(i, i) << std::endl;
18✔
507
        }
508
        std::cout << "\nResulting R (measurement noise covariance) diagonal:" << std::endl;
2✔
509
        for (int i = 0; i < R.rows(); ++i) {
12✔
510
            std::cout << "    R[" << i << "," << i << "] = " << R(i, i) << std::endl;
10✔
511
        }
512
    }
513

514
    auto kalman_filter = std::make_unique<KalmanFilter>(F, H, Q, R);
9✔
515

516
    // Build a state index map for dynamics-aware costs (order-independent)
517
    auto index_map = StateEstimation::KalmanMatrixBuilder::buildStateIndexMap(metadata_list);
9✔
518

519
    // Create dynamics-aware transition cost (measurement NLL + velocity + implied-acceleration)
520
    auto cost_fn = StateEstimation::createDynamicsAwareCostFunction(
9✔
521
            H,
522
            R,
523
            index_map,
524
            config.dt,
525
            /*beta=*/1.0,
526
            /*gamma=*/0.25,
527
            /*lambda_gap=*/0.0);
9✔
528

529
    // Use MinCostFlowTracker with the custom cost function
530
    // Relax greedy cheap-link threshold to account for added dynamics terms
531
    double const cheap_threshold = params->cheap_assignment_threshold * 5.0;
9✔
532
    // Use Mahalanobis for greedy chaining and dynamics-aware for transitions
533
    auto chain_cost = StateEstimation::createMahalanobisCostFunction(H, R);
9✔
534
    StateEstimation::MinCostFlowTracker<Line2D> tracker(
9✔
535
            std::move(kalman_filter),
9✔
536
            std::move(composite_extractor),
9✔
537
            chain_cost,
538
            cost_fn,
539
            params->cost_scale_factor,
9✔
540
            cheap_threshold);
27✔
541

542
    tracker.enableDebugLogging("tracker.log");
27✔
543

544
    // Process per-group consecutive anchor spans (safer when anchors are not coincident across groups)
545
    // Build group -> sorted anchor frames mapping
546
    std::map<GroupId, std::vector<TimeFrameIndex>> group_to_anchor_frames;
9✔
547
    for (auto const & [frame, assignments]: ground_truth) {
170✔
548
        for (auto const & [group_id, _]: assignments) {
483✔
549
            group_to_anchor_frames[group_id].push_back(frame);
322✔
550
        }
551
    }
552

553
    for (auto & [gid, frames]: group_to_anchor_frames) {
27✔
554
        std::sort(frames.begin(), frames.end());
18✔
555
        frames.erase(std::unique(frames.begin(), frames.end()), frames.end());
18✔
556
    }
557

558
    // Count total per-group intervals for progress
559
    size_t total_pairs = 0;
9✔
560
    for (auto const & [gid, frames]: group_to_anchor_frames) {
27✔
561
        if (frames.size() > 1) total_pairs += (frames.size() - 1);
18✔
562
    }
563
    size_t processed_pairs = 0;
9✔
564

565
    if (params->verbose_output) {
9✔
566
        std::cout << "\nProcessing per-group anchors across " << group_to_anchor_frames.size() << " groups" << std::endl;
2✔
567
    }
568

569
    for (auto const & [group_id, frames]: group_to_anchor_frames) {
27✔
570
        if (frames.size() < 2) continue;
18✔
571

572
        // Create putative output group for this anchor group if requested
573
        std::optional<GroupId> putative_group_id;
18✔
574
        if (params->write_to_putative_groups) {
18✔
575
            auto desc = group_manager->getGroupDescriptor(group_id);
18✔
576
            std::string base_name = desc ? desc->name : std::string("Group ") + std::to_string(group_id);
18✔
577
            std::string putative_name = params->putative_group_prefix + base_name;
18✔
578
            putative_group_id = group_manager->createGroup(putative_name, "Putative labels from Kalman grouping");
54✔
579
        }
18✔
580

581
        for (size_t i = 0; i + 1 < frames.size(); ++i) {
322✔
582
            TimeFrameIndex interval_start = frames[i];
304✔
583
            TimeFrameIndex interval_end = frames[i + 1];
304✔
584

585
            // Skip if consecutive frames - no gap to fill with MCF
586
            if (interval_end.getValue() - interval_start.getValue() <= 1) {
304✔
587
                continue;
144✔
588
            }
589

590
            // Build a minimal ground truth map for this group and interval only
591
            std::map<TimeFrameIndex, std::map<GroupId, EntityId>> gt_local;
160✔
592
            auto const & start_map = ground_truth.at(interval_start);
160✔
593
            auto const & end_map = ground_truth.at(interval_end);
160✔
594
            auto start_it = start_map.find(group_id);
160✔
595
            auto end_it = end_map.find(group_id);
160✔
596
            if (start_it == start_map.end() || end_it == end_map.end()) {
160✔
UNCOV
597
                continue;// safety: skip if either end missing
×
598
            }
599
            gt_local[interval_start][group_id] = start_it->second;
160✔
600
            gt_local[interval_end][group_id] = end_it->second;
160✔
601

602
            if (params->verbose_output) {
160✔
603
                std::cout << "\nProcessing group " << group_id << " interval: "
36✔
604
                          << interval_start.getValue() << " -> " << interval_end.getValue() << std::endl;
36✔
605
            }
606

607
            // Exclude already-labeled entities from matching; allow anchors explicitly
608
            std::unordered_set<EntityId> excluded_entities;
160✔
609
            for (auto gid: all_group_ids) {
548✔
610
                auto ents = group_manager->getEntitiesInGroup(gid);
388✔
611
                excluded_entities.insert(ents.begin(), ents.end());
388✔
612
            }
388✔
613
            std::unordered_set<EntityId> include_entities;// whitelist (anchors at ends)
160✔
614
            include_entities.insert(start_it->second);
160✔
615
            include_entities.insert(end_it->second);
160✔
616

617
            // Map write group: default write back to same anchor group; if putative, write to new group
618
            std::map<GroupId, GroupId> write_group_map;
160✔
619
            if (putative_group_id.has_value()) {
160✔
620
                write_group_map[group_id] = *putative_group_id;
160✔
621
            }
622

623
            [[maybe_unused]] auto smoothed_results = tracker.process(
160✔
624
                    data_source,
625
                    *group_manager,
626
                    gt_local,
627
                    interval_start,
628
                    interval_end,
629
                    progressCallback,
630
                    putative_group_id.has_value() ? &write_group_map : nullptr,
160✔
631
                    &excluded_entities,
632
                    &include_entities);
160✔
633

634
            // Report progress across all group-intervals
635
            processed_pairs++;
160✔
636
            int progress = total_pairs > 0 ? static_cast<int>(100.0 * processed_pairs / total_pairs) : 100;
160✔
637
            progressCallback(progress);
160✔
638
        }
160✔
639

640
        // After completing all intervals for this anchor group, emit one bulk notification
641
        group_manager->notifyGroupsChanged();
18✔
642
    }
643

644
    if (params->verbose_output) {
9✔
645
        std::cout << "Tracking complete. Groups updated in EntityGroupManager." << std::endl;
2✔
646
        for (auto group_id: all_group_ids) {
6✔
647
            auto entities = group_manager->getEntitiesInGroup(group_id);
4✔
648
            std::cout << "Group " << group_id << " now has " << entities.size() << " entities" << std::endl;
4✔
649
        }
4✔
650
    }
651

652
    progressCallback(100);
9✔
653
    return line_data;
9✔
654
}
10✔
655

656
// LineKalmanGroupingOperation implementation
657

658
/**
 * @brief Returns the name of this transform operation.
 */
std::string LineKalmanGroupingOperation::getName() const {
    std::string name{"Group Lines using Kalman Filtering"};
    return name;
}
661

662
std::type_index LineKalmanGroupingOperation::getTargetInputTypeIndex() const {
149✔
663
    return std::type_index(typeid(std::shared_ptr<LineData>));
149✔
664
}
665

UNCOV
666
/**
 * @brief Checks whether this operation can run on the given input.
 *
 * @param dataVariant Candidate input.
 * @return true only if the active alternative is std::shared_ptr<LineData>
 *         and that pointer is non-null; false otherwise.
 */
bool LineKalmanGroupingOperation::canApply(DataTypeVariant const & dataVariant) const {
    // get_if yields nullptr when the variant currently holds some other alternative.
    auto const * const line_ptr = std::get_if<std::shared_ptr<LineData>>(&dataVariant);
    if (line_ptr == nullptr) {
        return false;
    }
    return *line_ptr != nullptr;
}
670

671
std::unique_ptr<TransformParametersBase> LineKalmanGroupingOperation::getDefaultParameters() const {
1✔
672
    // Create default parameters with null group manager
673
    // The EntityGroupManager must be set via setGroupManager() before execution
674
    return std::make_unique<LineKalmanGroupingParameters>();
1✔
675
}
676

UNCOV
677
/**
 * @brief Runs the grouping without progress reporting.
 *
 * Forwards to the three-argument overload, passing a callback that discards
 * every progress update.
 */
DataTypeVariant LineKalmanGroupingOperation::execute(DataTypeVariant const & dataVariant,
                                                     TransformParametersBase const * transformParameters) {
    auto const discard_progress = [](int /*percent*/) { /* no progress reporting */ };
    return execute(dataVariant, transformParameters, discard_progress);
}
682

UNCOV
683
/**
 * @brief Groups the lines held in @p dataVariant using Kalman-filter tracking.
 *
 * @param dataVariant         Must hold a non-null std::shared_ptr<LineData>.
 * @param transformParameters Must be a LineKalmanGroupingParameters whose
 *                            EntityGroupManager is valid.
 * @param progressCallback    Receives progress updates from the grouping
 *                            routine. If it is empty, a no-op callback is used.
 * @return The result of ::lineKalmanGrouping wrapped in a DataTypeVariant, or
 *         an empty DataTypeVariant when the input, the parameter type, or the
 *         group manager is invalid.
 */
DataTypeVariant LineKalmanGroupingOperation::execute(DataTypeVariant const & dataVariant,
                                                     TransformParametersBase const * transformParameters,
                                                     ProgressCallback progressCallback) {
    if (!canApply(dataVariant)) {
        return DataTypeVariant{};
    }

    auto line_data = std::get<std::shared_ptr<LineData>>(dataVariant);
    auto params = dynamic_cast<LineKalmanGroupingParameters const *>(transformParameters);

    if (!params) {
        return DataTypeVariant{};
    }

    // Check if group manager is valid (required for grouping operations)
    if (!params->hasValidGroupManager()) {
        std::cerr << "LineKalmanGroupingOperation::execute: EntityGroupManager is required but not set. Call setGroupManager() on parameters before execution." << std::endl;
        return DataTypeVariant{};
    }

    // The grouping routine invokes the callback unconditionally (per-interval
    // progress and a final 100). Calling an empty callback fails at call time;
    // for a std::function it throws std::bad_function_call. Fall back to a no-op.
    if (!progressCallback) {
        progressCallback = [](int) { /* no progress reporting */ };
    }

    auto result = ::lineKalmanGrouping(line_data, params, progressCallback);
    return DataTypeVariant{result};
}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc