int-brain-lab / iblrig / 11407201950
18 Oct 2024 04:12PM UTC. Coverage: 47.898% (+1.1%) from 46.79%
Pull Request #730: 8.24.4 (commit 86ab26 by web-flow, merge 9801a3e94 into 0f4a57326)

47 of 68 new or added lines in 8 files covered (69.12%)
1013 existing lines in 22 files now uncovered
4170 of 8706 relevant lines covered (47.9%), 0.96 hits per line

Source File: /iblrig/base_choice_world.py (91.22%)
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import logging
import math
import random
import subprocess
import time
from pathlib import Path
from string import ascii_letters
from typing import Annotated, Any

import numpy as np
import pandas as pd
from annotated_types import Interval, IsNan
from pydantic import NonNegativeFloat, NonNegativeInt

import iblrig.base_tasks
import iblrig.graphic
from iblrig import choiceworld, misc
from iblrig.hardware import SOFTCODE
from iblrig.pydantic_definitions import TrialDataModel
from iblutil.io import jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

# TODO: task parameters should be verified through a pydantic model
#
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
#
# class ChoiceWorldParams(BaseModel):
#     AUTOMATIC_CALIBRATION: bool = True
#     ADAPTIVE_REWARD: bool = False
#     BONSAI_EDITOR: bool = False
#     CALIBRATION_VALUE: float = 0.067
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
#     GO_TONE_AMPLITUDE: float = 0.0272
#     GO_TONE_DURATION: float = 0.11
#     GO_TONE_IDX: int = Field(2, ge=0)
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
#     INTERACTIVE_DELAY: float = 0.0
#     ITI_DELAY_SECS: float = 0.5
#     NTRIALS: int = Field(2000, gt=0)
#     PROBABILITY_LEFT: Probability = 0.5
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
#     QUIESCENT_PERIOD: float = 0.2
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
#     RECORD_SOUND: bool = True
#     RESPONSE_WINDOW: float = 60
#     REWARD_AMOUNT_UL: float = 1.5
#     REWARD_TYPE: str = 'Water 10% Sucrose'
#     STIM_ANGLE: float = 0.0
#     STIM_FREQ: float = 0.1
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
#     STIM_POSITIONS: list[float] = [-35, 35]
#     STIM_SIGMA: float = 7.0
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
#     STIM_REVERSE: bool = False
#     SYNC_SQUARE_X: float = 1.33
#     SYNC_SQUARE_Y: float = -1.03
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
#     WHITE_NOISE_AMPLITUDE: float = 0.05
#     WHITE_NOISE_DURATION: float = 0.5
#     WHITE_NOISE_IDX: int = 3

class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data."""

    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    position: float
    quiescent_period: NonNegativeFloat
    reward_amount: NonNegativeFloat
    reward_valve_time: NonNegativeFloat
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    stim_freq: NonNegativeFloat
    stim_gain: float
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    stim_reverse: bool
    stim_sigma: float
    trial_num: NonNegativeInt
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld.
    # We keep them here with fixed default values for the sake of compatibility.
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False
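
# Editor's note: a minimal sketch of how the constrained fields above behave,
# assuming standard pydantic v2 validation semantics (an out-of-range value
# raises a ValidationError at construction time). Doctest-style; the values
# are hypothetical:
#
#     >>> trial = ChoiceWorldTrialData(
#     ...     contrast=0.5, stim_probability_left=0.5, position=35.0,
#     ...     quiescent_period=0.55, reward_amount=1.5, reward_valve_time=0.05,
#     ...     stim_angle=0.0, stim_freq=0.1, stim_gain=4.0, stim_phase=1.57,
#     ...     stim_reverse=False, stim_sigma=7.0, trial_num=0)
#     >>> ChoiceWorldTrialData(**{**trial.model_dump(), 'contrast': 1.5})
#     Traceback (most recent call last):
#     pydantic_core._pydantic_core.ValidationError: ...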

class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    TrialDataModel = ChoiceWorldTrialData

    def __init__(self, *args, delay_secs=0, **kwargs):
        super().__init__(**kwargs)
        self.task_params['SESSION_DELAY_START'] = delay_secs
        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]
        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1
        # init the tables; there are two of them: a trials table and an ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            {
                'Temperature_C': np.zeros(NTRIALS_INIT) * np.nan,
                'AirPressure_mb': np.zeros(NTRIALS_INIT) * np.nan,
                'RelativeHumidity': np.zeros(NTRIALS_INIT) * np.nan,
            }
        )

    @staticmethod
    def extra_parser():
        """:return: argparse.parser()"""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_secs',
            dest='delay_secs',
            default=0,
            type=int,
            required=False,
            help='initial delay before starting the first trial (default: 0s)',
        )
        parser.add_argument(
            '--remote',
            dest='remote_rigs',
            type=str,
            required=False,
            action='append',
            nargs='+',
            help='specify one of the remote rigs to interact with over the network',
        )
        return parser
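
    # Editor's sketch of the resulting command line (the entry point and the
    # --subject flag are hypothetical here, coming from the parent parser):
    #
    #     python task.py --subject test_mouse --delay_secs 10 --remote rig1
    #
    # With action='append' and nargs='+', repeated --remote flags accumulate
    # lists of names, e.g. args.remote_rigs == [['rig1'], ['rig2']].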

    def start_hardware(self):
        """
        Explicitly run the start methods of the various mixins.

        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """Run the task with the actual state machine."""
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # ITI_DELAY_SECS defines the grey screen period within the state machine, during which the
            # Bpod TTL is HIGH. The DEAD_TIME parameter defines the time between the last trial and the next one.
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
            # wait to achieve the desired ITI duration
            if dt > 0:
                time.sleep(dt)
            # Run state machine
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()
            # handle pause event
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session between trials {i} and {i + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
                if not flag_stop.exists():
                    log.info('Resuming session')

            # save trial and update log
            self.trial_completed(self.bpod.session.current_trial.export())
            self.ambient_sensor_table.loc[i] = self.bpod.get_ambient_sensor_reading()
            self.show_trial_log()

            # handle stop event
            if flag_stop.exists():
                log.info('Stopping session after trial %d', i)
                flag_stop.unlink()
                break
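
    # Editor's note on the ITI arithmetic above: with ITI_DELAY_SECS = 0.5 and,
    # say, DEAD_TIME = 0.2, if 0.1 s elapsed between the end of the last trial
    # and this point, dt = 0.5 - 0.2 - 0.1 = 0.2 s of extra sleep. With the
    # default DEAD_TIME of 0.5 s, dt is never positive and no sleep is added.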

    def mock(self, file_jsonable_fixture=None):
        """
        Instantiate a state machine and Bpod object to simulate a task's run.

        This is useful to test or display the state machine flow.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop the bpod data out of the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
            }
        )
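
    # Editor's sketch of a typical use for a dry run (the session class,
    # constructor arguments and fixture path are hypothetical):
    #
    #     session = BiasedChoiceWorldSession(subject='test_mouse')
    #     session.mock(file_jsonable_fixture='fixtures/task_data.jsonable')
    #     session.run()  # replays recorded bpod trials instead of driving hardware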

    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file: optional path to render the graph to
        :param view: whether to open the rendered graph (default: True)
        :return: the graphviz.Digraph, or None if no state machine is defined
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for input in sma.input_matrix[i]:
                if input[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[input[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
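
    # Editor's sketch: rendering the diagram for visual inspection (the output
    # path is hypothetical; rendering requires the graphviz system binaries):
    #
    #     session.mock()
    #     dot = session.get_graphviz_task(output_file='choice_world_sma', view=False)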

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        if i == 0:  # First trial exception: start the camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        # Reset the rotary encoder by sending the following opcodes via the module's serial interface
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between visual and auditory cue
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to next state if sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset rotary encoder (see above). Move on after a brief delay (to avoid a race condition in the bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition
        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.task_params.FEEDBACK_NOGO_DELAY_SECS,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.task_params.FEEDBACK_ERROR_DELAY_SECS,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze stimulus in center of screen.
        # Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_show_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.task_params.FEEDBACK_CORRECT_DELAY_SECS - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been removed from the screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma
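
    # Editor's summary of the state flow defined above (happy path first,
    # then the three feedback branches):
    #
    #   trial_start -> reset_rotary_encoder -> quiescent_period -> stim_on
    #       -> interactive_delay -> play_tone -> reset2_rotary_encoder -> closed_loop
    #   closed_loop -> freeze_reward -> reward -> correct -> hide_stim -> exit_state
    #   closed_loop -> freeze_error -> error -> hide_stim -> exit_state
    #   closed_loop -> no_go -> exit_state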

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw the next trial's variables.

        Calls :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.stimulus_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                continue  # editor's fix: skip 'index' (the original no-op `pass` fell through and wrote it)
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()
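
    # Editor's note on the quiescent period draw above: with the default
    # QUIESCENT_PERIOD of 0.2 s, adding a truncated exponential (scale 0.35 s,
    # truncated to [0.2, 0.5] s) yields a total required stillness between
    # 0.4 s and 0.7 s, drawn afresh on every trial.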

    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)
        # this is a flag for the online plots; if the online plots were in pyqt5 we could use its file watcher instead
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)

    def check_sync_pulses(self, bpod_data):
        # TODO: move this to the post-trial step once we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        This method retrieves information about the current trial from the
        trials table and logs it. It can also incorporate additional information
        provided through the `extra_info` parameter.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the
            log.

        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        # construct base info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Stim. Position': trial_info.position,
            'Stim. Contrast': trial_info.contrast,
            'Stim. Phase': f'{trial_info.stim_phase:.2f}',
            'Stim. p Left': trial_info.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # log info dict
        log.log(log_level, f'Outcome of Trial #{trial_info.trial_num}:')
        max_key_length = max(len(key) for key in info_dict)
        for key, value in info_dict.items():
            spaces = (max_key_length - len(key)) * ' '
            log.log(log_level, f'- {key}: {spaces}{str(value)}')

    @property
    def iti_reward(self):
        """
        Return the ITI time to set in order to achieve the desired ITI,
        by subtracting the time it takes to deliver a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    # These are the properties that are used in the state machine code.

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(-1 if self.task_params.STIM_REVERSE else 1) * self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(1 if self.task_params.STIM_REVERSE else -1) * self.position]


class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    delay_to_stim_center: NonNegativeFloat


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception: start the camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
            # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
            # the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave the reward valve open, i.e. the reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where the Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma


class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    response_side: Annotated[int, Interval(ge=-1, le=1)]
    response_time: NonNegativeFloat
    trial_correct: bool


class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse actively makes decisions
    by turning the wheel. It has the following characteristics:

    -   it is trial based
    -   it is decision based
    -   left and right stimuli are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   its performance is quantifiable as the proportion of correct trials, unlike passive stimulation or
        habituation protocols

    TrainingChoiceWorld and BiasedChoiceWorld are subclasses of this class.
    """

    TrialDataModel = ActiveChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)

    def _run(self):
        # starts online plotting
        if self.interactive:
            subprocess.Popen(
                ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Response Time': f'{trial_info.response_time:.2f} s',
            'Trial Correct': trial_info.trial_correct,
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)

    def trial_completed(self, bpod_data):
        """
        Update the trials table with information about the behaviour coming from the Bpod.

        Constraints on the state machine data:

        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states: ['omit_correct', 'omit_error', 'omit_no_go']

        :param bpod_data: dict of Bpod trial data, as returned by Trial.export()
        :return:
        """
        # get the response time from the behaviour data
        response_time = bpod_data['States timestamps']['closed_loop'][0][1] - bpod_data['States timestamps']['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time
        # get the trial outcome
        state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
        raw_outcome = {sn: ~np.isnan(bpod_data['States timestamps'].get(sn, [[np.nan]])[0][0]) for sn in state_names}
        try:
            outcome = next(k for k in raw_outcome if raw_outcome[k])
            # update the response buffer: -1 for left, 0 for no-go, and 1 for rightward
            position = self.trials_table.at[self.trial_num, 'position']
            self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
            if 'correct' in outcome:
                self.session_info.NTRIALS_CORRECT += 1
                self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
            elif 'error' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
            elif 'no_go' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = 0
            super().trial_completed(bpod_data)
            # here we throw potential errors after having written the trial to disk
            assert np.sum(list(raw_outcome.values())) == 1
            assert position != 0, 'the position value should be either 35 or -35'
        except StopIteration as e:
            log.error(f'No outcome detected for trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e
        except AssertionError as e:
            log.error(f'Assertion Error in trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e
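
    # Editor's illustration of the outcome bookkeeping above, with hypothetical
    # values: if the stimulus was drawn at position -35 (left) and the 'correct'
    # state fired, the wheel must have been turned so as to bring the stimulus
    # to the centre, i.e. a rightward response: response_side = -np.sign(-35) = 1.
    # Exactly one of the six outcome states should carry a non-NaN timestamp,
    # which is what the first assertion checks after the trial is saved.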


class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    block_num: NonNegativeInt = 0
    block_trial_num: NonNegativeInt = 0


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.nan, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )

    def new_block(self):
        """
        Define a new block of trials.

        If BLOCK_INIT_5050 is set, the first block has a 50/50 probability of
        leftward stimulus and is 90 trials long.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handle the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft
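
    # Editor's worked example of the block logic above, assuming standard
    # biased-block parameters (BLOCK_PROBABILITY_SET = [0.2, 0.8],
    # BLOCK_LEN_FACTOR = 60, BLOCK_LEN_MIN = 20, BLOCK_LEN_MAX = 100,
    # BLOCK_INIT_5050 = True): block 0 runs 90 trials at pleft = 0.5; block 1
    # draws pleft from {0.2, 0.8} with a truncated-exponential length of 20-100
    # trials; from block 2 onwards pleft simply alternates, 0.8 -> 0.2 -> 0.8 ...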

    def next_trial(self):
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)


class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    training_phase: NonNegativeInt
    debias_trial: bool
    signed_contrast: float | None = None


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol adapts both the number of contrasts (embodied by the training_phase property)
    and the reward amount (embodied by the adaptive_reward property).
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    @property
    def stimulus_gain(self) -> float:
        return self.session_info.get('ADAPTIVE_GAIN_VALUE')

    def get_subject_training_info(self):
        """
        Get the previous sessions according to this session's parameters and deduce the
        training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        tuple
            (training_phase, adaptive_reward, adaptive_gain)
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def compute_performance(self):
        """Aggregate the trials table to compute the performance of the mouse on each contrast."""
        self.trials_table['signed_contrast'] = self.trials_table.contrast * self.trials_table.position
        performance = self.trials_table.groupby(['signed_contrast']).agg(
            last_50_perf=pd.NamedAgg(column='trial_correct', aggfunc=lambda x: np.sum(x[np.maximum(-50, -x.size) :]) / 50),
            ntrials=pd.NamedAgg(column='trial_correct', aggfunc='count'),
        )
        return performance
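
    # Editor's note: last_50_perf above sums the correct trials among the last
    # (up to) 50 trials of each signed contrast and divides by a constant 50,
    # e.g. 40 correct in the last 50 trials at contrast +1.0 gives 0.8. With
    # fewer than 50 trials the denominator stays 50, so early estimates are
    # conservative.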

    def check_training_phase(self):
        """Check if the mouse is ready to move to the next training phase."""
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1
        # check if the subject graduates to a new training phase
        self.check_training_phase()
        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)
        # debiasing: if the previous trial was incorrect and easy, repeat it
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1) and last_contrast >= 0.5
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                iresponse = self.trials_table['response_side'] != 0  # trials that had a response
                # take the average of rightward responses over the last 10 response trials
                # (editor's fix: the original slice selected all response trials, not the last 10)
                average_right = np.mean(self.trials_table['response_side'][iresponse].iloc[-10:] == 1)
                # the probability of the next stimulus being on the left is a draw from a normal distribution
                # centered on average right with sigma 0.5; if it is less than 0.5 the next stimulus will be on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]
                # contrast is the last contrast
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase
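
    # Editor's illustration of the debiasing draw above, with hypothetical
    # numbers: if 8 of the last 10 response trials were rightward,
    # average_right = 0.8, so np.random.normal(0.8, 0.5) usually exceeds 0.5
    # and the repeated stimulus tends to appear on the right (STIM_POSITIONS[1]);
    # the correct response is then a leftward wheel turn, countering the
    # rightward response bias.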

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)