
int-brain-lab / ibllib / build 1761696499260742

05 Oct 2023 09:46AM UTC coverage: 55.27% (-1.4%) from 56.628%

Pull Request #655 (continuous-integration/UCL): add @sleepless decorator, by bimac

21 of 21 new or added lines in 1 file covered. (100.0%)
10330 of 18690 relevant lines covered (55.27%)
0.55 hits per line

Source File: /ibllib/qc/camera.py (53.38% of lines covered)
"""Video quality control.

This module runs a list of quality control metrics on the camera and extracted video data.

Examples
--------
Run right camera QC, downloading all but video file

>>> qc = CameraQC(eid, 'right', download_data=True, stream=True)
>>> qc.run()

Run left camera QC with session path, update QC field in Alyx

>>> qc = CameraQC(session_path, 'left')
>>> outcome, extended = qc.run(update=True)  # Returns outcome of videoQC only
>>> print(f'video QC = {outcome}; overall session QC = {qc.outcome}')  # NB: different outcomes

Run only video QC (no timestamp/alignment checks) on 20 frames for the body camera

>>> qc = CameraQC(eid, 'body', n_samples=20)
>>> qc.load_video_data()  # Quicker than loading all data
>>> qc.run()

Run specific video QC check and display the plots

>>> qc = CameraQC(eid, 'left')
>>> qc.load_data(download_data=True)
>>> qc.check_position(display=True)  # NB: Not all checks make plots

Run the QC for all cameras

>>> qcs = run_all_qc(eid)
>>> qcs['left'].metrics  # Dict of checks and outcomes for left camera
"""
import logging
from inspect import getmembers, isfunction
from pathlib import Path
from itertools import chain
import copy

import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.patches import Rectangle
from labcams import parse_cam_log

import one.alf.io as alfio
from one.util import filter_datasets
from one.alf.spec import is_session_path
from one.alf.exceptions import ALFObjectNotFound
from iblutil.util import Bunch
from iblutil.numerical import within_ranges

from ibllib.io.extractors.camera import extract_camera_sync, extract_all
from ibllib.io.extractors import ephys_fpga, training_wheel, mesoscope
from ibllib.io.extractors.video_motion import MotionAlignment
from ibllib.io.extractors.base import get_session_extractor_type
from ibllib.io import raw_data_loaders as raw
from ibllib.io.raw_daq_loaders import load_timeline_sync_and_chmap
from ibllib.io.session_params import read_params, get_sync, get_sync_namespace
import brainbox.behavior.wheel as wh
from ibllib.io.video import get_video_meta, get_video_frames_preload, assert_valid_label
from . import base

_log = logging.getLogger(__name__)


class CameraQC(base.QC):
    """A class for computing camera QC metrics"""
    dstypes = [
        '_ibl_experiment.description',
        '_iblrig_Camera.frameData',  # Replaces the next 3 datasets
        '_iblrig_Camera.frame_counter',
        '_iblrig_Camera.GPIO',
        '_iblrig_Camera.timestamps',
        '_iblrig_taskData.raw',
        '_iblrig_taskSettings.raw',
        '_iblrig_Camera.raw',
        'camera.times',
        'wheel.position',
        'wheel.timestamps'
    ]
    dstypes_fpga = [
        '_spikeglx_sync.channels',
        '_spikeglx_sync.polarities',
        '_spikeglx_sync.times',
        'ephysData.raw.meta'
    ]
    """Recall that for the training rig there is only one side camera at 30 Hz and 1280 x 1024 px.
    For the recording rig there are two side cameras (left: 60 Hz, 1280 x 1024 px;
    right: 150 Hz, 640 x 512 px) and one body camera (30 Hz, 640 x 512 px)."""
    video_meta = {
        'training': {
            'left': {
                'fps': 30,
                'width': 1280,
                'height': 1024
            }
        },
        'ephys': {
            'left': {
                'fps': 60,
                'width': 1280,
                'height': 1024
            },
            'right': {
                'fps': 150,
                'width': 640,
                'height': 512
            },
            'body': {
                'fps': 30,
                'width': 640,
                'height': 512
            },
        }
    }
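
    # Illustrative lookup: the expected video properties for a rig type and
    # camera label are read straight from this class attribute, e.g.
    #
    #     >>> CameraQC.video_meta['ephys']['right']['fps']
    #     150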

    def __init__(self, session_path_or_eid, camera, **kwargs):
        """
        :param session_path_or_eid: A session id or path
        :param camera: The camera to run QC on, if None QC is run for all three cameras
        :param n_samples: The number of frames to sample for the position and brightness QC
        :param stream: If true and local video files not available, the data are streamed from
        the remote source.
        :param log: A logging.Logger instance, if None the 'ibllib' logger is used
        :param one: An ONE instance for fetching and setting the QC on Alyx
        """
        # When an eid is provided, we will download the required data by default (if necessary)
        download_data = not is_session_path(session_path_or_eid)
        self.download_data = kwargs.pop('download_data', download_data)
        self.stream = kwargs.pop('stream', None)
        self.n_samples = kwargs.pop('n_samples', 100)
        self.sync_collection = kwargs.pop('sync_collection', None)
        self.sync = kwargs.pop('sync_type', None)
        super().__init__(session_path_or_eid, **kwargs)

        # Data
        self.label = assert_valid_label(camera)
        filename = f'_iblrig_{self.label}Camera.raw*.mp4'
        raw_video_path = self.session_path.joinpath('raw_video_data')
        self.video_path = next(raw_video_path.glob(filename), None)

        # If local video doesn't exist, change video path to URL
        if not self.video_path and self.stream is not False and self.one is not None:
            try:
                self.stream = True
                self.video_path = self.one.path2url(raw_video_path / filename.replace('*', ''))
            except (StopIteration, ALFObjectNotFound):
                _log.error('No remote or local video file found')
                self.video_path = None

        logging.disable(logging.NOTSET)
        keys = ('count', 'pin_state', 'audio', 'fpga_times', 'wheel', 'video',
                'frame_samples', 'timestamps', 'camera_times', 'bonsai_times')
        self.data = Bunch.fromkeys(keys)
        self.frame_samples_idx = None

        # QC outcomes map
        self.metrics = None
        self.outcome = 'NOT_SET'

        # Specify any checks to remove
        self.checks_to_remove = []
        self._type = None

    @property
    def type(self):
        """
        Returns the camera type based on the protocol.
        :return: Returns either None, 'ephys' or 'training'
        """
        if not self._type:
            return
        else:
            return 'ephys' if 'ephys' in self._type else 'training'

    def load_data(self, download_data: bool = None, extract_times: bool = False, load_video: bool = True) -> None:
        """Extract the data from raw data files
        Extracts all the required task data from the raw data files.

        Data keys:
            - count (int array): the sequential frame number (n, n+1, n+2...)
            - pin_state (): the camera GPIO pin; records the audio TTLs; should be one per frame
            - audio (float array): timestamps of audio TTL fronts
            - fpga_times (float array): timestamps of camera TTLs recorded by FPGA
            - timestamps (float array): extracted video timestamps (the camera.times ALF)
            - bonsai_times (datetime array): system timestamps of video PC; should be one per frame
            - camera_times (float array): camera frame timestamps extracted from frame headers
            - wheel (Bunch): rotary encoder timestamps, position and period used for wheel motion
            - video (Bunch): video meta data, including dimensions and FPS
            - frame_samples (h x w x n array): array of evenly sampled frames (1 colour channel)

        :param download_data: if True, any missing raw data is downloaded via ONE.
        Missing data will raise an AssertionError
        :param extract_times: if True, the camera.times are re-extracted from the raw data
        :param load_video: if True, calls the load_video_data method
        """
        assert self.session_path, 'no session path set'
        if download_data is not None:
            self.download_data = download_data
        if self.download_data and self.eid and self.one and not self.one.offline:
            self.ensure_required_data()
        _log.info('Gathering data for QC')

        # Get frame count and pin state
        self.data['count'], self.data['pin_state'] = \
            raw.load_embedded_frame_data(self.session_path, self.label, raw=True)

        # If there is an experiment description and there are video parameters
        sess_params = read_params(self.session_path) or {}
        task_collection = get_task_collection(sess_params)
        ns = get_sync_namespace(sess_params)
        self._set_sync(sess_params)
        if not self.sync:
            if not self.type:
                self._type = get_session_extractor_type(self.session_path, task_collection)
            self.sync = 'nidq' if 'ephys' in self.type else 'bpod'
        self._update_meta_from_session_params(sess_params)

        # Load the audio and raw FPGA times
        if self.sync != 'bpod' and self.sync is not None:
            self.sync_collection = self.sync_collection or 'raw_ephys_data'
            ns = ns or 'spikeglx'
            if ns == 'spikeglx':
                sync, chmap = ephys_fpga.get_sync_and_chn_map(self.session_path, self.sync_collection)
            elif ns == 'timeline':
                sync, chmap = load_timeline_sync_and_chmap(self.session_path / self.sync_collection)
            else:
                raise NotImplementedError(f'Unknown namespace "{ns}"')
            audio_ttls = ephys_fpga.get_sync_fronts(sync, chmap['audio'])
            self.data['audio'] = audio_ttls['times']  # Get rises
            # Load raw FPGA times
            cam_ts = extract_camera_sync(sync, chmap)
            self.data['fpga_times'] = cam_ts[self.label]
        else:
            self.sync_collection = self.sync_collection or task_collection
            bpod_data = raw.load_data(self.session_path, task_collection)
            _, audio_ttls = raw.load_bpod_fronts(
                self.session_path, data=bpod_data, task_collection=task_collection)
            self.data['audio'] = audio_ttls['times']

        # Load extracted frame times
        alf_path = self.session_path / 'alf'
        try:
            assert not extract_times
            self.data['timestamps'] = alfio.load_object(
                alf_path, f'{self.label}Camera', short_keys=True)['times']
        except AssertionError:  # Re-extract
            kwargs = dict(video_path=self.video_path, labels=self.label)
            if self.sync != 'bpod' and self.sync is not None:
                kwargs = {**kwargs, 'sync': sync, 'chmap': chmap}  # noqa
            outputs, _ = extract_all(self.session_path, self.sync, save=False,
                                     sync_collection=self.sync_collection, **kwargs)
            self.data['timestamps'] = outputs[f'{self.label}_camera_timestamps']
        except ALFObjectNotFound:
            _log.warning('no camera.times ALF found for session')

        # Get audio and wheel data
        wheel_keys = ('timestamps', 'position')
        try:
            # glob in case wheel data are in sub-collections
            alf_path = next(alf_path.rglob('*wheel.timestamps*')).parent
            self.data['wheel'] = alfio.load_object(alf_path, 'wheel', short_keys=True)
        except (StopIteration, ALFObjectNotFound):
            # Extract from raw data
            if self.sync != 'bpod' and self.sync is not None:
                if ns == 'spikeglx':
                    wheel_data = ephys_fpga.extract_wheel_sync(sync, chmap)
                elif ns == 'timeline':
                    extractor = mesoscope.TimelineTrials(self.session_path, sync_collection=self.sync_collection)
                    wheel_data = extractor.extract_wheel_sync()
                else:
                    raise NotImplementedError(f'Unknown namespace "{ns}"')
            else:
                wheel_data = training_wheel.get_wheel_position(
                    self.session_path, task_collection=task_collection)
            self.data['wheel'] = Bunch(zip(wheel_keys, wheel_data))

        # Find short period of wheel motion for motion correlation.
        if data_for_keys(wheel_keys, self.data['wheel']) and self.data['timestamps'] is not None:
            self.data['wheel'].period = self.get_active_wheel_period(self.data['wheel'])

        # Load Bonsai frame timestamps
        try:
            ssv_times = raw.load_camera_ssv_times(self.session_path, self.label)
            self.data['bonsai_times'], self.data['camera_times'] = ssv_times
        except AssertionError:
            _log.warning('No Bonsai video timestamps file found')

        # Gather information from video file
        if load_video:
            _log.info('Inspecting video file...')
            self.load_video_data()

    def load_video_data(self):
        # Get basic properties of video
        try:
            self.data['video'] = get_video_meta(self.video_path, one=self.one)
            # Sample some frames from the video file
            indices = np.linspace(100, self.data['video'].length - 100, self.n_samples).astype(int)
            self.frame_samples_idx = indices
            self.data['frame_samples'] = get_video_frames_preload(self.video_path, indices,
                                                                  mask=np.s_[:, :, 0])
        except AssertionError:
            _log.error('Failed to read video file; setting outcome to CRITICAL')
            self._outcome = 'CRITICAL'

    @staticmethod
    def get_active_wheel_period(wheel, duration_range=(3., 20.), display=False):
        """
        Attempts to find a period of movement where the wheel accelerates and decelerates for
        the wheel motion alignment QC.
        :param wheel: A Bunch of wheel timestamps and position data
        :param duration_range: The candidates must be within min/max duration range
        :param display: If true, plot the selected wheel movement
        :return: 2-element array comprising the start and end times of the active period
        """
        pos, ts = wh.interpolate_position(wheel.timestamps, wheel.position)
        v, acc = wh.velocity_filtered(pos, 1000)
        on, off, *_ = wh.movements(ts, acc, pos_thresh=.1, make_plots=False)
        edges = np.c_[on, off]
        indices, _ = np.where(np.logical_and(
            np.diff(edges) > duration_range[0], np.diff(edges) < duration_range[1]))
        if len(indices) == 0:
            _log.warning('No period of wheel movement found for motion alignment.')
            return None
        # Pick movement somewhere in the middle
        i = indices[int(indices.size / 2)]
        if display:
            _, (ax0, ax1) = plt.subplots(2, 1, sharex='all')
            mask = np.logical_and(ts > edges[i][0], ts < edges[i][1])
            ax0.plot(ts[mask], pos[mask])
            ax1.plot(ts[mask], acc[mask])
        return edges[i]
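
    # Illustrative usage sketch (assumes `wheel` is a Bunch with 'timestamps'
    # and 'position' arrays, e.g. loaded with alfio):
    #
    #     period = CameraQC.get_active_wheel_period(wheel)
    #     if period is not None:
    #         assert 3. < np.diff(period).item() < 20.  # within duration_range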

    def ensure_required_data(self):
        """
        Ensures the datasets required for QC are local.  If the download_data attribute is True,
        any missing data are downloaded.  If any data are still missing at the end, an exception
        is raised.  If the stream attribute is True, the video file is not required to be local,
        however it must be remotely accessible.
        NB: Requires a valid instance of ONE and a valid session eid.
        :return:
        """
        assert self.one is not None, 'ONE required to download data'

        sess_params = {}
        if self.download_data:
            dset = self.one.list_datasets(self.session_path, '*experiment.description*', details=True)
            if self.one._check_filesystem(dset):
                sess_params = read_params(self.session_path) or {}
        else:
            sess_params = read_params(self.session_path) or {}
        self._set_sync(sess_params)

        # Get extractor type
        is_ephys = 'ephys' in (self.type or self.one.get_details(self.eid)['task_protocol'])
        self.sync = self.sync or ('nidq' if is_ephys else 'bpod')

        is_fpga = 'bpod' not in self.sync

        # dataset collections outside this list are ignored (e.g. probe00, raw_passive_data)
        collections = (
            'alf', '', get_task_collection(sess_params), get_video_collection(sess_params, self.label)
        )
        dtypes = self.dstypes + self.dstypes_fpga if is_fpga else self.dstypes
        assert_unique = True
        # Check we have raw ephys data for session
        if is_ephys:
            if len(self.one.list_datasets(self.eid, collection='raw_ephys_data')) == 0:
                # Assert 3A probe model; if so download all probe data
                det = self.one.get_details(self.eid, full=True)
                probe_model = next(x['model'] for x in det['probe_insertion'])
                assert probe_model == '3A', 'raw ephys data missing'
                collections += (self.sync_collection or 'raw_ephys_data',)
                if sess_params:
                    probes = sess_params.get('devices', {}).get('neuropixel', {})
                    probes = set(x.get('collection') for x in chain(*map(dict.values, probes)))
                    collections += tuple(probes)
                else:
                    collections += ('raw_ephys_data/probe00', 'raw_ephys_data/probe01')
                assert_unique = False
            else:
                # 3B probes have data in root collection
                collections += ('raw_ephys_data',)
        for dstype in dtypes:
            datasets = self.one.type2datasets(self.eid, dstype, details=True)
            if 'camera' in dstype.lower():  # Download individual camera file
                datasets = filter_datasets(datasets, filename=f'.*{self.label}.*')
            else:  # Ignore probe datasets, etc.
                _datasets = filter_datasets(datasets, collection=collections, assert_unique=assert_unique)
                if '' in collections:  # Must be handled as a separate query
                    datasets = filter_datasets(datasets, collection='', assert_unique=assert_unique)
                    datasets = pd.concat([datasets, _datasets]).sort_index()
                else:
                    datasets = _datasets

            if any(x.endswith('.mp4') for x in datasets.rel_path) and self.stream:
                names = [x.split('/')[-1] for x in self.one.list_datasets(self.eid, details=False)]
                assert f'_iblrig_{self.label}Camera.raw.mp4' in names, 'No remote video file found'
                continue
            optional = ('camera.times', '_iblrig_Camera.raw', 'wheel.position', 'wheel.timestamps',
                        '_iblrig_Camera.timestamps', '_iblrig_Camera.frame_counter', '_iblrig_Camera.GPIO',
                        '_iblrig_Camera.frameData', '_ibl_experiment.description')
            present = (
                self.one._check_filesystem(datasets)
                if self.download_data
                else (next(self.session_path.rglob(d), None) for d in datasets['rel_path'])
            )

            required = (dstype not in optional)
            all_present = not datasets.empty and all(present)
            assert all_present or not required, f'Dataset {dstype} not found'

        if not self.type and self.sync != 'nidq':
            self._type = get_session_extractor_type(self.session_path)

    def _set_sync(self, session_params=False):
        """Set the sync and sync_collection attributes if not already set.

        Also set the type attribute based on the sync. NB: The type attribute is for legacy sessions.

        Parameters
        ----------
        session_params : dict, bool
            The loaded experiment description file.  If False, attempts to load it from the session_path.
        """
        if session_params is False:
            if not self.session_path:
                raise ValueError('No session path set')
            session_params = read_params(self.session_path)
        sync, sync_dict = get_sync(session_params)
        self.sync = self.sync or sync
        self.sync_collection = self.sync_collection or sync_dict.get('collection')
        if self.sync:
            self._type = 'ephys' if self.sync == 'nidq' else 'training'

    def _update_meta_from_session_params(self, sess_params):
        """
        Update the default expected video properties with those defined in the experiment
        description file (if any).  This updates the `video_meta` property with the fps, width and
        height for the type and camera label.

        Parameters
        ----------
        sess_params : dict
            The loaded experiment.description file.
        """
        try:
            assert sess_params
            video_pars = sess_params.get('devices', {}).get('cameras', {}).get(self.label)
        except (AssertionError, KeyError):
            return
        PROPERTIES = ('width', 'height', 'fps')
        video_meta = copy.deepcopy(self.video_meta)  # must re-assign as it's a class attribute
        if self.type not in video_meta:
            video_meta[self.type] = {}
        if self.label not in video_meta[self.type]:
            video_meta[self.type][self.label] = {}
        video_meta[self.type][self.label].update(
            **{k: v for k, v in video_pars.items() if k in PROPERTIES}
        )
        self.video_meta = video_meta
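
    # Illustrative sketch of the experiment-description structure this method
    # reads (hypothetical values):
    #
    #     sess_params = {'devices': {'cameras': {
    #         'left': {'fps': 60, 'width': 1280, 'height': 1024}}}}
    #
    # in which case the expected fps/width/height for the left camera are
    # overridden with these values.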

    def run(self, update: bool = False, **kwargs) -> (str, dict):
        """
        Run video QC checks and return outcome
        :param update: if True, updates the session QC fields on Alyx
        :param download_data: if True, downloads any missing data if required
        :param extract_times: if True, re-extracts the camera timestamps from the raw data
        :returns: overall outcome as a str, a dict of checks and their outcomes
        """
        _log.info(f'Computing QC outcome for {self.label} camera, session {self.eid}')
        namespace = f'video{self.label.capitalize()}'
        if all(x is None for x in self.data.values()):
            self.load_data(**kwargs)
        if self.data['frame_samples'] is None or self.data['timestamps'] is None:
            return 'NOT_SET', {}
        if self.data['timestamps'].shape[0] == 0:
            _log.error(f'No timestamps for {self.label} camera; setting outcome to CRITICAL')
            return 'CRITICAL', {}

        def is_metric(x):
            return isfunction(x) and x.__name__.startswith('check_')
        # import importlib
        # classe = getattr(importlib.import_module(self.__module__), self.__name__)
        # print(classe)

        checks = getmembers(self.__class__, is_metric)
        checks = self.remove_check(checks)
        self.metrics = {f'_{namespace}_' + k[6:]: fn(self) for k, fn in checks}

        values = [x if isinstance(x, str) else x[0] for x in self.metrics.values()]
        code = max(base.CRITERIA[x] for x in values)
        outcome = next(k for k, v in base.CRITERIA.items() if v == code)

        if update:
            extended = {
                k: 'NOT_SET' if v is None else v
                for k, v in self.metrics.items()
            }
            self.update_extended_qc(extended)
            self.update(outcome, namespace)
        return outcome, self.metrics
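
    # Illustrative sketch of the aggregation above, assuming base.CRITERIA
    # assigns higher codes to worse outcomes (PASS < WARNING < FAIL < CRITICAL):
    #
    #     >>> values = ['PASS', 'WARNING', 'PASS']
    #     >>> max(values, key=base.CRITERIA.get)
    #     'WARNING'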

    def remove_check(self, checks):
        if len(self.checks_to_remove) == 0:
            return checks
        else:
            for check in self.checks_to_remove:
                check_names = [ch[0] for ch in checks]
                idx = check_names.index(check)
                checks.pop(idx)
            return checks

    def check_brightness(self, bounds=(40, 200), max_std=20, roi=True, display=False):
        """Check that the video brightness is within a given range
        The mean brightness of each frame must be within the bounds provided, and the standard
        deviation across sample frames should be less than the given value.  Assumes that the
        frame samples are 2D (no colour channels).

        :param bounds: For each frame, check that: bounds[0] < M < bounds[1],
        where M = mean(frame). If fewer than 75% of sample frames are within these bounds, the
        outcome is WARNING. If fewer than 75% are within twice these bounds, the outcome is FAIL.
        :param max_std: The standard deviation of the frame luminance means must be less than this
        :param roi: If True, check brightness on ROI of frame
        :param display: When True the mean frame luminance is plotted against sample frames.
        The sample frames with the lowest and highest mean luminance are shown.
        """
        if self.data['frame_samples'] is None:
            return 'NOT_SET'
        if roi is True:
            _, h, w = self.data['frame_samples'].shape
            if self.label == 'body':  # Latter half
                roi = (slice(None), slice(None), slice(int(w / 2), None, None))
            elif self.label == 'left':  # Top left quadrant (~2/3, 1/2 height)
                roi = (slice(None), slice(None, int(h / 2), None), slice(None, int(w * .66), None))
            else:  # Top right quadrant (~2/3 width, 1/2 height)
                roi = (slice(None), slice(None, int(h / 2), None), slice(int(w * .66), None, None))
        else:
            roi = (slice(None), slice(None), slice(None))
        brightness = self.data['frame_samples'][roi].mean(axis=(1, 2))
        # dims = self.data['frame_samples'].shape
        # brightness /= np.array((*dims[1:], 255)).prod()  # Normalize

        if display:
            f = plt.figure()
            gs = f.add_gridspec(2, 3)
            indices = self.frame_samples_idx
            # Plot mean frame luminance
            ax = f.add_subplot(gs[:2, :2])
            plt.plot(indices, brightness, label='brightness')
            ax.set(
                xlabel='frame #',
                ylabel='brightness (mean pixel)',
                title='Brightness')
            ax.hlines(bounds, 0, indices[-1],
                      colors='tab:orange', linestyles=':', label='warning bounds')
            ax.hlines((bounds[0] / 2, bounds[1] * 2), 0, indices[-1],
                      colors='r', linestyles=':', label='failure bounds')
            ax.legend()
            # Plot min-max frames
            for i, idx in enumerate((np.argmax(brightness), np.argmin(brightness))):
                a = f.add_subplot(gs[i, 2])
                ax.annotate('*', (indices[idx], brightness[idx]),  # this is the point to label
                            textcoords='offset points', xytext=(0, 1), ha='center')
                frame = self.data['frame_samples'][idx][roi[1:]]
                title = ('min' if i else 'max') + ' mean luminance = %.2f' % brightness[idx]
                self.imshow(frame, ax=a, title=title)

        PCT_PASS = .75  # Proportion of sample frames that must pass
        # Warning if brightness not within range (3/4 of frames must be between bounds)
        warn_range = np.logical_and(brightness > bounds[0], brightness < bounds[1])
        warn_range = 'PASS' if sum(warn_range) / self.n_samples > PCT_PASS else 'WARNING'
        # Fail if brightness not within twice the range or std exceeds the threshold
        fail_range = np.logical_and(brightness > bounds[0] / 2, brightness < bounds[1] * 2)
        within_range = sum(fail_range) / self.n_samples > PCT_PASS
        fail_range = 'PASS' if within_range and np.std(brightness) < max_std else 'FAIL'
        return self.overall_outcome([warn_range, fail_range])
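
    # Illustrative sketch of the pass/warn logic above on synthetic data
    # (assumes n_samples equals the number of sample frames):
    #
    #     >>> brightness = np.array([120.] * 80 + [230.] * 20)  # 20% too bright
    #     >>> within = np.logical_and(brightness > 40, brightness < 200)
    #     >>> 'PASS' if within.sum() / brightness.size > .75 else 'WARNING'
    #     'PASS'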

    def check_file_headers(self):
        """Check reported frame rate matches FPGA frame rate"""
        if None in (self.data['video'], self.video_meta):
            return 'NOT_SET'
        expected = self.video_meta[self.type][self.label]
        return 'PASS' if self.data['video']['fps'] == expected['fps'] else 'FAIL'

    def check_framerate(self, threshold=1.):
        """Check camera times match specified frame rate for camera

        :param threshold: The maximum absolute difference between timestamp sample rate and video
        frame rate.  NB: Does not take into account dropped frames.
        """
        if any(x is None for x in (self.data['timestamps'], self.video_meta)):
            return 'NOT_SET'
        fps = self.video_meta[self.type][self.label]['fps']
        Fs = 1 / np.median(np.diff(self.data['timestamps']))  # Approx. frequency of camera
        return 'PASS' if abs(Fs - fps) < threshold else 'FAIL', float(round(Fs, 3))
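
    # Illustrative sketch: the median inter-frame interval makes the estimate
    # robust to a few dropped frames:
    #
    #     >>> ts = np.arange(0, 10, 1 / 60)  # ideal 60 Hz clock
    #     >>> abs(1 / np.median(np.diff(ts)) - 60) < 1e-6
    #     True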

    def check_pin_state(self, display=False):
        """Check the pin state reflects Bpod TTLs"""
        if not data_for_keys(('video', 'pin_state', 'audio'), self.data):
            return 'NOT_SET'
        size_diff = int(self.data['pin_state'].shape[0] - self.data['video']['length'])
        # NB: The pin state can be high for 2 consecutive frames
        low2high = np.insert(np.diff(self.data['pin_state'][:, -1].astype(int)) == 1, 0, False)
        # NB: Time between two consecutive TTLs can be sub-frame, so this will fail
        ndiff_low2high = int(self.data['audio'][::2].size - sum(low2high))
        # state_ttl_matches = ndiff_low2high == 0
        # Check within ms of audio times
        if display:
            plt.figure()
            plt.plot(self.data['timestamps'][low2high], np.zeros(sum(low2high)), 'o',
                     label='GPIO Low -> High')
            plt.plot(self.data['audio'], np.zeros(self.data['audio'].size), 'rx',
                     label='Audio TTL High')
            plt.xlabel('FPGA frame times / s')
            plt.gca().set(yticklabels=[])
            plt.gca().tick_params(left=False)
            plt.legend()

        outcome = self.overall_outcome(
            ('PASS' if size_diff == 0 else 'WARNING' if np.abs(size_diff) < 5 else 'FAIL',
             'PASS' if np.abs(ndiff_low2high) < 5 else 'WARNING')
        )
        return outcome, ndiff_low2high, size_diff

    def check_dropped_frames(self, threshold=.1):
        """Check how many frames were reported missing

        :param threshold: The maximum allowable percentage of dropped frames
        """
        if not data_for_keys(('video', 'count'), self.data):
            return 'NOT_SET'
        size_diff = int(self.data['count'].size - self.data['video']['length'])
        strict_increase = np.diff(self.data['count']) > 0
        if not np.all(strict_increase):
            n_affected = np.sum(np.invert(strict_increase))
            _log.info(f'frame count not strictly increasing: '
                      f'{n_affected} frames affected ({n_affected / strict_increase.size:.2%})')
            return 'CRITICAL'
        dropped = np.diff(self.data['count']).astype(int) - 1
        pct_dropped = (sum(dropped) / len(dropped) * 100)
        # Calculate overall outcome for this check
        outcome = self.overall_outcome(
            ('PASS' if size_diff == 0 else 'WARNING' if np.abs(size_diff) < 5 else 'FAIL',
             'PASS' if pct_dropped < threshold else 'FAIL')
        )
        return outcome, int(sum(dropped)), size_diff
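
    # Illustrative sketch: the embedded frame counter reveals drops as gaps in
    # an otherwise consecutive sequence:
    #
    #     >>> count = np.array([0, 1, 2, 4, 5, 8])  # frames 3, 6 and 7 missing
    #     >>> int(sum(np.diff(count).astype(int) - 1))
    #     3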

    def check_timestamps(self):
        """Check that the camera.times array is reasonable"""
        if not data_for_keys(('timestamps', 'video'), self.data):
            return 'NOT_SET'
        # Check number of timestamps matches video
        length_matches = self.data['timestamps'].size == self.data['video'].length
        # Check times are strictly increasing
        increasing = all(np.diff(self.data['timestamps']) > 0)
        # Check times do not contain nans
        nanless = not np.isnan(self.data['timestamps']).any()
        return 'PASS' if increasing and length_matches and nanless else 'FAIL'

    def check_camera_times(self):
        """Check that the number of raw camera timestamps matches the number of video frames"""
        if not data_for_keys(('bonsai_times', 'video'), self.data):
            return 'NOT_SET'
        length_match = len(self.data['camera_times']) == self.data['video'].length
        outcome = 'PASS' if length_match else 'WARNING'
        # 1 / np.median(np.diff(self.data.camera_times))
        return outcome, len(self.data['camera_times']) - self.data['video'].length

    def check_resolution(self):
        """Check that the timestamps and video file resolution match what we expect"""
        if self.data['video'] is None:
            return 'NOT_SET'
        actual = self.data['video']
        expected = self.video_meta[self.type][self.label]
        match = actual['width'] == expected['width'] and actual['height'] == expected['height']
        return 'PASS' if match else 'FAIL'

    def check_wheel_alignment(self, tolerance=(1, 2), display=False):
        """Check wheel motion in video correlates with the rotary encoder signal

        Check is skipped for body camera videos as the wheel is often obstructed

        Parameters
        ----------
        tolerance : int, (int, int)
            Maximum absolute offset in frames.  If two values, the maximum value is taken as the
            warning threshold.
        display : bool
            If true, the wheel motion energy is plotted against the rotary encoder.

        Returns
        -------
        str
            The outcome string, one of {'NOT_SET', 'FAIL', 'WARNING', 'PASS'}.
        int
            Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
            signal.  Negative values mean the video was shifted backwards with respect to the wheel
            timestamps.

        Notes
        -----
        - A negative frame offset typically means that there were frame TTLs at the beginning that
          do not correspond to any video frames (sometimes the first few frames aren't saved to
          disk).  Since 2021-09-15 the extractor should compensate for this.
        """
        wheel_present = data_for_keys(('position', 'timestamps', 'period'), self.data['wheel'])
        if not wheel_present or self.label == 'body':
            return 'NOT_SET'

        # Check the selected wheel movement period occurred within camera timestamp time
        camera_times = self.data['timestamps']
        in_range = within_ranges(camera_times, self.data['wheel']['period'].reshape(-1, 2))
        if not in_range.any():
            # Check if any camera timestamps overlap with the wheel times
            if np.any(np.logical_and(
                camera_times > self.data['wheel']['timestamps'][0],
                camera_times < self.data['wheel']['timestamps'][-1])
            ):
                _log.warning('Unable to check wheel alignment: '
                             'chosen movement is not during video')
                return 'NOT_SET'
            else:
                # No overlap, return fail
                return 'FAIL'
        aln = MotionAlignment(self.eid, self.one, self.log, session_path=self.session_path)
        aln.data = self.data.copy()
        aln.data['camera_times'] = {self.label: camera_times}
        aln.video_paths = {self.label: self.video_path}
        offset, *_ = aln.align_motion(period=self.data['wheel'].period,
                                      display=display, side=self.label)
        if offset is None:
            return 'NOT_SET'
        if display:
            aln.plot_alignment()

        # Determine the outcome.  If there are two values for the tolerance, one is taken to be
        # a warning threshold, the other a failure threshold.
        out_map = {0: 'WARNING', 1: 'WARNING', 2: 'PASS'}  # 0: FAIL -> WARNING Aug 2022
        passed = np.abs(offset) <= np.sort(np.array(tolerance))
        return out_map[sum(passed)], int(offset)
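
    # Illustrative sketch of the tolerance logic above with the default
    # tolerance=(1, 2): an offset of 0 or 1 frame passes both thresholds
    # (PASS); an offset of 2 passes only the upper one (WARNING); anything
    # larger passes neither (also WARNING since Aug 2022):
    #
    #     >>> passed = np.abs(2) <= np.sort(np.array((1, 2)))
    #     >>> {0: 'WARNING', 1: 'WARNING', 2: 'PASS'}[sum(passed)]
    #     'WARNING'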

    def check_position(self, hist_thresh=(75, 80), pos_thresh=(10, 15),
                       metric=cv2.TM_CCOEFF_NORMED,
                       display=False, test=False, roi=None, pct_thresh=True):
        """Check camera is positioned correctly
        For the template matching zero-normalized cross-correlation (default) should be more
        robust to exposure (which we're not checking here).  The L2 norm (TM_SQDIFF) should
        also work.

        If display is True, the template ROI (pink, hatched) is plotted over a video frame,
        along with the threshold regions (green solid).  The histogram correlations are plotted
        and the full histogram is plotted for one of the sample frames and the reference frame.

        :param hist_thresh: The minimum histogram cross-correlation threshold to pass (0-1, or
         a percentage when pct_thresh is True).
        :param pos_thresh: The maximum number of pixels that the template matcher may be off
         by. If two values are provided, the lower threshold is treated as a warning boundary.
        :param metric: The metric to use for template matching.
        :param display: If true, the results are plotted
        :param test: If true, a reference frame is used instead of the frames in frame_samples.
        :param roi: A tuple of indices for the face template in the form ((y1, y2), (x1, x2))
        :param pct_thresh: If true, the thresholds are treated as percentages
        """
        if not test and self.data['frame_samples'] is None:
            return 'NOT_SET'
        refs = self.load_reference_frames(self.label)
        # ensure iterable
        pos_thresh = np.sort(np.array(pos_thresh))
        hist_thresh = np.sort(np.array(hist_thresh))

        # Method 1: compareHist
        #### Mean hist comparison
        # ref_h = [cv2.calcHist([x], [0], None, [256], [0, 256]) for x in refs]
        # ref_h = np.array(ref_h).mean(axis=0)
        # frames = refs if test else self.data['frame_samples']
        # hists = [cv2.calcHist([x], [0], None, [256], [0, 256]) for x in frames]
        # test_h = np.array(hists).mean(axis=0)
        # corr = cv2.compareHist(test_h, ref_h, cv2.HISTCMP_CORREL)
        # if pct_thresh:
        #     corr *= 100
        # hist_passed = corr > hist_thresh
        ####
        ref_h = cv2.calcHist([refs[0]], [0], None, [256], [0, 256])
        frames = refs if test else self.data['frame_samples']
        hists = [cv2.calcHist([x], [0], None, [256], [0, 256]) for x in frames]
        corr = np.array([cv2.compareHist(test_h, ref_h, cv2.HISTCMP_CORREL) for test_h in hists])
        if pct_thresh:
            corr *= 100
        hist_passed = [np.all(corr > x) for x in hist_thresh]

        # Method 2:
        top_left, roi, template = self.find_face(roi=roi, test=test, metric=metric, refs=refs)
        (y1, y2), (x1, x2) = roi
        err = (x1, y1) - np.median(np.array(top_left), axis=0)
        h, w = frames[0].shape[:2]

        if pct_thresh:  # Threshold as percent
            # t_x, t_y = pct_thresh
            err_pct = [(abs(x) / y) * 100 for x, y in zip(err, (h, w))]
            face_passed = [all(err_pct < x) for x in pos_thresh]
        else:
            face_passed = [np.all(np.abs(err) < x) for x in pos_thresh]

        if display:
            plt.figure()
            # Plot frame with template overlay
            img = frames[0]
            ax0 = plt.subplot(221)
            ax0.imshow(img, cmap='gray', vmin=0, vmax=255)
            bounds = (x1 - err[0], x2 - err[0], y2 - err[1], y1 - err[1])
            ax0.imshow(template, cmap='gray', alpha=0.5, extent=bounds)
            if pct_thresh:
                for c, thresh in zip(('green', 'yellow'), pos_thresh):
                    t_y = (h / 100) * thresh
                    t_x = (w / 100) * thresh
                    xy = (x1 - t_x, y1 - t_y)
                    ax0.add_patch(Rectangle(xy, x2 - x1 + (t_x * 2), y2 - y1 + (t_y * 2),
                                            fill=True, facecolor=c, lw=0, alpha=0.05))
            else:
                for c, thresh in zip(('green', 'yellow'), pos_thresh):
                    xy = (x1 - thresh, y1 - thresh)
                    ax0.add_patch(Rectangle(xy, x2 - x1 + (thresh * 2), y2 - y1 + (thresh * 2),
                                            fill=True, facecolor=c, lw=0, alpha=0.05))
            xy = (x1 - err[0], y1 - err[1])
            ax0.add_patch(Rectangle(xy, x2 - x1, y2 - y1,
                                    edgecolor='pink', fill=False, hatch='//', lw=1))
            ax0.set(xlim=(0, img.shape[1]), ylim=(img.shape[0], 0))
            ax0.set_axis_off()
            # Plot the image histograms
            ax1 = plt.subplot(212)
            ax1.plot(ref_h[5:-1], label='reference frame')
            ax1.plot(np.array(hists).mean(axis=0)[5:-1], label='mean frame')
            ax1.set_xlim([0, 256])
            plt.legend()
            # Plot the correlations for each sample frame
            ax2 = plt.subplot(222)
            ax2.plot(corr, label='hist correlation')
            ax2.axhline(hist_thresh[0], 0, self.n_samples,
                        linestyle=':', color='r', label='fail threshold')
            ax2.axhline(hist_thresh[1], 0, self.n_samples,
                        linestyle=':', color='g', label='pass threshold')
            ax2.set(xlabel='Sample Frame #', ylabel='Hist correlation')
            plt.legend()
            plt.suptitle('Check position')
            plt.show()

        pass_map = {i: s for i, s in enumerate(('FAIL', 'WARNING', 'PASS'))}
        face_aligned = pass_map[sum(face_passed)]
        hist_correlates = pass_map[sum(hist_passed)]

        return self.overall_outcome([face_aligned, hist_correlates], agg=min)
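
    # Illustrative sketch of the histogram measure above: HISTCMP_CORREL
    # returns 1.0 for identical histograms, so frames matching the reference
    # score near 100 once scaled to a percentage:
    #
    #     >>> h = cv2.calcHist([np.zeros((4, 4), np.uint8)], [0], None, [256], [0, 256])
    #     >>> cv2.compareHist(h, h, cv2.HISTCMP_CORREL)
    #     1.0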

    def check_focus(self, n=20, threshold=(100, 6),
                    roi=False, display=False, test=False, equalize=True):
        """Check video is in focus
        Two methods are used here: Looking at the high frequencies with a DFT and
        applying a Laplacian HPF and looking at the variance.

        Note:
            - Both methods are sensitive to noise (Laplacian is 2nd order filter).
            - The thresholds for the fft may need to be different for the left/right vs body as
              the distribution of frequencies in the image is different (e.g. the holder
              comprises mostly very high frequencies).
            - The image may be overall in focus but the places we care about can still be out of
              focus (namely the face).  For this we'll take an ROI around the face.
            - The focus check is thrown off by brightness.  This may be fixed by equalizing the
              histogram (set equalize=True)

        :param n: number of frames from frame_samples data to use in check.
        :param threshold: the lower boundary for Laplacian variance and mean FFT filtered
         brightness, respectively
        :param roi: if False, the roi is determined via template matching for the face or body.
        If None, some set ROIs for face and paws are used.  A list of slices may also be passed.
        :param display: if true, the results are displayed
        :param test: if true, a set of artificially blurred reference frames are used as the
        input.  This can be used for selecting reasonable thresholds.
        :param equalize: if true, the histograms of the frames are equalized, increasing the
        global contrast and linearizing the CDF.  This makes the check robust to low light
        conditions.
        """
        no_frames = self.data['frame_samples'] is None or len(self.data['frame_samples']) == 0
        if not test and no_frames:
            return 'NOT_SET'

        if roi is False:
            top_left, roi, _ = self.find_face(test=test)  # (y1, y2), (x1, x2)
            h, w = map(lambda x: np.diff(x).item(), roi)
            x, y = np.median(np.array(top_left), axis=0).round().astype(int)
            roi = (np.s_[y: y + h, x: x + w],)
        else:
            ROI = {
                'left': (np.s_[:400, :561], np.s_[500:, 100:800]),  # (face, wheel)
                'right': (np.s_[:196, 397:], np.s_[221:, 255:]),
                'body': (np.s_[143:274, 84:433],)  # body holder
            }
            roi = roi or ROI[self.label]

        if test:
            """In test mode load a reference frame and run it through a normalized box filter with
            increasing kernel size.
            """
            idx = (0,)
            ref = self.load_reference_frames(self.label)[idx]
            kernel_sz = np.unique(np.linspace(0, 15, n, dtype=int))
            n = kernel_sz.size  # Size excluding repeated kernels
            img = np.empty((n, *ref.shape), dtype=np.uint8)
            for i, k in enumerate(kernel_sz):
                img[i] = ref.copy() if k == 0 else cv2.blur(ref, (k, k))
            if equalize:
                [cv2.equalizeHist(x, x) for x in img]
            if display:
                # Plot blurred images
                f, axes = plt.subplots(1, len(kernel_sz))
                for ax, ig, k in zip(axes, img, kernel_sz):
                    self.imshow(ig, ax=ax, title='Kernel ({0}, {0})'.format(k or 'None'))
                f.suptitle('Reference frame with box filter')
        else:
            # Sub-sample the frame samples
            idx = np.unique(np.linspace(0, len(self.data['frame_samples']) - 1, n, dtype=int))
            img = self.data['frame_samples'][idx]
            if equalize:
                [cv2.equalizeHist(x, x) for x in img]

        # A measure of the sharpness effectively taking the second derivative of the image
        lpc_var = np.empty((min(n, len(img)), len(roi)))
        for i, frame in enumerate(img[::-1]):
            lpc = cv2.Laplacian(frame, cv2.CV_16S, ksize=1)
            lpc_var[i] = [lpc[mask].var() for mask in roi]

        if display:
            # Plot the first sample image
            f = plt.figure()
            gs = f.add_gridspec(len(roi) + 1, 3)
            f.add_subplot(gs[0:len(roi), 0])
            frame = img[0]
            self.imshow(frame, title=f'Frame #{self.frame_samples_idx[idx[0]]}')
            # Plot the ROIs with and without filter
            lpc = cv2.Laplacian(frame, cv2.CV_16S, ksize=1)
            abs_lpc = cv2.convertScaleAbs(lpc)
            for i, r in enumerate(roi):
                f.add_subplot(gs[i, 1])
                self.imshow(frame[r], title=f'ROI #{i + 1}')
                f.add_subplot(gs[i, 2])
                self.imshow(abs_lpc[r], title=f'ROI #{i + 1} - Laplacian filter')
            f.suptitle('Laplacian blur detection')
            # Plot variance over frames
            ax = f.add_subplot(gs[len(roi), :])
            ln = plt.plot(lpc_var)
            [l.set_label(f'ROI #{i + 1}') for i, l in enumerate(ln)]
            ax.axhline(threshold[0], 0, n, linestyle=':', color='r', label='lower threshold')
            ax.set(xlabel='Frame sample', ylabel='Variance of the Laplacian')
            plt.tight_layout()
            plt.legend()

        # Second test is to highpass with dft
        h, w = img.shape[1:]
        cX, cY = w // 2, h // 2
        sz = 60  # Seems to be the magic number for high pass
        mask = np.ones((h, w, 2), bool)
        mask[cY - sz:cY + sz, cX - sz:cX + sz] = False
        filt_mean = np.empty(len(img))
        for i, frame in enumerate(img[::-1]):
            dft = cv2.dft(np.float32(frame), flags=cv2.DFT_COMPLEX_OUTPUT)
            f_shift = np.fft.fftshift(dft) * mask  # Shift & remove low frequencies
            f_ishift = np.fft.ifftshift(f_shift)  # Shift back
            filt_frame = cv2.idft(f_ishift)  # Reconstruct
            filt_frame = cv2.magnitude(filt_frame[..., 0], filt_frame[..., 1])
            # Re-normalize to 8-bits to make threshold simpler
            img_back = cv2.normalize(filt_frame, None, alpha=0, beta=256,
                                     norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            filt_mean[i] = np.mean(img_back)
            if i == len(img) - 1 and display:
                # Plot Fourier transforms
                f = plt.figure()
                gs = f.add_gridspec(2, 3)
                self.imshow(img[0], ax=f.add_subplot(gs[0, 0]), title='Original frame')
                dft_shift = np.fft.fftshift(dft)
                magnitude = 20 * np.log(cv2.magnitude(dft_shift[..., 0], dft_shift[..., 1]))
                self.imshow(magnitude, ax=f.add_subplot(gs[0, 1]), title='Magnitude spectrum')
                self.imshow(img_back, ax=f.add_subplot(gs[0, 2]), title='Filtered frame')
                ax = f.add_subplot(gs[1, :])
                ax.plot(filt_mean)
                ax.axhline(threshold[1], 0, n, linestyle=':', color='r', label='lower threshold')
                ax.set(xlabel='Frame sample', ylabel='Mean of filtered frame')
                f.suptitle('Discrete Fourier Transform')
                plt.show()
        passes = np.all(lpc_var > threshold[0]) or np.all(filt_mean > threshold[1])
        return 'PASS' if passes else 'FAIL'
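
    # Illustrative sketch of the first measure above: the variance of a
    # Laplacian-filtered image collapses as blur increases (synthetic data):
    #
    #     >>> rng = np.random.default_rng(0)
    #     >>> sharp = rng.integers(0, 255, (64, 64), dtype=np.uint8)
    #     >>> lap_var = lambda im: cv2.Laplacian(im, cv2.CV_16S, ksize=1).var()
    #     >>> lap_var(sharp) > lap_var(cv2.blur(sharp, (5, 5)))
    #     True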

    def find_face(self, roi=None, test=False, metric=cv2.TM_CCOEFF_NORMED, refs=None):
        """Use template matching to find face location in frame
        For the template matching zero-normalized cross-correlation (default) should be more
        robust to exposure (which we're not checking here).  The L2 norm (TM_SQDIFF) should
        also work.  That said, normalizing the histograms works best.

        :param roi: A tuple of indices for the face template in the form ((y1, y2), (x1, x2))
        :param test: If True the template is matched against frames that come from the same session
        :param metric: The metric to use for template matching
        :param refs: An array of frames to match the template to

        :returns: a list of (x, y) top-left matches (one per frame), the ROI used, and the template
        """
        ROI = {
            'left': ((45, 346), (138, 501)),
            'right': ((14, 174), (430, 618)),
            'body': ((141, 272), (90, 339))
        }
        roi = roi or ROI[self.label]
        refs = self.load_reference_frames(self.label) if refs is None else refs

        frames = refs if test else self.data['frame_samples']
        template = refs[0][tuple(slice(*r) for r in roi)]
        top_left = []  # [(x1, y1), ...]
        for frame in frames:
            res = cv2.matchTemplate(frame, template, metric)
            min_val, max_val, min_loc, max_loc = cv2.minMaxLoc(res)
            # If the method is TM_SQDIFF or TM_SQDIFF_NORMED, take minimum
            top_left.append(min_loc if metric < 2 else max_loc)
            # bottom_right = (top_left[0] + w, top_left[1] + h)
        return top_left, roi, template
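
    # Why `metric < 2` works above: OpenCV orders its template-matching modes
    # so that the two distance-like metrics (best match at the minimum) come
    # first and the correlation-like metrics (best match at the maximum) follow:
    #
    #     >>> int(cv2.TM_SQDIFF), int(cv2.TM_SQDIFF_NORMED), int(cv2.TM_CCOEFF_NORMED)
    #     (0, 1, 5)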
1✔
1025

1026
    @staticmethod
1✔
1027
    def load_reference_frames(side):
1✔
1028
        """Load some reference frames for a given video
1029

1030
        The reference frames are from sessions where the camera was well positioned. The
1031
        frames are in qc/reference, one file per camera, only one channel per frame.  The
1032
        session eids can be found in qc/reference/frame_src.json
1033

1034
        :param side: Video label, e.g. 'left'
1035
        :return: numpy array of frames with the shape (n, h, w)
1036
        """
1037
        file = next(Path(__file__).parent.joinpath('reference').glob(f'frames_{side}.npy'))
1✔
1038
        refs = np.load(file)
1✔
1039
        return refs
1✔
1040

    @staticmethod
    def imshow(frame, ax=None, title=None, **kwargs):
        """plt.imshow with some convenient defaults for greyscale frames."""
        h = ax or plt.gca()
        defaults = {
            'cmap': kwargs.pop('cmap', 'gray'),
            'vmin': kwargs.pop('vmin', 0),
            'vmax': kwargs.pop('vmax', 255)
        }
        h.imshow(frame, **defaults, **kwargs)
        h.set(title=title)
        h.set_axis_off()
        return h  # NB: return the axes drawn on, which differ from the `ax` input when None


class CameraQCCamlog(CameraQC):
    """A class for computing camera QC metrics from camlog data.

    For this QC we expect check_pin_state to be NOT_SET, as the GPIO is not used for
    timestamp alignment.
    """
    dstypes = [
        '_iblrig_taskData.raw',
        '_iblrig_taskSettings.raw',
        '_iblrig_Camera.raw',
        'camera.times',
        'wheel.position',
        'wheel.timestamps'
    ]
    dstypes_fpga = [
        '_spikeglx_sync.channels',
        '_spikeglx_sync.polarities',
        '_spikeglx_sync.times',
        'DAQData.raw.meta',
        'DAQData.wiring'
    ]

    def __init__(self, session_path_or_eid, camera, sync_collection='raw_sync_data', sync_type='nidq', **kwargs):
        super().__init__(session_path_or_eid, camera, sync_collection=sync_collection, sync_type=sync_type, **kwargs)
        self._type = 'ephys'
        self.checks_to_remove = ['check_pin_state']
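    # Example construction (a sketch; `session_path` is a hypothetical local session):
    #
    # >>> qc = CameraQCCamlog(session_path, 'left')  # nidq sync in raw_sync_data by default
    # >>> 'check_pin_state' in qc.checks_to_remove
    # True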

    def load_data(self, download_data: bool = None,
                  extract_times: bool = False, load_video: bool = True, **kwargs) -> None:
        """Extract the data from raw data files.

        Extracts all the required task data from the raw data files.

        Data keys:
            - count (int array): the sequential frame number (n, n+1, n+2...)
            - pin_state (): the camera GPIO pin; records the audio TTLs; should be one per frame
            - audio (float array): timestamps of audio TTL fronts
            - fpga_times (float array): timestamps of camera TTLs recorded by FPGA
            - timestamps (float array): extracted video timestamps (the camera.times ALF)
            - bonsai_times (datetime array): system timestamps of video PC; should be one per frame
            - camera_times (float array): camera frame timestamps extracted from frame headers
            - wheel (Bunch): rotary encoder timestamps, position and period used for wheel motion
            - video (Bunch): video meta data, including dimensions and FPS
            - frame_samples (h x w x n array): array of evenly sampled frames (1 colour channel)

        :param download_data: if True, any missing raw data are downloaded via ONE; missing
         data will raise an AssertionError
        :param extract_times: if True, the camera.times are re-extracted from the raw data
        :param load_video: if True, calls the load_video_data method
        """
        assert self.session_path, 'no session path set'
        if download_data is not None:
            self.download_data = download_data
        if self.download_data and self.eid and self.one and not self.one.offline:
            self.ensure_required_data()
        _log.info('Gathering data for QC')

        # If there is an experiment description, take the collections and sync from it
        sess_params = read_params(self.session_path) or {}
        video_collection = get_video_collection(sess_params, self.label)
        task_collection = get_task_collection(sess_params)
        self._set_sync(sess_params)
        self._update_meta_from_session_params(sess_params)

        # Get frame count
        log, _ = parse_cam_log(self.session_path.joinpath(video_collection, f'_iblrig_{self.label}Camera.raw.camlog'))
        self.data['count'] = log.frame_id.values

        # Load the audio and raw FPGA times
        if self.sync != 'bpod' and self.sync is not None:
            sync, chmap = ephys_fpga.get_sync_and_chn_map(self.session_path, self.sync_collection)
            audio_ttls = ephys_fpga.get_sync_fronts(sync, chmap['audio'])
            self.data['audio'] = audio_ttls['times']  # Get rises
            # Load raw FPGA times
            cam_ts = extract_camera_sync(sync, chmap)
            self.data['fpga_times'] = cam_ts[self.label]
        else:
            bpod_data = raw.load_data(self.session_path, task_collection=task_collection)
            _, audio_ttls = raw.load_bpod_fronts(self.session_path, data=bpod_data, task_collection=task_collection)
            self.data['audio'] = audio_ttls['times']

        # Load extracted frame times
        alf_path = self.session_path / 'alf'
        try:
            assert not extract_times
            self.data['timestamps'] = alfio.load_object(
                alf_path, f'{self.label}Camera', short_keys=True)['times']
        except AssertionError:  # Re-extract
            kwargs = dict(video_path=self.video_path, labels=self.label)
            if self.sync == 'bpod':
                kwargs = {**kwargs, 'task_collection': task_collection}
            else:
                kwargs = {**kwargs, 'sync': sync, 'chmap': chmap}  # noqa
            outputs, _ = extract_all(self.session_path, self.sync, save=False, camlog=True, **kwargs)
            self.data['timestamps'] = outputs[f'{self.label}_camera_timestamps']
        except ALFObjectNotFound:
            _log.warning('no camera.times ALF found for session')

        # Get audio and wheel data
        wheel_keys = ('timestamps', 'position')
        try:
            # glob in case wheel data are in sub-collections
            alf_path = next(alf_path.rglob('*wheel.timestamps*')).parent
            self.data['wheel'] = alfio.load_object(alf_path, 'wheel', short_keys=True)
        except (ALFObjectNotFound, StopIteration):  # StopIteration: no wheel files found
            # Extract from raw data
            if self.sync != 'bpod':
                wheel_data = ephys_fpga.extract_wheel_sync(sync, chmap)
            else:
                wheel_data = training_wheel.get_wheel_position(self.session_path, task_collection=task_collection)
            self.data['wheel'] = Bunch(zip(wheel_keys, wheel_data))

        # Find short period of wheel motion for motion correlation
        if data_for_keys(wheel_keys, self.data['wheel']) and self.data['timestamps'] is not None:
            self.data['wheel'].period = self.get_active_wheel_period(self.data['wheel'])

        # Load in camera times
        self.data['camera_times'] = log.timestamp.values

        # Gather information from video file
        if load_video:
            _log.info('Inspecting video file...')
            self.load_video_data()

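    # Example: gather the data without running the checks (a sketch; assumes the raw
    # camlog, task and sync data are present locally):
    #
    # >>> qc = CameraQCCamlog(session_path, 'left')
    # >>> qc.load_data(download_data=False)
    # >>> qc.data['count'].size  # number of frames reported by the camlog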
    def ensure_required_data(self):
        """
        Ensure the datasets required for QC are local.  If the download_data attribute is
        True, any missing data are downloaded.  If not all data are present locally by the
        end, an exception is raised.  If the stream attribute is True, the video file is not
        required to be local, however it must be remotely accessible.
        NB: Requires a valid instance of ONE and a valid session eid.
        :return:
        """
        assert self.one is not None, 'ONE required to download data'

        sess_params = {}
        if self.download_data:
            dset = self.one.list_datasets(self.session_path, '*experiment.description*', details=True)
            if self.one._check_filesystem(dset):
                sess_params = read_params(self.session_path) or {}
        else:
            sess_params = read_params(self.session_path) or {}
        self._set_sync(sess_params)

        # dataset collections outside this list are ignored (e.g. probe00, raw_passive_data)
        collections = (
            'alf', self.sync_collection, get_task_collection(sess_params),
            get_video_collection(sess_params, self.label))

        # Check that each required dataset type is present
        dtypes = self.dstypes + self.dstypes_fpga
        assert_unique = True

        for dstype in dtypes:
            datasets = self.one.type2datasets(self.eid, dstype, details=True)
            if 'camera' in dstype.lower():  # Download individual camera file
                datasets = filter_datasets(datasets, filename=f'.*{self.label}.*')
            else:  # Ignore probe datasets, etc.
                datasets = filter_datasets(datasets, collection=collections,
                                           assert_unique=assert_unique)
            if any(x.endswith('.mp4') for x in datasets.rel_path) and self.stream:
                names = [x.split('/')[-1] for x in self.one.list_datasets(self.eid, details=False)]
                assert f'_iblrig_{self.label}Camera.raw.mp4' in names, 'No remote video file found'
                continue
            optional = ('camera.times', '_iblrig_Camera.raw', 'wheel.position', 'wheel.timestamps')
            present = (
                self.one._check_filesystem(datasets)
                if self.download_data
                else (next(self.session_path.rglob(d), None) for d in datasets['rel_path'])
            )

            required = (dstype not in optional)
            all_present = not datasets.empty and all(present)
            assert all_present or not required, f'Dataset {dstype} not found'

    def check_camera_times(self):
        """Check that the number of raw camera timestamps matches the number of video frames."""
        if not data_for_keys(('camera_times', 'video'), self.data):
            return 'NOT_SET'
        length_match = len(self.data['camera_times']) == self.data['video'].length
        outcome = 'PASS' if length_match else 'WARNING'
        # NB: the median frame rate would be 1 / np.median(np.diff(self.data.camera_times))
        return outcome, len(self.data['camera_times']) - self.data['video'].length


def data_for_keys(keys, data):
    """Check keys exist in 'data' dict and contain values other than None."""
    return data is not None and all(k in data and data.get(k, None) is not None for k in keys)

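# Example for data_for_keys: every key must be present with a value other than None
# (hypothetical data):
#
# >>> data_for_keys(('camera_times', 'video'), {'camera_times': [0.1, 0.2], 'video': None})
# False
# >>> data_for_keys(('camera_times',), {'camera_times': [0.1, 0.2]})
# True
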
def get_task_collection(sess_params):
    """
    Returns the first task collection from the experiment description whose task name does not
    contain 'passive', otherwise returns 'raw_behavior_data'.

    Parameters
    ----------
    sess_params : dict
        The loaded experiment description file.

    Returns
    -------
    str:
        The collection presumed to contain wheel data.
    """
    sess_params = sess_params or {}
    tasks = chain(*map(dict.items, sess_params.get('tasks', [])))
    return next((v['collection'] for k, v in tasks if 'passive' not in k), 'raw_behavior_data')

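# Example for get_task_collection with a hypothetical experiment description ('tasks'
# is a list of single-key dicts; the passive task is skipped):
#
# >>> params = {'tasks': [{'passiveChoiceWorld': {'collection': 'raw_passive_data'}},
# ...                     {'ephysChoiceWorld': {'collection': 'raw_task_data_00'}}]}
# >>> get_task_collection(params)
# 'raw_task_data_00'
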
def get_video_collection(sess_params, label):
    """
    Returns the collection containing the raw video data for a given camera.

    Parameters
    ----------
    sess_params : dict
        The loaded experiment description file.
    label : str
        The camera label.

    Returns
    -------
    str:
        The collection presumed to contain the video data.
    """
    DEFAULT = 'raw_video_data'
    value = sess_params or {}
    for key in ('devices', 'cameras', label, 'collection'):
        value = value.get(key)
        if not value:
            return DEFAULT
    return value
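# Example for get_video_collection: the collection is read from the nested devices map,
# falling back to 'raw_video_data' when any level is missing (hypothetical description):
#
# >>> params = {'devices': {'cameras': {'left': {'collection': 'raw_video_data'}}}}
# >>> get_video_collection(params, 'left')
# 'raw_video_data'
# >>> get_video_collection({}, 'body')
# 'raw_video_data'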

def run_all_qc(session, cameras=('left', 'right', 'body'), **kwargs):
    """Run QC for all cameras.

    Run the camera QC for the left, right and body cameras.
    :param session: A session path or eid.
    :param update: If True, QC fields are updated on Alyx.
    :param cameras: A list of camera names to perform QC on.
    :param stream: If True and local video files are not available, the data are streamed
     from the remote source.
    :return: dict of CameraQC objects
    """
    qc = {}
    camlog = kwargs.pop('camlog', False)
    CamQC = CameraQCCamlog if camlog else CameraQC

    run_args = {k: kwargs.pop(k) for k in ('download_data', 'extract_times', 'update')
                if k in kwargs}
    for camera in cameras:
        qc[camera] = CamQC(session, camera, **kwargs)
        qc[camera].run(**run_args)
    return qc
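
# For camlog sessions, pass camlog=True to run CameraQCCamlog instead of CameraQC
# (a sketch; `session` is a session path or eid):
#
# >>> qcs = run_all_qc(session, cameras=('left',), camlog=True, update=False)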