Coveralls coverage report: int-brain-lab / iblrig, build 10992769815
Pull Request #716 (8.24.1), merge 73b6a53cb into a946a6ff9

23 Sep 2024 10:50AM UTC · coverage: 47.799% (+1.0%) from 46.79%

22 of 49 new or added lines in 9 files covered (44.9%).
1015 existing lines in 22 files now uncovered.
4191 of 8768 relevant lines covered (47.8%) · 0.96 hits per line.

Source file: /iblrig/base_choice_world.py (91.16% covered)
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import logging
import math
import random
import subprocess
import time
from pathlib import Path
from string import ascii_letters
from typing import Annotated, Any

import numpy as np
import pandas as pd
from annotated_types import Interval, IsNan
from pydantic import NonNegativeFloat, NonNegativeInt

import iblrig.base_tasks
import iblrig.graphic
from iblrig import choiceworld, misc
from iblrig.hardware import SOFTCODE
from iblrig.pydantic_definitions import TrialDataModel
from iblutil.io import jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

# TODO: task parameters should be verified through a pydantic model
#
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
#
# class ChoiceWorldParams(BaseModel):
#     AUTOMATIC_CALIBRATION: bool = True
#     ADAPTIVE_REWARD: bool = False
#     BONSAI_EDITOR: bool = False
#     CALIBRATION_VALUE: float = 0.067
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
#     GO_TONE_AMPLITUDE: float = 0.0272
#     GO_TONE_DURATION: float = 0.11
#     GO_TONE_IDX: int = Field(2, ge=0)
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
#     INTERACTIVE_DELAY: float = 0.0
#     ITI_DELAY_SECS: float = 0.5
#     NTRIALS: int = Field(2000, gt=0)
#     PROBABILITY_LEFT: Probability = 0.5
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
#     QUIESCENT_PERIOD: float = 0.2
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
#     RECORD_SOUND: bool = True
#     RESPONSE_WINDOW: float = 60
#     REWARD_AMOUNT_UL: float = 1.5
#     REWARD_TYPE: str = 'Water 10% Sucrose'
#     STIM_ANGLE: float = 0.0
#     STIM_FREQ: float = 0.1
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
#     STIM_POSITIONS: list[float] = [-35, 35]
#     STIM_SIGMA: float = 7.0
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
#     STIM_REVERSE: bool = False
#     SYNC_SQUARE_X: float = 1.33
#     SYNC_SQUARE_Y: float = -1.03
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
#     WHITE_NOISE_AMPLITUDE: float = 0.05
#     WHITE_NOISE_DURATION: float = 0.5
#     WHITE_NOISE_IDX: int = 3


class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data."""

    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    position: float
    quiescent_period: NonNegativeFloat
    reward_amount: NonNegativeFloat
    reward_valve_time: NonNegativeFloat
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    stim_freq: NonNegativeFloat
    stim_gain: float
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    stim_reverse: bool
    stim_sigma: float
    trial_num: NonNegativeInt
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld.
    # We keep them here with fixed default values for the sake of compatibility.
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False
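    # Illustrative sketch: as a pydantic model, out-of-range values are rejected
    # at construction time, e.g. a contrast outside [0, 1] raises a
    # pydantic.ValidationError (assuming default pydantic validation):
    #
    #   ChoiceWorldTrialData(contrast=1.5, stim_probability_left=0.5, ...)  # raises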


class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    TrialDataModel = ChoiceWorldTrialData

    def __init__(self, *args, delay_secs=0, **kwargs):
        super().__init__(**kwargs)
        self.task_params['SESSION_DELAY_START'] = delay_secs
        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]
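        # Illustrative note: THRESHOLD_EVENTS is assumed to map wheel-position
        # thresholds to Bpod event names (e.g. {-2: 'RotaryEncoder1_1', 2: 'RotaryEncoder1_2'}),
        # so movement_left/movement_right can be used as state-change conditions below.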
        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1
        # init the tables, there are 2 of them: a trials table and an ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            {
                'Temperature_C': np.zeros(NTRIALS_INIT) * np.nan,
                'AirPressure_mb': np.zeros(NTRIALS_INIT) * np.nan,
                'RelativeHumidity': np.zeros(NTRIALS_INIT) * np.nan,
            }
        )

    @staticmethod
    def extra_parser():
        """:return: argparse.parser()"""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_secs',
            dest='delay_secs',
            default=0,
            type=int,
            required=False,
            help='initial delay before starting the first trial (default: 0s)',
        )
        parser.add_argument(
            '--remote',
            dest='remote_rigs',
            type=str,
            required=False,
            action='append',
            nargs='+',
            help='specify one of the remote rigs to interact with over the network',
        )
        return parser
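        # Illustrative invocation (the entry-point name is an assumption; the flags
        # are defined above):
        #
        #   <task_entry_point> --delay_secs 30 --remote rig_pc
        #
        # waits 30 s before the first trial and registers a remote rig named 'rig_pc'.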

    def start_hardware(self):
        """
        In this step we explicitly run the start methods of the various mixins.
        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """Run the task with the actual state machine."""
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
            # Bpod TTL is HIGH. The DEAD_TIME param defines the time between the last trial and the next.
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
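            # Worked example (illustrative): with ITI_DELAY_SECS = 0.5 and the default
            # DEAD_TIME of 0.5, dt is negative whenever any time elapsed since the last
            # trial ended, so the sleep below is skipped and the trial starts immediately.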
            # wait to achieve the desired ITI duration
            if dt > 0:
                time.sleep(dt)
            # Run state machine
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()
            # handle pause event
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session between trials {i} and {i + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
                if not flag_stop.exists():
                    log.info('Resuming session')

            # save trial and update log
            self.trial_completed(self.bpod.session.current_trial.export())
            self.ambient_sensor_table.loc[i] = self.bpod.get_ambient_sensor_reading()
            self.show_trial_log()

            # handle stop event
            if flag_stop.exists():
                log.info('Stopping session after trial %d', i)
                flag_stop.unlink()
                break

    def mock(self, file_jsonable_fixture=None):
        """
        Instantiate a state machine and Bpod object to simulate a task's run.

        This is useful to test or display the state machine flow.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop-out the bpod data from the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
            }
        )
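        # Illustrative usage (the fixture path is hypothetical): replay previously
        # recorded Bpod data without any hardware attached:
        #
        #   session.mock(file_jsonable_fixture='fixtures/task_data.jsonable')
        #   session.run()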

    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file:
        :return:
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for input in sma.input_matrix[i]:
                if input[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[input[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
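        # Illustrative usage (requires the graphviz python package and its system
        # executables for rendering):
        #
        #   dot = session.get_graphviz_task(output_file='choice_world_sma', view=False)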

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        if i == 0:  # First trial exception start camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        # Reset the rotary encoder by sending the following opcodes via the module's serial interface
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: Reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between visual and auditory cue
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to next state if sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset rotary encoder (see above). Move on after a brief delay (to avoid a race condition in the bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition

        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.task_params.FEEDBACK_NOGO_DELAY_SECS,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: Freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.task_params.FEEDBACK_ERROR_DELAY_SECS,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze stimulus in center of screen.
        # Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_show_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.task_params.FEEDBACK_CORRECT_DELAY_SECS - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been hidden on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma
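    # Simplified state flow per trial, summarizing the states defined above:
    #   trial_start → reset_rotary_encoder → quiescent_period → stim_on
    #   → interactive_delay → play_tone → reset2_rotary_encoder → closed_loop,
    #   then one of three branches:
    #     no_go → exit_state
    #     freeze_error → error → hide_stim → exit_state
    #     freeze_reward → reward → correct → hide_stim → exit_state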

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw next trial variables.

        Calls :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
        stim_gain = (
            self.session_info.ADAPTIVE_GAIN_VALUE if self.task_params.get('ADAPTIVE_GAIN', False) else self.task_params.STIM_GAIN
        )
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = stim_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                pass
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()

    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)
        # this is a flag for the online plots; if the online plots were in PyQt5, a file-watcher could be used instead
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)

    def check_sync_pulses(self, bpod_data):
        # todo move this in the post trial when we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        This method retrieves information about the current trial from the
        trials table and logs it. It can also incorporate additional information
        provided through the `extra_info` parameter.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the
            log.

        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        # construct base info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Stim. Position': trial_info.position,
            'Stim. Contrast': trial_info.contrast,
            'Stim. Phase': f'{trial_info.stim_phase:.2f}',
            'Stim. p Left': trial_info.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # log info dict
        log.log(log_level, f'Outcome of Trial #{trial_info.trial_num}:')
        max_key_length = max(len(key) for key in info_dict)
        for key, value in info_dict.items():
            spaces = (max_key_length - len(key)) * ' '
            log.log(log_level, f'- {key}: {spaces}{str(value)}')

    @property
    def iti_reward(self):
        """
        Returns the ITI time that needs to be set in order to achieve the desired ITI,
        by subtracting the time it takes to give a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    """
    These are the properties that are used in the state machine code
    """

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(-1 if self.task_params.STIM_REVERSE else 1) * self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(1 if self.task_params.STIM_REVERSE else -1) * self.position]
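    # Illustrative note: with STIM_REVERSE = False and a stimulus at position +35,
    # event_error is the +35 wheel threshold (stimulus driven further out) and
    # event_reward the -35 threshold (stimulus brought to the centre); STIM_REVERSE
    # swaps the two mappings.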


class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    delay_to_stim_center: NonNegativeFloat


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception start camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
            # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
            # the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave reward valve open, i.e. reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma


class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    response_side: Annotated[int, Interval(ge=-1, le=1)]
    response_time: NonNegativeFloat
    trial_correct: bool


class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse is actively making decisions
    by turning the wheel. It has the following characteristics:

    -   it is trial based
    -   it is decision based
    -   left and right stimuli are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   it has a quantifiable performance, computed as the proportion of correct trials (unlike passive
        stimulation or habituation protocols)

    TrainingChoiceWorld and BiasedChoiceWorld are subclasses of this class.
    """

    TrialDataModel = ActiveChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)

    def _run(self):
        # starts online plotting
        if self.interactive:
            subprocess.Popen(
                ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Response Time': f'{trial_info.response_time:.2f} s',
            'Trial Correct': trial_info.trial_correct,
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)

    def trial_completed(self, bpod_data):
        """
        Update the trials table with information about the behaviour coming from the bpod.

        Constraints on the state machine data:

        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']

        :param bpod_data:
        :return:
        """
        # get the response time from the behaviour data
        response_time = bpod_data['States timestamps']['closed_loop'][0][1] - bpod_data['States timestamps']['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time
        # get the trial outcome
        state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
        raw_outcome = {sn: ~np.isnan(bpod_data['States timestamps'].get(sn, [[np.nan]])[0][0]) for sn in state_names}
        try:
            outcome = next(k for k in raw_outcome if raw_outcome[k])
            # Update response buffer: -1 for left, 0 for no-go, and 1 for rightward
            position = self.trials_table.at[self.trial_num, 'position']
            self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
            if 'correct' in outcome:
                self.session_info.NTRIALS_CORRECT += 1
                self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
            elif 'error' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
            elif 'no_go' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = 0
            super().trial_completed(bpod_data)
            # here we throw potential errors after having written the trial to disk
            assert np.sum(list(raw_outcome.values())) == 1
            assert position != 0, 'the position value should be either 35 or -35'
        except StopIteration as e:
            log.error(f'No outcome detected for trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e
        except AssertionError as e:
            log.error(f'Assertion Error in trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e


class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    block_num: NonNegativeInt = 0
    block_trial_num: NonNegativeInt = 0


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.nan, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )

    def new_block(self):
        """
        Draw the length and stimulus probability of the next block of trials.

        If BLOCK_INIT_5050 is set, the first block has a 50/50 probability of a
        leftward stimulus and is 90 trials long.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft
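        # Illustrative sequence, assuming BLOCK_PROBABILITY_SET = [0.2, 0.8] and
        # BLOCK_INIT_5050: p_left runs 0.5 (90 trials), then e.g. 0.8, 0.2, 0.8, ...
        # with truncated-exponential block lengths in [BLOCK_LEN_MIN, BLOCK_LEN_MAX].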

    def next_trial(self):
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)


class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    training_phase: NonNegativeInt
    debias_trial: bool


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol adapts both the number of contrasts (embodied by the training_phase property)
    and the reward amount (embodied by the adaptive_reward property).
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got Adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got Adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    def get_subject_training_info(self):
        """
        Get the previous session's training information according to this session's parameters
        and deduce the training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        training_info: dict
            Dictionary with keys: training_phase, adaptive_reward, adaptive_gain
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def compute_performance(self):
        """Aggregate the trials table to compute the performance of the mouse on each contrast."""
        self.trials_table['signed_contrast'] = self.trials_table.contrast * self.trials_table.position
        performance = self.trials_table.groupby(['signed_contrast']).agg(
            last_50_perf=pd.NamedAgg(column='trial_correct', aggfunc=lambda x: np.sum(x[np.maximum(-50, -x.size) :]) / 50),
            ntrials=pd.NamedAgg(column='trial_correct', aggfunc='count'),
        )
        return performance

    def check_training_phase(self):
        """Check if the mouse is ready to move to the next training phase."""
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')
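        # Phases are numbered 0-5; each successive phase changes the set of contrasts
        # drawn by choiceworld.draw_training_contrast (cf. choiceworld.contrasts_set
        # in show_trial_log below).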

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1
        # check if the subject graduates to a new training phase
        self.check_training_phase()
        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)
        # debiasing: if the previous trial was incorrect and easy, repeat the trial
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1) and last_contrast >= 0.5
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                iresponse = self.trials_table['response_side'] != 0  # trials that had a response
                # take the average of rightward responses over the last 10 response trials
                average_right = np.mean(self.trials_table['response_side'][iresponse[-np.maximum(10, iresponse.size) :]] == 1)
                # the probability of the next stimulus being on the left is a draw from a normal distribution
                # centered on average_right with sigma 0.5: if it is less than 0.5 the next stimulus will be on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]
                # the contrast is repeated from the last trial
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)