• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 1761696499260742

05 Oct 2023 09:46AM UTC coverage: 55.27% (-1.4%) from 56.628%
1761696499260742

Pull #655

continuous-integration/UCL

bimac
add @sleepless decorator
Pull Request #655: add @sleepless decorator

21 of 21 new or added lines in 1 file covered. (100.0%)

10330 of 18690 relevant lines covered (55.27%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

46.42
/brainbox/behavior/training.py
1
import logging
1✔
2
import datetime
1✔
3
import re
1✔
4
from enum import IntFlag, auto, unique
1✔
5

6
import numpy as np
1✔
7
import matplotlib
1✔
8
import matplotlib.pyplot as plt
1✔
9
import seaborn as sns
1✔
10
import pandas as pd
1✔
11
from scipy.stats import bootstrap
1✔
12
from iblutil.util import Bunch
1✔
13
from one.api import ONE
1✔
14
from one.alf.exceptions import ALFObjectNotFound
1✔
15

16
import psychofit as psy
1✔
17

18
_logger = logging.getLogger('ibllib')

# Per-trial attributes that are concatenated across sessions (see concatenate_trials); every
# trials object handed to the training-status functions must provide each of these keys.
TRIALS_KEYS = [
    'contrastLeft', 'contrastRight', 'feedbackType', 'probabilityLeft',
    'choice', 'response_times', 'stimOn_times',
]
27

28

29
@unique
class TrainingStatus(IntFlag):
    """Standard IBL training criteria.

    Every atomic status owns a single bit, assigned in training order, so statuses can be
    compared with ``<``/``>`` and grouped into compound flags with ``|``.

    Examples
    --------
    >>> status = 'ready4delay'
    >>> assert TrainingStatus[status.upper()] is TrainingStatus.READY4DELAY
    >>> assert TrainingStatus[status.upper()] not in TrainingStatus.FAILED, 'Subject failed training'
    >>> assert TrainingStatus[status.upper()] >= TrainingStatus.TRAINED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] > TrainingStatus.IN_TRAINING, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in ~TrainingStatus.FAILED, 'Subject untrained'
    >>> assert TrainingStatus[status.upper()] in TrainingStatus.TRAINED ^ TrainingStatus.READY

    # Get the next training status
    >>> next(member for member in sorted(TrainingStatus) if member > TrainingStatus[status.upper()])
    <TrainingStatus.READY4RECORDING: 128>

    Notes
    -----
    - ~TrainingStatus.TRAINED means any status but trained 1a or trained 1b.
    - A subject may achieve both TRAINED_1A and TRAINED_1B within a single session, therefore it
      is possible to have skipped the TRAINED_1A session status.
    """
    # Atomic statuses: one bit each, in ascending order of training progress
    UNTRAINABLE = 1 << 0
    UNBIASABLE = 1 << 1
    IN_TRAINING = 1 << 2
    TRAINED_1A = 1 << 3
    TRAINED_1B = 1 << 4
    READY4EPHYSRIG = 1 << 5
    READY4DELAY = 1 << 6
    READY4RECORDING = 1 << 7
    # Compound training statuses for convenience
    FAILED = UNTRAINABLE | UNBIASABLE
    READY = READY4EPHYSRIG | READY4DELAY | READY4RECORDING
    TRAINED = TRAINED_1A | TRAINED_1B
67

68

69
def get_lab_training_status(lab, date=None, details=True, one=None):
    """
    Display the training status of every alive, water-restricted subject of a lab.

    :param lab: lab name, exactly as registered on Alyx
    :type lab: string
    :param date: date from which the training status is computed; when omitted the status is
    computed from the latest date with available data
    :type date: string of format 'YYYY-MM-DD'
    :param details: if True, also display the metrics behind each status (performance, number
    of trials, psychometric fit parameters, reaction time)
    :type details: bool
    :param one: instantiation of ONE class; a new instance is created when not provided
    """
    if not one:
        one = ONE()
    alyx_subjects = one.alyx.rest('subjects', 'list', lab=lab, alive=True, water_restricted=True)
    # Collect every nickname up front, before any per-subject work starts
    nicknames = [record['nickname'] for record in alyx_subjects]
    for nickname in nicknames:
        get_subject_training_status(nickname, date=date, details=details, one=one)
88

89

90
def get_subject_training_status(subj, date=None, details=True, one=None):
    """
    Computes the training status of specified subject and displays it

    :param subj: subject nickname (must match the name registered on Alyx)
    :type subj: string
    :param date: the date from which to compute training status from. If not specified will compute
    from the latest date with available data
    :type date: string of format 'YYYY-MM-DD'
    :param details: whether to display all information about training status computation e.g
    performance, number of trials, psychometric fit parameters
    :type details: bool
    :param one: instantiation of ONE class
    """
    one = one or ONE()

    trials, task_protocol, ephys_sess, n_delay = get_sessions(subj, date=date, one=one)
    if not trials:
        # get_sessions returns None (or an empty Bunch) when there is nothing to evaluate
        return
    sess_dates = list(trials.keys())
    status, info = get_training_status(trials, task_protocol, ephys_sess, n_delay)

    if not details:
        display_status(subj, sess_dates, status)
    elif info.get('psych') is not None:
        # A single psychometric fit over all trials (trainingChoiceWorld metrics). Presence is
        # tested explicitly: np.any() would treat an all-zero parameter vector as missing.
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych=info.psych, rt=info.rt)
    elif info.get('psych_20') is not None:
        # Separate fits for the 20 and 80 probability-left blocks (biased sessions)
        display_status(subj, sess_dates, status, perf_easy=info.perf_easy,
                       n_trials=info.n_trials, psych_20=info.psych_20, psych_80=info.psych_80,
                       rt=info.rt)
    else:
        # No psychometric fit available: still report the status rather than print nothing
        display_status(subj, sess_dates, status)
122

123

124
def get_sessions(subj, date=None, one=None):
    """
    Download and load in training data for a specified subject. If a date is given it will load
    data from the three (or as many as are available) previous sessions up to the specified date.
    If not it will load data from the last three training sessions up to yesterday that have data
    available.

    :param subj: subject nickname (must match the name registered on Alyx)
    :type subj: string
    :param date: the date from which to compute training status from. If not specified sessions
    up to yesterday are considered
    :type date: string of format 'YYYY-MM-DD'
    :param one: instantiation of ONE class
    :returns:
        - trials - dict of trials objects where each key is the session date
        - task_protocol - list of the task protocol used for each of the sessions
        - ephys_sess_data - list of dates where training was conducted on ephys rig. Empty list if
                            all sessions on training rig
        - n_delay - number of sessions on ephys rig that had delay prior to starting session
                    > 15min. Returns 0 if no sessions detected
        All four values are None when no session with loadable trials data is found.
    """
    one = one or ONE()

    if date is None:
        # compute from yesterday
        specified_date = (datetime.date.today() - datetime.timedelta(days=1))
        latest_sess = specified_date.strftime("%Y-%m-%d")
        latest_minus_week = (datetime.date.today() -
                             datetime.timedelta(days=8)).strftime("%Y-%m-%d")
    else:
        # compute from the date specified
        specified_date = datetime.datetime.strptime(date, '%Y-%m-%d')
        latest_minus_week = (specified_date - datetime.timedelta(days=7)).strftime("%Y-%m-%d")
        latest_sess = date

    sessions = one.alyx.rest('sessions', 'list', subject=subj, date_range=[latest_minus_week,
                             latest_sess], dataset_types='trials.goCueTrigger_times')

    # If not enough sessions in the last week, then just fetch them all
    if len(sessions) < 3:
        specified_date_plus = (specified_date + datetime.timedelta(days=1)).strftime("%Y-%m-%d")
        django_query = 'start_time__lte,' + specified_date_plus
        sessions = one.alyx.rest('sessions', 'list', subject=subj,
                                 dataset_types='trials.goCueTrigger_times', django=django_query)

        # If still 0 sessions then return with warning
        if len(sessions) == 0:
            _logger.warning(f"No training sessions detected for {subj}")
            return [None] * 4

    trials, task_protocol, sess_dates = _load_latest_trials(sessions, one, n_sessions=3)
    if not trials:
        # Sessions exist but none had a loadable trials object; without this guard the
        # sess_dates[-1] lookup below would raise IndexError
        _logger.warning(f"No trials data could be loaded for {subj}")
        return [None] * 4

    if not np.any(np.array(task_protocol) == 'training'):
        # Range spanning the loaded sessions, written [sess_dates[-1], sess_dates[0]]
        # NOTE(review): assumes Alyx lists sessions newest first — confirm
        date_range = [sess_dates[-1], sess_dates[0]]
        ephys_sess = one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                   django='json__PYBPOD_BOARD__icontains,ephys')
        if len(ephys_sess) > 0:
            ephys_sess_dates = [sess['start_time'][:10] for sess in ephys_sess]

            # sessions whose start was delayed by at least 900 s (15 min)
            n_delay = len(one.alyx.rest('sessions', 'list', subject=subj, date_range=date_range,
                                        django='json__SESSION_START_DELAY_SEC__gte,900'))
        else:
            ephys_sess_dates = []
            n_delay = 0
    else:
        ephys_sess_dates = []
        n_delay = 0

    return trials, task_protocol, ephys_sess_dates, n_delay


def _load_latest_trials(sessions, one, n_sessions=3):
    """
    Load the trials of the first ``n_sessions`` sessions (in the given order) that have trials.

    Sessions whose trials object cannot be found are skipped. Iteration stops once ``n_sessions``
    session dates have been loaded or the session list is exhausted, so fewer than
    ``n_sessions`` may be returned.

    :param sessions: session records as returned by ``one.alyx.rest('sessions', 'list', ...)``
    :param one: instantiation of ONE class
    :param n_sessions: maximum number of session dates to load
    :returns:
        - trials - Bunch of trials objects keyed by session date ('YYYY-MM-DD')
        - task_protocol - list with the text between 'tasks_' and 'Choice' of each loaded
                          session's task protocol
        - sess_dates - list of the loaded session dates
    """
    trials = Bunch()
    task_protocol = []
    sess_dates = []
    for sess in sessions:
        if len(trials) >= n_sessions:
            break
        eid = sess['url'].split('/')[-1]
        _logger.debug(f"Loading trials for session {eid}")
        try:
            trials_ = one.load_object(eid, 'trials')
        except ALFObjectNotFound:
            continue

        if trials_:
            task_protocol.append(re.search('tasks_(.*)Choice', sess['task_protocol']).group(1))
            sess_date = sess['start_time'][:10]
            sess_dates.append(sess_date)
            trials[sess_date] = trials_
    return trials, task_protocol, sess_dates
224

225

226
def get_training_status(trials, task_protocol, ephys_sess_dates, n_delay):
    """
    Compute training status of a subject from three consecutive training datasets

    :param trials: dict containing trials objects from three consecutive training sessions
    :type trials: Bunch
    :param task_protocol: task protocol used for the three training session, can be 'training',
    'biased' or 'ephys'
    :type task_protocol: list of strings
    :param ephys_sess_dates: dates of sessions conducted on ephys rig
    :type ephys_sess_dates: list of strings
    :param n_delay: number of sessions on ephys rig with delay before start > 15 min
    :type n_delay: int
    :returns:
        - status - training status of subject
        - info - Bunch containing performance metrics that decide training status e.g performance
                 on easy trials, number of trials, psychometric fit parameters, reaction time
    """

    info = Bunch()
    trials_all = concatenate_trials(trials)
    # Which sessions ran trainingChoiceWorld; computed once and reused by every branch
    is_training = np.array(task_protocol) == 'training'

    # Case when all sessions are trainingChoiceWorld
    if np.all(is_training):
        signed_contrast = get_signed_contrast(trials_all)
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)
        if not np.any(signed_contrast == 0):
            # No 0% contrast trials in these sessions: criteria 1a/1b are not evaluated
            status = 'in training'
        elif criterion_1b(info.psych, info.n_trials, info.perf_easy, info.rt):
            status = 'trained 1b'
        elif criterion_1a(info.psych, info.n_trials, info.perf_easy):
            status = 'trained 1a'
        else:
            status = 'in training'
        return status, info

    # Case when there are < 3 biasedChoiceWorld sessions after reaching trained_1b criterion,
    # i.e. some (but, given the return above, not all) sessions are trainingChoiceWorld
    if np.any(is_training):
        status = 'trained 1b'
        (info.perf_easy, info.n_trials,
         info.psych, info.rt) = compute_training_info(trials, trials_all)
        return status, info

    # Case when there is biasedChoiceWorld or ephysChoiceWorld in last three sessions
    (info.perf_easy, info.n_trials,
     info.psych_20, info.psych_80,
     info.rt) = compute_bias_info(trials, trials_all)

    if len(ephys_sess_dates) == 0:
        # We are still on training rig and so all sessions should be biased
        assert np.all(np.array(task_protocol) == 'biased'), 'expected only biased sessions'
        if criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                           info.rt):
            status = 'ready4ephysrig'
        else:
            status = 'trained 1b'

    elif len(ephys_sess_dates) < 3:
        # Only the sessions run on the ephys rig are assessed here
        assert all(date in trials for date in ephys_sess_dates), 'ephys dates missing from trials'
        perf_ephys_easy = np.array([compute_performance_easy(trials[k]) for k in
                                    ephys_sess_dates])
        n_ephys_trials = np.array([compute_n_trials(trials[k]) for k in ephys_sess_dates])

        if criterion_delay(n_ephys_trials, perf_ephys_easy):
            status = 'ready4delay'
        else:
            status = 'ready4ephysrig'

    else:
        # Three or more ephys-rig sessions
        if n_delay > 0 and \
                criterion_ephys(info.psych_20, info.psych_80, info.n_trials, info.perf_easy,
                                info.rt):
            status = 'ready4recording'
        elif criterion_delay(info.n_trials, info.perf_easy):
            status = 'ready4delay'
        else:
            status = 'ready4ephysrig'

    return status, info
311

312

313
def display_status(subj, sess_dates, status, perf_easy=None, n_trials=None, psych=None,
                   psych_20=None, psych_80=None, rt=None):
    """
    Display training status of subject to terminal

    :param subj: subject nickname
    :type subj: string
    :param sess_dates: training session dates used to determine training status (any number)
    :type sess_dates: list of strings
    :param status: training status of subject
    :type status: string
    :param perf_easy: performance on easy trials for each training sessions; when None only the
    status and session dates are printed
    :type perf_easy: np.array
    :param n_trials: number of trials for each training sessions
    :type n_trials: np.array
    :param psych: parameters of psychometric curve fit to data from all training sessions
    :type psych: np.array - bias, threshold, lapse low, lapse high
    :param psych_20: parameters of psychometric curve fit to data in 20 (probability left) block
    from all training sessions
    :type psych_20: np.array - bias, threshold, lapse low, lapse high
    :param psych_80: parameters of psychometric curve fit to data in 80 (probability left) block
    from all training sessions
    :type psych_80: np.array - bias, threshold, lapse low, lapse high
    :param rt: median reaction time on zero contrast trials across all training sessions (if nan
    indicates no zero contrast stimuli in training sessions)
    """

    if perf_easy is None:
        # Join however many dates there are; indexing [0], [1], [2] crashed with < 3 sessions
        dates = ', '.join(str(d) for d in sess_dates)
        print(f"\n{subj} : {status} \nSession dates=[{dates}]")
        return

    # .tolist() yields plain Python numbers, so the lists print as [0.91, 0.8] rather than
    # [np.float64(0.91), ...] under NumPy >= 2
    perf_list = np.around(perf_easy, 2).tolist()
    n_trials_list = np.asarray(n_trials).tolist()

    if psych_20 is None:
        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
              f"Perf easy={perf_list}, "
              f"N trials={n_trials_list} "
              f"\nPsych fit over last 3 sessions: "
              f"bias={np.around(psych[0],2)}, thres={np.around(psych[1],2)}, "
              f"lapse_low={np.around(psych[2],2)}, lapse_high={np.around(psych[3],2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt,2)}")
    else:
        print(f"\n{subj} : {status} \nSession dates={[x for x in sess_dates]}, "
              f"Perf easy={perf_list}, "
              f"N trials={n_trials_list} "
              f"\nPsych fit over last 3 sessions (20): "
              f"bias={np.around(psych_20[0],2)}, thres={np.around(psych_20[1],2)}, "
              f"lapse_low={np.around(psych_20[2],2)}, lapse_high={np.around(psych_20[3],2)} "
              f"\nPsych fit over last 3 sessions (80): bias={np.around(psych_80[0],2)}, "
              f"thres={np.around(psych_80[1],2)}, lapse_low={np.around(psych_80[2],2)}, "
              f"lapse_high={np.around(psych_80[3],2)} "
              f"\nMedian reaction time at 0 contrast over last 3 sessions = "
              f"{np.around(rt, 2)}")
365

366

367
def concatenate_trials(trials):
    """
    Stitch the trials of several sessions together into one trials object.

    Only the keys listed in TRIALS_KEYS are kept. Sessions are concatenated in the iteration
    order of ``trials``.

    :param trials: dict containing trials objects from three consecutive training sessions,
    keys are session dates
    :type trials: Bunch
    :return: trials object whose TRIALS_KEYS arrays span all the given sessions
    :rtype: dict
    """
    per_session = list(trials.values())
    merged = Bunch()
    for key in TRIALS_KEYS:
        merged[key] = np.concatenate([session[key] for session in per_session])
    return merged
382

383

384
def compute_training_info(trials, trials_all):
    """
    Compute the metrics used to grade a subject on trainingChoiceWorld.

    :param trials: dict containing trials objects from three consecutive training sessions,
    keys are session dates
    :type trials: Bunch
    :param trials_all: the same trials concatenated over all sessions
    :type trials_all: Bunch
    :returns:
        - perf_easy - per-session performance on easy (>= 50% contrast) trials
        - n_trials - per-session number of trials
        - psych - psychometric fit parameters over all sessions
        - rt - median reaction time on 0% contrast trials over all sessions
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = [trials[date] for date in trials.keys()]

    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    psych = compute_psychometric(trials_all, signed_contrast=signed_contrast)
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych, rt
407

408

409
def compute_bias_info(trials, trials_all):
    """
    Compute the metrics used to grade a subject on biasedChoiceWorld.

    :param trials: dict containing trials objects from three consecutive training sessions,
    keys are session dates
    :type trials: Bunch
    :param trials_all: the same trials concatenated over all sessions
    :type trials_all: Bunch
    :returns:
        - perf_easy - per-session performance on easy (>= 50% contrast) trials
        - n_trials - per-session number of trials
        - psych_20 - psychometric fit parameters for trials with probabilityLeft == 0.2
        - psych_80 - psychometric fit parameters for trials with probabilityLeft == 0.8
        - rt - median reaction time on 0% contrast trials over all sessions
    """
    signed_contrast = get_signed_contrast(trials_all)
    sessions = [trials[date] for date in trials.keys()]

    perf_easy = np.array([compute_performance_easy(session) for session in sessions])
    n_trials = np.array([compute_n_trials(session) for session in sessions])
    psych_20, psych_80 = (compute_psychometric(trials_all, signed_contrast=signed_contrast, block=prob)
                          for prob in (0.2, 0.8))
    rt = compute_median_reaction_time(trials_all, contrast=0, signed_contrast=signed_contrast)

    return perf_easy, n_trials, psych_20, psych_80, rt
434

435

436
def get_signed_contrast(trials):
    """
    Compute the signed stimulus contrast of every trial.

    A NaN contrast on either side counts as 0, so each value is
    (contrastRight - contrastLeft) expressed in percent.

    :param trials: trials object that must contain contrastLeft and contrastRight keys
    :type trials: dict
    returns: array of signed contrasts in percent, where -ve values are on the left
    """
    left = np.nan_to_num(trials['contrastLeft'])
    right = np.nan_to_num(trials['contrastRight'])
    return np.ravel(right - left) * 100
447

448

449
def compute_performance_easy(trials):
    """
    Compute the fraction of correct responses on easy trials (|signed contrast| >= 50 %).

    :param trials: trials object that must contain contrastLeft, contrastRight and feedbackType
    keys
    :type trials: dict
    returns: float in [0, 1]; NaN (with a RuntimeWarning) when there are no easy trials
    """
    easy_idx = np.flatnonzero(np.abs(get_signed_contrast(trials)) >= 50)
    # feedbackType == 1 marks a correct response
    return np.mean(trials['feedbackType'][easy_idx] == 1)
461

462

463
def compute_performance(trials, signed_contrast=None, block=None, prob_right=False):
    """
    Compute performance on all trials at each contrast level from trials object

    :param trials: trials object that must contain probabilityLeft and feedbackType (or choice
    when prob_right is True) attributes, plus contrastLeft/contrastRight keys when
    signed_contrast is not given
    :type trials: dict
    :param signed_contrast: signed contrast in percent for each trial; computed from trials when
    None
    :type signed_contrast: np.array
    :param block: if given, only trials whose probabilityLeft equals this value are used
    :type block: float
    :param prob_right: if True compute the proportion of rightward (choice == -1) responses
    instead of the proportion of correct responses
    :type prob_right: bool
    returns: tuple (performance, contrasts, n_contrasts) - proportion correct (or rightward) at
    each contrast, the sorted unique signed contrasts and the number of trials at each. If no
    trial belongs to ``block`` an array of 3 NaNs is returned instead (it still unpacks into
    three values).
    """
    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        return np.nan * np.zeros(3)

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)

    if prob_right:
        # Proportion rightward: choice == -1 is counted as a rightward response
        outcome = trials.choice == -1
    else:
        outcome = trials.feedbackType == 1
    # Mean outcome over the block's trials at each contrast level
    performance = np.array([np.mean(outcome[(x == signed_contrast) & block_idx])
                            for x in contrasts])

    return performance, contrasts, n_contrasts
494

495

496
def compute_n_trials(trials):
    """
    Count the trials in a trials object (length of its ``choice`` array).

    :param trials: trials object that must contain a choice key
    :type trials: dict
    returns: int containing number of trials in session
    """
    n_trials = np.shape(trials['choice'])[0]
    return n_trials
505

506

507
def compute_psychometric(trials, signed_contrast=None, block=None, plotting=False, compute_ci=False, alpha=0.32):
    """
    Compute psychometric fit parameters for trials object

    :param trials: trials object that must contain contrastLeft, contrastRight, probabilityLeft
    and choice
    :type trials: dict
    :param signed_contrast: array of signed contrasts in percent, where -ve values are on the left
    :type signed_contrast: np.array
    :param block: biased block can be either 0.2 or 0.8
    :type block: float
    :param plotting: if True fit with fixed bounds (bias in [-50, 50], threshold in [10, 50],
    lapses <= 0.2) and 10 restarts; otherwise the bias bounds follow the tested contrasts
    :type plotting: bool
    :param compute_ci: if True, also return confidence intervals on the proportion of rightward
    choices at each contrast
    :type compute_ci: bool
    :param alpha: significance level for the confidence intervals; note it is passed to
    proportion_confint as alpha / 10
    :type alpha: float
    :return: array of psychometric fit parameters - bias, threshold, lapse low, lapse high. When
    compute_ci is True a tuple (params, ci) where ci is a (2, n_contrasts) array of the lower and
    upper interval bounds minus the observed proportion rightward. When no trial belongs to
    ``block`` the parameters are all NaN (and ci, if requested, is an empty (2, 0) array).
    """

    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    if not np.any(block_idx):
        no_fit = np.nan * np.zeros(4)
        # Keep the return shape consistent with compute_ci so callers can always unpack it
        return (no_fit, np.full((2, 0), np.nan)) if compute_ci else no_fit

    prob_choose_right, contrasts, n_contrasts = compute_performance(trials, signed_contrast=signed_contrast, block=block,
                                                                    prob_right=True)

    if plotting:
        fit_kwargs = dict(parstart=np.array([0., 40., 0.1, 0.1]),
                          parmin=np.array([-50., 10., 0., 0.]),
                          parmax=np.array([50., 50., 0.2, 0.2]),
                          nfits=10)
    else:
        fit_kwargs = dict(parstart=np.array([np.mean(contrasts), 20., 0.05, 0.05]),
                          parmin=np.array([np.min(contrasts), 0., 0., 0.]),
                          parmax=np.array([np.max(contrasts), 100., 1, 1]))

    psych, _ = psy.mle_fit_psycho(
        np.vstack([contrasts, n_contrasts, prob_choose_right]),
        P_model='erf_psycho_2gammas',
        **fit_kwargs)

    if compute_ci:
        import statsmodels.stats.proportion as smp # noqa
        # choice == -1 means contrast on right hand side
        n_right = np.array([np.sum(trials['choice'][(x == signed_contrast) & block_idx] == -1)
                            for x in contrasts])
        ci = smp.proportion_confint(n_right, n_contrasts, alpha=alpha / 10, method='normal') - prob_choose_right

        return psych, ci
    else:
        return psych
560

561

562
def compute_median_reaction_time(trials, stim_on_type='stimOn_times', contrast=None, signed_contrast=None):
    """
    Compute the median reaction time, optionally restricted to trials at one signed contrast.

    The reaction time of a trial is ``response_times - trials[stim_on_type]``; NaNs are ignored.

    :param trials: trials object that must contain response_times, probabilityLeft and the
    stim_on_type key (plus contrastLeft/contrastRight when signed_contrast is not given)
    :type trials: dict
    :param stim_on_type: event from which to compute the reaction time. Default is stimOn_times
    i.e when stimulus is presented
    :type stim_on_type: string (must be a valid key in trials object)
    :param contrast: signed contrast (percent) to restrict to; None uses every trial
    :type contrast: float
    :param signed_contrast: array of signed contrasts in percent, where -ve values are on the left
    :type signed_contrast: np.array
    :return: float of median reaction time (nan if no trial has the requested contrast)
    """
    signed_contrast = get_signed_contrast(trials) if signed_contrast is None else signed_contrast

    if contrast is not None:
        selected = signed_contrast == contrast
    else:
        selected = np.full(trials.probabilityLeft.shape, True, dtype=bool)

    if not np.any(selected):
        return np.nan
    return np.nanmedian((trials.response_times - trials[stim_on_type])[selected])
591

592

593
def compute_reaction_time(trials, stim_on_type='stimOn_times', stim_off_type='response_times', signed_contrast=None, block=None,
                          compute_ci=False, alpha=0.32):
    """
    Compute median reaction time for all contrasts

    :param trials: trials object that must contain the stim_on_type and stim_off_type keys and a
    probabilityLeft attribute (plus contrastLeft/contrastRight when signed_contrast is not given)
    :param stim_on_type: key of the event that starts the reaction-time interval
    :param stim_off_type: key of the event that ends the reaction-time interval
    :param signed_contrast: signed contrast in percent for each trial; computed when None
    :param block: if given, only trials whose probabilityLeft equals this value are used
    :param compute_ci: if True, also bootstrap a confidence interval of the median per contrast
    :param alpha: the bootstrap confidence level is 1 - alpha
    :return: (reaction_time, contrasts, n_contrasts) - median reaction time (NaNs ignored) at each
    sorted unique signed contrast, and the number of trials at each; with compute_ci a fourth
    element ci, an (n_contrasts, 2) array of lower/upper bounds. A block without trials yields
    empty arrays.
    """

    if signed_contrast is None:
        signed_contrast = get_signed_contrast(trials)

    if block is None:
        block_idx = np.full(trials.probabilityLeft.shape, True, dtype=bool)
    else:
        block_idx = trials.probabilityLeft == block

    contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
    # Interval of every trial, computed once instead of once per contrast
    intervals = trials[stim_off_type] - trials[stim_on_type]
    # Trials of the block at each contrast level
    masks = [(x == signed_contrast) & block_idx for x in contrasts]
    # A comprehension (unlike np.vectorize) also copes with a block that has no trials
    reaction_time = np.array([np.nanmedian(intervals[m]) for m in masks], dtype=float)

    if compute_ci:
        ci = np.full((contrasts.size, 2), np.nan)
        for i, m in enumerate(masks):
            bt = bootstrap((intervals[m],), np.nanmedian, confidence_level=1 - alpha)
            ci[i, 0] = bt.confidence_interval.low
            ci[i, 1] = bt.confidence_interval.high

        return reaction_time, contrasts, n_contrasts, ci

    return reaction_time, contrasts, n_contrasts
627

628

629
def criterion_1a(psych, n_trials, perf_easy):
    """
    Returns bool indicating whether criterion for trained_1a is met. All criteria documented here
    (https://figshare.com/articles/preprint/A_standardized_and_reproducible_method_to_measure_
    decision-making_in_mice_Appendix_2_IBL_protocol_for_mice_training/11634729)

    Requires |bias| < 16, threshold < 19, both lapse rates < 0.2, more than 200 trials in every
    session and more than 80 % correct on easy trials in every session.
    """
    good_fit = abs(psych[0]) < 16 and psych[1] < 19 and psych[2] < 0.2 and psych[3] < 0.2
    return good_fit and np.all(n_trials > 200) and np.all(perf_easy > 0.8)
639

640

641
def criterion_1b(psych, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criterion for trained_1b is met.

    Requires |bias| < 10, threshold < 20, both lapse rates < 0.1, more than 400 trials and more
    than 90 % correct on easy trials in every session, and a median reaction time below 2.
    """
    good_fit = abs(psych[0]) < 10 and psych[1] < 20 and psych[2] < 0.1 and psych[3] < 0.1
    return good_fit and np.all(n_trials > 400) and np.all(perf_easy > 0.9) and rt < 2
648

649

650
def criterion_ephys(psych_20, psych_80, n_trials, perf_easy, rt):
    """
    Returns bool indicating whether criterion for ready4ephysrig or ready4recording is met.

    Requires both lapse rates of the 20 and of the 80 block fits to be < 0.1, a bias shift
    (psych_80 bias - psych_20 bias) > 5, more than 400 trials and more than 90 % correct on easy
    trials in every session, and a median reaction time below 2.
    """
    # psych_80[3] previously lacked its "< 0.1": any non-zero high lapse passed the check and a
    # lapse of exactly 0 failed it
    criterion = (psych_20[2] < 0.1 and psych_20[3] < 0.1 and psych_80[2] < 0.1 and psych_80[3] < 0.1 and
                 psych_80[0] - psych_20[0] > 5 and np.all(n_trials > 400) and
                 np.all(perf_easy > 0.9) and rt < 2)
    return criterion
658

659

660
def criterion_delay(n_trials, perf_easy):
    """
    Returns bool indicating whether criterion for ready4delay is met.

    At least one session must have more than 400 trials, and at least one session must exceed
    90 % correct on easy trials (not necessarily the same session).
    """
    has_long_session = np.any(n_trials > 400)
    return has_long_session and np.any(perf_easy > 0.9)
666

667

668
def plot_psychometric(trials, ax=None, title=None, plot_ci=False, ci_aplha=0.32, **kwargs):
    """
    Function to plot psychometric curves a la datajoint webpage, one per probabilityLeft block.

    For each block (0.5, 0.2, 0.8), two things are drawn:
    - the observed probability of choosing right (from compute_performance), as a scatter;
    - the fitted psychometric curve (psychofit.erf_psycho_2gammas evaluated on contrasts
      -100..99), as a line.

    :param trials: trials object accepted by get_signed_contrast, compute_performance and
        compute_psychometric
    :param ax: matplotlib axis to draw into; if not given a new figure and axis are created
    :param title: optional axis title
    :param plot_ci: if True, request confidence intervals from compute_psychometric and draw
        them as error bars around the data points
    :param ci_aplha: alpha forwarded to compute_psychometric (misspelled name kept for
        backwards compatibility)
    :param kwargs: forwarded to plt.subplots when a new figure is created
    :return: (fig, ax) - the figure that contains ``ax``, and the axis that was plotted into
    """
    signed_contrast = get_signed_contrast(trials)
    contrasts_fit = np.arange(-100, 100)
    blocks = (0.5, 0.2, 0.8)

    # Compute all curves before creating a figure so that a failure leaves no empty figure behind
    curves = []
    for block in blocks:
        prob_right, contrasts, _ = compute_performance(
            trials, signed_contrast=signed_contrast, block=block, prob_right=True)
        out = compute_psychometric(trials, signed_contrast=signed_contrast, block=block,
                                   plotting=True, compute_ci=plot_ci, alpha=ci_aplha)
        # With compute_ci=True, out[0] holds the fit parameters and out[1] the CI values
        pars = out[0] if plot_ci else out
        prob_right_fit = psy.erf_psycho_2gammas(pars, contrasts_fit)
        curves.append((block, contrasts, prob_right, prob_right_fit, out))

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")
    colors = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure owning ax; plt.gcf() may return an unrelated (current) figure
        fig = ax.figure

    handles, labels = [], []
    for block, contrasts, prob_right, prob_right_fit, _ in curves:
        fit = ax.plot(contrasts_fit, prob_right_fit, color=colors[block])
        data = ax.scatter(contrasts, prob_right, color=colors[block])
        handles += [fit[0], data]
        labels += [f'p_left={block} fit', f'p_left={block} data']

    if plot_ci:
        # Error bars are drawn after all curves/points, as in the original layering
        for block, contrasts, prob_right, _, out in curves:
            errbar = np.c_[np.abs(out[1][0]), np.abs(out[1][1])].T
            ax.errorbar(contrasts, prob_right, yerr=errbar, ecolor=colors[block], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylim(-0.05, 1.05)
    ax.set_ylabel('Probability choosing right')
    ax.set_xlabel('Contrasts')
    if title:
        ax.set_title(title)

    return fig, ax
×
729

730

731
def plot_reaction_time(trials, ax=None, title=None, plot_ci=False, ci_alpha=0.32, **kwargs):
    """
    Function to plot reaction time against signed contrast a la datajoint webpage.

    One line is drawn per probabilityLeft block (0.5, 0.2, 0.8), using the reaction times
    returned by compute_reaction_time.

    :param trials: trials object accepted by get_signed_contrast and compute_reaction_time
    :param ax: matplotlib axis to draw into; if not given a new figure and axis are created
    :param title: optional axis title
    :param plot_ci: if True, request confidence intervals from compute_reaction_time and draw
        them as error bars
    :param ci_alpha: alpha forwarded to compute_reaction_time
    :param kwargs: forwarded to plt.subplots when a new figure is created
    :return: (fig, ax) - the figure that contains ``ax``, and the axis that was plotted into
    """
    signed_contrast = get_signed_contrast(trials)
    blocks = (0.5, 0.2, 0.8)
    # Each output is (reaction_time, contrasts, n_contrasts[, ci]), with ci columns (low, high)
    outputs = [compute_reaction_time(trials, signed_contrast=signed_contrast, block=block,
                                     compute_ci=plot_ci, alpha=ci_alpha)
               for block in blocks]

    cmap = sns.diverging_palette(20, 220, n=3, center="dark")
    colors = {0.5: cmap[1], 0.2: cmap[0], 0.8: cmap[2]}

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure owning ax; plt.gcf() may return an unrelated (current) figure
        fig = ax.figure

    handles, labels = [], []
    for block, out in zip(blocks, outputs):
        line = ax.plot(out[1], out[0], '-o', color=colors[block])
        handles.append(line[0])
        labels.append(f'p_left={block} data')

    if plot_ci:
        for block, out in zip(blocks, outputs):
            reaction_time, contrasts, ci = out[0], out[1], out[3]
            # errorbar expects distances below/above each point, not absolute bounds
            errbar = np.c_[reaction_time - ci[:, 0], ci[:, 1] - reaction_time].T
            ax.errorbar(contrasts, reaction_time, yerr=errbar, ecolor=colors[block], fmt='none',
                        capsize=5, alpha=0.4)

    ax.legend(handles, labels, loc='upper left')
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Contrasts')

    if title:
        ax.set_title(title)

    return fig, ax
×
773

774

775
def plot_reaction_time_over_trials(trials, stim_on_type='stimOn_times', ax=None, title=None, **kwargs):
    """
    Function to plot reaction time against trial number a la datajoint webpage.

    The reaction time is ``trials.response_times - trials[stim_on_type]``. Each trial is drawn
    as a grey point, and a rolling median over 10 trials as a black line. The y axis is log
    scaled.

    :param trials: trials object with a ``response_times`` attribute and a ``stim_on_type`` key
    :param stim_on_type: key of the stimulus-onset times in ``trials``
    :param ax: matplotlib axis to draw into; if not given a new figure and axis are created
    :param title: optional axis title
    :param kwargs: forwarded to plt.subplots when a new figure is created
    :return: (fig, ax) - the figure that contains ``ax``, and the axis that was plotted into
    """
    reaction_time = pd.Series(
        np.asarray(trials.response_times - trials[stim_on_type], dtype=float))
    # Rolling median over 10 trials; NaNs are left in place and skipped by matplotlib
    reaction_time_rolled = reaction_time.rolling(window=10).median()
    # 1-based trial numbers, matching the x-axis label (previously computed but unused)
    trial_number = np.arange(1, reaction_time.size + 1)

    if not ax:
        fig, ax = plt.subplots(**kwargs)
    else:
        # Use the figure owning ax; plt.gcf() may return an unrelated (current) figure
        fig = ax.figure

    ax.scatter(trial_number, reaction_time.values, s=16, color='darkgray')
    ax.plot(trial_number, reaction_time_rolled.values, color='k', linewidth=2)
    ax.set_yscale('log')
    ax.set_ylim(0.1, 100)
    ax.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_ylabel('Reaction time (s)')
    ax.set_xlabel('Trial number')
    if title:
        ax.set_title(title)

    return fig, ax
×
810

811

812
def query_criterion(subject, status, from_status=None, one=None, validate=True):
    """Get the session for which a given training criterion was met.

    Parameters
    ----------
    subject : str
        The subject name.
    status : str
        The training status to query for.
    from_status : str, optional
        Count number of sessions and days from reaching `from_status` to `status`.
    one : one.api.OneAlyx, optional
        An instance of ONE.
    validate : bool
        If true, check if status in TrainingStatus enumeration. Set to false for non-standard
        training pipelines.

    Returns
    -------
    str
        The eID of the first session where this training status was reached.
    int
        The number of sessions it took to reach `status` (optionally from reaching `from_status`).
    int
        The number of days it took to reach `status` (optionally from reaching `from_status`).

    All three values are None if the subject has no record of reaching `status`. The last two
    are None if no sessions are found in the date range.

    Raises
    ------
    ValueError
        If `validate` is true and `status` is not a TrainingStatus member.
    """
    if validate:
        # Normalize e.g. 'Trained 1a' -> 'trained_1a' before looking up the enumeration
        status = status.lower().replace(' ', '_')
        try:
            status = TrainingStatus[status.upper()].name.lower()
        except KeyError as ex:
            raise ValueError(
                f'Unknown status "{status}". For non-standard training protocols set validate=False'
            ) from ex
    one = one or ONE()
    subject_json = one.alyx.rest('subjects', 'read', id=subject)['json']
    if not (criteria := subject_json.get('trained_criteria')) or status not in criteria:
        return None, None, None
    to_date, eid = criteria[status]
    # An unknown/None from_status means counting from the first session
    from_date, _ = criteria.get(from_status, (None, None))
    eids, det = one.search(subject=subject, date_range=[from_date, to_date], details=True)
    if len(eids) == 0:
        return eid, None, None
    # Use the date extremes rather than relying on the sort order of the search results
    dates = [d['date'] for d in det]
    delta_date = max(dates) - min(dates)
    return eid, len(eids), delta_date.days
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc