• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 7961675356254463

pending completion
7961675356254463

Pull #557

continuous-integration/UCL

olivier
add test
Pull Request #557: Chained protocols

718 of 718 new or added lines in 27 files covered. (100.0%)

12554 of 18072 relevant lines covered (69.47%)

0.69 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.31
/ibllib/io/extractors/habituation_trials.py
1
import logging
1✔
2
import numpy as np
1✔
3

4
import ibllib.io.raw_data_loaders as raw
1✔
5
from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
1✔
6
from ibllib.io.extractors.biased_trials import ContrastLR
1✔
7
from ibllib.io.extractors.training_trials import (
1✔
8
    FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes
9
)
10

11
_logger = logging.getLogger(__name__)
1✔
12

13

14
class HabituationTrials(BaseBpodTrialsExtractor):
    """Extract the trial datasets of a habituationChoiceWorld Bpod session.

    Datasets shared with trainingChoiceWorld (contrasts, feedback times, intervals, go cue
    times and stim on trigger times) are delegated to the training extractors. The
    habituation-specific visual stimulus events (stim off, stim on, stim center) are
    recovered from the BNC1 TTLs, of which exactly 3 are expected per trial.

    Every returned dataset is truncated to the number of stim off events. The last trial is
    therefore dropped, because its stim off event is only recorded at the start of the
    following trial.
    """
    var_names = ('feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', 'contrastRight',
                 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals',
                 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times',
                 'stimCenterTrigger_times', 'stimCenter_times')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # These datasets are extracted and returned but never saved (their save name is None)
        exclude = ['itiIn_times', 'stimOffTrigger_times',
                   'stimCenter_times', 'stimCenterTrigger_times']
        self.save_names = tuple([f'_ibl_trials.{x}.npy' if x not in exclude else None
                                 for x in self.var_names])

    def _report_missing_ttls(self, ttls):
        """Log how many trials do not have exactly the expected 3 BNC1 TTLs.

        :param ttls: list with one sequence of BNC1 event times per trial
        """
        # A trial with extra pulses is counted here as well as one with missing pulses
        n_missing = sum(len(pulses) != 3 for pulses in ttls)
        if n_missing == len(ttls):
            # Not a single trial has a complete set of stimulus sync pulses
            _logger.error(f'{self.session_path}: Missing ALL BNC1 TTLs ({n_missing} trials)')
        elif n_missing > 0:
            # Only some trials are affected
            _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)')

    @staticmethod
    def _stim_center_times(ttls, center_triggers):
        """Assign a stim center time to each trial from its BNC1 TTLs.

        With the expected 3 pulses the stim center is the third (last) pulse. If pulses are
        missing, the correct one can only be identified with confidence when exactly one pulse
        occurs after the stim center trigger.

        :param ttls: list with one sequence of BNC1 event times per trial
        :param center_triggers: array of stim center trigger times, one per trial
        :returns: float array shaped like `center_triggers`; NaN where no pulse can be assigned
        """
        times = np.full(center_triggers.shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, center_triggers)):
            if len(sync) == 3 or (len(sync) > 0 and sum(pulse > trigger for pulse in sync) == 1):
                times[i] = sync[-1]
        return times

    @staticmethod
    def _stim_on_times(ttls, center_triggers, shape):
        """Assign a stim on time to each trial from its BNC1 TTLs.

        With the expected 3 pulses the stim on is the second pulse. If exactly one pulse is
        missing, the correct one can only be identified with confidence when both remaining
        pulses occur before the stim center trigger.

        :param ttls: list with one sequence of BNC1 event times per trial
        :param center_triggers: array of stim center trigger times, one per trial
        :param shape: shape of the output array
        :returns: float array of the given shape; NaN where no pulse can be assigned
        """
        times = np.full(shape, np.nan)
        for i, (sync, trigger) in enumerate(zip(ttls, center_triggers)):
            if len(sync) == 3 or (len(sync) == 2 and sum(pulse < trigger for pulse in sync) == 2):
                times[i] = sync[1]
        return times

    def _extract(self):
        """Extract all habituation trial datasets.

        :returns: list of arrays ordered as `var_names`, each truncated to the number of stim
            off events (i.e. the last trial is dropped)
        """
        # Stimulus sync events (BNC1 TTLs) detected on each trial
        ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials]
        self._report_missing_ttls(ttls)

        # Datasets common to trainingChoiceWorld
        training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes]
        out, _ = run_extractor_classes(training, session_path=self.session_path, save=False,
                                       bpod_trials=self.bpod_trials, settings=self.settings,
                                       task_collection=self.task_collection)

        # GoCueTriggerTimes is the same event as StimOnTriggerTimes
        out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy()

        # StimCenterTrigger times: start time of the first 'stim_center' state of each trial
        stim_center_state = np.array([tr['behavior_data']['States timestamps']['stim_center'][0]
                                      for tr in self.bpod_trials])
        out['stimCenterTrigger_times'] = stim_center_state[:, 0]

        # StimCenter and StimOn times, both assigned relative to the stim center trigger
        out['stimCenter_times'] = self._stim_center_times(ttls, out['stimCenterTrigger_times'])
        out['stimOn_times'] = self._stim_on_times(ttls, out['stimCenterTrigger_times'],
                                                  out['stimOnTrigger_times'].shape)

        # RewardVolume
        out['rewardVolume'] = np.array(
            [tr['reward_amount'] for tr in self.bpod_trials]).astype(np.float64)

        # StimOffTrigger times: stim off occurs at trial start. The first trial's state update
        # is ignored, so entry k belongs to trial k (recorded at the start of trial k + 1)
        out['stimOffTrigger_times'] = np.array(
            [tr['behavior_data']['States timestamps']['trial_start'][0][0]
             for tr in self.bpod_trials[1:]])

        # StimOff times: with exactly three TTLs the stim off is the first pulse. If any pulse is
        # missing the correct one cannot be assigned with confidence, so NaN is used instead
        out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan
                                         for sync in ttls[1:]])

        # FeedbackType is always positive during habituation
        out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8)

        # ItiIn times: start of the first 'iti' state of each trial
        out['itiIn_times'] = np.array(
            [tr['behavior_data']['States timestamps']['iti'][0][0] for tr in self.bpod_trials])

        # NB: We lose the last trial because the stim off event occurs at trial_num + 1
        n_trials = out['stimOff_times'].size
        return [out[k][:n_trials] for k in self.var_names]
107

108

109
def extract_all(session_path, save=False, bpod_trials=False, settings=False, task_collection='raw_behavior_data', save_path=None):
    """Extract all trial datasets from a habituationChoiceWorld session.

    Only the datasets produced by the HabituationTrials extractor are written to disk, and
    only when `save` is True.

    :param session_path: the session path where the raw data are saved
    :param save: if True, the datasets that are considered standard are saved
    :param bpod_trials: the raw Bpod trial data; loaded with `raw.load_data` when falsy
    :param settings: the raw Bpod settings; loaded with `raw.load_settings` when falsy
    :param task_collection: name of the raw task data collection used for loading and
        extraction (default 'raw_behavior_data')
    :param save_path: output path, forwarded to the extractor as `path_out`
    :returns: a dict of datasets and a corresponding list of file names
    """
    # Fall back to loading the raw data from disk when it was not provided by the caller
    bpod_trials = bpod_trials or raw.load_data(session_path, task_collection=task_collection)
    settings = settings or raw.load_settings(session_path, task_collection=task_collection)

    # Standard datasets that may be saved as ALFs
    datasets, files = run_extractor_classes(
        HabituationTrials,
        save=save,
        session_path=session_path,
        bpod_trials=bpod_trials,
        settings=settings,
        task_collection=task_collection,
        path_out=save_path,
    )
    return datasets, files
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc