• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

paulmthompson / WhiskerToolbox / 15029140720

14 May 2025 07:15PM UTC coverage: 23.816% (+3.3%) from 20.55%
15029140720

push

github

paulmthompson
fix failing tests with notify method

558 of 2343 relevant lines covered (23.82%)

2.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

22.37
/src/WhiskerToolbox/DataManager/DataManager.cpp
1

2
#include "DataManager.hpp"
3
#include "AnalogTimeSeries/Analog_Time_Series.hpp"
4
#include "DigitalTimeSeries/Digital_Event_Series.hpp"
5
#include "DigitalTimeSeries/Digital_Interval_Series.hpp"
6
#include "Lines/Line_Data.hpp"
7
#include "Masks/Mask_Data.hpp"
8
#include "Media/Media_Data.hpp"
9
#include "Media/Video_Data.hpp"
10
#include "Points/Point_Data.hpp"
11
#include "Tensors/Tensor_Data.hpp"
12

13
#include "AnalogTimeSeries/Analog_Time_Series_Loader.hpp"
14
#include "DigitalTimeSeries/Digital_Event_Series_Loader.hpp"
15
#include "DigitalTimeSeries/Digital_Interval_Series_Loader.hpp"
16
#include "Lines/IO/CSV/Line_Data_CSV.hpp"
17
#include "Masks/IO/HDF5/Mask_Data_HDF5.hpp"
18
#include "Media/Video_Data_Loader.hpp"
19
#include "Points/IO/CSV/Point_Data_CSV.hpp"
20

21
#include "loaders/binary_loaders.hpp"
22
#include "transforms/data_transforms.hpp"
23
#include "transforms/Masks/mask_area.hpp"
24

25
#include "TimeFrame.hpp"
26

27
#include "nlohmann/json.hpp"
28
#include "utils/string_manip.hpp"
29

30
#include <filesystem>
31
#include <fstream>
32
#include <iostream>
33
#include <optional>
34
#include <regex>
35

36
using namespace nlohmann;
37

38
/**
 * @brief Construct a DataManager with its default entries.
 *
 * Registers an empty TimeFrame under the key "time" and an empty MediaData
 * under the key "media", associates "media" with "time", and initialises the
 * output path to the current working directory.
 */
DataManager::DataManager() {
    std::string const default_time_key = "time";
    std::string const default_media_key = "media";

    _times[default_time_key] = std::make_shared<TimeFrame>();
    _data[default_media_key] = std::make_shared<MediaData>();

    // Both keys were just registered above, so this association succeeds.
    setTimeFrame(default_media_key, default_time_key);

    _output_path = std::filesystem::current_path();
}
31✔
45

46
bool DataManager::setTime(std::string const & key, std::shared_ptr<TimeFrame> timeframe) {
16✔
47

48
    if (!timeframe) {
16✔
49
        std::cerr << "Error: Cannot register a nullptr TimeFrame for key: " << key << std::endl;
1✔
50
        return false;
1✔
51
    }
52

53
    if (_times.find(key) != _times.end()) {
15✔
54
        std::cerr << "Error: Time key already exists in DataManager: " << key << std::endl;
2✔
55
        return false;
2✔
56
    }
57

58
    _times[key] = std::move(timeframe);
13✔
59
    return true;
13✔
60
}
61

62
std::shared_ptr<TimeFrame>  DataManager::getTime() {
1✔
63
    return _times["time"];
3✔
64
};
65

66
std::shared_ptr<TimeFrame> DataManager::getTime(std::string const & key) {
7✔
67
    if (_times.find(key) != _times.end()) {
7✔
68
        return _times[key];
5✔
69
    }
70
    return nullptr;
2✔
71
};
72

73
bool DataManager::setTimeFrame(std::string const & data_key, std::string const & time_key) {
58✔
74
    if (_data.find(data_key) == _data.end()) {
58✔
75
        std::cerr << "Error: Data key not found in DataManager: " << data_key << std::endl;
1✔
76
        return false;
1✔
77
    }
78

79
    if (_times.find(time_key) == _times.end()) {
57✔
80
        std::cerr << "Error: Time key not found in DataManager: " << time_key << std::endl;
1✔
81
        return false;
1✔
82
    }
83

84
    _time_frames[data_key] = time_key;
56✔
85
    return true;
56✔
86
}
87

88
std::string DataManager::getTimeFrame(std::string const & data_key) {
6✔
89
    // check if data_key exists
90
    if (_data.find(data_key) == _data.end()) {
6✔
91
        std::cerr << "Error: Data key not found in DataManager: " << data_key << std::endl;
1✔
92
        return "";
3✔
93
    }
94

95
    // check if data key has time frame
96
    if (_time_frames.find(data_key) == _time_frames.end()) {
5✔
97
        std::cerr << "Error: Data key "
98
                  << data_key
99
                  << " exists, but not assigned to a TimeFrame" <<  std::endl;
×
100
        return "";
×
101
    }
102

103
    return _time_frames[data_key];
5✔
104
}
105

106
std::vector<std::string> DataManager::getTimeFrameKeys() {
8✔
107
    std::vector<std::string> keys;
8✔
108
    keys.reserve(_times.size());
8✔
109
    for (auto const & [key, value]: _times) {
24✔
110

111
        keys.push_back(key);
16✔
112
    }
113
    return keys;
8✔
114
}
×
115

116
int DataManager::addCallbackToData(std::string const & key, ObserverCallback callback) {
7✔
117

118
    int id = -1;
7✔
119

120
    if (_data.find(key) != _data.end()) {
7✔
121
        auto data = _data[key];
6✔
122

123
        id = std::visit([callback](auto & x) {
12✔
124
            return x.get()->addObserver(callback);
6✔
125
        }, data);
126
    }
6✔
127

128
    return id;
7✔
129
}
130

131
bool DataManager::removeCallbackFromData(std::string const & key, int callback_id) {
4✔
132
    if (_data.find(key) != _data.end()) {
4✔
133
        auto data = _data[key];
3✔
134

135
        std::visit([callback_id](auto & x) {
6✔
136
            x.get()->removeObserver(callback_id);
3✔
137
        }, data);
3✔
138

139
        return true;
3✔
140
    }
3✔
141

142
    return false;
1✔
143
}
144

145
/**
 * @brief Append a callback to the manager's observer list.
 *
 * @param callback Callback to store; it is invoked by _notifyObservers().
 */
void DataManager::addObserver(ObserverCallback callback) {
    _observers.emplace_back(std::move(callback));
}
5✔
148

149
void DataManager::_notifyObservers() {
21✔
150
    for (auto & observer: _observers) {
29✔
151
        observer();
8✔
152
    }
153
}
21✔
154

155
/**
 * @brief Resolve a (possibly wildcarded) file path to an existing file.
 *
 * Without a '*' in @p file_path, a relative path is resolved against
 * @p base_path and returned when the file exists.
 *
 * With a '*', the direct children of @p base_path are scanned and the first
 * entry (in directory-iteration order, which is unspecified) whose full path
 * matches the pattern is returned. Each '*' matches any run of characters;
 * every other character is matched literally. A relative pattern is resolved
 * against @p base_path; the pattern exactly as written is accepted as well,
 * so patterns that already spell out the directory keep matching.
 *
 * @param file_path Path or wildcard pattern taken from the configuration.
 * @param base_path Directory used to resolve relative paths and to search for
 *        wildcard matches. An empty path means the current directory for the
 *        wildcard search.
 * @return The resolved path as a string, or std::nullopt when nothing exists
 *         or matches, or when the directory cannot be read.
 */
std::optional<std::string> processFilePath(
        std::string const & file_path,
        std::filesystem::path const & base_path) {
    namespace fs = std::filesystem;

    fs::path const given(file_path);
    std::error_code ec;// filesystem failures mean "not found" instead of throwing

    if (file_path.find('*') == std::string::npos) {
        fs::path const full_path = given.is_absolute() ? given : base_path / given;
        if (fs::exists(full_path, ec)) {
            std::cout << "Loading file " << full_path.string() << std::endl;
            return full_path.string();
        }
        return std::nullopt;
    }

    // Translate a glob into a regex: '*' becomes ".*", and regex
    // metacharacters (notably '.' and '\\') are escaped so they match literally.
    auto const glob_to_regex = [](std::string const & glob) {
        static std::string const metacharacters = "\\^$.|?+()[]{}";
        std::string expression;
        expression.reserve(glob.size() * 2);
        for (char const c: glob) {
            if (c == '*') {
                expression += ".*";
                continue;
            }
            if (metacharacters.find(c) != std::string::npos) {
                expression += '\\';
            }
            expression += c;
        }
        return std::regex(expression);
    };

    // directory_iterator cannot open an empty path; treat it as the current directory.
    fs::path const search_dir = base_path.empty() ? fs::path(".") : base_path;
    fs::path const full_pattern = given.is_absolute() ? given : search_dir / given;

    // generic_string() gives '/' separators on both sides of the comparison.
    std::regex const resolved_regex = glob_to_regex(full_pattern.generic_string());
    std::regex const literal_regex = glob_to_regex(given.generic_string());

    for (fs::directory_iterator it(search_dir, ec), last; !ec && it != last; it.increment(ec)) {
        std::string const candidate = it->path().generic_string();
        std::cout << "Checking " << it->path().string() << " with full path " << full_pattern << std::endl;
        if (std::regex_match(candidate, resolved_regex) || std::regex_match(candidate, literal_regex)) {
            std::cout << "Loading file " << it->path().string() << std::endl;
            return it->path().string();
        }
    }
    return std::nullopt;
}
×
189

190
/**
 * @brief Verify that a JSON item contains every required field.
 *
 * @param item JSON item to inspect.
 * @param requiredFields Field names that must be present.
 * @return true when all fields are present; false at the first missing field,
 *         which is reported on std::cerr.
 */
bool checkRequiredFields(json const & item, std::vector<std::string> const & requiredFields) {
    for (auto const & required: requiredFields) {
        if (item.contains(required)) {
            continue;
        }
        std::cerr << "Error: Missing required field \"" << required << "\" in JSON item." << std::endl;
        return false;
    }
    return true;
}
199

200
/**
 * @brief Warn on std::cout about each optional field absent from a JSON item.
 *
 * @param item JSON item to inspect.
 * @param optionalFields Field names whose absence is worth a warning.
 */
void checkOptionalFields(json const & item, std::vector<std::string> const & optionalFields) {
    for (auto const & optional_name: optionalFields) {
        bool const present = item.contains(optional_name);
        if (present) {
            continue;
        }
        std::cout << "Warning: Optional field \"" << optional_name << "\" is missing in JSON item." << std::endl;
    }
}
×
207

208
/**
 * @brief Map a configuration "data_type" string to its DM_DataType.
 *
 * @param data_type_str Type name as written in the configuration.
 * @return The matching enumerator, or DM_DataType::Unknown for any other string.
 */
DM_DataType stringToDataType(std::string const & data_type_str) {
    struct NamedType {
        char const * label;
        DM_DataType type;
    };

    static NamedType const known_types[] = {
            {"video", DM_DataType::Video},
            {"points", DM_DataType::Points},
            {"mask", DM_DataType::Mask},
            {"line", DM_DataType::Line},
            {"analog", DM_DataType::Analog},
            {"digital_event", DM_DataType::DigitalEvent},
            {"digital_interval", DM_DataType::DigitalInterval},
            {"tensor", DM_DataType::Tensor},
            {"time", DM_DataType::Time},
    };

    for (auto const & candidate: known_types) {
        if (data_type_str == candidate.label) {
            return candidate.type;
        }
    }
    return DM_DataType::Unknown;
}
220

221
std::vector<DataInfo> load_data_from_json_config(DataManager * dm, std::string const & json_filepath) {
×
222
    std::vector<DataInfo> data_info_list;
×
223
    // Open JSON file
224
    std::ifstream ifs(json_filepath);
×
225
    if (!ifs.is_open()) {
×
226
        std::cerr << "Failed to open JSON file: " << json_filepath << std::endl;
×
227
        return data_info_list;
×
228
    }
229

230
    // Parse JSON
231
    json j;
×
232
    ifs >> j;
×
233

234
    // get base path of filepath
235
    std::filesystem::path const base_path = std::filesystem::path(json_filepath).parent_path();
×
236

237
    // Iterate through JSON array
238
    for (auto const & item: j) {
×
239

240
        if (!checkRequiredFields(item, {"data_type", "name", "filepath"})) {
×
241
            continue;// Exit if any required field is missing
×
242
        }
243

244
        std::string const data_type_str = item["data_type"];
×
245
        auto const data_type = stringToDataType(data_type_str);
×
246
        if (data_type == DM_DataType::Unknown) {
×
247
            std::cout << "Unknown data type: " << data_type_str << std::endl;
×
248
            continue;
×
249
        }
250

251
        std::string const name = item["name"];
×
252

253
        auto file_exists = processFilePath(item["filepath"], base_path);
×
254
        if (!file_exists) {
×
255
            std::cerr << "File does not exist: " << item["filepath"] << std::endl;
×
256
            continue;
×
257
        }
258

259
        std::string const file_path = file_exists.value();
×
260

261
        switch (data_type) {
×
262
            case DM_DataType::Video: {
×
263

264
                auto video_data = load_video_into_VideoData(file_path);
×
265
                dm->setData<VideoData>("media", video_data);
×
266

267
                data_info_list.push_back({name, "VideoData", ""});
×
268
                break;
×
269
            }
×
270
            case DM_DataType::Points: {
×
271

272
                auto point_data = load_into_PointData(file_path, item);
×
273

274
                dm->setData<PointData>(name, point_data);
×
275

276
                std::string const color = item.value("color", "#0000FF");
×
277
                data_info_list.push_back({name, "PointData", color});
×
278
                break;
×
279
            }
×
280
            case DM_DataType::Mask: {
×
281

282
                auto mask_data = load_into_MaskData(file_path, item);
×
283

284
                std::string const color = item.value("color", "0000FF");
×
285
                dm->setData<MaskData>(name, mask_data);
×
286

287
                data_info_list.push_back({name, "MaskData", color});
×
288

289
                if (item.contains("operations")) {
×
290

291
                    for (auto const & operation: item["operations"]) {
×
292

293
                        std::string const operation_type = operation["type"];
×
294

295
                        if (operation_type == "area") {
×
296
                            std::cout << "Calculating area for mask: " << name << std::endl;
×
297
                            auto area_data = area(dm->getData<MaskData>(name).get());
×
298
                            std::string const output_name = name + "_area";
×
299
                            dm->setData<AnalogTimeSeries>(output_name, area_data);
×
300
                        }
×
301
                    }
×
302
                }
303
                break;
×
304
            }
×
305
            case DM_DataType::Line: {
×
306

307
                auto line_map = load_line_csv(file_path);
×
308

309
                //Get the whisker name from the filename using filesystem
310
                auto whisker_filename = std::filesystem::path(file_path).filename().string();
×
311

312
                //Remove .csv suffix from filename
313
                auto whisker_name = remove_extension(whisker_filename);
×
314

315
                dm->setData<LineData>(whisker_name, std::make_shared<LineData>(line_map));
×
316

317
                std::string const color = item.value("color", "0000FF");
×
318

319
                data_info_list.push_back({name, "LineData", color});
×
320

321
                break;
×
322
            }
×
323
            case DM_DataType::Analog: {
×
324

325
                auto analog_time_series = load_into_AnalogTimeSeries(file_path, item);
×
326

327
                for (int channel = 0; channel < analog_time_series.size(); channel++) {
×
328
                    std::string const channel_name = name + "_" + std::to_string(channel);
×
329

330
                    dm->setData<AnalogTimeSeries>(channel_name, analog_time_series[channel]);
×
331

332
                    if (item.contains("clock")) {
×
333
                        std::string const clock = item["clock"];
×
334
                        dm->setTimeFrame(channel_name, clock);
×
335
                    }
×
336
                }
×
337
                break;
×
338
            }
×
339
            case DM_DataType::DigitalEvent: {
×
340

341
                auto digital_event_series = load_into_DigitalEventSeries(file_path, item);
×
342

343
                for (int channel = 0; channel < digital_event_series.size(); channel++) {
×
344
                    std::string const channel_name = name + "_" + std::to_string(channel);
×
345

346
                    dm->setData<DigitalEventSeries>(channel_name, digital_event_series[channel]);
×
347

348
                    if (item.contains("clock")) {
×
349
                        std::string const clock = item["clock"];
×
350
                        dm->setTimeFrame(channel_name, clock);
×
351
                    }
×
352
                }
×
353
                break;
×
354
            }
×
355
            case DM_DataType::DigitalInterval: {
×
356

357
                auto digital_interval_series = load_into_DigitalIntervalSeries(file_path, item);
×
358
                dm->setData<DigitalIntervalSeries>(name, digital_interval_series);
×
359

360
                break;
×
361
            }
×
362
            case DM_DataType::Tensor: {
×
363

364
                if (item["format"] == "numpy") {
×
365

366
                    TensorData tensor_data;
×
367
                    loadNpyToTensorData(file_path, tensor_data);
×
368

369
                    dm->setData<TensorData>(name, std::make_shared<TensorData>(tensor_data));
×
370

371
                } else {
×
372
                    std::cout << "Format " << item["format"] << " not found for " << name << std::endl;
×
373
                }
374
                break;
×
375
            }
376
            case DM_DataType::Time: {
×
377

378
                if (item["format"] == "uint16") {
×
379

380
                    int const channel = item["channel"];
×
381
                    std::string const transition = item["transition"];
×
382

383
                    int const header_size = item.value("header_size", 0);
×
384

385
                    auto opts = Loader::BinaryAnalogOptions{.file_path = file_path,
×
386
                                                            .header_size_bytes = static_cast<size_t>(header_size)};
×
387
                    auto data = readBinaryFile<uint16_t>(opts);
×
388

389
                    auto digital_data = Loader::extractDigitalData(data, channel);
×
390
                    auto events = Loader::extractEvents(digital_data, transition);
×
391

392
                    // convert to int with std::transform
393
                    std::vector<int> events_int;
×
394
                    events_int.reserve(events.size());
×
395
                    for (auto e: events) {
×
396
                        events_int.push_back(static_cast<int>(e));
×
397
                    }
398
                    std::cout << "Loaded " << events_int.size() << " events for " << name << std::endl;
×
399

400
                    auto timeframe = std::make_shared<TimeFrame>(events_int);
×
401
                    dm->setTime(name, timeframe);
×
402
                }
×
403

404
                if (item["format"] == "uint16_length") {
×
405

406
                    int const header_size = item.value("header_size", 0);
×
407

408
                    auto opts = Loader::BinaryAnalogOptions{.file_path = file_path,
×
409
                                                            .header_size_bytes = static_cast<size_t>(header_size)};
×
410
                    auto data = readBinaryFile<uint16_t>(opts);
×
411

412
                    std::vector<int> t(data.size());
×
413
                    std::iota(std::begin(t), std::end(t), 0);
×
414

415
                    std::cout << "Total of " << t.size() << " timestamps for " << name << std::endl;
×
416

417
                    auto timeframe = std::make_shared<TimeFrame>(t);
×
418
                    dm->setTime(name, timeframe);
×
419
                }
×
420
                break;
×
421
            }
422
            default:
×
423
                std::cout << "Unsupported data type: " << data_type_str << std::endl;
×
424
                continue;
×
425
        }
×
426
        if (item.contains("clock")) {
×
427
            std::string const clock = item["clock"];
×
428
            std::cout << "Setting time for " << name << " to " << clock << std::endl;
×
429
            dm->setTimeFrame(name, clock);
×
430
        }
×
431
    }
×
432

433
    return data_info_list;
434
}
×
435

436
/**
 * @brief Report which kind of data is stored under a key.
 *
 * @param key Data key to query.
 * @return The DM_DataType corresponding to the alternative held by the stored
 *         variant (MediaData is reported as Video), or DM_DataType::Unknown
 *         when the key is absent or the alternative is not one checked here.
 */
DM_DataType DataManager::getType(std::string const & key) const {
    auto const entry = _data.find(key);
    if (entry == _data.end()) {
        return DM_DataType::Unknown;
    }

    auto const & stored = entry->second;

    if (std::holds_alternative<std::shared_ptr<MediaData>>(stored)) {
        return DM_DataType::Video;
    }
    if (std::holds_alternative<std::shared_ptr<PointData>>(stored)) {
        return DM_DataType::Points;
    }
    if (std::holds_alternative<std::shared_ptr<LineData>>(stored)) {
        return DM_DataType::Line;
    }
    if (std::holds_alternative<std::shared_ptr<MaskData>>(stored)) {
        return DM_DataType::Mask;
    }
    if (std::holds_alternative<std::shared_ptr<AnalogTimeSeries>>(stored)) {
        return DM_DataType::Analog;
    }
    if (std::holds_alternative<std::shared_ptr<DigitalEventSeries>>(stored)) {
        return DM_DataType::DigitalEvent;
    }
    if (std::holds_alternative<std::shared_ptr<DigitalIntervalSeries>>(stored)) {
        return DM_DataType::DigitalInterval;
    }
    if (std::holds_alternative<std::shared_ptr<TensorData>>(stored)) {
        return DM_DataType::Tensor;
    }
    return DM_DataType::Unknown;
}
460

461
/**
 * @brief Convert a DM_DataType to its lowercase configuration name.
 *
 * @param type Enumerator to convert.
 * @return The name used for that type in configuration files, or "unknown"
 *         for any enumerator not listed here.
 */
std::string convert_data_type_to_string(DM_DataType type) {
    if (type == DM_DataType::Video) {
        return "video";
    }
    if (type == DM_DataType::Points) {
        return "points";
    }
    if (type == DM_DataType::Mask) {
        return "mask";
    }
    if (type == DM_DataType::Line) {
        return "line";
    }
    if (type == DM_DataType::Analog) {
        return "analog";
    }
    if (type == DM_DataType::DigitalEvent) {
        return "digital_event";
    }
    if (type == DM_DataType::DigitalInterval) {
        return "digital_interval";
    }
    if (type == DM_DataType::Tensor) {
        return "tensor";
    }
    if (type == DM_DataType::Time) {
        return "time";
    }
    return "unknown";
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc