• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 18246927847

04 Oct 2025 04:44PM UTC coverage: 71.826% (+0.6%) from 71.188%
18246927847

push

github

paulmthompson
refactor out media producer consumer pipeline

0 of 120 new or added lines in 2 files covered. (0.0%)

646 existing lines in 14 files now uncovered.

48895 of 68074 relevant lines covered (71.83%)

1193.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

1.83
/src/DataManager/transforms/Media/whisker_tracing.cpp
1
#include "whisker_tracing.hpp"
2

3
#include "Masks/Mask_Data.hpp"
4
#include "whiskertracker.hpp"
5
#include "producer_consumer_pipeline.hpp"
6

7
#include <omp.h>
8

9
#include <algorithm>
10
#include <chrono>
11
#include <cmath>
12
#include <iostream>
13
#include <memory>
14
#include <vector>
15

16
namespace {
// Byte written into binary mask buffers for pixels that are inside a mask;
// the whisker tracker treats non-zero bytes as "true".
constexpr uint8_t MASK_TRUE_VALUE = 255;
// Progress value reported to the callback when the operation finishes.
constexpr int PROGRESS_COMPLETE = 100;
// Multiplier converting a 0..1 completion fraction into a percentage.
constexpr double PROGRESS_SCALE = 100.0;
}// namespace
21

22
// Convert whisker::Line2D to Line2D
23
Line2D WhiskerTracingOperation::convert_to_Line2D(whisker::Line2D const & whisker_line) {
×
24
    Line2D line;
×
25

26
    for (auto const & point: whisker_line) {
×
27
        line.push_back(Point2D<float>{point.x, point.y});
×
28
    }
29

30
    return line;
×
31
}
×
32

33
// Convert mask data to binary mask format for whisker tracker
//
// Produces a row-major, one-byte-per-pixel buffer of size
// image_size.width * image_size.height in which masked pixels are
// MASK_TRUE_VALUE (255) and all other pixels are 0.
//
// Behavior:
//  - Returns an all-zero buffer when mask_data is null or when no mask
//    exists at time_index.
//  - When the mask's own image size matches image_size, mask points are
//    written directly into the output buffer.
//  - When the sizes differ, a source-resolution binary image is built first
//    and then rescaled to image_size with nearest-neighbor sampling.
std::vector<uint8_t> convert_mask_to_binary(MaskData const * mask_data,
                                                                     int time_index,
                                                                     ImageSize const & image_size) {
    std::vector<uint8_t> binary_mask(static_cast<size_t>(image_size.width * image_size.height), 0);

    if (!mask_data) {
        return binary_mask;// Return empty mask if no mask data
    }

    // Get mask at the specified time
    auto const & masks_at_time = mask_data->getAtTime(TimeFrameIndex(time_index));
    if (masks_at_time.empty()) {
        return binary_mask;// Return empty mask if no mask at this time
    }

    // Source mask image size (may differ from target media image size)
    auto const src_size = mask_data->getImageSize();

    // Fast path: identical sizes, just map points directly
    if (src_size.width == image_size.width && src_size.height == image_size.height) {
        for (auto const & mask: masks_at_time) {
            for (auto const & point: mask) {
                // Bounds check against the target dimensions; the index check
                // below is a second line of defense against overflow.
                if (point.x < image_size.width && point.y < image_size.height) {
                    auto const index = static_cast<size_t>(point.y) * static_cast<size_t>(image_size.width)
                                     + static_cast<size_t>(point.x);
                    if (index < binary_mask.size()) {
                        binary_mask[index] = MASK_TRUE_VALUE;// Set to 255 for true pixels
                    }
                }
            }
        }
        return binary_mask;
    }

    // Build a source binary mask for nearest-neighbor scaling when sizes differ
    std::vector<uint8_t> src_binary(static_cast<size_t>(src_size.width * src_size.height), 0);
    for (auto const & mask: masks_at_time) {
        for (auto const & point: mask) {
            if (point.x < src_size.width && point.y < src_size.height) {
                auto const src_index = static_cast<size_t>(point.y) * static_cast<size_t>(src_size.width)
                                     + static_cast<size_t>(point.x);
                if (src_index < src_binary.size()) {
                    src_binary[src_index] = MASK_TRUE_VALUE;
                }
            }
        }
    }

    // Nearest-neighbor scale from src_binary (src_size) to binary_mask (image_size)
    // Clamp dimensions to at least 1 so the ratio math below cannot divide by
    // zero or index with a negative size.
    auto const src_w = std::max(1, src_size.width);
    auto const src_h = std::max(1, src_size.height);
    auto const dst_w = std::max(1, image_size.width);
    auto const dst_h = std::max(1, image_size.height);

    // Precompute ratios; use (N-1) mapping to preserve endpoints
    // (dst pixel 0 maps to src pixel 0, dst pixel dst-1 maps to src pixel
    // src-1). A degenerate 1-pixel dimension maps everything to source row/
    // column 0 via the 0.0 ratio.
    auto const rx = (dst_w > 1 && src_w > 1)
                            ? (static_cast<double>(src_w - 1) / static_cast<double>(dst_w - 1))
                            : 0.0;
    auto const ry = (dst_h > 1 && src_h > 1)
                            ? (static_cast<double>(src_h - 1) / static_cast<double>(dst_h - 1))
                            : 0.0;

    for (int y = 0; y < dst_h; ++y) {
        // Source row for this destination row (rounded nearest neighbor).
        int const ys = (dst_h > 1 && src_h > 1)
                               ? static_cast<int>(std::round(static_cast<double>(y) * ry))
                               : 0;
        for (int x = 0; x < dst_w; ++x) {
            int const xs = (dst_w > 1 && src_w > 1)
                                   ? static_cast<int>(std::round(static_cast<double>(x) * rx))
                                   : 0;

            auto const src_index = static_cast<size_t>(ys) * static_cast<size_t>(src_w)
                                 + static_cast<size_t>(xs);
            auto const dst_index = static_cast<size_t>(y) * static_cast<size_t>(dst_w)
                                 + static_cast<size_t>(x);

            if (src_index < src_binary.size() && dst_index < binary_mask.size()) {
                // Normalize any non-zero source byte to MASK_TRUE_VALUE.
                binary_mask[dst_index] = src_binary[src_index] ? MASK_TRUE_VALUE : 0;
            }
        }
    }

    return binary_mask;
}
118

119
// Clip whisker line by removing points from the end
120
void WhiskerTracingOperation::clip_whisker(Line2D & line, int clip_length) {
×
121
    if (line.size() <= static_cast<std::size_t>(clip_length)) {
×
122
        return;
×
123
    }
124

125
    line.erase(line.end() - clip_length, line.end());
×
126
}
127

128
// Trace whiskers in a single image
129
std::vector<Line2D> WhiskerTracingOperation::trace_single_image(
×
130
        whisker::WhiskerTracker & whisker_tracker,
131
        std::vector<uint8_t> const & image_data,
132
        ImageSize const & image_size,
133
        int clip_length,
134
        MaskData const * mask_data,
135
        int time_index) {
136

137
    std::vector<Line2D> whisker_lines;
×
138

139
    if (mask_data) {
×
140
        // Use mask-based tracing
141
        auto binary_mask = convert_mask_to_binary(mask_data, time_index, image_size);
×
142
        auto whiskers = whisker_tracker.trace_with_mask(image_data, binary_mask, image_size.height, image_size.width);
×
143

144
        whisker_lines.reserve(whiskers.size());
×
145
        for (auto const & whisker: whiskers) {
×
146
            Line2D line = convert_to_Line2D(whisker);
×
147
            clip_whisker(line, clip_length);
×
148
            whisker_lines.push_back(std::move(line));
×
149
        }
×
150
    } else {
×
151
        // Use standard tracing
152
        auto whiskers = whisker_tracker.trace(image_data, image_size.height, image_size.width);
×
153

154
        whisker_lines.reserve(whiskers.size());
×
155
        for (auto const & whisker: whiskers) {
×
156
            Line2D line = convert_to_Line2D(whisker);
×
157
            clip_whisker(line, clip_length);
×
158
            whisker_lines.push_back(std::move(line));
×
159
        }
×
160
    }
×
161

162
    return whisker_lines;
×
163
}
×
164

165
// Trace whiskers in multiple images in parallel
166
std::vector<std::vector<Line2D>> WhiskerTracingOperation::trace_multiple_images(
×
167
        whisker::WhiskerTracker & whisker_tracker,
168
        std::vector<std::vector<uint8_t>> const & images,
169
        ImageSize const & image_size,
170
        int clip_length,
171
        MaskData const * mask_data,
172
        std::vector<int> const & time_indices) {
173

174
    std::vector<std::vector<Line2D>> result;
×
175
    result.reserve(images.size());
×
176

177
    if (mask_data && !time_indices.empty()) {
×
178
        // Use mask-based parallel tracing
179

180
        auto t0 = std::chrono::high_resolution_clock::now();
×
181

182
        std::vector<std::vector<uint8_t>> masks;
×
183
        masks.reserve(images.size());
×
184

185
        for (size_t i = 0; i < images.size(); ++i) {
×
186
            int const time_idx = (i < time_indices.size()) ? time_indices[i] : 0;
×
187
            auto binary_mask = convert_mask_to_binary(mask_data, time_idx, image_size);
×
188
            masks.push_back(std::move(binary_mask));
×
189
        }
×
190

191
        //auto t1 = std::chrono::high_resolution_clock::now();
192

193
        //std::cout << "Mask Generation: " << std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count() << "ms" << std::endl;
194

195
        auto whiskers_batch = whisker_tracker.trace_multiple_images_with_masks(images, masks, image_size.height, image_size.width);
×
196

197
        //auto t2 = std::chrono::high_resolution_clock::now();
198

199
        //std::cout << "Mask Tracing: " << std::chrono::duration_cast<std::chrono::milliseconds>(t2 - t1).count() << "ms" << std::endl;
200

201

202
        for (auto const & whiskers: whiskers_batch) {
×
203
            std::vector<Line2D> whisker_lines;
×
204
            whisker_lines.reserve(whiskers.size());
×
205

206
            for (auto const & whisker: whiskers) {
×
207
                Line2D line = convert_to_Line2D(whisker);
×
208
                clip_whisker(line, clip_length);
×
209
                whisker_lines.push_back(std::move(line));
×
210
            }
×
211

212
            result.push_back(std::move(whisker_lines));
×
213
        }
×
214
    } else {
×
215
        // Use standard parallel tracing
216
        auto whiskers_batch = whisker_tracker.trace_multiple_images(images, image_size.height, image_size.width);
×
217

218
        for (auto const & whiskers: whiskers_batch) {
×
219
            std::vector<Line2D> whisker_lines;
×
220
            whisker_lines.reserve(whiskers.size());
×
221

222
            for (auto const & whisker: whiskers) {
×
223
                Line2D line = convert_to_Line2D(whisker);
×
224
                clip_whisker(line, clip_length);
×
225
                whisker_lines.push_back(std::move(line));
×
226
            }
×
227

228
            result.push_back(std::move(whisker_lines));
×
229
        }
×
230
    }
×
231

232
    return result;
×
233
}
×
234

235
std::string WhiskerTracingOperation::getName() const {
    // Human-readable name of this transform operation.
    return std::string{"Whisker Tracing"};
}
238

239
std::type_index WhiskerTracingOperation::getTargetInputTypeIndex() const {
148✔
240
    return typeid(std::shared_ptr<MediaData>);
148✔
241
}
242

243
// Returns true when the variant holds a non-null shared_ptr<MediaData>.
//
// Note: std::get_if already returns nullptr when the variant holds a
// different alternative, so the previous std::holds_alternative pre-check
// was redundant and has been removed.
bool WhiskerTracingOperation::canApply(DataTypeVariant const & dataVariant) const {
    auto const * media_ptr = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    return media_ptr != nullptr && *media_ptr != nullptr;
}
251

252
std::unique_ptr<TransformParametersBase> WhiskerTracingOperation::getDefaultParameters() const {
×
253
    return std::make_unique<WhiskerTracingParameters>();
×
254
}
255

256
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const & dataVariant,
                                                 TransformParametersBase const * transformParameters) {
    // Convenience overload: forward to the progress-aware overload with a
    // callback that discards all progress updates.
    auto const discard_progress = [](int /*percent*/) {};
    return execute(dataVariant, transformParameters, discard_progress);
}
260

261
/**
 * @brief A simple data structure to hold a frame's image data and its timestamp.
 *
 * Items of this type flow from the producer (frame loader) to the consumer
 * (whisker tracer) through the pipeline. This replaces the original
 * FrameData struct, removing the is_end_marker which is no longer needed
 * with the new BlockingQueue design.
 */
struct MediaFrame {
    std::vector<uint8_t> image_data;// 8-bit pixel buffer for one frame
    int time_index;                 // frame index this buffer was loaded from
};
271

272

/**
 * @brief Trace whiskers across every frame of the input media.
 *
 * Validates the input variant and parameters, then traces whiskers either
 * through a producer-consumer pipeline (parallel path) or frame-by-frame
 * (sequential path), accumulating results into a LineData keyed by frame
 * index.
 *
 * @param dataVariant        must hold a non-null shared_ptr<MediaData>
 * @param transformParameters must be a WhiskerTracingParameters (checked via
 *                            dynamic_cast); invalid parameters abort the run
 * @param progressCallback   invoked with 0..100; always receives 100 before
 *                           any return, including early error returns
 * @return shared_ptr<LineData> with traced whiskers, or an empty variant on
 *         validation failure / zero frames
 */
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const& dataVariant,
                                                 TransformParametersBase const* transformParameters,
                                                 ProgressCallback progressCallback) {
    // get_if returns nullptr when the variant holds a different alternative,
    // so this single check covers both wrong-type and null-pointer inputs.
    auto const* ptr_ptr = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    if (!ptr_ptr || !(*ptr_ptr)) {
        std::cerr << "WhiskerTracingOperation::execute: Incompatible variant type or null data." << std::endl;
        if (progressCallback) progressCallback(100);
        return {};
    }

    auto media_data = *ptr_ptr;
    auto const* params = dynamic_cast<WhiskerTracingParameters const*>(transformParameters);

    if (!params) {
        std::cerr << "WhiskerTracingOperation::execute: Invalid parameters." << std::endl;
        if (progressCallback) progressCallback(100);
        return {};
    }

    // Use the caller-supplied tracker when provided; otherwise create a
    // fresh one for this run.
    std::shared_ptr<whisker::WhiskerTracker> tracker = params->tracker;
    if (!tracker) {
        tracker = std::make_shared<whisker::WhiskerTracker>();
        std::cout << "Whisker Tracker Initialized" << std::endl;
    }
    tracker->setWhiskerLengthThreshold(params->whisker_length_threshold);
    // NOTE(review): 1000.0f is a magic pad radius — presumably "large enough
    // to be effectively unbounded"; consider promoting to a named constant.
    tracker->setWhiskerPadRadius(1000.0f);

    if (progressCallback) progressCallback(0);

    auto traced_whiskers = std::make_shared<LineData>();
    traced_whiskers->setImageSize(media_data->getImageSize());

    auto total_frame_count = media_data->getTotalFrameCount();
    if (total_frame_count <= 0) {
        if (progressCallback) progressCallback(100);
        return {};
    }

    if (params->use_parallel_processing && params->batch_size > 1) {
        // --- Producer-Consumer Parallel Processing ---

        // NOTE: MediaData is not assumed to be thread-safe; this mutex
        // serializes all media_data reads made from the producer.
        std::mutex media_data_mutex;

        // The consumer will update the final LineData object. This mutex
        // serializes all writes to traced_whiskers.
        std::mutex results_mutex;

        // Producer: load one frame's pixels. Returns nullopt on read failure
        // or an empty frame so the pipeline can skip that item.
        auto producer = [&](size_t frame_idx) -> std::optional<MediaFrame> {
            std::vector<uint8_t> image_data;
            try {
                // Lock the mutex before accessing media_data
                std::lock_guard<std::mutex> lock(media_data_mutex);
                if (params->use_processed_data) {
                    image_data = media_data->getProcessedData8(frame_idx);
                } else {
                    image_data = media_data->getRawData8(frame_idx);
                }
            } catch (const std::exception& e) {
                std::cerr << "Error producing frame " << frame_idx << ": " << e.what() << std::endl;
                return std::nullopt; // Signal failure for this item
            }

            if (image_data.empty()) {
                return std::nullopt; // Can happen if a frame is invalid
            }
            return MediaFrame{std::move(image_data), static_cast<int>(frame_idx)};
        };

        // Consumer: trace a batch of frames and merge the results into the
        // shared LineData under results_mutex.
        auto consumer = [&](std::vector<MediaFrame> batch) {
            std::vector<std::vector<uint8_t>> batch_images;
            std::vector<int> batch_times;
            batch_images.reserve(batch.size());
            batch_times.reserve(batch.size());

            // Move pixel buffers out of the batch; keep times parallel to
            // images so results can be mapped back to frames by index.
            for (auto& frame : batch) {
                batch_images.push_back(std::move(frame.image_data));
                batch_times.push_back(frame.time_index);
            }

            auto batch_results = trace_multiple_images(*tracker,
                                                       batch_images,
                                                       media_data->getImageSize(),
                                                       params->clip_length,
                                                       params->use_mask_data ? params->mask_data.get() : nullptr,
                                                       batch_times);
            
            // Lock the mutex to safely update the shared results container
            std::lock_guard<std::mutex> lock(results_mutex);
            for (size_t j = 0; j < batch_results.size(); ++j) {
                for (auto const& line : batch_results[j]) {
                    traced_whiskers->addAtTime(TimeFrameIndex(batch_times[j]), line, false);
                }
            }
        };

        // Set OpenMP threads. For an application-wide effect, this is okay.
        // For library code, using the 'num_threads' clause on the pragma is safer.
        int max_threads = omp_get_max_threads();
        int omp_threads = std::max(1, max_threads - 1); // Reserve 1 core for producer
        omp_set_num_threads(omp_threads);
        std::cout << "Using " << omp_threads << " OpenMP threads for processing." << std::endl;

        // Execute the pipeline. Progress reporting on this path is delegated
        // to run_pipeline via the callback.
        run_pipeline<MediaFrame>(
            params->queue_size,
            total_frame_count,
            producer,
            consumer,
            params->batch_size,
            progressCallback);

    } else {
        // Process frames one by one (original sequential approach)
        for (size_t time = 0; time < total_frame_count; ++time) {
            std::vector<uint8_t> image_data;

            if (params->use_processed_data) {
                image_data = media_data->getProcessedData8(static_cast<int>(time));
            } else {
                image_data = media_data->getRawData8(static_cast<int>(time));
            }

            // Empty buffers (unreadable frames) are silently skipped.
            if (!image_data.empty()) {
                auto whisker_lines = trace_single_image(*tracker, image_data, media_data->getImageSize(),
                                                        params->clip_length,
                                                        params->use_mask_data ? params->mask_data.get() : nullptr,
                                                        static_cast<int>(time));

                for (auto const & line: whisker_lines) {
                    traced_whiskers->addAtTime(TimeFrameIndex(static_cast<int64_t>(time)), line, false);
                }
            }

            if (progressCallback) {
                // Fraction of frames completed, scaled to a 0..100 percent.
                int const current_progress = static_cast<int>(std::round(static_cast<double>(time) / static_cast<double>(total_frame_count) * PROGRESS_SCALE));
                progressCallback(current_progress);
            }
        }
    }

    if (progressCallback) progressCallback(PROGRESS_COMPLETE);

    // NOTE(review): GetAllLinesAsRange().size() is reported as a whisker
    // count — confirm it counts individual lines rather than time entries.
    std::cout << "WhiskerTracingOperation executed successfully. Traced "
              << traced_whiskers->GetAllLinesAsRange().size() << " whiskers across "
              << total_frame_count << " frames." << std::endl;

    return traced_whiskers;
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc