• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 15883534687

25 Jun 2025 05:53PM UTC coverage: 50.257% (+3.5%) from 46.79%
15883534687

Pull #783

github

d0d43a
web-flow
Merge 6b5b19c5a into f353f90a7
Pull Request #783: UDP ephys

93 of 102 new or added lines in 1 file covered. (91.18%)

1073 existing lines in 20 files now uncovered.

4892 of 9734 relevant lines covered (50.26%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.51
/iblrig/base_choice_world.py
1
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""
2

3
import abc
2✔
4
import enum
2✔
5
import logging
2✔
6
import math
2✔
7
import random
2✔
8
import subprocess
2✔
9
import time
2✔
10
from pathlib import Path
2✔
11
from re import split as re_split
2✔
12
from string import ascii_letters
2✔
13
from typing import Annotated, Any, final
2✔
14

15
import numpy as np
2✔
16
import pandas as pd
2✔
17
from annotated_types import Interval, IsNan
2✔
18
from pydantic import NonNegativeFloat, NonNegativeInt
2✔
19

20
import iblrig.base_tasks
2✔
21
from iblrig import choiceworld, misc
2✔
22
from iblrig.hardware import DTYPE_AMBIENT_SENSOR_BIN, SOFTCODE
2✔
23
from iblrig.pydantic_definitions import TrialDataModel
2✔
24
from iblutil.io import binary, jsonable
2✔
25
from iblutil.util import Bunch
2✔
26
from pybpodapi.com.messaging.trial import Trial
2✔
27
from pybpodapi.protocol import StateMachine
2✔
28

29
log = logging.getLogger(__name__)
2✔
30

31
NTRIALS_INIT = 2000
2✔
32
NBLOCKS_INIT = 100
2✔
33

34

35
# TODO: task parameters should be verified through a pydantic model
36
#
37
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
38
#
39
# class ChoiceWorldParams(BaseModel):
40
#     AUTOMATIC_CALIBRATION: bool = True
41
#     ADAPTIVE_REWARD: bool = False
42
#     BONSAI_EDITOR: bool = False
43
#     CALIBRATION_VALUE: float = 0.067
44
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
45
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
46
#     GO_TONE_AMPLITUDE: float = 0.0272
47
#     GO_TONE_DURATION: float = 0.11
48
#     GO_TONE_IDX: int = Field(2, ge=0)
49
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
50
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
51
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
52
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
53
#     INTERACTIVE_DELAY: float = 0.0
54
#     ITI_DELAY_SECS: float = 0.5
55
#     NTRIALS: int = Field(2000, gt=0)
56
#     PROBABILITY_LEFT: Probability = 0.5
57
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
58
#     QUIESCENT_PERIOD: float = 0.2
59
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
60
#     RECORD_SOUND: bool = True
61
#     RESPONSE_WINDOW: float = 60
62
#     REWARD_AMOUNT_UL: float = 1.5
63
#     REWARD_TYPE: str = 'Water 10% Sucrose'
64
#     STIM_ANGLE: float = 0.0
65
#     STIM_FREQ: float = 0.1
66
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
67
#     STIM_POSITIONS: list[float] = [-35, 35]
68
#     STIM_SIGMA: float = 7.0
69
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
70
#     STIM_REVERSE: bool = False
71
#     SYNC_SQUARE_X: float = 1.33
72
#     SYNC_SQUARE_Y: float = -1.03
73
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
74
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
75
#     WHITE_NOISE_AMPLITUDE: float = 0.05
76
#     WHITE_NOISE_DURATION: float = 0.5
77
#     WHITE_NOISE_IDX: int = 3
78

79

80
class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data.

    Declares the per-trial variables preallocated into the session's trials
    table (see ``ChoiceWorldSession.__init__``) and validated on save.
    """

    # stimulus contrast, fraction in [0, 1]
    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    # probability of the stimulus appearing on the left, in [0, 1]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    # signed stimulus position, drawn from the STIM_POSITIONS task parameter
    position: float
    # duration of the pre-stimulus quiescent period (seconds)
    quiescent_period: NonNegativeFloat
    # reward volume (µl); nulled in trial_completed() if the reward state was not triggered
    reward_amount: NonNegativeFloat
    # valve opening time used to deliver the reward (seconds)
    reward_valve_time: NonNegativeFloat
    # stimulus angle in degrees (STIM_ANGLE task parameter)
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    # stimulus frequency (STIM_FREQ task parameter)
    stim_freq: NonNegativeFloat
    # wheel-to-stimulus gain (degrees of visual angle per mm of wheel displacement)
    stim_gain: float
    # stimulus phase in radians, drawn uniformly from [0, 2π)
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    # whether the wheel-stimulus relationship is reversed (STIM_REVERSE task parameter)
    stim_reverse: bool
    # stimulus Gaussian envelope sigma (STIM_SIGMA task parameter)
    stim_sigma: float
    # zero-based trial counter
    trial_num: NonNegativeInt
    # time spent paused between this trial and the next (seconds); see ChoiceWorldSession._run()
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld.
    # We keep them here with fixed default values for sake of compatibility;
    # the Interval(ge=0, le=0) constraints effectively pin them to their defaults.
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False
105

106

107
class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    """Session class implementing the task logic of the Choice World protocol.

    Combines the hardware mixins (Bpod, Frame2TTL, rotary encoder, sound, valve,
    Bonsai recording/stimulus, network) with the trial flow defined below.
    """

    # task_params = ChoiceWorldParams()
    # default task parameters, loaded from the YAML file next to this module
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    # pydantic model used to preallocate and validate the trials table
    TrialDataModel = ChoiceWorldTrialData
120

121
    def __init__(self, *args, delay_mins: float = 0, **kwargs):
        """Initialize the Choice World session.

        Parameters
        ----------
        delay_mins : float, optional
            Initial delay before the first trial, in minutes (default: 0).
            Stored internally in seconds as SESSION_DELAY_START.
        **kwargs
            Forwarded to the mixin/base-class constructors.

        NOTE(review): ``*args`` is accepted but not forwarded to
        ``super().__init__()`` — confirm this is intentional.
        """
        super().__init__(**kwargs)

        # session delay is handled in seconds internally
        self.task_params['SESSION_DELAY_START'] = delay_mins * 60.0

        # init behaviour data
        # map the two quiescence thresholds to their rotary-encoder threshold events
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]

        # init counter variables (-1 so the first increment yields 0)
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1

        # init the tables, there are 2 of them: a trials table and a ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            np.nan, index=range(NTRIALS_INIT), columns=['Temperature_C', 'AirPressure_mb', 'RelativeHumidity'], dtype=np.float32
        )
        self.ambient_sensor_table.rename_axis('Trial', inplace=True)
142

143
    @staticmethod
    def extra_parser():
        """Return the base-class argument parser extended with task-specific options.

        Returns
        -------
        argparse.ArgumentParser
            The parser from the super class, with the ``--delay_mins`` option
            added (initial delay before starting the first trial).
        """
        # two-argument super() form, as this is a staticmethod without cls/self
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        delay_option = dict(
            dest='delay_mins',
            type=float,
            default=0,
            required=False,
            help='initial delay before starting the first trial (default: 0 min)',
        )
        parser.add_argument('--delay_mins', **delay_option)
        return parser
156

157
    def start_hardware(self):
        """
        Start the individual hardware mixins, explicitly and in order.

        The super class start method is overloaded because the different
        hardware pieces must be started in this specific order. Does nothing
        when the session is mocked.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            # register the task's softcode callbacks with the Bpod last,
            # once all hardware is up
            self.bpod.register_softcodes(self.softcode_dictionary())
172

173
    @final
    def _wait_for_camera_and_initial_delay(self) -> None:
        """Wait for the camera to start recording and manage the initial delay.

        This method implements a temporary state machine to coordinate the process of waiting for the camera recording
        to commence and to handle any specified initial delay. It should be called just prior to the start of the task.
        The states defined here were previously part of the task's main state machine (see `get_state_machine_trial()`).

        The Bpod softcode handler is temporarily swapped for one that logs
        progress; the original handler is restored before returning.
        """
        initial_delay = self.task_params.get('SESSION_DELAY_START', 0)

        # temporary IntEnum for storing softcodes
        # SOFTCODE.TRIGGER_CAMERA is being reused; we add three more unique values
        class TemporarySoftcodes(enum.IntEnum):
            START_CAMERA_RECORDING = SOFTCODE.TRIGGER_CAMERA.value
            WAIT_FOR_CAMERA_TRIGGER = enum.auto()
            CAMERA_TRIGGER_RECEIVED = enum.auto()
            STARTING_INITIAL_DELAY = enum.auto()

        # store the original softcode handler
        original_softcode_handler = self.bpod.softcode_handler_function

        # define temporary softcode handler: only the camera-recording softcode
        # is forwarded to the original handler; the others just log progress
        def temporary_softcode_handler(softcode: int):
            match softcode:
                case TemporarySoftcodes.START_CAMERA_RECORDING:
                    original_softcode_handler(softcode)  # pass to original handler
                case TemporarySoftcodes.WAIT_FOR_CAMERA_TRIGGER:
                    log.info('Waiting to receive first camera trigger ...')
                case TemporarySoftcodes.CAMERA_TRIGGER_RECEIVED:
                    log.info('Camera trigger received')
                case TemporarySoftcodes.STARTING_INITIAL_DELAY:
                    if initial_delay > 0:
                        log.info(f'Waiting for {initial_delay} s')
                    else:
                        log.info('No initial delay defined')

        # overwrite softcode handler
        self.bpod.softcode_handler_function = temporary_softcode_handler

        # define and run state machine:
        # start camera workflow -> wait for first trigger on Port1 -> initial delay -> exit
        sma = StateMachine(self.bpod)
        sma.add_state(
            state_name='start_camera_workflow',
            output_actions=[('SoftCode', TemporarySoftcodes.START_CAMERA_RECORDING)],
            state_change_conditions={'Tup': 'wait_for_camera_trigger'},
        )
        sma.add_state(
            state_name='wait_for_camera_trigger',
            output_actions=[('SoftCode', TemporarySoftcodes.WAIT_FOR_CAMERA_TRIGGER)],
            state_change_conditions={'Port1In': 'camera_trigger_received'},
        )
        sma.add_state(
            state_name='camera_trigger_received',
            output_actions=[('SoftCode', TemporarySoftcodes.CAMERA_TRIGGER_RECEIVED)],
            state_change_conditions={'Tup': 'delay_initiation'},
        )
        sma.add_state(
            state_name='delay_initiation',
            state_timer=self.task_params.get('SESSION_DELAY_START', 0),
            output_actions=[('SoftCode', TemporarySoftcodes.STARTING_INITIAL_DELAY)],
            state_change_conditions={'Tup': 'exit'},
        )
        self.bpod.send_state_machine(sma)
        self.bpod.run_state_machine(sma)  # blocking until state-machine is finished
        if initial_delay > 0:
            log.info('Initial delay has passed')

        # restore original softcode handler
        self.bpod.softcode_handler_function = original_softcode_handler
242

243
    def _run(self) -> None:
        """Execute the task using the defined state machine.

        This method orchestrates the execution of the task by running a state machine for a specified number of trials.

        Per trial: draw trial parameters (`next_trial`), build and run the Bpod
        state machine, then save the trial data. Pausing and stopping between
        trials is controlled via `.pause` / `.stop` flag files in the session
        folder.
        """
        time_last_trial_end = time.time()
        for trial_number in range(self.task_params.NTRIALS):  # Main loop
            # obtain state machine definition
            self.next_trial()
            sma = self.get_state_machine_trial(trial_number)

            # Waiting for camera / initial delay will be handled just prior to the first trial
            # This is done here to allow for backward compatibility with unadapted tasks
            if trial_number == 0:
                # warn if state machine uses deprecated way of waiting for camera / initial delay
                if (5, SOFTCODE.TRIGGER_CAMERA) in sma.output_matrix[0] and sma.state_names[1] == 'delay_initiation':
                    log.warning('')
                    log.warning('**********************************************')
                    log.warning('ATTENTION: YOUR TASK DEFINITION NEEDS UPDATING')
                    log.warning('**********************************************')
                    log.warning('Camera and initial delay should not be handled')
                    log.warning('within the `get_state_machine_trial()` method.')
                    log.warning('For further details, please refer to section  ')
                    log.warning("'Deprecation Notes' in IBLRIG's documentation.")
                    log.warning('**********************************************')
                    log.warning('')
                    log.info('Waiting for 10s so you actually read this message ;-)')
                    time.sleep(10)
                else:
                    self._wait_for_camera_and_initial_delay()

            # send state machine description to Bpod device
            log.debug('Sending state machine to bpod')
            self.bpod.send_state_machine(sma)

            # handle ITI durations
            if trial_number > 0:
                # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
                # Bpod TTL is HIGH. The DEAD_TIME param defines the time between last trial and the next
                dead_time = self.task_params.get('DEAD_TIME', 0.5)
                dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)

                # wait to achieve the desired ITI duration
                if dt > 0:
                    log.debug(f'Waiting {dt} s to achieve an ITI duration of {self.task_params.ITI_DELAY_SECS} s')
                    time.sleep(dt)

            # run state machine
            log.info('-----------------------')
            log.info(f'Starting Trial #{trial_number}')
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()

            # handle pause event: poll the flag files once per second until
            # .pause is removed or .stop appears; record the time spent paused
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and trial_number < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session inbetween trials {trial_number} and {trial_number + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
                if not flag_stop.exists():
                    log.info('Resuming session')

            # save trial and update log
            self.trial_completed(self.bpod.session.current_trial.export())
            self.show_trial_log()

            # handle stop event: the .stop flag is consumed (unlinked) on use
            if flag_stop.exists():
                log.info('Stopping session after trial %d', trial_number)
                flag_stop.unlink()
                break
317

318
    def mock(self, file_jsonable_fixture=None):
        """
        Instantiate a state machine and Bpod object to simulate a task's run.

        This is useful to test or display the state machine flow.

        Parameters
        ----------
        file_jsonable_fixture : str or Path, optional
            A jsonable file of previously recorded trials; if given, mocked
            trials export a randomly chosen 'behavior_data' entry from it,
            otherwise they export an empty dict.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop-out the bpod data from the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        # replace the Bpod's trial list and state-machine calls with stubs;
        # run_state_machine just sleeps to emulate a trial's duration
        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        # dummy output action used for all sound/bonsai actions below
        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
                'bonsai_freeze_center': daction,
            }
        )
363

364
    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file: file path to render the graph to (optional); rendering
            is skipped with a log message if the graphviz executable is missing
        :param view: whether to open the rendered graph (passed to graphviz)
        :return: graphviz.Digraph of the trial state machine, or None if no
            state machine is defined
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        # map state indices to names; undeclared states are offset by 10000
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        # NOTE(review): one ascii letter per state limits this to 52 states
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            # '~' is bitwise-not on the numpy bool returned by np.isnan
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for inputs in sma.input_matrix[i]:
                if inputs[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[inputs[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
399

400
    def _instantiate_state_machine(self, *args, **kwargs):
        """Return a new StateMachine bound to this session's Bpod.

        ``*args`` and ``**kwargs`` (e.g. ``trial_number`` passed by
        `get_state_machine_trial`) are ignored here; subclasses may overload
        this method to make use of them.
        """
        return StateMachine(self.bpod)
402

403
    def get_state_machine_trial(self, i):
        """Build the Bpod state machine describing a single trial.

        :param i: trial number (zero-based), forwarded to
            `_instantiate_state_machine` for subclasses that may need it
        :return: the assembled StateMachine
        """
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        # Signal trial start and stop all sounds
        sma.add_state(
            state_name='trial_start',
            state_timer=0,  # ~100µs hardware irreducible delay
            state_change_conditions={'Tup': 'reset_rotary_encoder'},
            output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
        )

        # Reset the rotary encoder by sending the following opcodes via the modules serial interface
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: Reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between visual and auditory cue
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to next state if sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset rotary encoder (see above). Move on after brief delay (to avoid a race conditions in the bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition

        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.feedback_nogo_delay,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: Freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.feedback_error_delay,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze stimulus in center of screen.
        # Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.feedback_correct_delay - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma
549

550
    @abc.abstractmethod
    def next_trial(self):
        """Draw the variables for the upcoming trial (implemented by subclasses)."""
        pass
553

554
    @property
    def default_reward_amount(self):
        """Reward amount in µl, as set by the REWARD_AMOUNT_UL task parameter."""
        return self.task_params.REWARD_AMOUNT_UL
557

558
    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw next trial variables.

        Draws contrast, position, quiescent period and stimulus parameters for
        the next trial, writes them into the trials table, and calls
        :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.

        Parameters
        ----------
        pleft : float, optional
            Probability of the stimulus appearing on the left (default: 0.5).
        **kwargs
            Overrides for any computed trial-table column; an ``index`` key is
            ignored (it is not a trial variable).
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        # quiescent period = fixed baseline + truncated-exponential jitter
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.stimulus_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                # BUGFIX: was `pass`, which did not skip the key — a spurious
                # 'index' column would be written into the trials table
                continue
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()
590

591
    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        """Finalize the current trial and persist its data.

        Parameters
        ----------
        bpod_data : dict[str, Any]
            The exported Bpod trial data; must contain 'States timestamps'
            (and 'Events timestamps' for the subsequent sync-pulse check).
        """
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)

        # save ambient data: in-memory table plus binary append to the ambient file
        if self.hardware_settings.device_bpod.USE_AMBIENT_MODULE:
            self.ambient_sensor_table.iloc[self.trial_num] = (sensor_reading := self.bpod.get_ambient_sensor_reading())
            with self.paths['AMBIENT_FILE_PATH'].open('ab') as f:
                binary.write_array(f, [self.trial_num, *sensor_reading], DTYPE_AMBIENT_SENSOR_BIN)

        # this is a flag for the online plots. If online plots were in pyqt5, there is a file watcher functionality
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)
612

613
    def check_sync_pulses(self, bpod_data):
        """Warn if expected sync pulses are missing from the trial's Bpod events.

        Checks the Frame2TTL (BNC1), sound sync (BNC2) and camera sync (Port1)
        inputs; does nothing if the Bpod is not connected.
        """
        # TODO: move this to the post-trial stage once a proper task flow exists
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        expected_pulses = (
            ('BNC1', "NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1"),
            ('BNC2', "NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2"),
            ('Port1', "NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1"),
        )
        for port_name, warning_message in expected_pulses:
            if not misc.get_port_events(events, name=port_name):
                log.warning(warning_message)
624

625
    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        Retrieves information about the current trial from the trials table and
        logs it, optionally merged with additional key/value pairs supplied via
        `extra_info`.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the log.
        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        current = self.trials_table.iloc[self.trial_num]

        # base information shown for every trial
        info_dict = {
            'Stim. Position': current.position,
            'Stim. Contrast': current.contrast,
            'Stim. Phase': f'{current.stim_phase:.2f}',
            'Stim. p Left': current.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # merge in any caller-provided entries (may override the defaults)
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # emit one header line followed by one aligned line per item
        log.log(log_level, f'Outcome of Trial #{current.trial_num}:')
        width = max(map(len, info_dict)) + len('- : ')
        for label, content in info_dict.items():
            log.log(log_level, f'- {label}: '.ljust(width) + str(content))
674

675
    @property
    def iti_reward(self):
        """
        Return the ITI time that needs to be set in order to achieve the desired ITI.

        The time it takes to deliver the reward is subtracted from the desired
        inter-trial interval (task parameter ITI_CORRECT).

        Raises
        ------
        ValueError
            If REWARD_VALVE_TIME is missing from the calibration data, i.e. the
            valve has not been calibrated.
        """
        valve_time = self.calibration.get('REWARD_VALVE_TIME')
        if valve_time is None:
            # the previous implementation subtracted the default None, which raised
            # an opaque TypeError; fail explicitly instead
            raise ValueError('REWARD_VALVE_TIME is missing from the calibration data')
        return self.task_params.ITI_CORRECT - valve_time
682

683
    """
2✔
684
    Those are the properties that are used in the state machine code
685
    """
686

687
    @property
    def reward_time(self):
        """Valve-opening time (s) corresponding to the current trial's reward amount."""
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])
690

