int-brain-lab / iblrig — build 9032957364
10 May 2024 01:25PM UTC — coverage: 48.538% (+1.7%) from 46.79%
Pull Request #643: 8.19.0 — merge aebf2c9af into ec2d8e4fe (github / web-flow / 74d2ec)

377 of 1074 new or added lines in 38 files covered (35.1%).
977 existing lines in 19 files now uncovered.
3253 of 6702 relevant lines covered (48.54%).
0.97 hits per line.

Source File: /iblrig/base_choice_world.py — 93.17% covered
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import json
import logging
import math
import random
import subprocess
import time
import traceback
from pathlib import Path
from string import ascii_letters
from typing import Annotated, Literal

import numpy as np
import pandas as pd
from pydantic import BaseModel, Field

import iblrig.base_tasks
import iblrig.graphic
from iblrig import choiceworld, misc
from iblrig.hardware import SOFTCODE
from iblutil.io import jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

Probability = Annotated[float, Field(ge=0.0, le=1.0)]


class ChoiceWorldParams(BaseModel):
    AUTOMATIC_CALIBRATION: bool = True
    ADAPTIVE_REWARD: bool = False
    BONSAI_EDITOR: bool = False
    CALIBRATION_VALUE: float = 0.067
    CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_items=1)
    CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
    GO_TONE_AMPLITUDE: float = 0.0272
    GO_TONE_DURATION: float = 0.11
    GO_TONE_IDX: int = Field(2, ge=0)
    GO_TONE_FREQUENCY: float = Field(5000, gt=0)
    FEEDBACK_CORRECT_DELAY_SECS: float = 1
    FEEDBACK_ERROR_DELAY_SECS: float = 2
    FEEDBACK_NOGO_DELAY_SECS: float = 2
    INTERACTIVE_DELAY: float = 0.0
    ITI_DELAY_SECS: float = 0.5
    NTRIALS: int = Field(2000, gt=0)
    PROBABILITY_LEFT: Probability = 0.5
    QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
    QUIESCENT_PERIOD: float = 0.2
    RECORD_AMBIENT_SENSOR_DATA: bool = True
    RECORD_SOUND: bool = True
    RESPONSE_WINDOW: float = 60
    REWARD_AMOUNT_UL: float = 1.5
    REWARD_TYPE: str = 'Water 10% Sucrose'
    STIM_ANGLE: float = 0.0
    STIM_FREQ: float = 0.1
    STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
    STIM_POSITIONS: list[float] = [-35, 35]
    STIM_SIGMA: float = 7.0
    STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai parameter
    SYNC_SQUARE_X: float = 1.33
    SYNC_SQUARE_Y: float = -1.03
    USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
    VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
    WHITE_NOISE_AMPLITUDE: float = 0.05
    WHITE_NOISE_DURATION: float = 0.5
    WHITE_NOISE_IDX: int = 3

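
# A quick sketch of how the model above behaves (illustrative, not part of the
# task module): ChoiceWorldParams is a pydantic model and Probability is
# constrained to the [0, 1] interval, so invalid overrides fail at construction:
#
#     >>> ChoiceWorldParams(PROBABILITY_LEFT=0.8).PROBABILITY_LEFT
#     0.8
#     >>> ChoiceWorldParams(PROBABILITY_LEFT=1.5)  # raises pydantic.ValidationError
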
class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')

    def __init__(self, *args, delay_secs=0, **kwargs):
        super().__init__(**kwargs)
        self.task_params['SESSION_DELAY_START'] = delay_secs
        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]
        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1
        # init the tables; there are two of them: a trials table and an ambient sensor data table
        self.trials_table = pd.DataFrame(
            {
                'contrast': np.zeros(NTRIALS_INIT) * np.NaN,
                'position': np.zeros(NTRIALS_INIT) * np.NaN,
                'quiescent_period': np.zeros(NTRIALS_INIT) * np.NaN,
                'response_side': np.zeros(NTRIALS_INIT, dtype=np.int8),
                'response_time': np.zeros(NTRIALS_INIT) * np.NaN,
                'reward_amount': np.zeros(NTRIALS_INIT) * np.NaN,
                'reward_valve_time': np.zeros(NTRIALS_INIT) * np.NaN,
                'stim_angle': np.zeros(NTRIALS_INIT) * np.NaN,
                'stim_freq': np.zeros(NTRIALS_INIT) * np.NaN,
                'stim_gain': np.zeros(NTRIALS_INIT) * np.NaN,
                'stim_phase': np.zeros(NTRIALS_INIT) * np.NaN,
                'stim_reverse': np.zeros(NTRIALS_INIT, dtype=bool),
                'stim_sigma': np.zeros(NTRIALS_INIT) * np.NaN,
                'trial_correct': np.zeros(NTRIALS_INIT, dtype=bool),
                'trial_num': np.zeros(NTRIALS_INIT, dtype=np.int16),
            }
        )

        self.ambient_sensor_table = pd.DataFrame(
            {
                'Temperature_C': np.zeros(NTRIALS_INIT) * np.NaN,
                'AirPressure_mb': np.zeros(NTRIALS_INIT) * np.NaN,
                'RelativeHumidity': np.zeros(NTRIALS_INIT) * np.NaN,
            }
        )

    @staticmethod
    def extra_parser():
        """:return: argparse.parser()"""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_secs',
            dest='delay_secs',
            default=0,
            type=int,
            required=False,
            help='initial delay before starting the first trial (default: 0s)',
        )
        return parser
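
    # Hypothetical invocation sketch (the entry point is assumed, not defined in
    # this file): a task whose parser includes the argument above could be
    # launched as
    #
    #     python task.py --delay_secs 60
    #
    # making the first trial wait 60 s in the 'delay_initiation' state below.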

    def start_hardware(self):
        """
        In this step we explicitly run the start methods of the various mixins.
        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """
        Run the task with the actual state machine.
        """
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
            # Bpod TTL is HIGH. The DEAD_TIME param defines the time between last trial and the next.
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
            # wait to achieve the desired ITI duration
            if dt > 0:
                time.sleep(dt)
            # Run state machine
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()
            self.trial_completed(self.bpod.session.current_trial.export())
            self.ambient_sensor_table.loc[i] = self.bpod.get_ambient_sensor_reading()
            self.show_trial_log()

            # handle pause and stop events
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session in between trials {i} and {i + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                if not flag_stop.exists():
                    log.info('Resuming session')
            if flag_stop.exists():
                log.info(f'Stopping session after trial {i}')
                flag_stop.unlink()
                break
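
    # Usage note grounded in the loop above: a session can be paused or stopped
    # from outside the process by dropping sentinel files into the session
    # folder, which the loop polls between trials. Sketch (folder path assumed):
    #
    #     from pathlib import Path
    #     session_folder = Path('/data/Subjects/S001/2024-05-10/001')  # hypothetical
    #     (session_folder / '.pause').touch()   # pause after the current trial
    #     (session_folder / '.pause').unlink()  # resume
    #     (session_folder / '.stop').touch()    # stop after the current trial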

    def mock(self, file_jsonable_fixture=None):
        """
        This method serves to instantiate a state machine and Bpod object to simulate a task run.
        This is useful to test or display the state machine flow.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop-out the bpod data from the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
            }
        )
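
    # A minimal simulation sketch (subject name, fixture path and run() entry
    # point are hypothetical; TrainingChoiceWorldSession is defined below):
    #
    #     session = TrainingChoiceWorldSession(subject='test_subject')
    #     session.mock(file_jsonable_fixture='fixtures/task_data.jsonable')
    #     session.run()  # replays the recorded Bpod trial data instead of driving hardware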

    def get_graphviz_task(self, output_file=None, view=True):
        """
        For a given task, output the state machine states diagram in Digraph format.
        :param output_file:
        :return:
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for input in sma.input_matrix[i]:
                if input[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[input[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
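
    # Usage sketch (assumes a concrete subclass and the graphviz Python package):
    #
    #     session = BiasedChoiceWorldSession(subject='test_subject')  # hypothetical subject
    #     dot = session.get_graphviz_task(output_file='state_machine', view=False)
    #     print(dot.source)  # DOT description of the trial state machine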

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)
        if i == 0:  # First trial exception start camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        sma.add_state(  # '>back' | '>reset_timer'
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )
        # show stimulus, move on to next state if a frame2ttl is detected, with a time-out of 0.1s
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'Tup': 'interactive_delay', 'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay'},
        )
        # this is a feature that can eventually add a delay between the visual and auditory cues
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )
        # play tone, move on to next state if sound is detected, with a time-out of 0.1s
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,  # the delay here is to avoid race conditions in the bonsai flow
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        sma.add_state(
            state_name='no_go',
            state_timer=self.task_params.FEEDBACK_NOGO_DELAY_SECS,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )

        sma.add_state(
            state_name='error',
            state_timer=self.task_params.FEEDBACK_ERROR_DELAY_SECS,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'reward'},
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )

        sma.add_state(
            state_name='correct',
            state_timer=self.task_params.FEEDBACK_CORRECT_DELAY_SECS - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )
        return sma
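
    # For reference, the happy path on a correct trial, read off the transitions above:
    #
    #     trial_start > reset_rotary_encoder > quiescent_period > stim_on
    #       > interactive_delay > play_tone > reset2_rotary_encoder > closed_loop
    #       > freeze_reward > reward > correct > hide_stim > exit_state > exit
    #
    # Wheel movement during 'quiescent_period' loops back to 'reset_rotary_encoder';
    # in 'closed_loop' a wrong-side threshold goes to 'freeze_error' > 'error' and a
    # RESPONSE_WINDOW timeout goes to 'no_go'.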

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, contrast=None, position=None, reward_amount=None):
        """Draw the next trial variables.

        This is called by the `next_trial` method before updating the Bpod state machine.
        It also calls :meth:`send_trial_info_to_bonsai`.
        """
        if contrast is None:
            contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        position = position or int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
        reward_amount = self.default_reward_amount if reward_amount is None else reward_amount
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.task_params.STIM_GAIN
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft
        self.send_trial_info_to_bonsai()
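
    # Worked example of the draw above: with STIM_POSITIONS = [-35, 35] and
    # pleft=0.8, np.random.choice returns -35 (left) with probability 0.8 and
    # +35 with probability 0.2. The quiescent period adds a truncated-exponential
    # jitter of 0.2-0.5 s to QUIESCENT_PERIOD, i.e. 0.4-0.7 s with the defaults.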

    def trial_completed(self, bpod_data):
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        save_dict = self.trials_table.iloc[self.trial_num].to_dict()
        save_dict['behavior_data'] = bpod_data
        # Dump and save
        with open(self.paths['DATA_FILE_PATH'], 'a') as fp:
            fp.write(json.dumps(save_dict) + '\n')
        # this is a flag for the online plots; if the online plots were in pyqt5, a file-watcher functionality could be used instead
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)
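
    # The data file written above is JSON Lines: one JSON document per trial per
    # line, which is why mock() can replay it through jsonable.read. A post-hoc
    # inspection sketch (the path is assumed):
    #
    #     trials = jsonable.read('/path/to/_iblrig_taskData.raw.jsonable')
    #     print(trials[0]['contrast'], trials[0]['behavior_data']['States timestamps'].keys())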

    def check_sync_pulses(self, bpod_data):
        # todo move this in the post trial when we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info=''):
        trial_info = self.trials_table.iloc[self.trial_num]
        level = logging.INFO
        log.log(level=level, msg=f'Outcome of Trial #{trial_info.trial_num}:')
        log.log(level=level, msg=f'- Stim. Position:  {trial_info.position}')
        log.log(level=level, msg=f'- Stim. Contrast:  {trial_info.contrast}')
        log.log(level=level, msg=f'- Stim. Phase:     {trial_info.stim_phase}')
        log.log(level=level, msg=f'- Stim. p Left:    {trial_info.stim_probability_left}')
        log.log(level=level, msg=f'- Water delivered: {self.session_info.TOTAL_WATER_DELIVERED:.1f} µl')
        log.log(level=level, msg=f'- Time from Start: {self.time_elapsed}')
        log.log(level=level, msg=f'- Temperature:     {self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C')
        log.log(level=level, msg=f'- Air Pressure:    {self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb')
        log.log(
            level=level, msg=f'- Rel. Humidity:   {self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %\n'
        )
        if extra_info:
            # display any extra information passed in by subclasses (previously dropped silently)
            log.log(level=level, msg=extra_info)

    @property
    def iti_reward(self):
        """
        Returns the ITI time that needs to be set in order to achieve the desired ITI,
        by subtracting the time it takes to give a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    """
    These are the properties that are used in the state machine code
    """

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[-self.position]


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['delay_to_stim_center'] = np.zeros(NTRIALS_INIT) * np.NaN

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception start camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
            # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
            # the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave reward valve open, i.e. reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma


class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse actively makes
    decisions by turning the wheel. It has the following characteristics:

    -   it is trial based
    -   it is decision based
    -   left and right stimuli are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   it has quantifiable performance, computed as the proportion of correct trials, unlike passive
        stimulation protocols or habituation protocols.

    TrainingChoiceWorld and BiasedChoiceWorld are subclasses of this class.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float32)

    def _run(self):
        # starts online plotting
        if self.interactive:
            subprocess.Popen(
                ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def show_trial_log(self, extra_info=''):
        trial_info = self.trials_table.iloc[self.trial_num]
        extra_info = f"""
RESPONSE TIME:        {trial_info.response_time}
{extra_info}

TRIAL CORRECT:        {trial_info.trial_correct}
NTRIALS CORRECT:      {self.session_info.NTRIALS_CORRECT}
NTRIALS ERROR:        {self.trial_num - self.session_info.NTRIALS_CORRECT}
        """
        super().show_trial_log(extra_info=extra_info)

    def trial_completed(self, bpod_data):
        """
        Update the trials table with information about the behaviour coming from the bpod.

        Constraints on the state machine data:
        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']
        :param bpod_data:
        :return:
        """
        # get the response time from the behaviour data
        response_time = bpod_data['States timestamps']['closed_loop'][0][1] - bpod_data['States timestamps']['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time
        # get the trial outcome
        state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
        raw_outcome = {sn: ~np.isnan(bpod_data['States timestamps'].get(sn, [[np.NaN]])[0][0]) for sn in state_names}
        outcome = next(k for k in raw_outcome if raw_outcome[k])
        # update the response buffer: -1 for leftward, 0 for no-go and 1 for rightward responses
        position = self.trials_table.at[self.trial_num, 'position']
        if 'correct' in outcome:
            self.trials_table.at[self.trial_num, 'trial_correct'] = True
            self.session_info.NTRIALS_CORRECT += 1
            self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
        elif 'error' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
        elif 'no_go' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = 0
        super().trial_completed(bpod_data)
        # here we throw potential errors after having written the trial to disk
        assert np.sum(list(raw_outcome.values())) == 1
        assert position != 0, 'the position value should be either 35 or -35'


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.NaN, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )
        self.trials_table['block_num'] = np.zeros(NTRIALS_INIT, dtype=np.int16)
        self.trials_table['block_trial_num'] = np.zeros(NTRIALS_INIT, dtype=np.int16)

    def new_block(self):
        """
        Initialize a new block. If the BLOCK_INIT_5050 parameter is set, the first block
        has a 50/50 probability of leftward stimulus and is 90 trials long.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft
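
    # Worked example of the switching rule above, assuming a conventional
    # BLOCK_PROBABILITY_SET of [0.2, 0.8] (defined in the parameter file, not
    # here): after a block with probability_left == 0.2, round(abs(1 - 0.2), 1)
    # gives 0.8, so pleft alternates between the two biased values, optionally
    # preceded by a 90-trial 50/50 block when BLOCK_INIT_5050 is set.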

    def next_trial(self):
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self):
        trial_info = self.trials_table.iloc[self.trial_num]
        extra_info = f"""
BLOCK NUMBER:         {trial_info.block_num}
BLOCK LENGTH:         {self.blocks_table.loc[self.block_num, 'block_length']}
TRIALS IN BLOCK:      {trial_info.block_trial_num}
        """
        super().show_trial_log(extra_info=extra_info)


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol has a complicated adaptation of the number of contrasts (embodied by the training_phase
    property) and of the reward amount (embodied by the adaptive_reward property).
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}
        self.trials_table['training_phase'] = np.zeros(NTRIALS_INIT, dtype=np.int8)
        self.trials_table['debias_trial'] = np.zeros(NTRIALS_INIT, dtype=bool)

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    def get_subject_training_info(self):
        """
        Get the previous sessions' training information according to this session's parameters
        and deduce the training level, adaptive reward amount and adaptive gain value.
        :return:
        """
        try:
            tinfo, _ = choiceworld.get_subject_training_info(
                subject_name=self.session_info.SUBJECT_NAME,
                default_reward=self.task_params.REWARD_AMOUNT_UL,
                stim_gain=self.task_params.STIM_GAIN,
                local_path=self.iblrig_settings['iblrig_local_data_path'],
                remote_path=self.iblrig_settings['iblrig_remote_data_path'],
                lab=self.iblrig_settings['ALYX_LAB'],
                task_name=self.protocol_name,
                iblrig_settings=self.iblrig_settings,
            )
        except Exception:
            log.critical('Failed to get training information from previous sessions: %s', traceback.format_exc())
            tinfo = dict(
                training_phase=iblrig.choiceworld.DEFAULT_TRAINING_PHASE,
                adaptive_reward=iblrig.choiceworld.DEFAULT_REWARD_VOLUME,
                adaptive_gain=self.task_params.AG_INIT_VALUE,
            )
            log.critical(
                f"The mouse will train on level {tinfo['training_phase']}, "
                f"with reward {tinfo['adaptive_reward']} uL and gain {tinfo['adaptive_gain']}"
            )
        return tinfo['training_phase'], tinfo['adaptive_reward'], tinfo['adaptive_gain']

    def compute_performance(self):
        """
        Aggregate the trials table to compute the performance of the mouse on each contrast.
        :return: the performance DataFrame indexed by signed contrast
        """
        self.trials_table['signed_contrast'] = self.trials_table['contrast'] * np.sign(self.trials_table['position'])
        performance = self.trials_table.groupby(['signed_contrast']).agg(
            last_50_perf=pd.NamedAgg(column='trial_correct', aggfunc=lambda x: np.sum(x[np.maximum(-50, -x.size) :]) / 50),
            ntrials=pd.NamedAgg(column='trial_correct', aggfunc='count'),
        )
        return performance
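
    # Worked example of the aggregation above: x[np.maximum(-50, -x.size):]
    # selects at most the last 50 trials of a given contrast, while the sum of
    # correct trials is divided by a fixed 50; with fewer than 50 trials the
    # metric is therefore deflated, e.g. 20 correct out of 30 trials gives
    # last_50_perf == 0.4 rather than 0.67.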

    def check_training_phase(self):
        """
        Check if the mouse is ready to move to the next training phase.
        :return: None
        """
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1
        # check if the subject graduates to a new training phase
        self.check_training_phase()
        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)
        # debiasing: if the previous trial was incorrect and easy, repeat the trial
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1) and last_contrast >= 0.5
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                iresponse = self.trials_table['response_side'] != 0  # trials that had a response
                # take the average of rightward responses over the last 10 response trials
                # (the original slice used np.maximum, which selected all trials instead of the last 10)
                responses = self.trials_table['response_side'][iresponse]
                average_right = np.mean(responses.iloc[-np.minimum(10, responses.size) :] == 1)
                # the probability of the next stimulus being on the left is drawn from a normal distribution
                # centered on average_right with sigma 0.5: if the draw is below 0.5 the next stimulus is on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]
                # contrast is the last contrast
                contrast = last_contrast
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase

    def show_trial_log(self):
        extra_info = f"""
CONTRAST SET:         {np.unique(np.abs(choiceworld.contrasts_set(self.training_phase)))}
SUBJECT TRAINING PHASE (0-5):         {self.training_phase}
        """
        super().show_trial_log(extra_info=extra_info)