• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 6161719581526678

28 Jun 2024 01:14PM UTC coverage: 64.584% (+0.03%) from 64.55%
6161719581526678

push

tests

web-flow
Stim on extraction (#788)

* Issue #775
* Handle no go trials
* Pre-6.2.5 trials extraction
* DeprecationWarning -> FutureWarning; extractor fixes; timeline trials extraction

49 of 60 new or added lines in 9 files covered. (81.67%)

2 existing lines in 1 file now uncovered.

13055 of 20214 relevant lines covered (64.58%)

0.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.14
/ibllib/io/extractors/ephys_fpga.py
1
"""Data extraction from raw FPGA output.
1✔
2

3
The behaviour extraction happens in the following stages:
4

5
    1. The NI DAQ events are extracted into a map of event times and TTL polarities.
6
    2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol.
7
    3. As protocols may be chained together within a given recording, the period of a given task
8
       protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`).
9
    4. Physical behaviour events such as stim on and reward time are separated out by TTL length or
10
       sequence within the trial.
11
    5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events.
12
    6. The Bpod software events are then converted to FPGA time.
13

14
Examples
15
--------
16
For simple extraction, use the FPGATrials class:
17

18
>>> extractor = FpgaTrials(session_path)
19
>>> trials, _ = extractor.extract(update=False, save=False)
20

21
Notes
22
-----
23
Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a
24
Neuropixels recording system, however a sync and channel map extracted from a different DAQ format
25
can be passed to the FpgaTrials class.
26

27
See Also
28
--------
29
For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class.
30

31
TODO notes on subclassing various methods of FpgaTrials for custom hardware.
32
"""
33
import logging
1✔
34
from itertools import cycle
1✔
35
from pathlib import Path
1✔
36
import uuid
1✔
37
import re
1✔
38
import warnings
1✔
39
from functools import partial
1✔
40

41
import matplotlib.pyplot as plt
1✔
42
from matplotlib.colors import TABLEAU_COLORS
1✔
43
import numpy as np
1✔
44
from packaging import version
1✔
45

46
import spikeglx
1✔
47
import ibldsp.utils
1✔
48
import one.alf.io as alfio
1✔
49
from one.alf.files import filename_parts
1✔
50
from iblutil.util import Bunch
1✔
51
from iblutil.spacer import Spacer
1✔
52

53
import ibllib.exceptions as err
1✔
54
from ibllib.io import raw_data_loaders as raw, session_params
1✔
55
from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all
1✔
56
import ibllib.io.extractors.base as extractors_base
1✔
57
from ibllib.io.extractors.training_wheel import extract_wheel_moves
1✔
58
from ibllib import plots
1✔
59
from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS
1✔
60

61
_logger = logging.getLogger(__name__)
1✔
62

63
SYNC_BATCH_SIZE_SECS = 100
1✔
64
"""int: Number of samples to read at once in bin file for sync."""
1✔
65

66
WHEEL_RADIUS_CM = 1  # stay in radians
1✔
67
"""float: The radius of the wheel used in the task. A value of 1 ensures units remain in radians."""
1✔
68

69
WHEEL_TICKS = 1024
1✔
70
"""int: The number of encoder pulses per channel for one complete rotation."""
1✔
71

72
BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150
1✔
73
"""int: Throws an error if Bpod to FPGA clock drift is higher than this value."""
1✔
74

75
CHMAPS = {'3A':
1✔
76
          {'ap':
77
           {'left_camera': 2,
78
            'right_camera': 3,
79
            'body_camera': 4,
80
            'bpod': 7,
81
            'frame2ttl': 12,
82
            'rotary_encoder_0': 13,
83
            'rotary_encoder_1': 14,
84
            'audio': 15
85
            }
86
           },
87
          '3B':
88
          {'nidq':
89
           {'left_camera': 0,
90
            'right_camera': 1,
91
            'body_camera': 2,
92
            'imec_sync': 3,
93
            'frame2ttl': 4,
94
            'rotary_encoder_0': 5,
95
            'rotary_encoder_1': 6,
96
            'audio': 7,
97
            'bpod': 16,
98
            'laser': 17,
99
            'laser_ttl': 18},
100
           'ap':
101
           {'imec_sync': 6}
102
           },
103
          }
104
"""dict: The default channel indices corresponding to various devices for different recording systems."""
1✔
105

106

107
def data_for_keys(keys, data):
1✔
108
    """Check keys exist in 'data' dict and contain values other than None."""
109
    return data is not None and all(k in data and data.get(k, None) is not None for k in keys)
1✔
110

111

112
def get_ibl_sync_map(ef, version):
1✔
113
    """
114
    Gets default channel map for the version/binary file type combination
115
    :param ef: ibllib.io.spikeglx.glob_ephys_file dictionary with field 'ap' or 'nidq'
116
    :return: channel map dictionary
117
    """
118
    # Determine default channel map
119
    if version == '3A':
1✔
120
        default_chmap = CHMAPS['3A']['ap']
1✔
121
    elif version == '3B':
1✔
122
        if ef.get('nidq', None):
1✔
123
            default_chmap = CHMAPS['3B']['nidq']
1✔
124
        elif ef.get('ap', None):
1✔
125
            default_chmap = CHMAPS['3B']['ap']
1✔
126
    # Try to load channel map from file
127
    chmap = spikeglx.get_sync_map(ef['path'])
1✔
128
    # If chmap provided but not with all keys, fill up with default values
129
    if not chmap:
1✔
130
        return default_chmap
1✔
131
    else:
132
        if data_for_keys(default_chmap.keys(), chmap):
1✔
133
            return chmap
1✔
134
        else:
135
            _logger.warning("Keys missing from provided channel map, "
1✔
136
                            "setting missing keys from default channel map")
137
            return {**default_chmap, **chmap}
1✔
138

139

140
def _sync_to_alf(raw_ephys_apfile, output_path=None, save=False, parts=''):
1✔
141
    """
142
    Extracts sync.times, sync.channels and sync.polarities from binary ephys dataset
143

144
    :param raw_ephys_apfile: bin file containing ephys data or spike
145
    :param output_path: output directory
146
    :param save: bool write to disk only if True
147
    :param parts: string or list of strings that will be appended to the filename before extension
148
    :return:
149
    """
150
    # handles input argument: support ibllib.io.spikeglx.Reader, str and pathlib.Path
151
    if isinstance(raw_ephys_apfile, spikeglx.Reader):
1✔
152
        sr = raw_ephys_apfile
1✔
153
    else:
154
        raw_ephys_apfile = Path(raw_ephys_apfile)
×
155
        sr = spikeglx.Reader(raw_ephys_apfile)
×
156
    if not (opened := sr.is_open):
1✔
157
        sr.open()
×
158
    # if no output, need a temp folder to swap for big files
159
    if not output_path:
1✔
160
        output_path = raw_ephys_apfile.parent
×
161
    file_ftcp = Path(output_path).joinpath(f'fronts_times_channel_polarity{uuid.uuid4()}.bin')
1✔
162

163
    # loop over chunks of the raw ephys file
164
    wg = ibldsp.utils.WindowGenerator(sr.ns, int(SYNC_BATCH_SIZE_SECS * sr.fs), overlap=1)
1✔
165
    fid_ftcp = open(file_ftcp, 'wb')
1✔
166
    for sl in wg.slice:
1✔
167
        ss = sr.read_sync(sl)
1✔
168
        ind, fronts = ibldsp.utils.fronts(ss, axis=0)
1✔
169
        # a = sr.read_sync_analog(sl)
170
        sav = np.c_[(ind[0, :] + sl.start) / sr.fs, ind[1, :], fronts.astype(np.double)]
1✔
171
        sav.tofile(fid_ftcp)
1✔
172
    # close temp file, read from it and delete
173
    fid_ftcp.close()
1✔
174
    tim_chan_pol = np.fromfile(str(file_ftcp))
1✔
175
    tim_chan_pol = tim_chan_pol.reshape((int(tim_chan_pol.size / 3), 3))
1✔
176
    file_ftcp.unlink()
1✔
177
    sync = {'times': tim_chan_pol[:, 0],
1✔
178
            'channels': tim_chan_pol[:, 1],
179
            'polarities': tim_chan_pol[:, 2]}
180
    # If opened Reader was passed into function, leave open
181
    if not opened:
1✔
182
        sr.close()
×
183
    if save:
1✔
184
        out_files = alfio.save_object_npy(output_path, sync, 'sync',
1✔
185
                                          namespace='spikeglx', parts=parts)
186
        return Bunch(sync), out_files
1✔
187
    else:
188
        return Bunch(sync)
×
189

190

191
def _rotary_encoder_positions_from_fronts(ta, pa, tb, pb, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4'):
1✔
192
    """
193
    Extracts the rotary encoder absolute position as function of time from fronts detected
194
    on the 2 channels. Outputs in units of radius parameters, by default radians
195
    Coding options detailed here: http://www.ni.com/tutorial/7109/pt/
196
    Here output is clockwise from subject perspective
197

198
    :param ta: time of fronts on channel A
199
    :param pa: polarity of fronts on channel A
200
    :param tb: time of fronts on channel B
201
    :param pb: polarity of fronts on channel B
202
    :param ticks: number of ticks corresponding to a full revolution (1024 for IBL rotary encoder)
203
    :param radius: radius of the wheel. Defaults to 1 for an output in radians
204
    :param coding: x1, x2 or x4 coding (IBL default is x4)
205
    :return: indices vector (ta) and position vector
206
    """
207
    if coding == 'x1':
1✔
208
        ia = np.searchsorted(tb, ta[pa == 1])
1✔
209
        ia = ia[ia < ta.size]
1✔
210
        ia = ia[pa[ia] == 1]
1✔
211
        ib = np.searchsorted(ta, tb[pb == 1])
1✔
212
        ib = ib[ib < tb.size]
1✔
213
        ib = ib[pb[ib] == 1]
1✔
214
        t = np.r_[ta[ia], tb[ib]]
1✔
215
        p = np.r_[ia * 0 + 1, ib * 0 - 1]
1✔
216
        ordre = np.argsort(t)
1✔
217
        t = t[ordre]
1✔
218
        p = p[ordre]
1✔
219
        p = np.cumsum(p) / ticks * np.pi * 2 * radius
1✔
220
        return t, p
1✔
221
    elif coding == 'x2':
1✔
222
        p = pb[np.searchsorted(tb, ta) - 1] * pa
1✔
223
        p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 2
1✔
224
        return ta, p
1✔
225
    elif coding == 'x4':
1✔
226
        p = np.r_[pb[np.searchsorted(tb, ta) - 1] * pa, -pa[np.searchsorted(ta, tb) - 1] * pb]
1✔
227
        t = np.r_[ta, tb]
1✔
228
        ordre = np.argsort(t)
1✔
229
        t = t[ordre]
1✔
230
        p = p[ordre]
1✔
231
        p = - np.cumsum(p) / ticks * np.pi * 2 * radius / 4
1✔
232
        return t, p
1✔
233

234

235
def _assign_events_to_trial(t_trial_start, t_event, take='last', t_trial_end=None):
1✔
236
    """
237
    Assign events to a trial given trial start times and event times.
238

239
    Trials without an event result in nan value in output time vector.
240
    The output has a consistent size with t_trial_start and ready to output to alf.
241

242
    Parameters
243
    ----------
244
    t_trial_start : numpy.array
245
        An array of start times, used to bin edges for assigning values from `t_event`.
246
    t_event : numpy.array
247
        An array of event times to assign to trials.
248
    take : str {'first', 'last'}, int
249
        'first' takes first event > t_trial_start; 'last' takes last event < the next
250
        t_trial_start; an int defines the index to take for events within trial bounds. The index
251
        may be negative.
252
    t_trial_end : numpy.array
253
        Optional array of end times, used to bin edges for assigning values from `t_event`.
254

255
    Returns
256
    -------
257
    numpy.array
258
        An array the length of `t_trial_start` containing values from `t_event`. Unassigned values
259
        are replaced with np.nan.
260

261
    See Also
262
    --------
263
    FpgaTrials._assign_events - Assign trial events based on TTL length.
264
    """
265
    # make sure the events are sorted
266
    try:
1✔
267
        assert np.all(np.diff(t_trial_start) >= 0)
1✔
268
    except AssertionError:
1✔
269
        raise ValueError('Trial starts vector not sorted')
1✔
270
    try:
1✔
271
        assert np.all(np.diff(t_event) >= 0)
1✔
272
    except AssertionError:
1✔
273
        raise ValueError('Events vector is not sorted')
1✔
274

275
    # remove events that happened before the first trial start
276
    remove = t_event < t_trial_start[0]
1✔
277
    if t_trial_end is not None:
1✔
278
        if not np.all(np.diff(t_trial_end) >= 0):
1✔
NEW
279
            raise ValueError('Trial end vector not sorted')
×
280
        if not np.all(t_trial_end[:-1] < t_trial_start[1:]):
1✔
NEW
281
            raise ValueError('Trial end times must not overlap with trial start times')
×
282
        # remove events between end and next start, and after last end
283
        remove |= t_event > t_trial_end[-1]
1✔
284
        for e, s in zip(t_trial_end[:-1], t_trial_start[1:]):
1✔
285
            remove |= np.logical_and(s > t_event, t_event >= e)
1✔
286
    t_event = t_event[~remove]
1✔
287
    ind = np.searchsorted(t_trial_start, t_event) - 1
1✔
288
    t_event_nans = np.zeros_like(t_trial_start) * np.nan
1✔
289
    # select first or last element matching each trial start
290
    if take == 'last':
1✔
291
        iall, iu = np.unique(np.flip(ind), return_index=True)
1✔
292
        t_event_nans[iall] = t_event[- (iu - ind.size + 1)]
1✔
293
    elif take == 'first':
1✔
294
        iall, iu = np.unique(ind, return_index=True)
1✔
295
        t_event_nans[iall] = t_event[iu]
1✔
296
    else:  # if the index is arbitrary, needs to be numeric (could be negative if from the end)
297
        iall = np.unique(ind)
1✔
298
        minsize = take + 1 if take >= 0 else - take
1✔
299
        # for each trial, take the take nth element if there are enough values in trial
300
        for iu in iall:
1✔
301
            match = t_event[iu == ind]
1✔
302
            if len(match) >= minsize:
1✔
303
                t_event_nans[iu] = match[take]
1✔
304
    return t_event_nans
1✔
305

306

307
def get_sync_fronts(sync, channel_nb, tmin=None, tmax=None):
1✔
308
    """
309
    Return the sync front polarities and times for a given channel.
310

311
    Parameters
312
    ----------
313
    sync : dict
314
        'polarities' of fronts detected on sync trace for all 16 channels and their 'times'.
315
    channel_nb : int
316
        The integer corresponding to the desired sync channel.
317
    tmin : float
318
        The minimum time from which to extract the sync pulses.
319
    tmax : float
320
        The maximum time up to which we extract the sync pulses.
321

322
    Returns
323
    -------
324
    Bunch
325
        Channel times and polarities.
326
    """
327
    selection = sync['channels'] == channel_nb
1✔
328
    selection = np.logical_and(selection, sync['times'] <= tmax) if tmax else selection
1✔
329
    selection = np.logical_and(selection, sync['times'] >= tmin) if tmin else selection
1✔
330
    return Bunch({'times': sync['times'][selection],
1✔
331
                  'polarities': sync['polarities'][selection]})
332

333

334
def _clean_audio(audio, display=False):
1✔
335
    """
336
    one guy wired the 150 Hz camera output onto the soundcard. The effect is to get 150 Hz periodic
337
    square pulses, 2ms up and 4.666 ms down. When this happens we remove all of the intermediate
338
    pulses to repair the audio trace
339
    Here is some helper code
340
        dd = np.diff(audio['times'])
341
        1 / np.median(dd[::2]) # 2ms up
342
        1 / np.median(dd[1::2])  # 4.666 ms down
343
        1 / (np.median(dd[::2]) + np.median(dd[1::2])) # both sum to 150 Hz
344
    This only runs on sessions when the bug is detected and leaves others untouched
345
    """
346
    DISCARD_THRESHOLD = 0.01
1✔
347
    average_150_hz = np.mean(1 / np.diff(audio['times'][audio['polarities'] == 1]) > 140)
1✔
348
    naudio = audio['times'].size
1✔
349
    if average_150_hz > 0.7 and naudio > 100:
1✔
350
        _logger.warning('Soundcard signal on FPGA seems to have been mixed with 150Hz camera')
1✔
351
        keep_ind = np.r_[np.diff(audio['times']) > DISCARD_THRESHOLD, False]
1✔
352
        keep_ind = np.logical_and(keep_ind, audio['polarities'] == -1)
1✔
353
        keep_ind = np.where(keep_ind)[0]
1✔
354
        keep_ind = np.sort(np.r_[0, keep_ind, keep_ind + 1, naudio - 1])
1✔
355

356
        if display:  # pragma: no cover
357
            from ibllib.plots import squares
358
            squares(audio['times'], audio['polarities'], ax=None, yrange=[-1, 1])
359
            squares(audio['times'][keep_ind], audio['polarities'][keep_ind], yrange=[-1, 1])
360
        audio = {'times': audio['times'][keep_ind],
1✔
361
                 'polarities': audio['polarities'][keep_ind]}
362
    return audio
1✔
363

364

365
def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False):
1✔
366
    """
367
    Clean the frame2ttl events.
368

369
    Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic
370
    pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold
371
    of F2TTL_THRESH.
372

373
    Parameters
374
    ----------
375
    frame2ttl : dict
376
        A dictionary of frame2TTL events, with keys {'times', 'polarities'}.
377
    threshold : float
378
        Consecutive pulses occurring with this many seconds ignored.
379
    display : bool
380
        If true, plots the input TTLs and the cleaned output.
381

382
    Returns
383
    -------
384

385
    """
386
    dt = np.diff(frame2ttl['times'])
1✔
387
    iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0]
1✔
388
    iko = np.unique(np.r_[iko, iko + 1])
1✔
389
    frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko),
1✔
390
                  'polarities': np.delete(frame2ttl['polarities'], iko)}
391
    if iko.size > (0.1 * frame2ttl['times'].size):
1✔
392
        _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) '
1✔
393
                        f'frame to TTL polarity switches below {threshold} secs')
394
    if display:  # pragma: no cover
395
        fig, (ax0, ax1) = plt.subplots(2, sharex=True)
396
        plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0)
397
        plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1)
398
        import seaborn as sns
399
        sns.displot(dt[dt < 0.05], binwidth=0.0005)
400

401
    return frame2ttl_
1✔
402

403

404
def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None):
1✔
405
    """
406
    Extract wheel positions and times from sync fronts dictionary for all 16 channels.
407
    Output position is in radians, mathematical convention.
408

409
    Parameters
410
    ----------
411
    sync : dict
412
        'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
413
    chmap : dict
414
        Map of channel names and their corresponding index.  Default to constant.
415
    tmin : float
416
        The minimum time from which to extract the sync pulses.
417
    tmax : float
418
        The maximum time up to which we extract the sync pulses.
419

420
    Returns
421
    -------
422
    numpy.array
423
        Wheel timestamps in seconds.
424
    numpy.array
425
        Wheel positions in radians.
426
    """
427
    # Assume two separate edge count channels
428
    assert chmap.keys() >= {'rotary_encoder_0', 'rotary_encoder_1'}
1✔
429
    channela = get_sync_fronts(sync, chmap['rotary_encoder_0'], tmin=tmin, tmax=tmax)
1✔
430
    channelb = get_sync_fronts(sync, chmap['rotary_encoder_1'], tmin=tmin, tmax=tmax)
1✔
431
    re_ts, re_pos = _rotary_encoder_positions_from_fronts(
1✔
432
        channela['times'], channela['polarities'], channelb['times'], channelb['polarities'],
433
        ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4')
434
    return re_ts, re_pos
1✔
435

436

437
def extract_sync(session_path, overwrite=False, ephys_files=None, namespace='spikeglx'):
1✔
438
    """
439
    Reads ephys binary file (s) and extract sync within the binary file folder
440
    Assumes ephys data is within a `raw_ephys_data` folder
441

442
    :param session_path: '/path/to/subject/yyyy-mm-dd/001'
443
    :param overwrite: Bool on re-extraction, forces overwrite instead of loading existing files
444
    :return: list of sync dictionaries
445
    """
446
    session_path = Path(session_path)
1✔
447
    if not ephys_files:
1✔
448
        ephys_files = spikeglx.glob_ephys_files(session_path)
1✔
449
    syncs = []
1✔
450
    outputs = []
1✔
451
    for efi in ephys_files:
1✔
452
        bin_file = efi.get('ap', efi.get('nidq', None))
1✔
453
        if not bin_file:
1✔
454
            continue
×
455
        alfname = dict(object='sync', namespace=namespace)
1✔
456
        if efi.label:
1✔
457
            alfname['extra'] = efi.label
1✔
458
        file_exists = alfio.exists(bin_file.parent, **alfname)
1✔
459
        if not overwrite and file_exists:
1✔
460
            _logger.warning(f'Skipping raw sync: SGLX sync found for {efi.label}!')
1✔
461
            sync = alfio.load_object(bin_file.parent, **alfname)
1✔
462
            out_files, _ = alfio._ls(bin_file.parent, **alfname)
1✔
463
        else:
464
            sr = spikeglx.Reader(bin_file)
1✔
465
            sync, out_files = _sync_to_alf(sr, bin_file.parent, save=True, parts=efi.label)
1✔
466
            sr.close()
1✔
467
        outputs.extend(out_files)
1✔
468
        syncs.extend([sync])
1✔
469

470
    return syncs, outputs
1✔
471

472

473
def _get_all_probes_sync(session_path, bin_exists=True):
1✔
474
    # round-up of all bin ephys files in the session, infer revision and get sync map
475
    ephys_files = spikeglx.glob_ephys_files(session_path, bin_exists=bin_exists)
1✔
476
    version = spikeglx.get_neuropixel_version_from_files(ephys_files)
1✔
477
    # attach the sync information to each binary file found
478
    for ef in ephys_files:
1✔
479
        ef['sync'] = alfio.load_object(ef.path, 'sync', namespace='spikeglx', short_keys=True)
1✔
480
        ef['sync_map'] = get_ibl_sync_map(ef, version)
1✔
481
    return ephys_files
1✔
482

483

484
def get_wheel_positions(sync, chmap, tmin=None, tmax=None):
1✔
485
    """
486
    Gets the wheel position from synchronisation pulses
487

488
    Parameters
489
    ----------
490
    sync : dict
491
        A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
492
        the corresponding channel numbers.
493
    chmap : dict[str, int]
494
        A map of channel names and their corresponding indices.
495
    tmin : float
496
        The minimum time from which to extract the sync pulses.
497
    tmax : float
498
        The maximum time up to which we extract the sync pulses.
499

500
    Returns
501
    -------
502
    Bunch
503
        A dictionary with keys ('timestamps', 'position'), containing the wheel event timestamps and
504
        position in radians
505
    Bunch
506
        A dictionary of detected movement times with keys ('intervals', 'peakAmplitude', 'peakVelocity_times').
507
    """
508
    ts, pos = extract_wheel_sync(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
1✔
509
    moves = Bunch(extract_wheel_moves(ts, pos))
1✔
510
    wheel = Bunch({'timestamps': ts, 'position': pos})
1✔
511
    return wheel, moves
1✔
512

513

514
def get_main_probe_sync(session_path, bin_exists=False):
1✔
515
    """
516
    From 3A or 3B multiprobe session, returns the main probe (3A) or nidq sync pulses
517
    with the attached channel map (default chmap if none)
518

519
    Parameters
520
    ----------
521
    session_path : str, pathlib.Path
522
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
523
    bin_exists : bool
524
        Whether there is a .bin file present.
525

526
    Returns
527
    -------
528
    one.alf.io.AlfBunch
529
        A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
530
        the corresponding channel numbers.
531
    dict
532
        A map of channel names and their corresponding indices.
533
    """
534
    ephys_files = _get_all_probes_sync(session_path, bin_exists=bin_exists)
1✔
535
    if not ephys_files:
1✔
536
        raise FileNotFoundError(f"No ephys files found in {session_path}")
1✔
537
    version = spikeglx.get_neuropixel_version_from_files(ephys_files)
1✔
538
    if version == '3A':
1✔
539
        # the sync master is the probe with the most sync pulses
540
        sync_box_ind = np.argmax([ef.sync.times.size for ef in ephys_files])
1✔
541
    elif version == '3B':
1✔
542
        # the sync master is the nidq breakout box
543
        sync_box_ind = np.argmax([1 if ef.get('nidq') else 0 for ef in ephys_files])
1✔
544
    sync = ephys_files[sync_box_ind].sync
1✔
545
    sync_chmap = ephys_files[sync_box_ind].sync_map
1✔
546
    return sync, sync_chmap
1✔
547

548

549
def get_protocol_period(session_path, protocol_number, bpod_sync):
1✔
550
    """
551

552
    Parameters
553
    ----------
554
    session_path : str, pathlib.Path
555
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
556
    protocol_number : int
557
        The order that the protocol was run in.
558
    bpod_sync : dict
559
        The sync times and polarities for Bpod BNC1.
560

561
    Returns
562
    -------
563
    float
564
        The time of the detected spacer for the protocol number.
565
    float, None
566
        The time of the next detected spacer or None if this is the last protocol run.
567
    """
568
    # The spacers are TTLs generated by Bpod at the start of each protocol
569
    spacer_times = Spacer().find_spacers_from_fronts(bpod_sync)
1✔
570
    # Ensure that the number of detected spacers matched the number of expected tasks
571
    if acquisition_description := session_params.read_params(session_path):
1✔
572
        n_tasks = len(acquisition_description.get('tasks', []))
1✔
573
        assert n_tasks == len(spacer_times), f'expected {n_tasks} spacers, found {len(spacer_times)}'
1✔
574
        assert n_tasks > protocol_number >= 0, f'protocol number must be between 0 and {n_tasks}'
1✔
575
    else:
576
        assert protocol_number < len(spacer_times)
×
577
    start = spacer_times[int(protocol_number)]
1✔
578
    end = None if len(spacer_times) - 1 == protocol_number else spacer_times[int(protocol_number + 1)]
1✔
579
    return start, end
1✔
580

581

582
class FpgaTrials(extractors_base.BaseExtractor):
1✔
583
    save_names = ('_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None,
1✔
584
                  '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy',
585
                  '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy',
586
                  '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy',
587
                  '_ibl_wheelMoves.peakAmplitude.npy')
588
    var_names = ('goCueTrigger_times', 'stimOnTrigger_times',
1✔
589
                 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times',
590
                 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times',
591
                 'valveOpen_times', 'phase', 'position', 'quiescence', 'table',
592
                 'wheel_timestamps', 'wheel_position',
593
                 'wheelMoves_intervals', 'wheelMoves_peakAmplitude')
594

595
    bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times',
1✔
596
                         'stimOnTrigger_times', 'stimOffTrigger_times',
597
                         'stimFreezeTrigger_times', 'errorCueTrigger_times')
598
    """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
1✔
599

600
    bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight',
1✔
601
                   'probabilityLeft', 'phase', 'position', 'quiescence')
602
    """tuple of str: Fields from bpod extractor that we want to save."""
1✔
603

604
    sync_field = 'intervals_0'  # trial start events
1✔
605
    """str: The trial event to synchronize (must be present in extracted trials)."""
1✔
606

607
    bpod = None
1✔
608
    """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. Used in the QC viewer plot."""
1✔
609

610
    def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs):
1✔
611
        """An extractor for ephysChoiceWorld trials data, in FPGA time.
612

613
        This class may be subclassed to handle moderate variations in hardware and task protocol,
614
        however there is flexible
615
        """
616
        super().__init__(*args, **kwargs)
1✔
617
        self.bpod2fpga = None
1✔
618
        self.bpod_trials = bpod_trials
1✔
619
        self.frame2ttl = self.audio = self.bpod = self.settings = None
1✔
620
        if bpod_extractor:
1✔
621
            self.bpod_extractor = bpod_extractor
1✔
622
            self._update_var_names()
1✔
623

624
    def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None):
1✔
625
        """
626
        Updates this object's attributes based on the Bpod trials extractor.
627

628
        Fields updated: bpod_fields, bpod_rsync_fields, save_names, and var_names.
629

630
        Parameters
631
        ----------
632
        bpod_fields : tuple
633
            A set of Bpod trials fields to keep.
634
        bpod_rsync_fields : tuple
635
            A set of Bpod trials fields to sync to the DAQ times.
636
        """
637
        if self.bpod_extractor:
1✔
638
            for var_name, save_name in zip(self.bpod_extractor.var_names, self.bpod_extractor.save_names):
1✔
639
                if var_name not in self.var_names:
1✔
640
                    self.var_names += (var_name,)
1✔
641
                    self.save_names += (save_name,)
1✔
642

643
            # self.var_names = self.bpod_extractor.var_names
644
            # self.save_names = self.bpod_extractor.save_names
645
            self.settings = self.bpod_extractor.settings  # This is used by the TaskQC
1✔
646
            self.bpod_rsync_fields = bpod_rsync_fields
1✔
647
            if self.bpod_rsync_fields is None:
1✔
648
                self.bpod_rsync_fields = tuple(self._time_fields(self.bpod_extractor.var_names))
1✔
649
                if 'table' in self.bpod_extractor.var_names:
1✔
650
                    if not self.bpod_trials:
1✔
651
                        self.bpod_trials = self.bpod_extractor.extract(save=False)
×
652
                    table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
1✔
653
                    self.bpod_rsync_fields += tuple(self._time_fields(table_keys))
1✔
654
        elif bpod_rsync_fields:
×
655
            self.bpod_rsync_fields = bpod_rsync_fields
×
656
        excluded = (*self.bpod_rsync_fields, 'table')
1✔
657
        if bpod_fields:
1✔
658
            assert not set(self.bpod_fields).intersection(excluded), 'bpod_fields must not also be bpod_rsync_fields'
×
659
            self.bpod_fields = bpod_fields
×
660
        elif self.bpod_extractor:
1✔
661
            self.bpod_fields = tuple(x for x in self.bpod_extractor.var_names if x not in excluded)
1✔
662
            if 'table' in self.bpod_extractor.var_names:
1✔
663
                if not self.bpod_trials:
1✔
664
                    self.bpod_trials = self.bpod_extractor.extract(save=False)
×
665
                table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys()
1✔
666
                self.bpod_fields += tuple([x for x in table_keys if x not in excluded])
1✔
667

668
    @staticmethod
1✔
669
    def _time_fields(trials_attr) -> set:
1✔
670
        """
671
        Iterates over Bpod trials attributes returning those that correspond to times for syncing.
672

673
        Parameters
674
        ----------
675
        trials_attr : iterable of str
676
            The Bpod field names.
677

678
        Returns
679
        -------
680
        set
681
            The field names that contain timestamps.
682
        """
683
        FIELDS = ('times', 'timestamps', 'intervals')
1✔
684
        pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$')
1✔
685
        return set(filter(pattern.match, trials_attr))
1✔
686

687
    def load_sync(self, sync_collection='raw_ephys_data', **kwargs):
1✔
688
        """Load the DAQ sync and channel map data.
689

690
        This method may be subclassed for novel DAQ systems. The sync must contain the following
691
        keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1}
692
        corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints
693
        corresponding to channel number.
694

695
        Parameters
696
        ----------
697
        sync_collection : str
698
            The session subdirectory where the sync data are located.
699
        kwargs
700
            Optional arguments used by subclass methods.
701

702
        Returns
703
        -------
704
        one.alf.io.AlfBunch
705
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
706
            and the corresponding channel numbers.
707
        dict
708
            A map of channel names and their corresponding indices.
709
        """
710
        return get_sync_and_chn_map(self.session_path, sync_collection)
×
711

712
    def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
1✔
713
                 task_collection='raw_behavior_data', **kwargs) -> dict:
714
        """Extracts ephys trials by combining Bpod and FPGA sync pulses.
715

716
        It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
717
        attributes are all correct for the bpod protocol used.
718

719
        Below are the steps involved:
720
          0. Load sync and bpod trials, if required.
721
          1. Determine protocol period and discard sync events outside the task.
722
          2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`).
723
          3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events.
724
          4. Assign classified TTL events to trial events based on order within the trial.
725
          4. Convert Bpod software event times to DAQ clock.
726
          5. Extract the wheel from the DAQ rotary encoder signal, if required.
727

728
        Parameters
729
        ----------
730
        sync : dict
731
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
732
            and the corresponding channel numbers. If None, the sync is loaded using the
733
            `load_sync` method.
734
        chmap : dict
735
            A map of channel names and their corresponding indices. If None, the channel map is
736
            loaded using the :meth:`FpgaTrials.load_sync` method.
737
        sync_collection : str
738
            The session subdirectory where the sync data are located. This is only used if the
739
            sync or channel maps are not provided.
740
        task_collection : str
741
            The session subdirectory where the raw Bpod data are located. This is used for loading
742
            the task settings and extracting the bpod trials, if not already done.
743
        protocol_number : int
744
            The protocol number if multiple protocols were run during the session. If provided, a
745
            spacer signal must be present in order to determine the correct period.
746
        kwargs
747
            Optional arguments for subclass methods to use.
748

749
        Returns
750
        -------
751
        dict
752
            A dictionary of numpy arrays with `FpgaTrials.var_names` as keys.
753
        """
754
        if sync is None or chmap is None:
1✔
755
            _sync, _chmap = self.load_sync(sync_collection)
1✔
756
            sync = sync or _sync
1✔
757
            chmap = chmap or _chmap
1✔
758

759
        if not self.bpod_trials:  # extract the behaviour data from bpod
1✔
760
            self.bpod_trials, *_ = bpod_extract_all(
×
761
                session_path=self.session_path, task_collection=task_collection, save=False,
762
                extractor_type=kwargs.get('extractor_type'))
763

764
        # Explode trials table df
765
        if 'table' in self.var_names:
1✔
766
            trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table'))
1✔
767
            table_columns = trials_table.keys()
1✔
768
            self.bpod_trials.update(trials_table)
1✔
769
        else:
770
            if 'table' in self.bpod_trials:
1✔
771
                _logger.error(
×
772
                    '"table" found in Bpod trials but missing from `var_names` attribute and will'
773
                    'therefore not be extracted. This is likely in error.')
774
            table_columns = None
1✔
775

776
        bpod = get_sync_fronts(sync, chmap['bpod'])
1✔
777
        # Get the spacer times for this protocol
778
        if any(arg in kwargs for arg in ('tmin', 'tmax')):
1✔
779
            tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax')
×
780
        elif (protocol_number := kwargs.get('protocol_number')) is not None:  # look for spacer
1✔
781
            # The spacers are TTLs generated by Bpod at the start of each protocol
782
            tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod)
×
783
        else:
784
            # Older sessions don't have protocol spacers so we sync the Bpod intervals here to
785
            # find the approximate end time of the protocol (this will exclude the passive signals
786
            # in ephysChoiceWorld that tend to ruin the final trial extraction).
787
            _, trial_ints = self.get_bpod_event_times(sync, chmap, **kwargs)
1✔
788
            t_trial_start = trial_ints.get('trial_start', np.array([[np.nan, np.nan]]))[:, 0]
1✔
789
            bpod_start = self.bpod_trials['intervals'][:, 0]
1✔
790
            if len(t_trial_start) > len(bpod_start) / 2:  # if least half the trial start TTLs detected
1✔
791
                _logger.warning('Attempting to get protocol period from aligning trial start TTLs')
1✔
792
                fcn, *_ = ibldsp.utils.sync_timestamps(bpod_start, t_trial_start)
1✔
793
                buffer = 2.5  # the number of seconds to include before/after task
1✔
794
                start, end = fcn(self.bpod_trials['intervals'].flat[[0, -1]])
1✔
795
                tmin = min(sync['times'][0], start - buffer)
1✔
796
                tmax = max(sync['times'][-1], end + buffer)
1✔
797
            else:  # This type of alignment fails for some sessions, e.g. mesoscope
798
                tmin = tmax = None
1✔
799

800
        # Remove unnecessary data from sync
801
        selection = np.logical_and(
1✔
802
            sync['times'] <= (tmax if tmax is not None else sync['times'][-1]),
803
            sync['times'] >= (tmin if tmin is not None else sync['times'][0]),
804
        )
805
        sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()})
1✔
806
        _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)',
1✔
807
                      *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60)
808

809
        # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets
810
        out = self.build_trials(sync=sync, chmap=chmap, **kwargs)
1✔
811

812
        # extract the wheel data
813
        if any(x.startswith('wheel') for x in self.var_names):
1✔
814
            wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax)
1✔
815
            from ibllib.io.extractors.training_wheel import extract_first_movement_times
1✔
816
            if not self.settings:
1✔
817
                self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
1✔
818
            min_qt = self.settings.get('QUIESCENT_PERIOD', None)
1✔
819
            first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt)
1✔
820
            out.update({'firstMovement_times': first_move_onsets})
1✔
821
            out.update({f'wheel_{k}': v for k, v in wheel.items()})
1✔
822
            out.update({f'wheelMoves_{k}': v for k, v in moves.items()})
1✔
823

824
        # Re-create trials table
825
        if table_columns:
1✔
826
            trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns})
1✔
827
            out['table'] = trials_table.to_df()
1✔
828

829
        out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out})  # Reorder output
1✔
830
        assert self.var_names == tuple(out.keys())
1✔
831
        return out
1✔
832

833
    def _is_trials_object_attribute(self, var_name, variable_length_vars=None):
1✔
834
        """
835
        Check if variable name is expected to have the same length as trials.intervals.
836

837
        Parameters
838
        ----------
839
        var_name : str
840
            The variable name to check.
841
        variable_length_vars : list
842
            Set of variable names that are not expected to have the same length as trials.intervals.
843
            This list may be passed by superclasses.
844

845
        Returns
846
        -------
847
        bool
848
            True if variable is a trials dataset.
849

850
        Examples
851
        --------
852
        >>> assert self._is_trials_object_attribute('stimOnTrigger_times') is True
853
        >>> assert self._is_trials_object_attribute('wheel_position') is False
854
        """
855
        save_name = self.save_names[self.var_names.index(var_name)] if var_name in self.var_names else None
1✔
856
        if save_name:
1✔
857
            return filename_parts(save_name)[1] == 'trials'
1✔
858
        else:
859
            return var_name not in (variable_length_vars or [])
1✔
860

861
    def build_trials(self, sync, chmap, display=False, **kwargs):
1✔
862
        """
863
        Extract task related event times from the sync.
864

865
        The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The
866
        first trial start TTL of the session is longer and must be handled differently. The trial
867
        start TTL is used to assign the other trial events to each trial.
868

869
        The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest
870
        of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio
871
        tones. The first of these after each trial start is taken to be the go cue time. Error
872
        tones are longer audio TTLs and assigned as the last of such occurrence after each trial
873
        start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial.
874
        The feedback times are times of either valve open or error tone as there should be only one
875
        such event per trial.
876

877
        The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs
878
        removed): the first TTL after each trial start is assumed to be the stim onset time; the
879
        second to last and last are taken as the stimulus freeze and offset times, respectively.
880

881
        Parameters
882
        ----------
883
        sync : dict
884
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
885
        chmap : dict
886
            Map of channel names and their corresponding index.  Default to constant.
887
        display : bool, matplotlib.pyplot.Axes
888
            Show the full session sync pulses display.
889

890
        Returns
891
        -------
892
        dict
893
            A map of trial event timestamps.
894
        """
895
        # Get the events from the sync.
896
        # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
897
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
1✔
898
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
1✔
899
        if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}:
1✔
900
            raise ValueError(
×
901
                'Expected at least "ready_tone" and "error_tone" audio events.'
902
                '`audio_event_ttls` kwarg may be incorrect.')
903
        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
1✔
904
        if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}:
1✔
905
            raise ValueError(
×
906
                'Expected at least "trial_start", "trial_end", and "valve_open" audio events. '
907
                '`bpod_event_ttls` kwarg may be incorrect.')
908

909
        t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T
1✔
910
        fpga_events = alfio.AlfBunch({
1✔
911
            'goCue_times': audio_event_intervals['ready_tone'][:, 0],
912
            'errorCue_times': audio_event_intervals['error_tone'][:, 0],
913
            'valveOpen_times': bpod_event_intervals['valve_open'][:, 0],
914
            'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
915
            'itiIn_times': t_iti_in,
916
            'intervals_0': bpod_event_intervals['trial_start'][:, 0],
917
            'intervals_1': t_trial_end
918
        })
919

920
        # Sync the Bpod clock to the DAQ.
921
        # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is
922
        # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is
923
        # dropped after assigning the FPGA events, using the `ifpga` index. Doing this after
924
        # assigning the FPGA trial events ensures the last trial has the correct timestamps.
925
        self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)
1✔
926

927
        if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0':
1✔
928
            # One issue is that sometimes pulses may not have been detected, in this case
929
            # add the events that have not been detected and re-extract the behaviour sync.
930
            # This is only really relevant for the Bpod interval events as the other TTLs are
931
            # from devices where a missing TTL likely means the Bpod event was truly absent.
932
            _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times')
×
933
            bpod_start = self.bpod_trials['intervals'][:, 0]
×
934
            missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))])
×
935
            t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod])
×
936
        else:
937
            t_trial_start = fpga_events['intervals_0']
1✔
938

939
        out = alfio.AlfBunch()
1✔
940
        # Add the Bpod trial events, converting the timestamp fields to FPGA time.
941
        # NB: The trial intervals are by default a Bpod rsync field.
942
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
1✔
943
        for k in self.bpod_rsync_fields:
1✔
944
            # Some personal projects may extract non-trials object datasets that may not have 1 event per trial
945
            idx = ibpod if self._is_trials_object_attribute(k) else np.arange(len(self.bpod_trials[k]), dtype=int)
1✔
946
            out[k] = self.bpod2fpga(self.bpod_trials[k][idx])
1✔
947

948
        f2ttl_t = self.frame2ttl['times']
1✔
949
        # Assign the FPGA events to individual trials
950
        fpga_trials = {
1✔
951
            'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'),
952
            'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']),
953
            'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']),
954
            'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']),
955
            'stimOn_times': np.full_like(t_trial_start, np.nan),
956
            'stimOff_times': np.full_like(t_trial_start, np.nan),
957
            'stimFreeze_times': np.full_like(t_trial_start, np.nan)
958
        }
959

960
        # f2ttl times are unreliable owing to calibration and Bonsai sync square update issues.
961
        # Take the first event after the FPGA aligned stimulus trigger time.
962
        fpga_trials['stimOn_times'][ibpod] = _assign_events_to_trial(
1✔
963
            out['stimOnTrigger_times'], f2ttl_t, take='first', t_trial_end=out['stimOffTrigger_times'])
964
        fpga_trials['stimOff_times'][ibpod] = _assign_events_to_trial(
1✔
965
            out['stimOffTrigger_times'], f2ttl_t, take='first', t_trial_end=out['intervals'][:, 1])
966
        # For stim freeze we take the last event before the stim off trigger time.
967
        # To avoid assigning early events (e.g. for sessions where there are few flips due to
968
        # mis-calibration), we discount events before stim freeze trigger times (or stim on trigger
969
        # times for versions below 6.2.5). We take the last event rather than the first after stim
970
        # freeze trigger because often there are multiple flips after the trigger, presumably
971
        # before the stim actually stops.
972
        stim_freeze = np.copy(out['stimFreezeTrigger_times'])
1✔
973
        go_trials = np.where(out['choice'] != 0)[0]
1✔
974
        # NB: versions below 6.2.5 have no trigger times so use stim on trigger times
975
        lims = np.copy(out['stimOnTrigger_times'])
1✔
976
        if not np.isnan(stim_freeze).all():
1✔
977
            # Stim freeze times are NaN for nogo trials, but for all others use stim freeze trigger
978
            # times. _assign_events_to_trial requires ascending timestamps so no NaNs allowed.
979
            lims[go_trials] = stim_freeze[go_trials]
1✔
980
        # take last event after freeze/stim on trigger, before stim off trigger
981
        stim_freeze = _assign_events_to_trial(lims, f2ttl_t, take='last', t_trial_end=out['stimOffTrigger_times'])
1✔
982
        fpga_trials['stimFreeze_times'][go_trials] = stim_freeze[go_trials]
1✔
983

984
        # Feedback times are valve open on correct trials and error tone in on incorrect trials
985
        fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times'])
1✔
986
        ind_err = np.isnan(fpga_trials['valveOpen_times'])
1✔
987
        fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err]
1✔
988

989
        out.update({k: fpga_trials[k][ifpga] for k in fpga_trials.keys()})
1✔
990

991
        if display:  # pragma: no cover
992
            width = 0.5
993
            ymax = 5
994
            if isinstance(display, bool):
995
                plt.figure('Bpod FPGA Sync')
996
                ax = plt.gca()
997
            else:
998
                ax = display
999
            plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
1000
            plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
1001
            plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
1002
            color_map = TABLEAU_COLORS.keys()
1003
            for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)):
1004
                plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
1005
            # Plot the stimulus events along with the trigger times
1006
            stim_events = filter(lambda t: 'stim' in t[0], fpga_trials.items())
1007
            for (event_name, event_times), c in zip(stim_events, cycle(color_map)):
1008
                plots.vertical_lines(
1009
                    event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width, linestyle='--')
1010
                nm = event_name.replace('_times', 'Trigger_times')
1011
                plots.vertical_lines(
1012
                    out[nm], ymin=0, ymax=ymax, ax=ax, color=c, label=nm, linewidth=width, linestyle=':')
1013
            ax.legend()
1014
            ax.set_yticks([0, 1, 2, 3])
1015
            ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
1016
            ax.set_ylim([0, 5])
1017
        return out
1✔
1018

1019
    def get_wheel_positions(self, *args, **kwargs):
1✔
1020
        """Extract wheel and wheelMoves objects.
1021

1022
        This method is called by the main extract method and may be overloaded by subclasses.
1023
        """
1024
        return get_wheel_positions(*args, **kwargs)
1✔
1025

1026
    def get_stimulus_update_times(self, sync, chmap, display=False, **_):
1✔
1027
        """
1028
        Extract stimulus update times from sync.
1029

1030
        Gets the stimulus times from the frame2ttl channel and cleans the signal.
1031

1032
        Parameters
1033
        ----------
1034
        sync : dict
1035
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1036
            and the corresponding channel numbers.
1037
        chmap : dict
1038
            A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key.
1039
        display : bool
1040
            If true, plots the input TTLs and the cleaned output.
1041

1042
        Returns
1043
        -------
1044
        dict
1045
            A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts.
1046
        """
1047
        frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'])
1✔
1048
        frame2ttl = _clean_frame2ttl(frame2ttl, display=display)
1✔
1049
        return frame2ttl
1✔
1050

1051
    def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_):
1✔
1052
        """
1053
        Extract audio times from sync.
1054

1055
        Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL
1056
        event by length.
1057

1058
        Parameters
1059
        ----------
1060
        sync : dict
1061
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1062
            and the corresponding channel numbers.
1063
        chmap : dict
1064
            A map of channel names and their corresponding indices. Must contain an 'audio' key.
1065
        audio_event_ttls : dict
1066
            A map of event names to (min, max) TTL length.
1067
        display : bool
1068
            If true, plots the input TTLs and the cleaned output.
1069

1070
        Returns
1071
        -------
1072
        dict
1073
            A dictionary with keys {'times', 'polarities'} containing audio TTL fronts.
1074
        dict
1075
            A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array.
1076
        """
1077
        audio = get_sync_fronts(sync, chmap['audio'])
1✔
1078
        audio = _clean_audio(audio)
1✔
1079

1080
        if audio['times'].size == 0:
1✔
1081
            _logger.error('No audio sync fronts found.')
×
1082

1083
        if audio_event_ttls is None:
1✔
1084
            # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error
1085
            # tone should be between 400ms and 1200ms
1086
            audio_event_ttls = {'ready_tone': (0, 0.11), 'error_tone': (0.4, 1.2)}
1✔
1087
        audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display)
1✔
1088

1089
        return audio, audio_event_intervals
1✔
1090

1091
    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
1✔
1092
        """
1093
        Extract Bpod times from sync.
1094

1095
        Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by
1096
        length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned.
1097
        This method accounts for this.
1098

1099
        Parameters
1100
        ----------
1101
        sync : dict
1102
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1103
            and the corresponding channel numbers. Must contain a 'bpod' key.
1104
        chmap : dict
1105
            A map of channel names and their corresponding indices.
1106
        bpod_event_ttls : dict of tuple
1107
            A map of event names to (min, max) TTL length.
1108

1109
        Returns
1110
        -------
1111
        dict
1112
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
1113
        dict
1114
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
1115
        """
1116
        bpod = get_sync_fronts(sync, chmap['bpod'])
1✔
1117
        if bpod.times.size == 0:
1✔
1118
            raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
×
1119
                                            'Check channel maps.')
1120
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
1121
        # lengths are defined by the state machine of the task protocol and therefore vary.
1122
        if bpod_event_ttls is None:
1✔
1123
            # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has
1124
            # proven to drift on some Bpods and this is the highest possible value that
1125
            # discriminates trial start from valve. Valve open events are between 50ms to 300 ms.
1126
            # ITI events are above 400 ms.
1127
            bpod_event_ttls = {
1✔
1128
                'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
1129
        bpod_event_intervals = self._assign_events(
1✔
1130
            bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
1131

1132
        if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0:
1✔
1133
            return bpod, bpod_event_intervals
1✔
1134

1135
        # The first trial pulse is longer and often assigned to another event.
1136
        # Here we move the earliest non-trial_start event to the trial_start array.
1137
        t0 = bpod_event_intervals['trial_start'][0, 0]  # expect 1st event to be trial_start
1✔
1138
        pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0]
1✔
1139
        if pretrial:
1✔
1140
            (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0]  # take the earliest event
1✔
1141
            dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3  # record TTL length to log
1✔
1142
            _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt)
1✔
1143
            bpod_event_intervals['trial_start'] = np.r_[
1✔
1144
                bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start']
1145
            ]
1146
            bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :]
1✔
1147

1148
        return bpod, bpod_event_intervals
1✔
1149

1150
    @staticmethod
1✔
1151
    def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False):
1✔
1152
        """
1153
        Classify TTL events by length.
1154

1155
        Outputs the synchronisation events such as trial intervals, valve opening, and audio.
1156

1157
        Parameters
1158
        ----------
1159
        ts : numpy.array
1160
            Numpy vector containing times of TTL fronts.
1161
        polarities : numpy.array
1162
            Numpy vector containing polarity of TTL fronts (1 rise, -1 fall).
1163
        event_lengths : dict of tuple
1164
            A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1.
1165
        precedence : str {'shortest', 'longest', 'dict order'}
1166
            In the case of overlapping event TTL lengths, assign shortest/longest first or go by
1167
            the `event_lengths` dict order.
1168
        display : bool
1169
            If true, plots the TTLs with coloured lines delineating the assigned events.
1170

1171
        Returns
1172
        -------
1173
        Dict[str, numpy.array]
1174
            A dictionary of events and their intervals as an Nx2 array.
1175

1176
        See Also
1177
        --------
1178
        _assign_events_to_trial - classify TTLs by event order within a given trial period.
1179
        """
1180
        event_intervals = dict.fromkeys(event_lengths)
1✔
1181
        assert 'unassigned' not in event_lengths.keys()
1✔
1182

1183
        if len(ts) == 0:
1✔
1184
            return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')}
×
1185

1186
        # make sure that there are no 2 consecutive fall or consecutive rise events
1187
        assert np.all(np.abs(np.diff(polarities)) == 2)
1✔
1188
        if polarities[0] == -1:
1✔
1189
            ts = np.delete(ts, 0)
1✔
1190
        if polarities[-1] == 1:  # if the final TTL is left HIGH, insert a NaN
1✔
1191
            ts = np.r_[ts, np.nan]
1✔
1192
        # take only even time differences: i.e. from rising to falling fronts
1193
        dt = np.diff(ts)[::2]
1✔
1194

1195
        # Assign events from shortest TTL to largest
1196
        assigned = np.zeros(ts.shape, dtype=bool)
1✔
1197
        if precedence.lower() == 'shortest':
1✔
1198
            event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]))
1✔
1199
        elif precedence.lower() == 'longest':
×
1200
            event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True)
×
1201
        elif precedence.lower() == 'dict order':
×
1202
            event_items = event_lengths.items()
×
1203
        else:
1204
            raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".')
×
1205
        for event, (min_len, max_len) in event_items:
1✔
1206
            _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len)
1✔
1207
            i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2
1✔
1208
            i_event = i_event[np.where(~assigned[i_event])[0]]  # remove those already assigned
1✔
1209
            event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]]
1✔
1210
            assigned[np.r_[i_event, i_event + 1]] = True
1✔
1211

1212
        # Include the unassigned events for convenience and debugging
1213
        event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2)
1✔
1214

1215
        # Assert that event TTLs mutually exclusive
1216
        all_assigned = np.concatenate(list(event_intervals.values())).flatten()
1✔
1217
        assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events'
1✔
1218

1219
        # some debug plots when needed
1220
        if display:  # pragma: no cover
1221
            plt.figure()
1222
            plots.squares(ts, polarities, label='raw fronts')
1223
            for event, intervals in event_intervals.items():
1224
                plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event)
1225
            plt.legend()
1226

1227
        # Return map of event intervals in the same order as `event_lengths` dict
1228
        return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')}
1✔
1229

1230
    @staticmethod
1✔
1231
    def sync_bpod_clock(bpod_trials, fpga_trials, sync_field):
1✔
1232
        """
1233
        Sync the Bpod clock to FPGA one using the provided trial event.
1234

1235
        It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both
1236
        intervals is not supported so to sync on trial start times, `sync_field` should be
1237
        'intervals_0'.
1238

1239
        Parameters
1240
        ----------
1241
        bpod_trials : dict
1242
            A dictionary of extracted Bpod trial events.
1243
        fpga_trials : dict
1244
            A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync`
1245
            method).
1246
        sync_field : str
1247
            The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the
1248
            column index, e.g. 'intervals_0'.
1249

1250
        Returns
1251
        -------
1252
        function
1253
            Interpolation function such that f(timestamps_bpod) = timestamps_fpga.
1254
        float
1255
            The clock drift in parts per million.
1256
        numpy.array of int
1257
            The indices of the Bpod trial events in the FPGA trial events array.
1258
        numpy.array of int
1259
            The indices of the FPGA trial events in the Bpod trial events array.
1260

1261
        Raises
1262
        ------
1263
        ValueError
1264
            The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts.
1265
        """
1266
        _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"')
1✔
1267
        bpod_fpga_timestamps = [None, None]
1✔
1268
        for i, trials in enumerate((bpod_trials, fpga_trials)):
1✔
1269
            if sync_field not in trials:
1✔
1270
                # handle syncing on intervals
1271
                if not (m := re.match(r'(.*)_(\d)', sync_field)):
1✔
1272
                    # If missing from bpod trials, either the sync field is incorrect,
1273
                    # or the Bpod extractor is incorrect. If missing from the fpga events, check
1274
                    # the sync field and the `extract_behaviour_sync` method.
1275
                    raise ValueError(
×
1276
                        f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events')
1277
                _sync_field, n = m.groups()
1✔
1278
                bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)]
1✔
1279
            else:
1280
                bpod_fpga_timestamps[i] = trials[sync_field]
1✔
1281

1282
        # Sync the two timestamps
1283
        fcn, drift, ibpod, ifpga = ibldsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True)
1✔
1284

1285
        # If it's drifting too much throw warning or error
1286
        _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm',
1✔
1287
                     *map(len, bpod_fpga_timestamps), len(ibpod), drift)
1288
        if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size:
1✔
1289
            raise err.SyncBpodFpgaException('sync cluster f*ck')
×
1290
        elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM:
1✔
1291
            _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm',
×
1292
                            BPOD_FPGA_DRIFT_THRESHOLD_PPM)
1293

1294
        return fcn, drift, ibpod, ifpga
1✔
1295

1296

1297
class FpgaTrialsHabituation(FpgaTrials):
1✔
1298
    """Extract habituationChoiceWorld trial events from an NI DAQ."""
1✔
1299

1300
    save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy',
1✔
1301
                  '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy',
1302
                  '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy',
1303
                  '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy',
1304
                  None, None, None, None, None)
1305
    """tuple of str: The filenames of each extracted dataset, or None if array should not be saved."""
1✔
1306

1307
    var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft',
1✔
1308
                 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
1309
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
1310
                 'stimCenterTrigger_times', 'position', 'phase')
1311
    """tuple of str: A list of names for the extracted variables. These become the returned output keys."""
1✔
1312

1313
    bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times',
1✔
1314
                         'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times',
1315
                         'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times')
1316
    """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA."""
1✔
1317

1318
    bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase')
1✔
1319
    """tuple of str: Fields from Bpod extractor that we want to save."""
1✔
1320

1321
    sync_field = 'feedback_times'  # valve open events
1✔
1322
    """str: The trial event to synchronize (must be present in extracted trials)."""
1✔
1323

1324
    def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data',
1✔
1325
                 task_collection='raw_behavior_data', **kwargs) -> dict:
1326
        """
1327
        Extract habituationChoiceWorld trial events from an NI DAQ.
1328

1329
        It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field`
1330
        attributes are all correct for the bpod protocol used.
1331

1332
        Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock
1333
        using the valve open times, instead of the trial start times.
1334

1335
        Parameters
1336
        ----------
1337
        sync : dict
1338
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1339
            and the corresponding channel numbers. If None, the sync is loaded using the
1340
            `load_sync` method.
1341
        dict
1342
            A map of channel names and their corresponding indices. If None, the channel map is
1343
            loaded using the `load_sync` method.
1344
        sync_collection : str
1345
            The session subdirectory where the sync data are located. This is only used if the
1346
            sync or channel maps are not provided.
1347
        task_collection : str
1348
            The session subdirectory where the raw Bpod data are located. This is used for loading
1349
            the task settings and extracting the bpod trials, if not already done.
1350
        protocol_number : int
1351
            The protocol number if multiple protocols were run during the session. If provided, a
1352
            spacer signal must be present in order to determine the correct period.
1353
        kwargs
1354
            Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'.
1355

1356
        Returns
1357
        -------
1358
        dict
1359
            A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys.
1360
        """
1361
        # Version check: the ITI in TTL was added in a later version
1362
        if not self.settings:
1✔
1363
            self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection)
×
1364
        iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0'))
1✔
1365
        if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'):
1✔
1366
            """A second 1s TTL was added in this version during the 'iti' state, however this is
×
1367
            unrelated to the trial ITI and is unfortunately the same length as the trial start TTL."""
1368
            raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6')
×
1369

1370
        trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection,
1✔
1371
                                  task_collection=task_collection, **kwargs)
1372

1373
        return trials
1✔
1374

1375
    def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs):
1✔
1376
        """
1377
        Extract Bpod times from sync.
1378

1379
        Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse.
1380
        Also the first trial pulse is incorrectly assigned due to its abnormal length.
1381

1382
        Parameters
1383
        ----------
1384
        sync : dict
1385
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
1386
            and the corresponding channel numbers. Must contain a 'bpod' key.
1387
        chmap : dict
1388
            A map of channel names and their corresponding indices.
1389
        bpod_event_ttls : dict of tuple
1390
            A map of event names to (min, max) TTL length.
1391

1392
        Returns
1393
        -------
1394
        dict
1395
            A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts.
1396
        dict
1397
            A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array.
1398
        """
1399
        bpod = get_sync_fronts(sync, chmap['bpod'])
1✔
1400
        if bpod.times.size == 0:
1✔
1401
            raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. '
×
1402
                                            'Check channel maps.')
1403
        # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these
1404
        # lengths are defined by the state machine of the task protocol and therefore vary.
1405
        if bpod_event_ttls is None:
1✔
1406
            # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse
1407
            bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)}
1✔
1408
        bpod_event_intervals = self._assign_events(
1✔
1409
            bpod['times'], bpod['polarities'], bpod_event_ttls, display=display)
1410

1411
        # The first trial pulse is shorter and assigned to valve_open. Here we remove the first
1412
        # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was
1413
        # incomplete in Bpod.
1414
        bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :],
1✔
1415
                                                  bpod_event_intervals['trial_iti']]
1416
        bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :]
1✔
1417

1418
        return bpod, bpod_event_intervals
1✔
1419

1420
    def build_trials(self, sync, chmap, display=False, **kwargs):
1✔
1421
        """
1422
        Extract task related event times from the sync.
1423

1424
        This is called by the superclass `_extract` method.  The key difference here is that the
1425
        `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start.
1426

1427
        Parameters
1428
        ----------
1429
        sync : dict
1430
            'polarities' of fronts detected on sync trace for all 16 chans and their 'times'
1431
        chmap : dict
1432
            Map of channel names and their corresponding index.  Default to constant.
1433
        display : bool, matplotlib.pyplot.Axes
1434
            Show the full session sync pulses display.
1435

1436
        Returns
1437
        -------
1438
        dict
1439
            A map of trial event timestamps.
1440
        """
1441
        # Get the events from the sync.
1442
        # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC
1443
        self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs)
1✔
1444
        self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs)
1✔
1445
        self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs)
1✔
1446
        if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}:
1✔
1447
            raise ValueError(
×
1448
                'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.')
1449

1450
        fpga_events = alfio.AlfBunch({
1✔
1451
            'feedback_times': bpod_event_intervals['valve_open'][:, 0],
1452
            'valveClose_times': bpod_event_intervals['valve_open'][:, 1],
1453
            'intervals_0': bpod_event_intervals['trial_iti'][:, 1],
1454
            'intervals_1': bpod_event_intervals['trial_iti'][:, 0],
1455
            'goCue_times': audio_event_intervals['ready_tone'][:, 0]
1456
        })
1457

1458
        # Sync the Bpod clock to the DAQ.
1459
        self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field)
1✔
1460

1461
        out = alfio.AlfBunch()
1✔
1462
        # Add the Bpod trial events, converting the timestamp fields to FPGA time.
1463
        # NB: The trial intervals are by default a Bpod rsync field.
1464
        out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields})
1✔
1465
        out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields})
1✔
1466

1467
        # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN)
1468
        assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0'])
1✔
1469
        trials = alfio.AlfBunch({
1✔
1470
            'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first'),
1471
            'feedback_times': assign_to_trial(fpga_events['feedback_times']),
1472
            'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2),
1473
            'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first'),
1474
            'stimOff_times': assign_to_trial(self.frame2ttl['times']),
1475
        })
1476
        out.update({k: trials[k][ifpga] for k in trials.keys()})
1✔
1477

1478
        # If stim on occurs before trial end, use stim on time. Likewise for trial end and stim off
1479
        to_correct = ~np.isnan(out['stimOn_times']) & (out['stimOn_times'] < out['intervals'][:, 0])
1✔
1480
        if np.any(to_correct):
1✔
1481
            _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct))
×
1482
            out['intervals'][to_correct, 0] = out['stimOn_times'][to_correct]
×
1483
        to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1])
1✔
1484
        if np.any(to_correct):
1✔
1485
            _logger.debug(
1✔
1486
                '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end',
1487
                sum(to_correct), len(to_correct))
1488
            out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct]
1✔
1489

1490
        if display:  # pragma: no cover
1491
            width = 0.5
1492
            ymax = 5
1493
            if isinstance(display, bool):
1494
                plt.figure('Bpod FPGA Sync')
1495
                ax = plt.gca()
1496
            else:
1497
                ax = display
1498
            plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k')
1499
            plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k')
1500
            plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k')
1501
            color_map = TABLEAU_COLORS.keys()
1502
            for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)):
1503
                plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width)
1504
            ax.legend()
1505
            ax.set_yticks([0, 1, 2, 3])
1506
            ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio'])
1507
            ax.set_ylim([0, 4])
1508

1509
        return out
1✔
1510

1511

1512
def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None,
1✔
1513
                task_collection='raw_behavior_data', protocol_number=None, **kwargs):
1514
    """
1515
    For the IBL ephys task, reads ephys binary file and extract:
1516
        -   sync
1517
        -   wheel
1518
        -   behaviour
1519

1520
    These `extract_all` functions should be deprecated as they make assumptions about hardware
1521
    parameters.  Additionally the FpgaTrials class now automatically loads DAQ sync files, extracts
1522
    the Bpod trials, and returns a dict instead of a tuple. Therefore this function is entirely
1523
    redundant. See the examples for the correct way to extract NI DAQ behaviour sessions.
1524

1525
    Parameters
1526
    ----------
1527
    session_path : str, pathlib.Path
1528
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1529
    sync_collection : str
1530
        The session subdirectory where the sync data are located.
1531
    save : bool
1532
        If true, save the extracted files to save_path.
1533
    task_collection : str
1534
        The location of the behaviour protocol data.
1535
    save_path : str, pathlib.Path
1536
        The save location of the extracted files, defaults to the alf directory of the session path.
1537
    protocol_number : int
1538
        The order that the protocol was run in.
1539
    **kwargs
1540
        Optional extractor keyword arguments.
1541

1542
    Returns
1543
    -------
1544
    list
1545
        The extracted data.
1546
    list of pathlib.Path, None
1547
        If save is True, a list of file paths to the extracted data.
1548
    """
1549
    warnings.warn(
1✔
1550
        'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; '
1551
        'use FpgaTrials instead. For reliable extraction, use the dynamic pipeline behaviour tasks.',
1552
        FutureWarning)
1553
    return_extractor = kwargs.pop('return_extractor', False)
1✔
1554
    # Extract Bpod trials
1555
    bpod_raw = raw.load_data(session_path, task_collection=task_collection)
1✔
1556
    assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit'
1✔
1557
    bpod_trials, bpod_wheel, *_ = bpod_extract_all(
1✔
1558
        session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection,
1559
        save=False, extractor_type=kwargs.get('extractor_type'))
1560

1561
    # Sync Bpod trials to FPGA
1562
    sync, chmap = get_sync_and_chn_map(session_path, sync_collection)
1✔
1563
    # sync, chmap = get_main_probe_sync(session_path, bin_exists=bin_exists)
1564
    trials = FpgaTrials(session_path, bpod_trials={**bpod_trials, **bpod_wheel})  # py3.9 -> |
1✔
1565
    outputs, files = trials.extract(
1✔
1566
        save=save, sync=sync, chmap=chmap, path_out=save_path,
1567
        task_collection=task_collection, protocol_number=protocol_number, **kwargs)
1568
    if not isinstance(outputs, dict):
1✔
1569
        outputs = {k: v for k, v in zip(trials.var_names, outputs)}
×
1570
    if return_extractor:
1✔
1571
        return outputs, files, trials
1✔
1572
    else:
1573
        return outputs, files
×
1574

1575

1576
def get_sync_and_chn_map(session_path, sync_collection):
1✔
1577
    """
1578
    Return sync and channel map for session based on collection where main sync is stored.
1579

1580
    Parameters
1581
    ----------
1582
    session_path : str, pathlib.Path
1583
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1584
    sync_collection : str
1585
        The session subdirectory where the sync data are located.
1586

1587
    Returns
1588
    -------
1589
    one.alf.io.AlfBunch
1590
        A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
1591
        the corresponding channel numbers.
1592
    dict
1593
        A map of channel names and their corresponding indices.
1594
    """
1595
    if sync_collection == 'raw_ephys_data':
1✔
1596
        # Check to see if we have nidq files, if we do just go with this otherwise go into other function that deals with
1597
        # 3A probes
1598
        nidq_meta = next(session_path.joinpath(sync_collection).glob('*nidq.meta'), None)
1✔
1599
        if not nidq_meta:
1✔
1600
            sync, chmap = get_main_probe_sync(session_path)
1✔
1601
        else:
1602
            sync = load_sync(session_path, sync_collection)
1✔
1603
            ef = Bunch()
1✔
1604
            ef['path'] = session_path.joinpath(sync_collection)
1✔
1605
            ef['nidq'] = nidq_meta
1✔
1606
            chmap = get_ibl_sync_map(ef, '3B')
1✔
1607

1608
    else:
1609
        sync = load_sync(session_path, sync_collection)
1✔
1610
        chmap = load_channel_map(session_path, sync_collection)
1✔
1611

1612
    return sync, chmap
1✔
1613

1614

1615
def load_channel_map(session_path, sync_collection):
1✔
1616
    """
1617
    Load syncing channel map for session path and collection
1618

1619
    Parameters
1620
    ----------
1621
    session_path : str, pathlib.Path
1622
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1623
    sync_collection : str
1624
        The session subdirectory where the sync data are located.
1625

1626
    Returns
1627
    -------
1628
    dict
1629
        A map of channel names and their corresponding indices.
1630
    """
1631

1632
    device = sync_collection.split('_')[1]
1✔
1633
    default_chmap = DEFAULT_MAPS[device]['nidq']
1✔
1634

1635
    # Try to load channel map from file
1636
    chmap = spikeglx.get_sync_map(session_path.joinpath(sync_collection))
1✔
1637
    # If chmap provided but not with all keys, fill up with default values
1638
    if not chmap:
1✔
1639
        return default_chmap
1✔
1640
    else:
1641
        if data_for_keys(default_chmap.keys(), chmap):
1✔
1642
            return chmap
1✔
1643
        else:
1644
            _logger.warning('Keys missing from provided channel map, '
×
1645
                            'setting missing keys from default channel map')
1646
            return {**default_chmap, **chmap}
×
1647

1648

1649
def load_sync(session_path, sync_collection):
1✔
1650
    """
1651
    Load sync files from session path and collection.
1652

1653
    Parameters
1654
    ----------
1655
    session_path : str, pathlib.Path
1656
        The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'.
1657
    sync_collection : str
1658
        The session subdirectory where the sync data are located.
1659

1660
    Returns
1661
    -------
1662
    one.alf.io.AlfBunch
1663
        A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses and
1664
        the corresponding channel numbers.
1665
    """
1666
    sync = alfio.load_object(session_path.joinpath(sync_collection), 'sync', namespace='spikeglx', short_keys=True)
1✔
1667

1668
    return sync
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc