int-brain-lab / ibllib · build 6161719581526678 (push, tests, web-flow)
28 Jun 2024 01:14PM UTC · coverage: 64.584% (+0.03%) from 64.55%

Stim on extraction (#788)

* Issue #775
* Handle no go trials
* Pre-6.2.5 trials extraction
* DeprecationWarning -> FutureWarning; extractor fixes; timeline trials extraction

49 of 60 new or added lines in 9 files covered (81.67%).
2 existing lines in 1 file now uncovered.
13055 of 20214 relevant lines covered (64.58%).
0.65 hits per line.

Source file: /brainbox/processing.py (44.34% of lines covered)
'''
Processes data from one form into another, e.g. taking spike times and binning them into
non-overlapping bins, or convolving spike times with a gaussian kernel.
'''

import logging
import traceback
import warnings

import numpy as np
import pandas as pd
from scipy import interpolate, sparse

from brainbox import core
from iblutil.numerical import bincount2D as _bincount2D
from iblutil.util import Bunch

_logger = logging.getLogger(__name__)


def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zero',
         fillval=np.nan):
    """
    Resamples a single time series, or multiple time series, to a single, evenly-spaced delta t
    between observations. Uses interpolation to find values.

    Can be used on raw numpy arrays of timestamps and values using the 'times' and 'values'
    kwargs and/or on brainbox.core.TimeSeries objects passed to the 'timeseries' kwarg. If
    passing both TimeSeries objects and numpy arrays, the offsets passed should be for the TS
    objects first and then the numpy arrays.

    Uses scipy's interpolation library to perform interpolation.
    See scipy.interpolate.interp1d for more information regarding the interp and fillval
    parameters.

    :param dt: Separation of points at which the output timeseries will be sampled
    :type dt: float
    :param timeseries: A group of time series to align, or a single time series.
        Must have time stamps.
    :type timeseries: tuple of TimeSeries objects, or a single TimeSeries object
    :param times: time stamps for the observations in 'values'
    :type times: np.ndarray or list of np.ndarrays
    :param values: observations corresponding to the timestamps in 'times'
    :type values: np.ndarray or list of np.ndarrays
    :param offsets: tuple of offsets for the time stamps of each time series. Offsets for passed
        TimeSeries objects first, then offsets for passed numpy arrays. Defaults to None
    :type offsets: tuple of floats, optional
    :param interp: Type of interpolation to use. Refer to scipy.interpolate.interp1d for possible
        values, defaults to 'zero'
    :type interp: str
    :param fillval: Fill values to use when interpolating outside of the range of the data. See
        interp1d for possible values, defaults to np.nan
    :return: TimeSeries object with each row representing synchronized values of all
        input TimeSeries. Will carry column names from the input time series if all of them have
        column names.
    """
    #########################################
    # Checks on inputs and input processing #
    #########################################

    # Initialize a list to contain times/values pairs if no TS objs are passed
    if timeseries is None:
        timeseries = []
    # If a single time series is passed for resampling, wrap it in an iterable
    elif isinstance(timeseries, core.TimeSeries):
        timeseries = [timeseries]
    # Yell at the user if they try to pass stuff to timeseries that isn't a TimeSeries object
    elif not all([isinstance(ts, core.TimeSeries) for ts in timeseries]):
        raise TypeError("All elements of 'timeseries' argument must be brainbox.core.TimeSeries "
                        "objects. Please use 'times' and 'values' for np.ndarray args.")
    # Check that if something is passed to times or values, there is a corresponding equal-length
    # argument for the other element.
    if (times is not None) or (values is not None):
        if len(times) != len(values):
            raise ValueError("'times' and 'values' must have the same number of elements.")
        if type(times[0]) is np.ndarray:
            if not all([t.shape == v.shape for t, v in zip(times, values)]):
                raise ValueError("All arrays in 'times' must match the shape of the"
                                 " corresponding entry in 'values'.")
            # If all checks are passed, convert all times and values args into TimeSeries objects
            timeseries.extend([core.TimeSeries(t, v) for t, v in zip(times, values)])
        else:
            # If times and values are only numpy arrays and lists of arrays, pair them and add
            timeseries.append(core.TimeSeries(times, values))

    # Adjust each timeseries by the associated offset if necessary, then load into a list
    if offsets is not None:
        tstamps = [ts.times + os for ts, os in zip(timeseries, offsets)]
    else:
        tstamps = [ts.times for ts in timeseries]
    # If all input timeseries have column names, put them together for the output TS
    if all([ts.columns is not None for ts in timeseries]):
        colnames = []
        for ts in timeseries:
            colnames.extend(ts.columns)
    else:
        colnames = None

    #################
    # Main function #
    #################

    # Get the min and max values for all timeseries combined after offsetting
    tbounds = np.array([(np.amin(ts), np.amax(ts)) for ts in tstamps])
    if not np.all(np.isfinite(tbounds)):
        # A np.inf or np.nan in the time stamps of any of the timeseries would break the code
        # below, so check for all-finite values and throw an informative error.
        raise ValueError('NaN or inf encountered in passed timeseries. '
                         'Please either drop or fill these values.')
    tmin, tmax = np.amin(tbounds[:, 0]), np.amax(tbounds[:, 1])
    if fillval == 'extrapolate':
        # If extrapolation is enabled we can ensure full coverage of the data by extending tmax
        # to a whole-integer multiple of dt above tmin.
        # The 0.01% fudge factor accounts for floating point arithmetic errors.
        newt = np.arange(tmin, tmax + 1.0001 * (dt - (tmax - tmin) % dt), dt)
    else:
        newt = np.arange(tmin, tmax, dt)
    tsinterps = [interpolate.interp1d(ts.times, ts.values, kind=interp, fill_value=fillval, axis=0)
                 for ts in timeseries]
    syncd = core.TimeSeries(newt, np.hstack([tsi(newt) for tsi in tsinterps]), columns=colnames)
    return syncd


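# Illustrative usage sketch (editor's addition, not part of the original module): resampling
# two synthetic TimeSeries sampled at different rates onto a common 5 ms grid. The names
# `ts_a`/`ts_b` are hypothetical; values are made 2D columns so np.hstack in sync() stacks
# them side by side.
def _example_sync():
    t_a = np.arange(0, 1, 0.01)  # 100 Hz timestamps
    t_b = np.arange(0, 1, 0.04)  # 25 Hz timestamps
    ts_a = core.TimeSeries(t_a, np.sin(2 * np.pi * t_a)[:, np.newaxis], columns=('sine',))
    ts_b = core.TimeSeries(t_b, np.cos(2 * np.pi * t_b)[:, np.newaxis], columns=('cosine',))
    # Linear interpolation with extrapolation at the edges; the output carries both column names.
    return sync(0.005, timeseries=(ts_a, ts_b), interp='linear', fillval='extrapolate')

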
def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
    """
    Computes a 2D histogram by aggregating values in a 2D array.

    :param x: values to bin along the 2nd dimension (c-contiguous)
    :param y: values to bin along the 1st dimension
    :param xbin:
        scalar: bin size along 2nd dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count reduce operation)
    :param ybin:
        scalar: bin size along 1st dimension
        0: aggregate according to unique values
        array: aggregate according to exact values (count reduce operation)
    :param xlim: (optional) 2 values (array or list) that restrict the range along the 2nd
        dimension
    :param ylim: (optional) 2 values (array or list) that restrict the range along the 1st
        dimension
    :param weights: (optional) weights to apply to each value for aggregation, defaults to None
    :return: 3 numpy arrays: MAP [ny, nx] image, xscale [nx], yscale [ny]
    """
    # Print the call stack so the origin of the deprecated call can be located and updated.
    for line in traceback.format_stack():
        print(line.strip())
    warning_text = ('Future warning: bincount2D() is now a part of iblutil. '
                    'brainbox.processing.bincount2D will be removed in future versions. '
                    'Please replace imports with iblutil.numerical.bincount2D.')
    _logger.warning(warning_text)
    warnings.warn(warning_text, FutureWarning)
    return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)


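# Illustrative usage sketch (editor's addition): binning random 2D points into a 0.5 x 0.5
# grid. Since the wrapper above is deprecated, this calls iblutil.numerical.bincount2D
# (imported here as _bincount2D) directly; the example data are made up.
def _example_bincount2D():
    rng = np.random.default_rng(0)
    x, y = rng.uniform(0, 2, 100), rng.uniform(0, 3, 100)
    # `image` has shape [ny, nx]; `xscale`/`yscale` hold the bin scales along each axis.
    image, xscale, yscale = _bincount2D(x, y, xbin=0.5, ybin=0.5)
    return image, xscale, yscale

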
def compute_cluster_average(spike_clusters, spike_var):
    """
    Quickish way to compute the average of some quantity across the spikes in each cluster,
    given the quantity for each spike.

    :param spike_clusters: cluster idx of each spike
    :param spike_var: variable of each spike (e.g. spike amps or spike depths)
    :return: cluster id, average of the quantity for each cluster, no. of spikes per cluster
    """
    clust, inverse, counts = np.unique(spike_clusters, return_inverse=True, return_counts=True)
    # Sum the per-spike values within each cluster via a sparse matrix, then divide by the
    # per-cluster spike counts to get the averages.
    _spike_var = sparse.csr_matrix((spike_var, (inverse, np.zeros(inverse.size, dtype=int))))
    spike_var_avg = np.ravel(_spike_var.toarray()) / counts

    return clust, spike_var_avg, counts


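# Illustrative usage sketch (editor's addition): average amplitude per cluster for a toy
# spike train. The arrays below are made-up example data.
def _example_compute_cluster_average():
    spike_clusters = np.array([0, 1, 0, 2, 1, 0])
    spike_amps = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    clust, avg, counts = compute_cluster_average(spike_clusters, spike_amps)
    # clust -> [0, 1, 2]; avg -> [(1+3+6)/3, (2+5)/2, 4/1]; counts -> [3, 2, 1]
    return clust, avg, counts

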
def bin_spikes(spikes, binsize, interval_indices=False):
    """
    Wrapper for bincount2D which is intended to take in a TimeSeries object of spike times
    and cluster identities and spit out spike counts in bins of a specified width binsize, also
    in another TimeSeries object. Can either return a TS object with each row labeled with the
    corresponding interval or with the value of the left edge of the bin.

    :param spikes: Spike times and cluster identities of sorted spikes
    :type spikes: TimeSeries object with 'clusters' column and timestamps
    :param binsize: Width of the non-overlapping bins in which to bin spikes
    :type binsize: float
    :param interval_indices: Whether to use intervals as the time stamps for binned spikes,
        rather than the left edge value of the bins, defaults to False
    :type interval_indices: bool, optional
    :return: Object with a 2D array of shape T x N, for T timesteps and N clusters, and the
        associated time stamps.
    :rtype: TimeSeries object
    """
    if not isinstance(spikes, core.TimeSeries):
        raise TypeError('Input spikes need to be in TimeSeries object format')

    if not hasattr(spikes, 'clusters'):
        raise AttributeError('Input spikes need to have a clusters attribute. Make sure you set '
                             "columns=('clusters',) when constructing spikes.")

    rates, tbins, clusters = bincount2D(spikes.times, spikes.clusters, binsize)
    if interval_indices:
        intervals = pd.interval_range(tbins[0], tbins[-1], freq=binsize, closed='left')
        return core.TimeSeries(times=intervals, values=rates.T[:-1], columns=clusters)
    else:
        return core.TimeSeries(times=tbins, values=rates.T, columns=clusters)


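# Illustrative usage sketch (editor's addition): counting toy spikes in 100 ms bins. This
# assumes, as the checks above imply, that passing columns=('clusters',) to TimeSeries
# exposes the values as a `clusters` attribute; the data are synthetic.
def _example_bin_spikes():
    rng = np.random.default_rng(1)
    spike_times = np.sort(rng.uniform(0, 2, 200))
    spike_clusters = rng.integers(0, 5, 200)
    spikes = core.TimeSeries(spike_times, spike_clusters, columns=('clusters',))
    # Rows of the result are 100 ms time bins; columns are cluster ids.
    return bin_spikes(spikes, binsize=0.1)

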
def get_units_bunch(spks_b, *args):
    '''
    Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information
    (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values
    for each unit: these arrays are ordered and can be indexed by unit id.

    Parameters
    ----------
    spks_b : bunch
        A spikes bunch containing fields with spike information (e.g. unit IDs, times, features,
        etc.) for all spikes.
    features : list of strings (optional positional arg)
        A list of names of labels of spike information (which must be keys in `spks`) that
        specify which labels to return as keys in `units`. If not provided, all keys in `spks`
        are returned as keys in `units`.

    Returns
    -------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features,
        etc.) whose values are arrays that hold values for each unit. The arrays for each key
        are ordered by unit ID.

    Examples
    --------
    1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
    bunch.
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory)
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = bb.processing.get_units_bunch(spks_b)
        # Get amplitudes for unit 4.
        >>> amps = units_b['amps']['4']

    TODO add computation time estimate?
    '''
    # Initialize `units`
    units_b = Bunch()
    # Get the keys to return for `units`:
    if not args:
        feat_keys = list(spks_b.keys())
    else:
        feat_keys = args[0]
    # Get unit id for each spike and number of units. *Note: `n_units` might not equal
    # `len(units)` because some clusters may be empty (due to a "wontfix" bug in ks2).
    spks_unit_id = spks_b['clusters']
    n_units = np.max(spks_unit_id)
    units = np.unique(spks_b['clusters'])
    # For each key in `units`, iteratively get each unit's values and add as a key to a bunch,
    # `feat_bunch`. After iterating through all units, add `feat_bunch` as a key to `units`:
    for feat in feat_keys:
        # Initialize `feat_bunch` with a key for each unit.
        feat_bunch = Bunch((str(unit), np.array([])) for unit in np.arange(n_units))
        for unit in units:
            unit_idxs = np.where(spks_unit_id == unit)[0]
            feat_bunch[str(unit)] = spks_b[feat][unit_idxs]
        units_b[feat] = feat_bunch
    return units_b


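# Illustrative usage sketch (editor's addition): building a units bunch from a small
# synthetic spikes bunch rather than loaded ALF data. The 'clusters'/'amps'/'times' keys
# follow the convention assumed by the docstring above.
def _example_get_units_bunch():
    spks_b = Bunch({
        'clusters': np.array([0, 1, 0, 1, 2]),
        'amps': np.array([1e-5, 2e-5, 3e-5, 4e-5, 5e-5]),
        'times': np.array([0.1, 0.2, 0.3, 0.4, 0.5]),
    })
    units_b = get_units_bunch(spks_b)
    # Per-unit arrays keyed by stringified unit id, e.g. units_b['amps']['1'] -> [2e-5, 4e-5]
    return units_b

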
def filter_units(units_b, t, **kwargs):
    '''
    Filters units according to some parameters. **kwargs are the keyword parameters used to
    filter the units.

    Parameters
    ----------
    units_b : bunch
        A bunch with keys of labels of spike information (e.g. cluster IDs, times, features,
        etc.) whose values are arrays that hold values for each unit. The arrays for each key
        are ordered by unit ID.
    t : float
        Duration of time over which to calculate the firing rate and false positive rate.

    Keyword Parameters
    ------------------
    min_amp : float
        The minimum mean amplitude (in V) of the spikes in the unit. Default value is 50e-6.
    min_fr : float
        The minimum firing rate (in Hz) of the unit. Default value is 0.5.
    max_fpr : float
        The maximum false positive rate of the unit (using the fp formula in Hill et al. (2011)
        J Neurosci 31: 8699-8705). Default value is 0.2.
    rp : float
        The refractory period (in s) of the unit. Used to calculate `max_fpr`. Default value is
        0.002.

    Returns
    -------
    filt_units : ndarray
        The ids of the filtered units.

    See Also
    --------
    get_units_bunch

    Examples
    --------
    1) Filter units according to the default parameters.
        >>> import brainbox as bb
        >>> import alf.io as aio
        >>> import ibllib.ephys.spikes as e_spks
        (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory)
        >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
        # Get a spikes bunch, a units bunch, and filter the units.
        >>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
        >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
        >>> T = spks_b['times'][-1] - spks_b['times'][0]
        >>> filtered_units = bb.processing.filter_units(units_b, T)

    2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and the default
    maximum false positive rate of 0.2, given a refractory period of 2 ms.
        >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)

    TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
          are in `clstrs_b['metrics']`
    '''
    # Set params
    params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002}  # defaults
    params.update(kwargs)  # update from **kwargs

    # Iteratively filter the units for each filter param #
    # -------------------------------------------------- #
    units = np.asarray(list(units_b.amps.keys()))
    # Remove empty clusters
    empty_cl = np.where([len(units_b.amps[unit]) == 0 for unit in units])[0]
    filt_units = np.delete(units, empty_cl)
    for param in params.keys():
        if param == 'min_amp':  # return units with amp > `'min_amp'`
            mean_amps = np.asarray([np.mean(units_b.amps[unit]) for unit in filt_units])
            filt_idxs = np.where(mean_amps > params['min_amp'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'min_fr':  # return units with fr > `'min_fr'`
            fr = np.asarray([len(units_b.amps[unit]) /
                            (units_b.times[unit][-1] - units_b.times[unit][0])
                            for unit in filt_units])
            filt_idxs = np.where(fr > params['min_fr'])[0]
            filt_units = filt_units[filt_idxs]
        elif param == 'max_fpr':  # return units with fpr < `'max_fpr'`
            fpr = np.zeros_like(filt_units, dtype='float')
            for i, unit in enumerate(filt_units):
                n_spks = len(units_b.amps[unit])
                n_isi_viol = len(np.where(np.diff(units_b.times[unit]) < params['rp'])[0])
                # fpr is the min of the roots of the solved quadratic equation (Hill et al. 2011)
                c = (t * n_isi_viol) / (2 * params['rp'] * n_spks**2)  # 3rd term in quadratic
                fpr[i] = np.min(np.abs(np.roots([-1, 1, c])))  # solve quadratic
            filt_idxs = np.where(fpr < params['max_fpr'])[0]
            filt_units = filt_units[filt_idxs]
    return filt_units.astype(int)
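

# Illustrative usage sketch (editor's addition): filtering three synthetic units. Unit '0'
# fires at only 0.5 Hz (not above the min_fr default) and unit '1' has a mean amplitude below
# the 50e-6 V default, so only unit '2' survives. The data are made up for demonstration.
def _example_filter_units():
    units_b = Bunch({
        'amps': Bunch({'0': np.full(5, 100e-6),      # 5 spikes over 10 s: too slow
                       '1': np.full(100, 10e-6),     # amplitude below the 50e-6 V default
                       '2': np.full(100, 100e-6)}),  # passes all default thresholds
        'times': Bunch({'0': np.linspace(0, 10, 5),
                        '1': np.linspace(0, 10, 100),
                        '2': np.linspace(0, 10, 100)}),
    })
    # Regularly spaced spikes give zero refractory-period violations, so fpr is 0 for all units.
    return filter_units(units_b, t=10)  # -> array([2])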