• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 17743018711

15 Sep 2025 06:33PM UTC coverage: 72.505% (-0.1%) from 72.602%
17743018711

push

github

paulmthompson
consumer producer whisker tracking enabled. much faster whisker processing on windows

0 of 98 new or added lines in 2 files covered. (0.0%)

2 existing lines in 1 file now uncovered.

38156 of 52625 relevant lines covered (72.51%)

1347.15 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

44.66
/src/DataManager/transforms/Media/whisker_tracing.cpp
1
#include "whisker_tracing.hpp"
2

3
#include "Masks/Mask_Data.hpp"
4
#include "whiskertracker.hpp"
5

6
#include <omp.h>
7

8
#include <algorithm>
9
#include <chrono>
10
#include <cmath>
11
#include <iostream>
12
#include <memory>
13
#include <vector>
14

15
namespace {
// Pixel value written into binary masks for pixels that belong to a mask.
constexpr uint8_t MASK_TRUE_VALUE{255};
// Progress value reported when execution finishes, including early error returns.
constexpr int PROGRESS_COMPLETE{100};
// Multiplier turning a completed fraction in [0, 1] into a percentage.
constexpr double PROGRESS_SCALE{100.0};
}// namespace
20

21
// Convert whisker::Line2D to Line2D
22
Line2D WhiskerTracingOperation::convert_to_Line2D(whisker::Line2D const & whisker_line) {
2✔
23
    Line2D line;
2✔
24

25
    for (auto const & point: whisker_line) {
503✔
26
        line.push_back(Point2D<float>{point.x, point.y});
501✔
27
    }
28

29
    return line;
2✔
30
}
×
31

32
// Convert mask data to binary mask format for whisker tracker
33
std::vector<uint8_t> convert_mask_to_binary(MaskData const * mask_data,
2✔
34
                                                                     int time_index,
35
                                                                     ImageSize const & image_size) {
36
    std::vector<uint8_t> binary_mask(static_cast<size_t>(image_size.width * image_size.height), 0);
6✔
37

38
    if (!mask_data) {
2✔
39
        return binary_mask;// Return empty mask if no mask data
×
40
    }
41

42
    // Get mask at the specified time
43
    auto const & masks_at_time = mask_data->getAtTime(TimeFrameIndex(time_index));
2✔
44
    if (masks_at_time.empty()) {
2✔
45
        return binary_mask;// Return empty mask if no mask at this time
×
46
    }
47

48
    // Source mask image size (may differ from target media image size)
49
    auto const src_size = mask_data->getImageSize();
2✔
50

51
    // Fast path: identical sizes, just map points directly
52
    if (src_size.width == image_size.width && src_size.height == image_size.height) {
2✔
53
        for (auto const & mask: masks_at_time) {
2✔
54
            for (auto const & point: mask) {
77✔
55
                if (point.x < image_size.width && point.y < image_size.height) {
76✔
56
                    auto const index = static_cast<size_t>(point.y) * static_cast<size_t>(image_size.width)
76✔
57
                                     + static_cast<size_t>(point.x);
76✔
58
                    if (index < binary_mask.size()) {
76✔
59
                        binary_mask[index] = MASK_TRUE_VALUE;// Set to 255 for true pixels
76✔
60
                    }
61
                }
62
            }
63
        }
64
        return binary_mask;
65
    }
66

67
    // Build a source binary mask for nearest-neighbor scaling when sizes differ
68
    std::vector<uint8_t> src_binary(static_cast<size_t>(src_size.width * src_size.height), 0);
3✔
69
    for (auto const & mask: masks_at_time) {
2✔
70
        for (auto const & point: mask) {
60✔
71
            if (point.x < src_size.width && point.y < src_size.height) {
59✔
72
                auto const src_index = static_cast<size_t>(point.y) * static_cast<size_t>(src_size.width)
59✔
73
                                     + static_cast<size_t>(point.x);
59✔
74
                if (src_index < src_binary.size()) {
59✔
75
                    src_binary[src_index] = MASK_TRUE_VALUE;
59✔
76
                }
77
            }
78
        }
79
    }
80

81
    // Nearest-neighbor scale from src_binary (src_size) to binary_mask (image_size)
82
    auto const src_w = std::max(1, src_size.width);
1✔
83
    auto const src_h = std::max(1, src_size.height);
1✔
84
    auto const dst_w = std::max(1, image_size.width);
1✔
85
    auto const dst_h = std::max(1, image_size.height);
1✔
86

87
    // Precompute ratios; use (N-1) mapping to preserve endpoints
88
    auto const rx = (dst_w > 1 && src_w > 1)
1✔
89
                            ? (static_cast<double>(src_w - 1) / static_cast<double>(dst_w - 1))
2✔
90
                            : 0.0;
91
    auto const ry = (dst_h > 1 && src_h > 1)
1✔
92
                            ? (static_cast<double>(src_h - 1) / static_cast<double>(dst_h - 1))
2✔
93
                            : 0.0;
94

95
    for (int y = 0; y < dst_h; ++y) {
481✔
96
        int const ys = (dst_h > 1 && src_h > 1)
480✔
97
                               ? static_cast<int>(std::round(static_cast<double>(y) * ry))
960✔
98
                               : 0;
99
        for (int x = 0; x < dst_w; ++x) {
307,680✔
100
            int const xs = (dst_w > 1 && src_w > 1)
307,200✔
101
                                   ? static_cast<int>(std::round(static_cast<double>(x) * rx))
614,400✔
102
                                   : 0;
103

104
            auto const src_index = static_cast<size_t>(ys) * static_cast<size_t>(src_w)
307,200✔
105
                                 + static_cast<size_t>(xs);
307,200✔
106
            auto const dst_index = static_cast<size_t>(y) * static_cast<size_t>(dst_w)
307,200✔
107
                                 + static_cast<size_t>(x);
307,200✔
108

109
            if (src_index < src_binary.size() && dst_index < binary_mask.size()) {
307,200✔
110
                binary_mask[dst_index] = src_binary[src_index] ? MASK_TRUE_VALUE : 0;
307,200✔
111
            }
112
        }
113
    }
114

115
    return binary_mask;
1✔
116
}
1✔
117

118
// Clip whisker line by removing points from the end
119
void WhiskerTracingOperation::clip_whisker(Line2D & line, int clip_length) {
2✔
120
    if (line.size() <= static_cast<std::size_t>(clip_length)) {
2✔
121
        return;
×
122
    }
123

124
    line.erase(line.end() - clip_length, line.end());
2✔
125
}
126

127
// Trace whiskers in a single image
128
std::vector<Line2D> WhiskerTracingOperation::trace_single_image(
2✔
129
        whisker::WhiskerTracker & whisker_tracker,
130
        std::vector<uint8_t> const & image_data,
131
        ImageSize const & image_size,
132
        int clip_length,
133
        MaskData const * mask_data,
134
        int time_index) {
135

136
    std::vector<Line2D> whisker_lines;
2✔
137

138
    if (mask_data) {
2✔
139
        // Use mask-based tracing
140
        auto binary_mask = convert_mask_to_binary(mask_data, time_index, image_size);
2✔
141
        auto whiskers = whisker_tracker.trace_with_mask(image_data, binary_mask, image_size.height, image_size.width);
2✔
142

143
        whisker_lines.reserve(whiskers.size());
2✔
144
        for (auto const & whisker: whiskers) {
4✔
145
            Line2D line = convert_to_Line2D(whisker);
2✔
146
            clip_whisker(line, clip_length);
2✔
147
            whisker_lines.push_back(std::move(line));
2✔
148
        }
2✔
149
    } else {
2✔
150
        // Use standard tracing
151
        auto whiskers = whisker_tracker.trace(image_data, image_size.height, image_size.width);
×
152

153
        whisker_lines.reserve(whiskers.size());
×
154
        for (auto const & whisker: whiskers) {
×
155
            Line2D line = convert_to_Line2D(whisker);
×
156
            clip_whisker(line, clip_length);
×
157
            whisker_lines.push_back(std::move(line));
×
158
        }
×
159
    }
×
160

161
    return whisker_lines;
2✔
162
}
×
163

164
// Trace whiskers in multiple images in parallel
165
std::vector<std::vector<Line2D>> WhiskerTracingOperation::trace_multiple_images(
×
166
        whisker::WhiskerTracker & whisker_tracker,
167
        std::vector<std::vector<uint8_t>> const & images,
168
        ImageSize const & image_size,
169
        int clip_length,
170
        MaskData const * mask_data,
171
        std::vector<int> const & time_indices) {
172

173
    std::vector<std::vector<Line2D>> result;
×
174
    result.reserve(images.size());
×
175

176
    if (mask_data && !time_indices.empty()) {
×
177
        // Use mask-based parallel tracing
178

NEW
179
        auto t0 = std::chrono::high_resolution_clock::now();
×
180

181
        std::vector<std::vector<uint8_t>> masks;
×
182
        masks.reserve(images.size());
×
183

184
        for (size_t i = 0; i < images.size(); ++i) {
×
185
            int const time_idx = (i < time_indices.size()) ? time_indices[i] : 0;
×
186
            auto binary_mask = convert_mask_to_binary(mask_data, time_idx, image_size);
×
187
            masks.push_back(std::move(binary_mask));
×
188
        }
×
189

190
        //auto t1 = std::chrono::high_resolution_clock::now();
191

192
        //std::cout << "Mask Generation: " << std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count() << "ms" << std::endl;
193

UNCOV
194
        auto whiskers_batch = whisker_tracker.trace_multiple_images_with_masks(images, masks, image_size.height, image_size.width);
×
195

196
        //auto t2 = std::chrono::high_resolution_clock::now();
197

198
        //std::cout << "Mask Tracing: " << std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() << "ms" << std::endl;
199

200

201
        for (auto const & whiskers: whiskers_batch) {
×
202
            std::vector<Line2D> whisker_lines;
×
203
            whisker_lines.reserve(whiskers.size());
×
204

205
            for (auto const & whisker: whiskers) {
×
206
                Line2D line = convert_to_Line2D(whisker);
×
207
                clip_whisker(line, clip_length);
×
208
                whisker_lines.push_back(std::move(line));
×
209
            }
×
210

211
            result.push_back(std::move(whisker_lines));
×
212
        }
×
213
    } else {
×
214
        // Use standard parallel tracing
215
        auto whiskers_batch = whisker_tracker.trace_multiple_images(images, image_size.height, image_size.width);
×
216

217
        for (auto const & whiskers: whiskers_batch) {
×
218
            std::vector<Line2D> whisker_lines;
×
219
            whisker_lines.reserve(whiskers.size());
×
220

221
            for (auto const & whisker: whiskers) {
×
222
                Line2D line = convert_to_Line2D(whisker);
×
223
                clip_whisker(line, clip_length);
×
224
                whisker_lines.push_back(std::move(line));
×
225
            }
×
226

227
            result.push_back(std::move(whisker_lines));
×
228
        }
×
229
    }
×
230

231
    return result;
×
232
}
×
233

234
// Producer thread function that loads frames from media_data
NEW
235
void WhiskerTracingOperation::producer_thread(std::shared_ptr<MediaData> media_data,
×
236
                                             FrameQueue& frame_queue,
237
                                             WhiskerTracingParameters const* params,
238
                                             int total_frames,
239
                                             std::atomic<int>& progress_atomic) {
240
    try {
NEW
241
        for (int frame_idx = 0; frame_idx < total_frames; ++frame_idx) {
×
NEW
242
            std::vector<uint8_t> image_data;
×
243
            
NEW
244
            if (params->use_processed_data) {
×
NEW
245
                image_data = media_data->getProcessedData8(frame_idx);
×
246
            } else {
NEW
247
                image_data = media_data->getRawData8(frame_idx);
×
248
            }
249
            
NEW
250
            if (!image_data.empty()) {
×
NEW
251
                FrameData frame(std::move(image_data), frame_idx);
×
NEW
252
                frame_queue.push(std::move(frame));
×
NEW
253
            }
×
254
            
NEW
255
            progress_atomic.fetch_add(1);
×
NEW
256
        }
×
257
        
258
        // Signal end of data
NEW
259
        frame_queue.push_end_marker();
×
NEW
260
    } catch (const std::exception& e) {
×
NEW
261
        std::cerr << "Producer thread error: " << e.what() << std::endl;
×
NEW
262
        frame_queue.push_end_marker();
×
NEW
263
    }
×
NEW
264
}
×
265

266
// Consumer function that processes frames from the queue
NEW
267
void WhiskerTracingOperation::consumer_processing(FrameQueue& frame_queue,
×
268
                                                 whisker::WhiskerTracker& tracker,
269
                                                 ImageSize const& image_size,
270
                                                 WhiskerTracingParameters const* params,
271
                                                 std::shared_ptr<LineData> traced_whiskers,
272
                                                 std::atomic<int>& progress_atomic,
273
                                                 int total_frames,
274
                                                 ProgressCallback progressCallback) {
NEW
275
    std::vector<std::vector<uint8_t>> batch_images;
×
NEW
276
    std::vector<int> batch_times;
×
NEW
277
    batch_images.reserve(params->batch_size);
×
NEW
278
    batch_times.reserve(params->batch_size);
×
279
    
NEW
280
    int processed_frames = 0;
×
281
    
NEW
282
    while (processed_frames < total_frames) {
×
NEW
283
        FrameData frame;
×
284
        
285
        // Try to get a frame with timeout
NEW
286
        bool got_frame = frame_queue.pop(frame, std::chrono::milliseconds(1000));
×
287
        
NEW
288
        if (!got_frame) {
×
NEW
289
            std::cerr << "Consumer timeout waiting for frame" << std::endl;
×
NEW
290
            break;
×
291
        }
292
        
293
        // Check for end marker
NEW
294
        if (frame.is_end_marker) {
×
NEW
295
            break;
×
296
        }
297
        
NEW
298
        batch_images.push_back(std::move(frame.image_data));
×
NEW
299
        batch_times.push_back(frame.time_index);
×
300
        
301
        // Process batch when we have enough frames or we've reached the end
302
        // Use adaptive batch sizing: smaller batches when queue is empty, larger when full
NEW
303
        int adaptive_batch_size = (frame_queue.size() > 10) ? params->batch_size : std::min(params->batch_size, 20);
×
NEW
304
        if (batch_images.size() >= static_cast<size_t>(adaptive_batch_size) || 
×
NEW
305
            processed_frames + batch_images.size() >= total_frames) {
×
306
            
NEW
307
            if (!batch_images.empty()) {
×
308
                // Trace whiskers in parallel for this batch
NEW
309
                auto batch_results = trace_multiple_images(tracker,
×
310
                                                         batch_images,
311
                                                         image_size,
NEW
312
                                                         params->clip_length,
×
NEW
313
                                                         params->use_mask_data ? params->mask_data.get() : nullptr,
×
NEW
314
                                                         batch_times);
×
315
                
316
                // Add results to LineData
NEW
317
                for (size_t j = 0; j < batch_results.size(); ++j) {
×
NEW
318
                    for (auto const & line: batch_results[j]) {
×
NEW
319
                        traced_whiskers->addAtTime(TimeFrameIndex(batch_times[j]), line, false);
×
320
                    }
321
                }
322
                
NEW
323
                processed_frames += batch_images.size();
×
324
                
325
                // Update progress from consumer thread
NEW
326
                if (progressCallback) {
×
NEW
327
                    int const current_progress = static_cast<int>(std::round(static_cast<double>(processed_frames) / static_cast<double>(total_frames) * PROGRESS_SCALE));
×
NEW
328
                    progressCallback(current_progress);
×
329
                }
330
                
331
                // Clear batch for next iteration
NEW
332
                batch_images.clear();
×
NEW
333
                batch_times.clear();
×
NEW
334
            }
×
335
        }
NEW
336
    }
×
NEW
337
}
×
338

339
/// @brief Name of this transform operation.
std::string WhiskerTracingOperation::getName() const {
    static constexpr char const * kOperationName = "Whisker Tracing";
    return std::string{kOperationName};
}
342

343
std::type_index WhiskerTracingOperation::getTargetInputTypeIndex() const {
148✔
344
    return typeid(std::shared_ptr<MediaData>);
148✔
345
}
346

347
/// @brief True when @p dataVariant holds a non-null std::shared_ptr<MediaData>.
bool WhiskerTracingOperation::canApply(DataTypeVariant const & dataVariant) const {
    auto const * media = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    if (media == nullptr) {
        return false;
    }
    return static_cast<bool>(*media);
}
355

356
std::unique_ptr<TransformParametersBase> WhiskerTracingOperation::getDefaultParameters() const {
×
357
    return std::make_unique<WhiskerTracingParameters>();
×
358
}
359

360
/// @brief Run whisker tracing without progress reporting.
/// Delegates to the three-argument overload with a callback that ignores every progress update.
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const & dataVariant,
                                                 TransformParametersBase const * transformParameters) {
    auto const ignore_progress = [](int /*percent*/) {};
    return execute(dataVariant, transformParameters, ignore_progress);
}
364

365
/**
 * @brief Trace whiskers in every frame of a MediaData object.
 *
 * Uses a producer/consumer pipeline when params->use_parallel_processing is set and
 * params->batch_size > 1. In that mode, producer_thread loads frames on a separate std::thread while
 * consumer_processing traces them in batches on this thread. Otherwise frames are traced one by one
 * with trace_single_image.
 *
 * @param dataVariant         Must hold a non-null std::shared_ptr<MediaData>; otherwise an empty
 *                            variant is returned.
 * @param transformParameters Must point to WhiskerTracingParameters; otherwise an empty variant is
 *                            returned. A tracker supplied in the parameters is reused (and has its
 *                            length threshold and pad radius overwritten); otherwise a new one is built.
 * @param progressCallback    Optional. Receives 0 at start, intermediate percentages, and
 *                            PROGRESS_COMPLETE at the end (also on early error returns).
 * @return A std::shared_ptr<LineData> holding the traced whiskers at their frame indices, or an
 *         empty variant on error.
 */
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const & dataVariant,
                                                 TransformParametersBase const * transformParameters,
                                                 ProgressCallback progressCallback) {
    auto const * ptr_ptr = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    if (!ptr_ptr || !(*ptr_ptr)) {
        std::cerr << "WhiskerTracingOperation::execute: Incompatible variant type or null data." << std::endl;
        if (progressCallback) progressCallback(PROGRESS_COMPLETE);
        return {};
    }

    auto media_data = *ptr_ptr;

    auto const * typed_params =
            transformParameters ? dynamic_cast<WhiskerTracingParameters const *>(transformParameters) : nullptr;

    if (!typed_params) {
        std::cerr << "WhiskerTracingOperation::execute: Invalid parameters." << std::endl;
        if (progressCallback) progressCallback(PROGRESS_COMPLETE);
        return {};
    }

    // Allow caller (tests) to pass an already-initialized tracker to avoid heavy setup
    std::shared_ptr<whisker::WhiskerTracker> tracker_ptr = typed_params->tracker;
    if (!tracker_ptr) {
        tracker_ptr = std::make_shared<whisker::WhiskerTracker>();
        std::cout << "Whisker Tracker Initialized" << std::endl;
    }
    tracker_ptr->setWhiskerLengthThreshold(typed_params->whisker_length_threshold);
    // Disable whisker pad exclusion by using a large radius by default
    tracker_ptr->setWhiskerPadRadius(1000.0f);

    if (progressCallback) progressCallback(0);

    // Create new LineData for the traced whiskers
    auto traced_whiskers = std::make_shared<LineData>();
    traced_whiskers->setImageSize(media_data->getImageSize());

    auto total_frame_count = media_data->getTotalFrameCount();
    if (total_frame_count <= 0) {
        std::cerr << "WhiskerTracingOperation::execute: No data available in media." << std::endl;
        if (progressCallback) progressCallback(PROGRESS_COMPLETE);
        return {};
    }

    auto total_time_points = static_cast<size_t>(total_frame_count);

    if (typed_params->use_parallel_processing && typed_params->batch_size > 1) {
        // omp_set_num_threads changes the calling thread's default team size for every later
        // parallel region, and omp_get_max_threads reads that same setting back. Without the
        // restore below, each call would shrink the pool by producer_threads again.
        int const previous_omp_threads = omp_get_max_threads();
        // Reserve threads for producer, use rest for OpenMP processing
        int const omp_threads = std::max(1, previous_omp_threads - typed_params->producer_threads);
        omp_set_num_threads(omp_threads);
        std::cout << "Total CPU cores: " << previous_omp_threads
                  << ", OpenMP threads: " << omp_threads
                  << ", Producer threads: " << typed_params->producer_threads << std::endl;

        FrameQueue frame_queue(typed_params->queue_size);

        // Incremented by the producer for every frame it reads.
        std::atomic<int> progress_atomic{0};

        // NOTE(review): if consumer_processing throws, `producer` is still joinable when destroyed,
        // and std::thread's destructor then calls std::terminate.
        std::thread producer([&]() {
            producer_thread(media_data, frame_queue, typed_params,
                            static_cast<int>(total_time_points), progress_atomic);
        });

        // Consumer processing (runs in this thread)
        consumer_processing(frame_queue, *tracker_ptr, media_data->getImageSize(),
                            typed_params, traced_whiskers, progress_atomic,
                            static_cast<int>(total_time_points), progressCallback);

        producer.join();

        // Restore the caller's OpenMP default so later calls and other code see the original value.
        omp_set_num_threads(previous_omp_threads);
    } else {
        // Process frames one by one (original sequential approach)
        size_t processed_time_points = 0;
        for (size_t time = 0; time < total_time_points; ++time) {
            std::vector<uint8_t> image_data;

            if (typed_params->use_processed_data) {
                image_data = media_data->getProcessedData8(static_cast<int>(time));
            } else {
                image_data = media_data->getRawData8(static_cast<int>(time));
            }

            if (!image_data.empty()) {
                auto whisker_lines = trace_single_image(*tracker_ptr, image_data, media_data->getImageSize(),
                                                        typed_params->clip_length,
                                                        typed_params->use_mask_data ? typed_params->mask_data.get() : nullptr,
                                                        static_cast<int>(time));

                for (auto const & line: whisker_lines) {
                    traced_whiskers->addAtTime(TimeFrameIndex(static_cast<int64_t>(time)), line, false);
                }
            }

            processed_time_points++;
            if (progressCallback) {
                int const current_progress = static_cast<int>(std::round(static_cast<double>(processed_time_points) / static_cast<double>(total_time_points) * PROGRESS_SCALE));
                progressCallback(current_progress);
            }
        }
    }

    if (progressCallback) progressCallback(PROGRESS_COMPLETE);

    std::cout << "WhiskerTracingOperation executed successfully. Traced "
              << traced_whiskers->GetAllLinesAsRange().size() << " whiskers across "
              << total_frame_count << " time points." << std::endl;

    return traced_whiskers;
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc