• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 17737327548

15 Sep 2025 02:54PM UTC coverage: 72.602% (+0.5%) from 72.1%
17737327548

push

github

paulmthompson
Merge branch 'main' of https://github.com/paulmthompson/WhiskerToolbox

38155 of 52554 relevant lines covered (72.6%)

1349.01 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

53.33
/src/DataManager/transforms/Media/whisker_tracing.cpp
1
#include "whisker_tracing.hpp"
2

3
#include "Masks/Mask_Data.hpp"
4
#include "whiskertracker.hpp"
5

6
#include <algorithm>
7
#include <cmath>
8
#include <iostream>
9
#include <memory>
10
#include <vector>
11

12
namespace {
// Pixel value written into binary masks for foreground ("true") pixels.
constexpr uint8_t MASK_TRUE_VALUE{255};
// Progress value reported when execution finishes or bails out early.
constexpr int PROGRESS_COMPLETE{100};
// Multiplier that turns a completed fraction in [0, 1] into a percentage.
constexpr double PROGRESS_SCALE{100.0};
}// namespace
17

18
// Convert whisker::Line2D to Line2D
19
Line2D WhiskerTracingOperation::convert_to_Line2D(whisker::Line2D const & whisker_line) {
2✔
20
    Line2D line;
2✔
21

22
    for (auto const & point: whisker_line) {
503✔
23
        line.push_back(Point2D<float>{point.x, point.y});
501✔
24
    }
25

26
    return line;
2✔
27
}
×
28

29
// Convert mask data to binary mask format for whisker tracker
30
std::vector<uint8_t> convert_mask_to_binary(MaskData const * mask_data,
2✔
31
                                                                     int time_index,
32
                                                                     ImageSize const & image_size) {
33
    std::vector<uint8_t> binary_mask(static_cast<size_t>(image_size.width * image_size.height), 0);
6✔
34

35
    if (!mask_data) {
2✔
36
        return binary_mask;// Return empty mask if no mask data
×
37
    }
38

39
    // Get mask at the specified time
40
    auto const & masks_at_time = mask_data->getAtTime(TimeFrameIndex(time_index));
2✔
41
    if (masks_at_time.empty()) {
2✔
42
        return binary_mask;// Return empty mask if no mask at this time
×
43
    }
44

45
    // Source mask image size (may differ from target media image size)
46
    auto const src_size = mask_data->getImageSize();
2✔
47

48
    // Fast path: identical sizes, just map points directly
49
    if (src_size.width == image_size.width && src_size.height == image_size.height) {
2✔
50
        for (auto const & mask: masks_at_time) {
2✔
51
            for (auto const & point: mask) {
77✔
52
                if (point.x < image_size.width && point.y < image_size.height) {
76✔
53
                    auto const index = static_cast<size_t>(point.y) * static_cast<size_t>(image_size.width)
76✔
54
                                     + static_cast<size_t>(point.x);
76✔
55
                    if (index < binary_mask.size()) {
76✔
56
                        binary_mask[index] = MASK_TRUE_VALUE;// Set to 255 for true pixels
76✔
57
                    }
58
                }
59
            }
60
        }
61
        return binary_mask;
62
    }
63

64
    // Build a source binary mask for nearest-neighbor scaling when sizes differ
65
    std::vector<uint8_t> src_binary(static_cast<size_t>(src_size.width * src_size.height), 0);
3✔
66
    for (auto const & mask: masks_at_time) {
2✔
67
        for (auto const & point: mask) {
60✔
68
            if (point.x < src_size.width && point.y < src_size.height) {
59✔
69
                auto const src_index = static_cast<size_t>(point.y) * static_cast<size_t>(src_size.width)
59✔
70
                                     + static_cast<size_t>(point.x);
59✔
71
                if (src_index < src_binary.size()) {
59✔
72
                    src_binary[src_index] = MASK_TRUE_VALUE;
59✔
73
                }
74
            }
75
        }
76
    }
77

78
    // Nearest-neighbor scale from src_binary (src_size) to binary_mask (image_size)
79
    auto const src_w = std::max(1, src_size.width);
1✔
80
    auto const src_h = std::max(1, src_size.height);
1✔
81
    auto const dst_w = std::max(1, image_size.width);
1✔
82
    auto const dst_h = std::max(1, image_size.height);
1✔
83

84
    // Precompute ratios; use (N-1) mapping to preserve endpoints
85
    auto const rx = (dst_w > 1 && src_w > 1)
1✔
86
                            ? (static_cast<double>(src_w - 1) / static_cast<double>(dst_w - 1))
2✔
87
                            : 0.0;
88
    auto const ry = (dst_h > 1 && src_h > 1)
1✔
89
                            ? (static_cast<double>(src_h - 1) / static_cast<double>(dst_h - 1))
2✔
90
                            : 0.0;
91

92
    for (int y = 0; y < dst_h; ++y) {
481✔
93
        int const ys = (dst_h > 1 && src_h > 1)
480✔
94
                               ? static_cast<int>(std::round(static_cast<double>(y) * ry))
960✔
95
                               : 0;
96
        for (int x = 0; x < dst_w; ++x) {
307,680✔
97
            int const xs = (dst_w > 1 && src_w > 1)
307,200✔
98
                                   ? static_cast<int>(std::round(static_cast<double>(x) * rx))
614,400✔
99
                                   : 0;
100

101
            auto const src_index = static_cast<size_t>(ys) * static_cast<size_t>(src_w)
307,200✔
102
                                 + static_cast<size_t>(xs);
307,200✔
103
            auto const dst_index = static_cast<size_t>(y) * static_cast<size_t>(dst_w)
307,200✔
104
                                 + static_cast<size_t>(x);
307,200✔
105

106
            if (src_index < src_binary.size() && dst_index < binary_mask.size()) {
307,200✔
107
                binary_mask[dst_index] = src_binary[src_index] ? MASK_TRUE_VALUE : 0;
307,200✔
108
            }
109
        }
110
    }
111

112
    return binary_mask;
1✔
113
}
1✔
114

115
// Clip whisker line by removing points from the end
116
void WhiskerTracingOperation::clip_whisker(Line2D & line, int clip_length) {
2✔
117
    if (line.size() <= static_cast<std::size_t>(clip_length)) {
2✔
118
        return;
×
119
    }
120

121
    line.erase(line.end() - clip_length, line.end());
2✔
122
}
123

124
// Trace whiskers in a single image
125
std::vector<Line2D> WhiskerTracingOperation::trace_single_image(
2✔
126
        whisker::WhiskerTracker & whisker_tracker,
127
        std::vector<uint8_t> const & image_data,
128
        ImageSize const & image_size,
129
        int clip_length,
130
        MaskData const * mask_data,
131
        int time_index) {
132

133
    std::vector<Line2D> whisker_lines;
2✔
134

135
    if (mask_data) {
2✔
136
        // Use mask-based tracing
137
        auto binary_mask = convert_mask_to_binary(mask_data, time_index, image_size);
2✔
138
        auto whiskers = whisker_tracker.trace_with_mask(image_data, binary_mask, image_size.height, image_size.width);
2✔
139

140
        whisker_lines.reserve(whiskers.size());
2✔
141
        for (auto const & whisker: whiskers) {
4✔
142
            Line2D line = convert_to_Line2D(whisker);
2✔
143
            clip_whisker(line, clip_length);
2✔
144
            whisker_lines.push_back(std::move(line));
2✔
145
        }
2✔
146
    } else {
2✔
147
        // Use standard tracing
148
        auto whiskers = whisker_tracker.trace(image_data, image_size.height, image_size.width);
×
149

150
        whisker_lines.reserve(whiskers.size());
×
151
        for (auto const & whisker: whiskers) {
×
152
            Line2D line = convert_to_Line2D(whisker);
×
153
            clip_whisker(line, clip_length);
×
154
            whisker_lines.push_back(std::move(line));
×
155
        }
×
156
    }
×
157

158
    return whisker_lines;
2✔
159
}
×
160

161
// Trace whiskers in multiple images in parallel
162
std::vector<std::vector<Line2D>> WhiskerTracingOperation::trace_multiple_images(
×
163
        whisker::WhiskerTracker & whisker_tracker,
164
        std::vector<std::vector<uint8_t>> const & images,
165
        ImageSize const & image_size,
166
        int clip_length,
167
        MaskData const * mask_data,
168
        std::vector<int> const & time_indices) {
169

170
    std::vector<std::vector<Line2D>> result;
×
171
    result.reserve(images.size());
×
172

173
    if (mask_data && !time_indices.empty()) {
×
174
        // Use mask-based parallel tracing
175
        std::vector<std::vector<uint8_t>> masks;
×
176
        masks.reserve(images.size());
×
177

178
        for (size_t i = 0; i < images.size(); ++i) {
×
179
            int const time_idx = (i < time_indices.size()) ? time_indices[i] : 0;
×
180
            auto binary_mask = convert_mask_to_binary(mask_data, time_idx, image_size);
×
181
            masks.push_back(std::move(binary_mask));
×
182
        }
×
183

184
        auto whiskers_batch = whisker_tracker.trace_multiple_images_with_masks(images, masks, image_size.height, image_size.width);
×
185

186
        for (auto const & whiskers: whiskers_batch) {
×
187
            std::vector<Line2D> whisker_lines;
×
188
            whisker_lines.reserve(whiskers.size());
×
189

190
            for (auto const & whisker: whiskers) {
×
191
                Line2D line = convert_to_Line2D(whisker);
×
192
                clip_whisker(line, clip_length);
×
193
                whisker_lines.push_back(std::move(line));
×
194
            }
×
195

196
            result.push_back(std::move(whisker_lines));
×
197
        }
×
198
    } else {
×
199
        // Use standard parallel tracing
200
        auto whiskers_batch = whisker_tracker.trace_multiple_images(images, image_size.height, image_size.width);
×
201

202
        for (auto const & whiskers: whiskers_batch) {
×
203
            std::vector<Line2D> whisker_lines;
×
204
            whisker_lines.reserve(whiskers.size());
×
205

206
            for (auto const & whisker: whiskers) {
×
207
                Line2D line = convert_to_Line2D(whisker);
×
208
                clip_whisker(line, clip_length);
×
209
                whisker_lines.push_back(std::move(line));
×
210
            }
×
211

212
            result.push_back(std::move(whisker_lines));
×
213
        }
×
214
    }
×
215

216
    return result;
×
217
}
×
218

219
/// @return The operation's name, "Whisker Tracing".
std::string WhiskerTracingOperation::getName() const {
    std::string name{"Whisker Tracing"};
    return name;
}
222

223
std::type_index WhiskerTracingOperation::getTargetInputTypeIndex() const {
148✔
224
    return typeid(std::shared_ptr<MediaData>);
148✔
225
}
226

227
/// @return true only when @p dataVariant holds a non-null std::shared_ptr<MediaData>.
bool WhiskerTracingOperation::canApply(DataTypeVariant const & dataVariant) const {
    // get_if yields nullptr when the variant holds a different alternative,
    // so this one call covers both the type check and the null check.
    auto const * media = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    return media != nullptr && *media != nullptr;
}
235

236
std::unique_ptr<TransformParametersBase> WhiskerTracingOperation::getDefaultParameters() const {
×
237
    return std::make_unique<WhiskerTracingParameters>();
×
238
}
239

240
/// Run whisker tracing without progress reporting.
///
/// Forwards to the three-argument overload with a callback that ignores every update.
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const & dataVariant,
                                                 TransformParametersBase const * transformParameters) {
    auto const ignore_progress = [](int /*percent*/) {};
    return execute(dataVariant, transformParameters, ignore_progress);
}
244

245
/// Trace whiskers in every frame of a MediaData and collect them into a new LineData.
///
/// @p dataVariant must hold a non-null std::shared_ptr<MediaData>, and
/// @p transformParameters must be a WhiskerTracingParameters. Otherwise an
/// error is printed, progress PROGRESS_COMPLETE is reported and an empty
/// variant is returned. An empty variant is likewise returned when the media
/// reports no frames.
///
/// The tracker in the parameters is used if present; otherwise a new one is
/// created. Its length threshold is set from the parameters and its pad
/// radius to 1000. Frames are read as 8-bit processed or raw data, according
/// to use_processed_data. Frames that come back empty are skipped.
///
/// Frames are traced in batches when use_parallel_processing is set and
/// batch_size > 1, and one at a time otherwise. Masks are used only if
/// use_mask_data is set.
///
/// @param dataVariant         Input; expected to hold std::shared_ptr<MediaData>.
/// @param transformParameters Expected to be WhiskerTracingParameters.
/// @param progressCallback    Optional; receives integer percentages 0..PROGRESS_COMPLETE.
/// @return A std::shared_ptr<LineData> with the traced whiskers, or an empty variant on error.
DataTypeVariant WhiskerTracingOperation::execute(DataTypeVariant const & dataVariant,
                                                 TransformParametersBase const * transformParameters,
                                                 ProgressCallback progressCallback) {
    // Forward a progress value only if the caller supplied a callback.
    auto const report_progress = [&progressCallback](int percent) {
        if (progressCallback) progressCallback(percent);
    };

    auto const * ptr_ptr = std::get_if<std::shared_ptr<MediaData>>(&dataVariant);
    if (!ptr_ptr || !(*ptr_ptr)) {
        std::cerr << "WhiskerTracingOperation::execute: Incompatible variant type or null data." << std::endl;
        report_progress(PROGRESS_COMPLETE);
        return {};
    }

    auto media_data = *ptr_ptr;

    auto const * typed_params =
            transformParameters ? dynamic_cast<WhiskerTracingParameters const *>(transformParameters) : nullptr;

    if (!typed_params) {
        std::cerr << "WhiskerTracingOperation::execute: Invalid parameters." << std::endl;
        report_progress(PROGRESS_COMPLETE);
        return {};
    }

    // Allow caller (tests) to pass an already-initialized tracker to avoid heavy setup
    std::shared_ptr<whisker::WhiskerTracker> tracker_ptr = typed_params->tracker;
    if (!tracker_ptr) {
        tracker_ptr = std::make_shared<whisker::WhiskerTracker>();
        std::cout << "Whisker Tracker Initialized" << std::endl;
    }
    tracker_ptr->setWhiskerLengthThreshold(typed_params->whisker_length_threshold);
    // Disable whisker pad exclusion by using a large radius by default
    tracker_ptr->setWhiskerPadRadius(1000.0f);

    report_progress(0);

    // Create new LineData for the traced whiskers
    auto traced_whiskers = std::make_shared<LineData>();
    traced_whiskers->setImageSize(media_data->getImageSize());

    auto total_frame_count = media_data->getTotalFrameCount();
    if (total_frame_count <= 0) {
        std::cerr << "WhiskerTracingOperation::execute: No data available in media." << std::endl;
        report_progress(PROGRESS_COMPLETE);
        return {};
    }

    auto const total_time_points = static_cast<size_t>(total_frame_count);
    size_t processed_time_points = 0;

    // Mask source is fixed for the whole run.
    MaskData const * const mask_source = typed_params->use_mask_data ? typed_params->mask_data.get() : nullptr;

    // Fetch one frame as 8-bit data (whisker tracking expects 8 bit).
    auto const load_frame = [&](size_t time) -> std::vector<uint8_t> {
        auto const frame = static_cast<int>(time);
        if (typed_params->use_processed_data) {
            return media_data->getProcessedData8(frame);
        }
        return media_data->getRawData8(frame);
    };

    // Percentage of frames handled so far (total_time_points > 0 is guaranteed above).
    auto const percent_done = [total_time_points](size_t done) {
        return static_cast<int>(std::round(static_cast<double>(done) / static_cast<double>(total_time_points) * PROGRESS_SCALE));
    };

    if (typed_params->use_parallel_processing && typed_params->batch_size > 1) {
        // Process frames in batches for parallel processing
        auto const batch_size = static_cast<size_t>(typed_params->batch_size);
        for (size_t batch_start = 0; batch_start < total_time_points; batch_start += batch_size) {
            size_t const batch_end = std::min(batch_start + batch_size, total_time_points);

            std::vector<std::vector<uint8_t>> batch_images;
            std::vector<int> batch_times;
            batch_images.reserve(batch_end - batch_start);
            batch_times.reserve(batch_end - batch_start);

            // Collect the non-empty images for this batch
            for (size_t time = batch_start; time < batch_end; ++time) {
                auto image_data = load_frame(time);
                if (!image_data.empty()) {
                    batch_images.push_back(std::move(image_data));
                    batch_times.push_back(static_cast<int>(time));
                }
            }

            // Every frame in the batch counts as handled (even empty ones), matching
            // the sequential path so progress reaches 100% on the last batch.
            processed_time_points += batch_end - batch_start;

            if (!batch_images.empty()) {
                // Trace whiskers in parallel for this batch
                auto batch_results = trace_multiple_images(*tracker_ptr, batch_images, media_data->getImageSize(),
                                                           typed_params->clip_length,
                                                           mask_source,
                                                           batch_times);

                // Add results to LineData; never index past either vector.
                size_t const result_count = std::min(batch_results.size(), batch_times.size());
                for (size_t j = 0; j < result_count; ++j) {
                    for (auto const & line: batch_results[j]) {
                        traced_whiskers->addAtTime(TimeFrameIndex(batch_times[j]), line, false);
                    }
                }

                std::cout << "Processed " << processed_time_points << " time points" << std::endl;
            }

            report_progress(percent_done(processed_time_points));
        }
    } else {
        // Process frames one by one
        for (size_t time = 0; time < total_time_points; ++time) {
            auto image_data = load_frame(time);

            if (!image_data.empty()) {
                auto whisker_lines = trace_single_image(*tracker_ptr, image_data, media_data->getImageSize(),
                                                        typed_params->clip_length,
                                                        mask_source,
                                                        static_cast<int>(time));

                for (auto const & line: whisker_lines) {
                    traced_whiskers->addAtTime(TimeFrameIndex(static_cast<int64_t>(time)), line, false);
                }
            }

            processed_time_points++;
            report_progress(percent_done(processed_time_points));
        }
    }

    report_progress(PROGRESS_COMPLETE);

    std::cout << "WhiskerTracingOperation executed successfully. Traced "
              << traced_whiskers->GetAllLinesAsRange().size() << " whiskers across "
              << total_frame_count << " time points." << std::endl;

    return traced_whiskers;
}
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc