• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 6161719581526678

28 Jun 2024 01:14PM UTC coverage: 64.584% (+0.03%) from 64.55%
6161719581526678

push

tests

web-flow
Stim on extraction (#788)

* Issue #775
* Handle no go trials
* Pre-6.2.5 trials extraction
* DeprecationWarning -> FutureWarning; extractor fixes; timeline trials extraction

49 of 60 new or added lines in 9 files covered. (81.67%)

2 existing lines in 1 file now uncovered.

13055 of 20214 relevant lines covered (64.58%)

0.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.34
/brainbox/behavior/wheel.py
1
"""
1✔
2
Set of functions to handle wheel data.
3
"""
4
import logging
1✔
5
import warnings
1✔
6
import traceback
1✔
7

8
import numpy as np
1✔
9
from numpy import pi
1✔
10
from iblutil.numerical import between_sorted
1✔
11
import scipy.interpolate as interpolate
1✔
12
import scipy.signal
1✔
13
from scipy.linalg import hankel
1✔
14
import matplotlib.pyplot as plt
1✔
15
from matplotlib.collections import LineCollection
1✔
16
# from ibllib.io.extractors.ephys_fpga import WHEEL_TICKS  # FIXME Circular dependencies
17

18
# Public API.  The functions marked (DEPRECATED) in their docstrings (velocity,
# velocity_smoothed, last_movement_onset) are not exported.
__all__ = ['cm_to_deg',
           'cm_to_rad',
           'direction_changes',
           'interpolate_position',
           'get_movement_onset',
           'movements',
           'samples_to_cm',
           'traces_by_trial',
           'velocity_filtered']

# Define some constants
ENC_RES = 1024 * 4  # Rotary encoder resolution, assumes X4 encoding
WHEEL_DIAMETER = 3.1 * 2  # Wheel diameter in cm (radius of 3.1 cm)
1✔
30

31

32
def interpolate_position(re_ts, re_pos, freq=1000, kind='linear', fill_gaps=None):
    """
    Return wheel position interpolated onto an evenly sampled time base.

    Parameters
    ----------
    re_ts : array_like
        Array of ascending timestamps
    re_pos: array_like
        Array of unwrapped wheel positions
    freq : float
        frequency in Hz of the interpolation
    kind : {'linear', 'cubic'}
        Type of interpolation. Defaults to linear interpolation.
    fill_gaps : float
        Minimum gap length to fill. Wherever two consecutive timestamps are at least this many
        seconds apart, the resampled points inside that gap are set to the position at the start
        of the gap (forward filled) instead of being interpolated.

    Returns
    -------
    yinterp : array
        Interpolated position
    t : array
        Timestamps of interpolated positions
    """
    t = np.arange(re_ts[0], re_ts[-1], 1 / freq)  # Evenly resample at frequency
    if t[-1] > re_ts[-1]:
        t = t[:-1]  # Occasionally due to precision errors the last sample may be outside of range.
    yinterp = interpolate.interp1d(re_ts, re_pos, kind=kind)(t)

    if fill_gaps:
        re_ts = np.asarray(re_ts)
        re_pos = np.asarray(re_pos)
        # Find large gaps and forward fill them.  Because `t` is sorted, the resampled points with
        # re_ts[i] <= t < re_ts[i + 1] form one contiguous slice; locate its bounds by binary
        # search instead of building a full-length boolean mask for every gap.
        gaps, = np.where(np.diff(re_ts) >= fill_gaps)
        starts = np.searchsorted(t, re_ts[gaps], side='left')
        stops = np.searchsorted(t, re_ts[gaps + 1], side='left')
        for start, stop, value in zip(starts, stops, re_pos[gaps]):
            yinterp[start:stop] = value

    return yinterp, t
1✔
69

70

71
def velocity(re_ts, re_pos):
    """
    (DEPRECATED) Compute wheel velocity from non-uniformly sampled wheel data. Returns the velocity
    at the same samples locations as the position through interpolation.

    Parameters
    ----------
    re_ts : array_like
        Array of timestamps
    re_pos: array_like
        Array of unwrapped wheel positions

    Returns
    -------
    np.ndarray or None
        numpy array of velocities, or None when fewer than three samples are given
    """
    msg = 'brainbox.behavior.wheel.velocity will soon be removed. Use velocity_filtered instead.'
    # stacklevel=2 attributes the warning to the caller's line, replacing the former practice of
    # dumping the entire call stack to stdout on every call.
    warnings.warn(msg, FutureWarning, stacklevel=2)
    logging.getLogger(__name__).warning(msg)

    dp = np.diff(re_pos)
    dt = np.diff(re_ts)
    # Compute raw velocity
    vel = dp / dt
    # Compute velocity time scale: each finite difference is placed at its interval midpoint
    tv = re_ts[:-1] + dt / 2
    if tv.size < 2:
        # Too few midpoints to interpolate from; legacy behaviour is to return None
        return None
    # interpolate over original time scale
    ifcn = interpolate.interp1d(tv, vel, fill_value="extrapolate")
    return ifcn(re_ts)
×
105

106

107
def velocity_filtered(pos, fs, corner_frequency=20, order=8):
    """
    Compute wheel velocity and acceleration from uniformly sampled wheel data.

    The position trace is low-pass filtered with a Butterworth filter (second-order sections,
    applied forwards and backwards with ``sosfiltfilt``), then differentiated by first
    differences scaled by the sampling frequency.

    Parameters
    ----------
    pos : array_like
        Vector of uniformly sampled wheel positions.
    fs : float
        Sampling frequency in Hz.
    corner_frequency : float
        Corner frequency of the low-pass filter in Hz.
    order : int
        Order of the Butterworth filter.

    Returns
    -------
    vel : np.ndarray
        Array of velocity values; the first element is 0.
    acc : np.ndarray
        Array of acceleration values; the first element is 0.
    """
    # Cut-off expressed as a fraction of the Nyquist frequency, as scipy expects
    normalized_cutoff = corner_frequency / fs * 2
    sections = scipy.signal.butter(N=order, Wn=normalized_cutoff, btype='lowpass', output='sos')
    smoothed = scipy.signal.sosfiltfilt(sections, pos)

    step_change = np.diff(smoothed)
    vel = np.insert(step_change, 0, 0) * fs
    vel_change = np.diff(vel)
    acc = np.insert(vel_change, 0, 0) * fs
    return vel, acc
1✔
131

132

133
def velocity_smoothed(pos, freq, smooth_size=0.03):
    """
    (DEPRECATED) Compute wheel velocity from uniformly sampled wheel data.

    Parameters
    ----------
    pos : array_like
        Array of wheel positions
    smooth_size : float
        Size of Gaussian smoothing window in seconds
    freq : float
        Sampling frequency of the data

    Returns
    -------
    vel : np.ndarray
        Array of velocity values
    acc : np.ndarray
        Array of acceleration values
    """
    msg = 'brainbox.behavior.wheel.velocity_smoothed will be removed. Use velocity_filtered instead.'
    # stacklevel=2 attributes the warning to the caller's line, replacing the former practice of
    # dumping the entire call stack to stdout on every call.
    warnings.warn(msg, FutureWarning, stacklevel=2)
    logging.getLogger(__name__).warning(msg)

    # Define our smoothing window with an area of 1 so the units won't be changed
    std_samps = np.round(smooth_size * freq)  # Standard deviation relative to sampling frequency
    # Number of points in the Gaussian covering +/-3 standard deviations; a window length is an
    # integer count of samples
    N = int(std_samps * 6)
    gauss_std = (N - 1) / 6
    win = scipy.signal.windows.gaussian(N, gauss_std)
    win = win / win.sum()  # Normalize amplitude

    # Convolve and multiply by sampling frequency to restore original units
    vel = np.insert(scipy.signal.convolve(np.diff(pos), win, mode='same'), 0, 0) * freq
    acc = np.insert(scipy.signal.convolve(np.diff(vel), win, mode='same'), 0, 0) * freq

    return vel, acc
×
172

173

174
def last_movement_onset(t, vel, event_time):
    """
    (DEPRECATED) Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Starting from the last sample before `event_time` and stepping backwards, the returned
    timestamp is the first one whose trailing 50 ms window (including the sample itself) has an
    absolute velocity below 0.5 throughout.  Wheel inputs should be evenly sampled.

    :param t: numpy array of wheel timestamps in seconds
    :param vel: numpy array of wheel velocities
    :param event_time: timestamp anywhere during movement of interest, e.g. peak velocity
    :return: timestamp of movement onset; the earliest timestamp before `event_time` if no quiet
        window is found, or None if there are no samples before `event_time`
    """
    msg = 'brainbox.behavior.wheel.last_movement_onset has been deprecated. Use get_movement_onset instead.'
    # stacklevel=2 attributes the warning to the caller's line, replacing the former practice of
    # dumping the entire call stack to stdout on every call.
    warnings.warn(msg, FutureWarning, stacklevel=2)
    logging.getLogger(__name__).warning(msg)

    # Look back from timestamp
    threshold = 50e-3
    mask = t < event_time
    times = t[mask]
    vel = vel[mask]
    t = None  # Initialize
    for i, t in enumerate(times[::-1]):
        i = times.size - i  # exclusive end index, so the slice below includes the current sample
        idx = np.min(np.where((t - times) < threshold))  # first sample within `threshold` of t
        if np.max(np.abs(vel[idx:i])) < 0.5:
            break

    # Return timestamp
    return t
×
208

209

210
def get_movement_onset(intervals, event_times):
    """
    Find the time at which movement started, given an event timestamp that occurred during the
    movement.

    Parameters
    ----------
    intervals : numpy.array
        An (n, 2) array of wheel movement intervals, the first column being the onsets.
    event_times : numpy.array
        Strictly increasing event timestamps anywhere during movement of interest, e.g. peak
        velocity, feedback time.

    Returns
    -------
    numpy.array
        An array the same length as `event_times` holding, for each event, the onset (first
        column of `intervals`) of the interval containing that event, or NaN where no interval
        contains it.  If several intervals contain an event, the last such row wins.

    Raises
    ------
    ValueError
        If `event_times` is not strictly ascending.

    Examples
    --------
    Find the last movement onset before each trial response time

    >>> trials = one.load_object(eid, 'trials')  # doctest: +SKIP
    >>> wheelMoves = one.load_object(eid, 'wheelMoves')  # doctest: +SKIP
    >>> onsets = get_movement_onset(wheelMoves.intervals, trials.response_times)  # doctest: +SKIP
    """
    if not np.all(np.diff(event_times) > 0):
        raise ValueError('event_times must be in ascending order.')
    onsets = np.full(event_times.size, np.nan)
    for interval in intervals:
        # Boolean mask of events falling within this interval (as computed by between_sorted)
        in_interval = between_sorted(event_times, interval)
        if np.any(in_interval):
            onsets[in_interval] = interval[0]
    return onsets
1✔
244

245

246
def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thresh_onset=1.5,
              min_dur=.05, make_plots=False):
    """
    Detect wheel movements.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    pos : array_like
        An array of evenly sampled wheel positions
    freq : int
        The sampling rate of the wheel data
    pos_thresh : float
        The minimum required movement during the t_thresh window to be considered part of a
        movement
    t_thresh : float
        The time window over which to check whether the pos_thresh has been crossed
    min_gap : float
        The minimum time between one movement's offset and another movement's onset in order to be
        considered separate.  Movements with a gap smaller than this are 'stitched together'
    pos_thresh_onset : float
        A lower threshold for finding precise onset times.  The onset is placed at the last sample
        before the position finally departs from the starting position by more than this amount
    min_dur : float
        The minimum duration of a valid movement.  Detected movements shorter than this are ignored
    make_plots : boolean
        Plot trace of position and velocity, showing detected onsets and offsets

    Returns
    -------
    onsets : np.ndarray
        Timestamps of detected movement onsets
    offsets : np.ndarray
        Timestamps of detected movement offsets
    peak_amps : np.ndarray
        For each detected movement, the signed displacement from the onset position at the sample
        of greatest absolute displacement
    peak_vel_times : np.ndarray
        Timestamps of peak velocity for each detected movement

    Raises
    ------
    AssertionError
        If the timestamps are not evenly sampled.
    """
    # Wheel position must be evenly sampled.
    # NOTE(review): `assert` is stripped under `python -O`, in which case this check is skipped.
    dt = np.diff(t)
    assert np.all(np.abs(dt - dt.mean()) < 1e-10), 'Values not evenly sampled'

    # Convert the time threshold into number of samples given the sampling frequency
    t_thresh_samps = int(np.round(t_thresh * freq))
    max_disp = np.empty(t.size, dtype=float)  # initialize array of total wheel displacement

    # Calculate a Hankel matrix of size t_thresh_samps in batches.  This is effectively a
    # sliding window within which we look for changes in position greater than pos_thresh.
    # Row k holds the t_thresh_samps positions starting at sample i2proc[k]; rows near the end of
    # a batch run off its edge and are NaN-padded, so consecutive batches overlap by
    # t_thresh_samps samples and the next batch overwrites those rows with full windows.
    # The loop only advances when t_thresh_samps < BATCH_SIZE.
    BATCH_SIZE = 10000  # do this in batches in order to keep memory usage reasonable
    c = 0  # index of 'window' position
    while True:
        i2proc = np.arange(BATCH_SIZE) + c
        i2proc = i2proc[i2proc < t.size]
        w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
        # Below is the total change in position for each window
        max_disp[i2proc] = np.nanmax(w2e, axis=1) - np.nanmin(w2e, axis=1)
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] == t.size - 1:
            break

    moving = max_disp > pos_thresh  # for each window is the change in position greater than our threshold?
    moving = np.insert(moving, 0, False)  # First sample should always be not moving to ensure we have an onset
    moving[-1] = False  # Likewise, ensure we always end on an offset

    # Rising and falling edges of the (t.size + 1)-long `moving` mask
    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    # Stitch together movements separated by less than min_gap by marking the gap as moving
    too_short = np.where((onset_samps[1:] - offset_samps[:-1]) / freq < min_gap)[0]
    for p in too_short:
        moving[offset_samps[p]:onset_samps[p + 1] + 1] = True

    # For each onset, collect the absolute displacement from the onset position over the next
    # t_thresh_samps samples (one row per onset, in onset order).
    onset_samps = np.where(~moving[:-1] & moving[1:])[0]
    onsets_disp_arr = np.empty((onset_samps.size, t_thresh_samps))
    c = 0
    cwt = 0  # number of rows of onsets_disp_arr filled so far
    while onset_samps.size != 0:
        i2proc = np.arange(BATCH_SIZE) + c
        # Only the first BATCH_SIZE - t_thresh_samps rows of each batch are used (their windows
        # lie within the batch); the next batch starts exactly where they end, so every onset is
        # captured once.  The former slice [:-t_thresh_samps - 1] skipped sample
        # BATCH_SIZE - t_thresh_samps - 1 of every batch, leaving a row of the np.empty array
        # uninitialised and misaligning the rows of all later onsets.
        icomm, itpltz, _ = np.intersect1d(i2proc[:BATCH_SIZE - t_thresh_samps], onset_samps,
                                          return_indices=True, assume_unique=True)
        i2proc = i2proc[i2proc < t.size]
        if icomm.size > 0:
            w2e = hankel(pos[i2proc], np.full(t_thresh_samps, np.nan))
            w2e = np.abs((w2e.T - w2e[:, 0]).T)  # displacement relative to each window's first sample
            onsets_disp_arr[cwt + np.arange(icomm.size), :] = w2e[itpltz, :]
            cwt += icomm.size
        c += BATCH_SIZE - t_thresh_samps
        if i2proc[-1] >= onset_samps[-1]:
            break

    # Refine each onset to the last sample in its window whose displacement does not exceed
    # pos_thresh_onset: argmin over the column-reversed mask finds the first False from the end.
    # (If no sample exceeds the threshold, argmin returns 0 and the window's last sample is used.)
    has_onset = onsets_disp_arr > pos_thresh_onset
    A = np.argmin(np.fliplr(has_onset).T, axis=0)
    onset_lags = t_thresh_samps - A
    onset_samps = onset_samps + onset_lags - 1
    onsets = t[onset_samps]
    offset_samps = np.where(moving[:-1] & ~moving[1:])[0]
    offsets = t[offset_samps]

    # Discard movements shorter than min_dur
    durations = offsets - onsets
    too_short = durations < min_dur
    onset_samps = onset_samps[~too_short]
    onsets = onsets[~too_short]
    offset_samps = offset_samps[~too_short]
    offsets = offsets[~too_short]

    # Merge movements whose refined gap is still smaller than min_gap
    moveGaps = onsets[1:] - offsets[:-1]
    gap_too_small = moveGaps < min_gap
    if onsets.size > 0:
        onsets = onsets[np.insert(~gap_too_small, 0, True)]  # always keep first onset
        onset_samps = onset_samps[np.insert(~gap_too_small, 0, True)]
        offsets = offsets[np.append(~gap_too_small, True)]  # always keep last offset
        offset_samps = offset_samps[np.append(~gap_too_small, True)]

    # Calculate the peak amplitudes -
    # the maximum absolute value of the difference from the onset position
    peaks = (pos[m + np.abs(pos[m:n] - pos[m]).argmax()] - pos[m]
             for m, n in zip(onset_samps, offset_samps))
    peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
    N = 10  # Number of points in the Gaussian
    STDEV = 1.8  # Equivalent to a width factor (alpha value) of 2.5
    gauss = scipy.signal.windows.gaussian(N, STDEV)  # A 10-point Gaussian window of a given s.d.
    # Smoothed sample-to-sample velocity, used only to locate peak-velocity times.  Prepending
    # pos[0] (rather than 0) makes the first difference zero; prepending 0 injected a spurious
    # spike of size pos[0] at the start of the trace whenever the position did not start at 0.
    vel = scipy.signal.convolve(np.diff(np.insert(pos, 0, pos[0])), gauss, mode='same')
    # For each movement period, find the timestamp where the absolute velocity was greatest
    peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))
    peak_vel_times = np.fromiter(peaks, dtype=float, count=onsets.size)

    if make_plots:
        fig, axes = plt.subplots(nrows=2, sharex='all')
        indices = np.sort(np.hstack((onset_samps, offset_samps)))  # Points to split trace
        vel, acc = velocity_filtered(pos, freq)

        # Plot the wheel position and velocity
        for ax, y in zip(axes, (pos, vel)):
            ax.plot(onsets, y[onset_samps], 'go')
            ax.plot(offsets, y[offset_samps], 'bo')

            t_split = np.split(np.vstack((t, y)).T, indices, axis=0)
            ax.add_collection(LineCollection(t_split[1::2], colors='r'))  # Moving
            ax.add_collection(LineCollection(t_split[0::2], colors='k'))  # Not moving

        axes[1].autoscale()  # rescale after adding line collections
        axes[0].autoscale()
        axes[0].set_ylabel('position')
        axes[1].set_ylabel('velocity')
        axes[1].set_xlabel('time')
        axes[0].legend(['onsets', 'offsets', 'in movement'])
        plt.show()

    return onsets, offsets, peak_amps, peak_vel_times
1✔
396

397

398
def cm_to_deg(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement in cm into degrees of wheel rotation.

    One revolution moves the wheel surface by its circumference (pi * diameter), so the
    displacement is first expressed in revolutions and then scaled to degrees.  This may be useful
    for e.g. calculating velocity in revolutions per second.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel positions in degrees turned

    # Example: Convert linear cm to degrees
    >>> cm_to_deg(3.142 * WHEEL_DIAMETER)
    360.04667846020925

    # Example: Get positions in deg from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_deg(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.61999992, 0.93000011, 1.24000007, 1.55000003])
    """
    circumference = wheel_diameter * pi
    revolutions = positions / circumference
    return revolutions * 360
×
416

417

418
def cm_to_rad(positions, wheel_diameter=WHEEL_DIAMETER):
    """
    Convert linear wheel displacement in cm into wheel rotation angle in radians.

    The angle subtended by an arc is its length divided by the radius, i.e. each cm of
    displacement corresponds to 2 / diameter radians.  This may be useful for e.g. calculating
    angular velocity.

    :param positions: array of wheel positions in cm
    :param wheel_diameter: the diameter of the wheel in cm
    :return: array of wheel angle in radians

    # Example: Convert linear cm to radians
    >>> cm_to_rad(1)
    0.3225806451612903

    # Example: Get positions in rad from cm for 5cm diameter wheel
    >>> import numpy as np
    >>> cm_to_rad(np.array([0.0270526 , 0.04057891, 0.05410521, 0.06763151]), wheel_diameter=5)
    array([0.01082104, 0.01623156, 0.02164208, 0.0270526 ])
    """
    rad_per_cm = 2 / wheel_diameter
    return positions * rad_per_cm
1✔
435

436

437
def samples_to_cm(positions, wheel_diameter=WHEEL_DIAMETER, resolution=ENC_RES):
    """
    Convert wheel position samples to cm linear displacement.  This may be useful for
    inter-converting threshold units

    The count is divided by the encoder resolution to give revolutions, which are then
    multiplied by the wheel circumference (pi * diameter).

    :param positions: array of wheel positions in sample counts
    :param wheel_diameter: the diameter of the wheel in cm
    :param resolution: resolution of the rotary encoder, in counts per revolution
    :return: array of wheel positions in cm of linear displacement

    # Example: Get resolution in linear cm
    >>> samples_to_cm(1)
    0.004755340442445488

    # Example: Get positions in linear cm for 4X, 360 ppr encoder
    >>> import numpy as np
    >>> samples_to_cm(np.array([2, 3, 4, 5, 6, 7, 6, 5, 4]), resolution=360*4)
    array([0.0270526 , 0.04057891, 0.05410521, 0.06763151, 0.08115781,
           0.09468411, 0.08115781, 0.06763151, 0.05410521])
    """
    return positions / resolution * pi * wheel_diameter
1✔
457

458

459
def direction_changes(t, vel, intervals):
    """
    Find the direction changes for the given movement intervals.

    A direction change is flagged at every sample whose velocity sign differs from that of the
    preceding sample; note that a transition to or from exactly zero velocity also counts.

    Parameters
    ----------
    t : array_like
        An array of evenly sampled wheel timestamps in absolute seconds
    vel : array_like
        An array of evenly sampled wheel velocities, the same length as `t`
    intervals : array_like
        An n-by-2 array of wheel movement intervals; bounds are exclusive

    Returns
    -------
    times : list of numpy.ndarray
        Direction change timestamps, one array per interval
    indices : list of numpy.ndarray
        Indices into `t` of the direction changes; each array has the size of the matching
        entry of `times`
    """
    t = np.asarray(t)
    # True where sign(vel) differs from the previous sample; the first sample has no predecessor
    # and is never flagged.
    chg = np.insert(np.diff(np.sign(vel)) != 0, 0, False)

    times, indices = [], []
    for on, off in np.reshape(intervals, (-1, 2)):
        ind, = np.where((t > on) & (t < off) & chg)
        times.append(t[ind])
        indices.append(ind)
    return times, indices
1✔
490

491

492
def traces_by_trial(t, *args, start=None, end=None, separate=True):
    """
    Slice timestamps and accompanying arrays into per-trial segments, e.g. wheel positions and
    velocities between stimulus onset and feedback.

    Only samples strictly between each start and end time are kept.

    :param t: numpy array of timestamps
    :param args: optional numpy arrays of the same length as timestamps, such as positions,
    velocities or accelerations
    :param start: start timestamp or array thereof; defaults to the first timestamp
    :param end: end timestamp or array thereof; defaults to the last timestamp
    :param separate: when True, the output is returned as tuples list of the form [(t, args[0],
    args[1]), ...], when False, the output is a list of n-by-m ndarrays where n = 1 + number of
    positional args (t is the first row) and m = number of samples in the segment
    :return: list of sliced arrays where length == len(start)
    """
    t = np.asarray(t)
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]
    # Accept scalar bounds (including the defaults above) as well as arrays of bounds
    start = np.atleast_1d(start)
    end = np.atleast_1d(end)
    traces = np.stack((t, *args))
    assert len(start) == len(end), 'number of start timestamps must equal end timestamps'

    def to_mask(a, b):
        return np.logical_and(t > a, t < b)

    cuts = [traces[:, to_mask(s, e)] for s, e in zip(start, end)]
    # One tuple entry per stacked row: (t, args[0], args[1], ...)
    return [tuple(cut) for cut in cuts] if separate else cuts
1✔
518

519

520
if __name__ == '__main__':
    # Running this module as a script executes the `>>>` examples in the docstrings as doctests
    import doctest
    doctest.testmod()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc