
int-brain-lab / iblrig / build 12279337432
11 Dec 2024 03:15PM UTC · coverage: 47.031% (+0.2%) from 46.79%

Pull Request #751: Fiber trajectory GUI
Commit d4edef (web-flow, via github): Merge eea51f2f7 into 2f9d65d86

0 of 114 new or added lines in 1 file covered (0.0%)
1076 existing lines in 22 files now uncovered
4246 of 9028 relevant lines covered (47.03%)
0.94 hits per line

Source file: /iblrig/base_choice_world.py — 91.37% covered
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import logging
import math
import random
import subprocess
import time
from pathlib import Path
from string import ascii_letters
from typing import Annotated, Any

import numpy as np
import pandas as pd
from annotated_types import Interval, IsNan
from pydantic import NonNegativeFloat, NonNegativeInt

import iblrig.base_tasks
from iblrig import choiceworld, misc
from iblrig.hardware import SOFTCODE
from iblrig.pydantic_definitions import TrialDataModel
from iblutil.io import jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

# TODO: task parameters should be verified through a pydantic model
#
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
#
# class ChoiceWorldParams(BaseModel):
#     AUTOMATIC_CALIBRATION: bool = True
#     ADAPTIVE_REWARD: bool = False
#     BONSAI_EDITOR: bool = False
#     CALIBRATION_VALUE: float = 0.067
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
#     GO_TONE_AMPLITUDE: float = 0.0272
#     GO_TONE_DURATION: float = 0.11
#     GO_TONE_IDX: int = Field(2, ge=0)
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
#     INTERACTIVE_DELAY: float = 0.0
#     ITI_DELAY_SECS: float = 0.5
#     NTRIALS: int = Field(2000, gt=0)
#     PROBABILITY_LEFT: Probability = 0.5
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
#     QUIESCENT_PERIOD: float = 0.2
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
#     RECORD_SOUND: bool = True
#     RESPONSE_WINDOW: float = 60
#     REWARD_AMOUNT_UL: float = 1.5
#     REWARD_TYPE: str = 'Water 10% Sucrose'
#     STIM_ANGLE: float = 0.0
#     STIM_FREQ: float = 0.1
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
#     STIM_POSITIONS: list[float] = [-35, 35]
#     STIM_SIGMA: float = 7.0
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
#     STIM_REVERSE: bool = False
#     SYNC_SQUARE_X: float = 1.33
#     SYNC_SQUARE_Y: float = -1.03
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
#     WHITE_NOISE_AMPLITUDE: float = 0.05
#     WHITE_NOISE_DURATION: float = 0.5
#     WHITE_NOISE_IDX: int = 3


class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data."""

    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    position: float
    quiescent_period: NonNegativeFloat
    reward_amount: NonNegativeFloat
    reward_valve_time: NonNegativeFloat
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    stim_freq: NonNegativeFloat
    stim_gain: float
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    stim_reverse: bool
    stim_sigma: float
    trial_num: NonNegativeInt
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld
    # We keep them here with fixed default values for the sake of compatibility
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False

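
# A minimal usage sketch (not part of the original file): TrialDataModel is a
# pydantic model, so constructing trial data with out-of-range values raises a
# ValidationError. The field values below are illustrative only.
#
#   trial = ChoiceWorldTrialData(
#       contrast=0.25, stim_probability_left=0.5, position=35.0, quiescent_period=0.45,
#       reward_amount=1.5, reward_valve_time=0.05, stim_angle=0.0, stim_freq=0.1,
#       stim_gain=4.0, stim_phase=1.57, stim_reverse=False, stim_sigma=7.0, trial_num=0,
#   )
#   ChoiceWorldTrialData(**{**trial.model_dump(), 'contrast': 1.5})  # ValidationError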

class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    TrialDataModel = ChoiceWorldTrialData

    def __init__(self, *args, delay_secs=0, **kwargs):
        super().__init__(**kwargs)
        self.task_params['SESSION_DELAY_START'] = delay_secs
        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]
        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1
        # init the tables, there are 2 of them: a trials table and an ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            {
                'Temperature_C': np.zeros(NTRIALS_INIT) * np.nan,
                'AirPressure_mb': np.zeros(NTRIALS_INIT) * np.nan,
                'RelativeHumidity': np.zeros(NTRIALS_INIT) * np.nan,
            }
        )

    @staticmethod
    def extra_parser():
        """:return: argparse.ArgumentParser"""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_secs',
            dest='delay_secs',
            default=0,
            type=int,
            required=False,
            help='initial delay before starting the first trial (default: 0s)',
        )
        parser.add_argument(
            '--remote',
            dest='remote_rigs',
            type=str,
            required=False,
            action='append',
            nargs='+',
            help='specify one of the remote rigs to interact with over the network',
        )
        return parser

    def start_hardware(self):
        """
        In this step we explicitly run the start methods of the various mixins.
        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """Run the task with the actual state machine."""
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
            # Bpod TTL is HIGH. The DEAD_TIME param defines the time between the last trial and the next
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
            # wait to achieve the desired ITI duration
            if dt > 0:
                time.sleep(dt)
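            # Worked example of the ITI arithmetic above (illustrative values): with
            # ITI_DELAY_SECS = 1.0, DEAD_TIME = 0.5 and 0.2 s elapsed since the last
            # trial ended, dt = 1.0 - 0.5 - 0.2 = 0.3, so we sleep 0.3 s; if the
            # overhead already exceeds the ITI budget, dt <= 0 and the next state
            # machine starts immediately.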
            # Run state machine
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()
            # handle pause event
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session between trials {i} and {i + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
                if not flag_stop.exists():
                    log.info('Resuming session')

            # save trial and update log
            self.trial_completed(self.bpod.session.current_trial.export())
            self.ambient_sensor_table.loc[i] = self.bpod.get_ambient_sensor_reading()
            self.show_trial_log()

            # handle stop event
            if flag_stop.exists():
                log.info('Stopping session after trial %d', i)
                flag_stop.unlink()
                break

    def mock(self, file_jsonable_fixture=None):
        """
        Instantiate a state machine and Bpod object to simulate a task's run.

        This is useful to test or display the state machine flow.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop-out the bpod data from the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
            }
        )

    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file: optional file path to render the diagram to
        :param view: whether to open the rendered diagram
        :return: the graphviz.Digraph object, or None if no state machine is defined
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for input in sma.input_matrix[i]:
                if input[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[input[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
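
    # Usage sketch (hypothetical, not part of the module): mock() stubs out the
    # Bpod hardware, which also allows rendering the state diagram without a rig.
    # Constructor arguments depend on the concrete subclass.
    #
    #   session = HabituationChoiceWorldSession(subject='demo_mouse')
    #   session.mock()
    #   session.get_graphviz_task(output_file='choice_world_states', view=False)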

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        if i == 0:  # First trial exception start camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        # Reset the rotary encoder by sending the following opcodes via the module's serial interface
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: Reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between visual and auditory cue
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to next state if sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset rotary encoder (see above). Move on after a brief delay (to avoid a race condition in the bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition

        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.feedback_nogo_delay,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: Freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.feedback_error_delay,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze stimulus in center of screen.
        # Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_show_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.feedback_correct_delay - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been removed from the screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw next trial variables.

        Calls :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
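        # e.g. with QUIESCENT_PERIOD = 0.2 the drawn quiescent period lies in
        # [0.4, 0.7] s: the truncated exponential above returns a value between
        # min_value=0.2 and max_value=0.5, which is added to the base period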
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.stimulus_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                continue
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()

    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)
        # this is a flag for the online plots. If online plots were in pyqt5, there is a file watcher functionality
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)

    def check_sync_pulses(self, bpod_data):
        # todo move this in the post trial when we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        This method retrieves information about the current trial from the
        trials table and logs it. It can also incorporate additional information
        provided through the `extra_info` parameter.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the
            log.

        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        # construct base info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Stim. Position': trial_info.position,
            'Stim. Contrast': trial_info.contrast,
            'Stim. Phase': f'{trial_info.stim_phase:.2f}',
            'Stim. p Left': trial_info.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # log info dict
        log.log(log_level, f'Outcome of Trial #{trial_info.trial_num}:')
        max_key_length = max(len(key) for key in info_dict)
        for key, value in info_dict.items():
            spaces = (max_key_length - len(key)) * ' '
            log.log(log_level, f'- {key}: {spaces}{str(value)}')

    @property
    def iti_reward(self):
        """
        Returns the ITI time that needs to be set in order to achieve the desired ITI,
        by subtracting the time it takes to give a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    """
    These are the properties that are used in the state machine code
    """

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def feedback_correct_delay(self):
        return self.task_params['FEEDBACK_CORRECT_DELAY_SECS']

    @property
    def feedback_error_delay(self):
        return self.task_params['FEEDBACK_ERROR_DELAY_SECS']

    @property
    def feedback_nogo_delay(self):
        return self.task_params['FEEDBACK_NOGO_DELAY_SECS']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(-1 if self.task_params.STIM_REVERSE else 1) * self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(1 if self.task_params.STIM_REVERSE else -1) * self.position]
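
    # Example of the threshold sign logic above, assuming STIM_POSITIONS = [-35, 35]
    # and STIM_REVERSE = False: for a stimulus at position +35, event_reward is the
    # rotary-encoder threshold at -35 (the wheel movement that brings the stimulus to
    # the centre) while event_error is the threshold at +35.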


class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    delay_to_stim_center: NonNegativeFloat


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception start camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
            # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
            # the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave reward valve open, i.e. reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
2✔
701
            state_name='post_reward',
702
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
703
            state_change_conditions={'Tup': 'exit'},
704
            output_actions=[],
705
        )
706
        return sma
2✔
707

708

709
class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
2✔
710
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""
711

712
    response_side: Annotated[int, Interval(ge=-1, le=1)]
2✔
713
    response_time: NonNegativeFloat
2✔
714
    trial_correct: bool
2✔
715

716

717
class ActiveChoiceWorldSession(ChoiceWorldSession):
2✔
718
    """
719
    The ActiveChoiceWorldSession is a base class for protocols where the mouse is actively making decisions
720
    by turning the wheel. It has the following characteristics
721

722
    -   it is trial based
723
    -   it is decision based
724
    -   left and right simulus are equiprobable: there is no biased block
725
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
726
    -   it has a quantifiable performance by computing the proportion of correct trials of passive stimulations protocols or
727
        habituation protocols.
728

729
    The TrainingChoiceWorld, BiasedChoiceWorld are all subclasses of this class
730
    """
731

732
    TrialDataModel = ActiveChoiceWorldTrialData
2✔
733

734
    def __init__(self, **kwargs):
2✔
735
        super().__init__(**kwargs)
2✔
736
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)
2✔
737

738
    def _run(self):
2✔
739
        # starts online plotting
740
        if self.interactive:
2✔
UNCOV
741
            subprocess.Popen(
×
742
                ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])],
743
                stdout=subprocess.DEVNULL,
744
                stderr=subprocess.STDOUT,
745
            )
746
        super()._run()
2✔
747

748
    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
2✔
749
        # construct info dict
750
        trial_info = self.trials_table.iloc[self.trial_num]
2✔
751
        info_dict = {
2✔
752
            'Response Time': f'{trial_info.response_time:.2f} s',
753
            'Trial Correct': trial_info.trial_correct,
754
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
755
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT,
756
        }
757

758
        # update info dict with extra_info dict
759
        if isinstance(extra_info, dict):
2✔
760
            info_dict.update(extra_info)
2✔
761

762
        # call parent method
763
        super().show_trial_log(extra_info=info_dict, log_level=log_level)
2✔
764

765
    def trial_completed(self, bpod_data):
2✔
766
        """
767
        The purpose of this method is to
768

769
        - update the trials table with information about the behaviour coming from the bpod
770
          Constraints on the state machine data:
771
        - mandatory states: ['correct', 'error', 'no_go', 'reward']
772
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']
773

774
        :param bpod_data:
775
        :return:
776
        """
777
        # get the response time from the behaviour data
778
        response_time = bpod_data['States timestamps']['closed_loop'][0][1] - bpod_data['States timestamps']['stim_on'][0][0]
2✔
779
        self.trials_table.at[self.trial_num, 'response_time'] = response_time
2✔
780
        # get the trial outcome
781
        state_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
2✔
782
        raw_outcome = {sn: ~np.isnan(bpod_data['States timestamps'].get(sn, [[np.nan]])[0][0]) for sn in state_names}
2✔
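        # Sketch of the structure parsed above (illustrative values): a state that ran
        # holds real timestamps while a state that never ran holds NaNs, so exactly one
        # raw_outcome entry should be True, e.g.
        #   bpod_data['States timestamps'] == {'correct': [[4.1, 5.1]], 'error': [[nan, nan]], ...}
        #   raw_outcome == {'correct': True, 'error': False, 'no_go': False, ...}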
        try:
            outcome = next(k for k in raw_outcome if raw_outcome[k])
            # Update response buffer: -1 for leftward, 0 for no-go, and 1 for rightward responses
            position = self.trials_table.at[self.trial_num, 'position']
            self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
            if 'correct' in outcome:
                self.session_info.NTRIALS_CORRECT += 1
                self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
            elif 'error' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
            elif 'no_go' in outcome:
                self.trials_table.at[self.trial_num, 'response_side'] = 0
            super().trial_completed(bpod_data)
            # here we throw potential errors after having written the trial to disk
            assert np.sum(list(raw_outcome.values())) == 1
            assert position != 0, 'the position value should be either 35 or -35'
        except StopIteration as e:
            log.error(f'No outcome detected for trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e
        except AssertionError as e:
            log.error(f'Assertion Error in trial {self.trial_num}.')
            log.error(f'raw_outcome: {raw_outcome}')
            log.error('State names: ' + ', '.join(bpod_data['States timestamps'].keys()))
            raise e


class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    block_num: NonNegativeInt = 0
    block_trial_num: NonNegativeInt = 0


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.nan, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )

    def new_block(self):
        """
        Start a new block.

        If BLOCK_INIT_5050 is set, the first block has a 50/50 probability of
        a leftward stimulus and is 90 trials long.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft
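
    # Example block sequence (assuming BLOCK_INIT_5050 and BLOCK_PROBABILITY_SET =
    # [0.2, 0.8], as in the standard biased task): pleft = 0.5 for 90 trials, then
    # e.g. 0.2, 0.8, 0.2, ... with each block length drawn from the truncated
    # exponential between BLOCK_LEN_MIN and BLOCK_LEN_MAX.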

    def next_trial(self):
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)


class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    training_phase: NonNegativeInt
    debias_trial: bool
    signed_contrast: float | None = None


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol has a complicated adaptation of the number of contrasts (embodied by the training_phase
    property) and the reward amount, embodied by the adaptive_reward property.
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got Adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got Adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    @property
    def stimulus_gain(self) -> float:
        return self.session_info.get('ADAPTIVE_GAIN_VALUE')

    def get_subject_training_info(self):
        """
        Get the previous session's info according to this session's parameters and deduce the
        training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        tuple
            (training_phase, adaptive_reward, adaptive_gain)
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def compute_performance(self):
        """Aggregate the trials table to compute the performance of the mouse on each contrast."""
        self.trials_table['signed_contrast'] = self.trials_table.contrast * self.trials_table.position
        performance = self.trials_table.groupby(['signed_contrast']).agg(
            last_50_perf=pd.NamedAgg(column='trial_correct', aggfunc=lambda x: np.sum(x[np.maximum(-50, -x.size) :]) / 50),
            ntrials=pd.NamedAgg(column='trial_correct', aggfunc='count'),
        )
        return performance
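
    # Worked example: if 42 of the last 50 trials at signed contrast +1.0 were
    # correct, last_50_perf for that contrast is 42 / 50 = 0.84. Note the fixed
    # denominator of 50: contrasts with fewer than 50 trials are penalised.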

    def check_training_phase(self):
        """Check if the mouse is ready to move to the next training phase."""
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = self.compute_performance()
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1
        # check if the subject graduates to a new training phase
        self.check_training_phase()
        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)
        # debiasing: if the previous trial was incorrect and easy, repeat the trial
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1) and last_contrast >= 0.5
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                iresponse = np.logical_and(
                    ~self.trials_table['response_side'].isna(), self.trials_table['response_side'] != 0
                )  # trials that had a response
                # take the average of rightward responses over the last 10 response trials
                average_right = np.mean(self.trials_table['response_side'][iresponse][-10:] == 1)
                # the probability of the next stimulus being on the left is drawn from a normal distribution
                # centered on average_right with sigma 0.5; if the draw is less than 0.5 the next stimulus will be on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]
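                # e.g. if average_right = 0.8 (a strong rightward bias), the draw from
                # N(0.8, 0.5) exceeds 0.5 with probability ~0.73, so the repeated stimulus
                # most likely lands at +35, where the correct response side (-sign(position))
                # opposes the bias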
                # contrast is the last contrast
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
            'Debias Trial': self.trials_table.at[self.trial_num, 'debias_trial'],
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)