691
    @property
    def quiescent_period(self):
        """Quiescence duration drawn for the current trial (from the trials table)."""
        return self.trials_table.at[self.trial_num, 'quiescent_period']
694

695
    @property
    def feedback_correct_delay(self):
        """Delay (s) following correct feedback, from the task parameters."""
        return self.task_params['FEEDBACK_CORRECT_DELAY_SECS']
698

699
    @property
    def feedback_error_delay(self):
        """Delay (s) following error feedback, from the task parameters."""
        return self.task_params['FEEDBACK_ERROR_DELAY_SECS']
702

703
    @property
    def feedback_nogo_delay(self):
        """Delay (s) following a no-go outcome, from the task parameters."""
        return self.task_params['FEEDBACK_NOGO_DELAY_SECS']
706

707
    @property
    def position(self):
        """Stimulus position for the current trial (signed; sign encodes the side)."""
        return self.trials_table.at[self.trial_num, 'position']
710

711
    @property
    def event_error(self):
        """Rotary encoder threshold event corresponding to an error response on this trial."""
        # STIM_REVERSE flips which movement direction counts as an error
        direction = -1 if self.task_params.STIM_REVERSE else 1
        return self.device_rotary_encoder.THRESHOLD_EVENTS[direction * self.position]
714

715
    @property
    def event_reward(self):
        """Rotary encoder threshold event corresponding to a rewarded response on this trial."""
        # mirror image of event_error: the opposite threshold yields the reward
        direction = 1 if self.task_params.STIM_REVERSE else -1
        return self.device_rotary_encoder.THRESHOLD_EVENTS[direction * self.position]
718

719

720
class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    # duration (s) of the stim_on state, i.e. time until the stimulus moves to the screen centre
    delay_to_stim_center: NonNegativeFloat
724

725

726
class HabituationChoiceWorldSession(ChoiceWorldSession):
    """
    Session for the habituation protocol: a passive variant of choice world in which
    the stimulus is shown, moved to the screen centre and rewarded without requiring
    a wheel response.
    """

    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        """Advance the trial counter and draw the next trial's parameters."""
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        """Draw the habituation-specific trial fields, then defer to the parent."""
        # update trial table fields specific to habituation choice world:
        # the delay to centre is drawn from a normal distribution (sigma = 2 s)
        # around the DELAY_TO_STIM_CENTER task parameter
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        """
        Build the Bpod state machine for habituation trial `i`.

        State sequence: iti -> stim_on -> stim_center -> reward -> post_reward.

        Parameters
        ----------
        i : int
            Trial index (unused here; kept for interface compatibility).

        Returns
        -------
        StateMachine
            The assembled Bpod state machine for this trial.
        """
        sma = StateMachine(self.bpod)

        # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
        # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
        # the offset of this state is trial start!
        sma.add_state(
            state_name='iti',
            state_timer=1,  # Stim off for 1 sec
            state_change_conditions={'Tup': 'stim_on'},
            output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
        )

        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave reward valve open, i.e. reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        # NOTE(review): if reward_time exceeds ITI_DELAY_SECS this timer goes negative — confirm intended range.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma
783

784

785
class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    # -1 / 1 for the two wheel directions, 0 for a no-go
    response_side: Annotated[int, Interval(ge=-1, le=1)]
    # seconds between stim_on onset and closed_loop offset
    response_time: NonNegativeFloat
    # whether the trial outcome was (omit_)correct
    trial_correct: bool
791

792

793
class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse is actively making decisions
    by turning the wheel. It has the following characteristics

    -   it is trial based
    -   it is decision based
    -   left and right simulus are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   it has a quantifiable performance by computing the proportion of correct trials of passive stimulations protocols or
        habituation protocols.

    The TrainingChoiceWorld, BiasedChoiceWorld are all subclasses of this class
    """

    TrialDataModel = ActiveChoiceWorldTrialData
    # handle on the online-plots subprocess started in _run (None until/unless started)
    plot_subprocess: subprocess.Popen | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # pre-allocate the probability-left column for the whole session
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)

    def _run(self):
        """Start the online-plots viewer (interactive mode only), then run the task."""
        # starts online plotting
        if self.interactive:
            log.info('Starting subprocess: online plots')
            self.plot_subprocess = subprocess.Popen(
                ['view_session', str(self.paths['SESSION_RAW_DATA_FOLDER'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def __del__(self):
        # best-effort cleanup of the online-plots subprocess: terminate, wait up
        # to 5 s, then kill if it is still alive
        if isinstance(self.plot_subprocess, subprocess.Popen) and self.plot_subprocess.poll() is None:
            log.info('Terminating subprocess: online plots')
            self.plot_subprocess.terminate()
            try:
                self.plot_subprocess.wait(timeout=5)
            except subprocess.TimeoutExpired:
                log.warning('Process did not terminate within 5 seconds - killing it.')
                self.plot_subprocess.kill()

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """Add response/performance details to the trial log, then defer to the parent."""
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Response Time': f'{trial_info.response_time:.2f} s',
            'Trial Correct': trial_info.trial_correct,
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
            # trial_num is zero-based, hence the +1; this counts all non-correct trials
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT + 1,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)

    def trial_completed(self, bpod_data: dict) -> None:
        """
        Update the trials table with information about the behaviour coming from the bpod.

        Constraints on the state machine data:

        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']

        Parameters
        ----------
        bpod_data : dict
            The Bpod data as returned by pybpod

        Raises
        ------
        AssertionError
            If the position is zero or if the number of detected outcomes is not exactly one.
        """
        # Get the response time from the behaviour data.
        # It is defined as the time passing between the start of `stim_on` and the end of `closed_loop`.
        state_times = bpod_data['States timestamps']
        response_time = state_times['closed_loop'][0][1] - state_times['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time

        try:
            # Get the stimulus position
            position = self.trials_table.at[self.trial_num, 'position']
            assert position != 0, 'the stimulus position should not be 0'

            # Get the trial's outcome, i.e., the states that have a matching name and a valid time-stamp
            # Assert that we have exactly one outcome
            outcome_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
            outcomes = [name for name, times in state_times.items() if name in outcome_names and ~np.isnan(times[0][0])]
            if (n_outcomes := len(outcomes)) != 1:
                trial_states = 'Trial states: ' + ', '.join(k for k, v in state_times.items() if ~np.isnan(v[0][0]))
                assert n_outcomes != 0, f'No outcome detected for trial {self.trial_num}.\n{trial_states}'
                assert n_outcomes == 1, f'{n_outcomes} outcomes detected for trial {self.trial_num}.\n{trial_states}'
            outcome = outcomes[0]

        except AssertionError as e:
            # write bpod_data to disk, log exception then raise
            self.save_trial_data_to_json(bpod_data, validate=False)
            for line in re_split(r'\n', e.args[0]):
                log.error(line)
            raise e

        # record the trial's outcome in the trials_table
        # NB: substring matching is intentional so that 'omit_correct' also counts as correct
        # (same for 'omit_error' / 'omit_no_go' below)
        self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
        if 'correct' in outcome:
            self.session_info.NTRIALS_CORRECT += 1
            # for a correct trial the response side is opposite the stimulus position sign
            self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
        elif 'error' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
        elif 'no_go' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = 0

        super().trial_completed(bpod_data)
911

912

913
class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    # zero-based index of the biased block this trial belongs to
    block_num: NonNegativeInt = 0
    # zero-based index of the trial within its block
    block_trial_num: NonNegativeInt = 0
918

919

920
class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # pre-allocate the blocks table with sentinel values for blocks not yet drawn:
        # NaN probability and -1 length. NB: the previous `np.zeros(..., dtype=np.int16) * -1`
        # was a no-op (zeros times -1 is still zeros) and left 0 instead of the intended -1;
        # this has no behavioural impact since new_block() always writes a row before it is read.
        self.blocks_table = pd.DataFrame(
            {
                'probability_left': np.full(NBLOCKS_INIT, np.nan),
                'block_length': np.full(NBLOCKS_INIT, -1, dtype=np.int16),
            }
        )

    def new_block(self):
        """
        Draw the length and left-stimulus probability of a new block and append it to the blocks table.

        If BLOCK_INIT_5050 is set, the first block has a 50/50 probability of leftward
        stimulus and is 90 trials long; otherwise block lengths are drawn from a
        truncated exponential and the probability alternates between blocks.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft

    def next_trial(self):
        """Advance counters, start a new block if needed, and draw the next trial."""
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """Add block details to the trial log, then defer to the parent."""
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)
995

996

997
class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    # training phase (0-5) the trial was drawn in
    training_phase: NonNegativeInt
    # whether this trial's stimulus was drawn by the debiasing routine
    debias_trial: bool
1002

1003

1004
class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol has a complicated adaptation of the number of contrasts (embodied by the training_phase
    property) and the reward amount, embodied by the adaptive_reward property.
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        """
        Parameters
        ----------
        training_phase : int, optional
            Training phase to use; -1 (default) infers it from the subject's history.
        adaptive_reward : float, optional
            Reward amount (uL) to use; -1.0 (default) infers it from the subject's history.
        adaptive_gain : float, optional
            Stimulus gain (degrees/mm) to use; None (default) infers it from the subject's history.
        """
        super().__init__(**kwargs)
        # each of the three adaptive values falls back to the inferred value when left at its default
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got Adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got Adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        # per-phase trial counters (phases 0-5) and response-side history buffer
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        """Adaptive reward amount (uL), falling back to the REWARD_AMOUNT_UL task parameter."""
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    @property
    def stimulus_gain(self) -> float:
        """Adaptive stimulus gain (degrees/mm) as determined in __init__."""
        return self.session_info.get('ADAPTIVE_GAIN_VALUE')

    def get_subject_training_info(self):
        """
        Get the previous session's according to this session parameters and deduce the
        training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        training_info: dict
            Dictionary with keys: training_phase, adaptive_reward, adaptive_gain
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def check_training_phase(self) -> bool:
        """
        Check if the mouse is ready to move to the next training phase.

        Returns
        -------
        bool
            True if the training phase was advanced (capped at phase 5).
        """
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrast should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')
        return move_on

    def next_trial(self):
        """Advance counters, possibly graduate the training phase, and draw the next trial (with debiasing)."""
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1

        # check if the subject graduates to a new training phase
        self.check_training_phase()

        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)

        # debiasing: if the previous trial was incorrect, not a no-go and easy
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (
                (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1)
                and (self.trials_table.loc[self.trial_num - 1, 'response_side'] != 0)
                and last_contrast >= 0.5
            )
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                # indices of trials that had a response
                iresponse = np.logical_and(self.trials_table['response_side'].notna(), self.trials_table['response_side'] != 0)
                iresponse = iresponse.index[iresponse]

                # takes the average of right responses over last 10 response trials
                average_right = (self.trials_table['response_side'][iresponse[-np.minimum(10, iresponse.size) :]] == 1).mean()

                # the probability of the next stimulus being on the left is a draw from a normal distribution centered
                # on the average right with sigma 0.5 - if it is less than 0.5 the next stimulus will be on the left.
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]

                # contrast is the last contrast
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False

        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """Add training-phase details to the trial log, then defer to the parent."""
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
            'Debias Trial': self.trials_table.at[self.trial_num, 'debias_trial'],
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc