
PrincetonUniversity / PsyNeuLink, build 15917088825 (push from GitHub by web-flow)

05 Jun 2025 04:18AM UTC coverage: 84.482% (+0.5%) from 84.017%

Merge pull request #3271 from PrincetonUniversity/devel: Devel

9909 of 12966 branches covered (76.42%). Branch coverage is included in the aggregate %.

1708 of 1908 new or added lines in 54 files covered (89.52%).

25 existing lines in 14 files now uncovered.

34484 of 39581 relevant lines covered (87.12%).

0.87 hits per line.

Source File: /psyneulink/core/components/functions/nonstateful/distributionfunctions.py (96.24% covered)

#
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
#
#
# ****************************************   DISTRIBUTION FUNCTIONS   **************************************************
"""

* `NormalDist`
* `UniformToNormalDist`
* `ExponentialDist`
* `UniformDist`
* `GammaDist`
* `WaldDist`

Overview
--------

Functions that return one or more samples from a distribution.

"""

import numpy as np
from beartype import beartype
from scipy.special import erfinv

from psyneulink._typing import Optional

from psyneulink.core import llvm as pnlvm
from psyneulink.core.components.functions.function import (
    DEFAULT_SEED, Function_Base, FunctionError,
    _random_state_getter, _seed_setter, _noise_setter
)
from psyneulink.core.globals.keywords import (
    ADDITIVE_PARAM, DIST_FUNCTION_TYPE, BETA, DIST_MEAN, DIST_SHAPE, DRIFT_DIFFUSION_ANALYTICAL_FUNCTION,
    EXPONENTIAL_DIST_FUNCTION, GAMMA_DIST_FUNCTION, HIGH, LOW, MULTIPLICATIVE_PARAM, NOISE, NORMAL_DIST_FUNCTION,
    SCALE, STANDARD_DEVIATION, THRESHOLD, UNIFORM_DIST_FUNCTION, WALD_DIST_FUNCTION, DEFAULT,
)
from psyneulink.core.globals.utilities import convert_all_elements_to_np_array, convert_to_np_array, ValidParamSpecType
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet

from psyneulink.core.globals.parameters import Parameter, check_user_specified

__all__ = [
    'DistributionFunction', 'DRIFT_RATE', 'DRIFT_RATE_VARIABILITY', 'DriftDiffusionAnalytical', 'ExponentialDist',
    'GammaDist', 'NON_DECISION_TIME', 'NormalDist', 'STARTING_VALUE', 'STARTING_VALUE_VARIABILITY',
    'THRESHOLD_VARIABILITY', 'UniformDist', 'UniformToNormalDist', 'WaldDist',
]


class DistributionFunction(Function_Base):
    componentType = DIST_FUNCTION_TYPE

    def as_mdf_model(self):
        model = super().as_mdf_model()
        self._set_mdf_arg(model, 'shape', self.defaults.variable.shape)
        return model


class NormalDist(DistributionFunction):
    """
    NormalDist(                      \
             mean=0.0,               \
             standard_deviation=1.0, \
             params=None,            \
             owner=None,             \
             prefs=None              \
             )

    .. _NormalDist:

    Return a random sample from a normal distribution using numpy.random.normal.

    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `standard_deviation <NormalDist.standard_deviation>`
    | *ADDITIVE_PARAM:* `mean <NormalDist.mean>`
    |

    Arguments
    ---------

    mean : float : default 0.0
        The mean or center of the normal distribution.

    standard_deviation : float : default 1.0
        Standard deviation of the normal distribution. Must be >= 0.0; if it is 0.0, the Function returns
        `mean <NormalDist.mean>`.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    mean : float : default 0.0
        The mean or center of the normal distribution.

    random_state : numpy.random.RandomState
        private pseudorandom number generator

    standard_deviation : float : default 1.0
        Standard deviation of the normal distribution; if it is 0.0, returns `mean <NormalDist.mean>`.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    """

    componentName = NORMAL_DIST_FUNCTION
    _model_spec_generic_type_name = 'onnx::RandomNormal'

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                mean
                    see `mean <NormalDist.mean>`

                    :default value: 0.0
                    :type: ``float``

                standard_deviation
                    see `standard_deviation <NormalDist.standard_deviation>`

                    :default value: 1.0
                    :type: ``float``

                random_state
                    see `random_state <NormalDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``
        """
        mean = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        standard_deviation = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM], mdf_name='scale')
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 mean=None,
                 standard_deviation=None,
                 params=None,
                 owner=None,
                 seed=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            mean=mean,
            standard_deviation=standard_deviation,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_params(self, request_set, target_set=None, context=None):
        super()._validate_params(request_set=request_set, target_set=target_set, context=context)

        if STANDARD_DEVIATION in target_set and target_set[STANDARD_DEVIATION] is not None:
            if target_set[STANDARD_DEVIATION] < 0.0:
                raise FunctionError("The standard_deviation parameter ({}) of {} must be non-negative.".
                                    format(target_set[STANDARD_DEVIATION], self.name))

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        mean = self._get_current_parameter_value(DIST_MEAN, context)
        standard_deviation = self._get_current_parameter_value(STANDARD_DEVIATION, context)
        random_state = self._get_current_parameter_value("random_state", context)

        result = random_state.normal(mean, standard_deviation)

        return self.convert_output_type(result)

    def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset):
        random_state = ctx.get_random_state_ptr(builder, self, state, params)
        mean_ptr = ctx.get_param_or_state_ptr(builder, self, DIST_MEAN, param_struct_ptr=params)
        std_dev_ptr = ctx.get_param_or_state_ptr(builder, self, STANDARD_DEVIATION, param_struct_ptr=params)
        ret_val_ptr = builder.alloca(ctx.float_ty)
        # Draw a standard normal sample into ret_val_ptr
        norm_rand_f = ctx.get_normal_dist_function_by_state(random_state)
        builder.call(norm_rand_f, [random_state, ret_val_ptr])

        ret_val = builder.load(ret_val_ptr)
        mean = builder.load(mean_ptr)
        std_dev = builder.load(std_dev_ptr)

        # Scale by the standard deviation and shift by the mean: sample * std_dev + mean
        ret_val = builder.fmul(ret_val, std_dev)
        ret_val = builder.fadd(ret_val, mean)

        builder.store(ret_val, arg_out)
        return builder
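
NormalDist has two code paths: the Python path calls `random_state.normal(mean, standard_deviation)`, while the compiled path draws a standard normal sample, multiplies it by `standard_deviation`, and adds `mean`. The following is a minimal NumPy-only sketch (it does not call PsyNeuLink) suggesting that the two formulations agree for identically seeded `numpy.random.RandomState` generators, and that a zero `standard_deviation` collapses the sample to `mean`. The parameter values and seed are arbitrary assumptions, and exact agreement of the two draws depends on NumPy's legacy sampler, so the comparison uses a tolerance.

```python
import numpy as np

MEAN = 2.0   # illustrative values, not PsyNeuLink defaults
STD = 0.5
SEED = 42    # arbitrary seed

# Python path: sample N(MEAN, STD**2) directly, as NormalDist._function does.
direct = np.random.RandomState(SEED).normal(MEAN, STD)

# Compiled-path analogue: standard normal draw, then scale (fmul) and shift (fadd).
z = np.random.RandomState(SEED).standard_normal()
scaled = z * STD + MEAN

# With standard_deviation == 0.0 the sample is just the mean.
degenerate = np.random.RandomState(SEED).normal(MEAN, 0.0)

print(direct, scaled, degenerate)
```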


class UniformToNormalDist(DistributionFunction):
    """
    UniformToNormalDist(             \
             mean=0.0,               \
             standard_deviation=1.0, \
             params=None,            \
             owner=None,             \
             prefs=None              \
             )

    .. _UniformToNormalDist:

    Return a random sample from a normal distribution by first calling ``rand(1)`` on
    `random_state <UniformToNormalDist.random_state>` to generate a sample from a uniform distribution over [0, 1),
    and then converting that sample to a sample from a normal distribution with the following equation:

    .. math::

        normal\\_sample = \\sqrt{2} \\cdot standard\\_deviation \\cdot \\operatorname{erfinv}(2 \\cdot uniform\\_sample - 1) + mean

    where erfinv is scipy.special.erfinv, the inverse error function.  Converting a uniform sample to a normal one in
    this way allows for a more direct comparison with MATLAB scripts
    (see https://github.com/jonasrauber/randn-matlab-python).

    .. note::

        This function requires `SciPy <https://pypi.python.org/pypi/scipy>`_.

    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `standard_deviation <UniformToNormalDist.standard_deviation>`
    | *ADDITIVE_PARAM:* `mean <UniformToNormalDist.mean>`
    |

    Arguments
    ---------

    mean : float : default 0.0
        The mean or center of the normal distribution.

    standard_deviation : float : default 1.0
        Standard deviation of the normal distribution.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    mean : float : default 0.0
        The mean or center of the normal distribution.

    standard_deviation : float : default 1.0
        Standard deviation of the normal distribution.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    random_state : numpy.random.RandomState
        private pseudorandom number generator

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    """

    componentName = NORMAL_DIST_FUNCTION

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <UniformToNormalDist.variable>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``
                    :read only: True

                random_state
                    see `random_state <UniformToNormalDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

                mean
                    see `mean <UniformToNormalDist.mean>`

                    :default value: 0.0
                    :type: ``float``

                standard_deviation
                    see `standard_deviation <UniformToNormalDist.standard_deviation>`

                    :default value: 1.0
                    :type: ``float``
        """
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
        variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable')
        mean = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        standard_deviation = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 mean=None,
                 standard_deviation=None,
                 params=None,
                 owner=None,
                 seed=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            mean=mean,
            standard_deviation=standard_deviation,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):

        mean = self._get_current_parameter_value(DIST_MEAN, context)
        standard_deviation = self._get_current_parameter_value(STANDARD_DEVIATION, context)
        random_state = self.parameters.random_state._get(context)

        sample = random_state.rand(1)[0]
        result = ((np.sqrt(2) * erfinv(2 * sample - 1)) * standard_deviation) + mean

        return self.convert_output_type(result)
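
The uniform-to-normal conversion in UniformToNormalDist is inverse transform sampling: sqrt(2) * erfinv(2u - 1) is the quantile function (inverse CDF) of the standard normal distribution. The following is a NumPy/SciPy sketch, independent of PsyNeuLink, that checks this identity against `scipy.stats.norm.ppf` and confirms that transformed uniform draws have roughly the requested mean and standard deviation. The helper `uniform_to_normal`, the sample size, the seed, and the tolerances are illustrative assumptions.

```python
import numpy as np
from scipy.special import erfinv
from scipy.stats import norm

MEAN = 1.5   # illustrative values
STD = 2.0
rng = np.random.RandomState(0)

def uniform_to_normal(u, mean, std):
    """Hypothetical helper: map uniform samples in [0, 1) to normal samples.

    A draw of exactly 0.0 would map to -inf; with 53-bit doubles this is vanishingly rare.
    """
    return np.sqrt(2) * erfinv(2 * u - 1) * std + mean

# The transform agrees with the normal quantile function (inverse CDF).
u_grid = np.array([0.025, 0.5, 0.975])
via_erfinv = uniform_to_normal(u_grid, MEAN, STD)
via_ppf = norm.ppf(u_grid, loc=MEAN, scale=STD)

# Transformed uniform draws have approximately the requested moments.
samples = uniform_to_normal(rng.rand(200_000), MEAN, STD)
print(samples.mean(), samples.std())
```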


class ExponentialDist(DistributionFunction):
    """
    ExponentialDist(                \
             beta=1.0,              \
             params=None,           \
             owner=None,            \
             prefs=None             \
             )

    .. _ExponentialDist:

    Return a random sample from an exponential distribution using numpy.random.exponential.

    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `beta <ExponentialDist.beta>`
    |

    Arguments
    ---------

    beta : float : default 1.0
        The scale parameter of the exponential distribution (the inverse of the rate parameter, and equal to the
        distribution's mean).

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    beta : float : default 1.0
        The scale parameter of the exponential distribution (the inverse of the rate parameter, and equal to the
        distribution's mean).

    random_state : numpy.random.RandomState
        private pseudorandom number generator

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    """
    componentName = EXPONENTIAL_DIST_FUNCTION

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                beta
                    see `beta <ExponentialDist.beta>`

                    :default value: 1.0
                    :type: ``float``

                random_state
                    see `random_state <ExponentialDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``
        """
        beta = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 beta=None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            beta=beta,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        random_state = self._get_current_parameter_value('random_state', context)
        beta = self._get_current_parameter_value(BETA, context)

        result = random_state.exponential(beta)

        return self.convert_output_type(result)
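
ExponentialDist passes `beta` straight to `random_state.exponential` as NumPy's `scale` argument, so `beta` is the mean of the distribution (the reciprocal of the rate lambda), not the rate itself. The following is a short NumPy-only sketch of that distinction. The value of `beta`, the seed, and the tolerances are illustrative assumptions.

```python
import numpy as np

BETA = 4.0                # scale parameter, as passed by ExponentialDist
RATE = 1.0 / BETA         # the corresponding rate (lambda)
rng = np.random.RandomState(1)

samples = rng.exponential(BETA, size=200_000)

# For an exponential distribution, mean == scale == 1 / rate,
# and the standard deviation equals the mean.
print(samples.mean(), samples.std(), 1.0 / RATE)
```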


class UniformDist(DistributionFunction):
    """
    UniformDist(                    \
             low=0.0,               \
             high=1.0,              \
             params=None,           \
             owner=None,            \
             prefs=None             \
             )

    .. _UniformDist:

    Return a random sample from a uniform distribution using numpy.random.uniform.

    Arguments
    ---------

    low : float : default 0.0
        Lower bound of the uniform distribution.

    high : float : default 1.0
        Upper bound of the uniform distribution.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    low : float : default 0.0
        Lower bound of the uniform distribution.

    high : float : default 1.0
        Upper bound of the uniform distribution.

    random_state : numpy.random.RandomState
        private pseudorandom number generator

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    """
    componentName = UNIFORM_DIST_FUNCTION
    _model_spec_generic_type_name = 'onnx::RandomUniform'

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                high
                    see `high <UniformDist.high>`

                    :default value: 1.0
                    :type: ``float``

                low
                    see `low <UniformDist.low>`

                    :default value: 0.0
                    :type: ``float``

                random_state
                    see `random_state <UniformDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``
        """
        low = Parameter(0.0, modulable=True)
        high = Parameter(1.0, modulable=True)
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 low=None,
                 high=None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            low=low,
            high=high,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):

        random_state = self._get_current_parameter_value('random_state', context)
        low = self._get_current_parameter_value(LOW, context)
        high = self._get_current_parameter_value(HIGH, context)
        result = random_state.uniform(low, high)

        return self.convert_output_type(result)

    def _gen_llvm_function_body(self, ctx, builder, params, state, _, arg_out, *, tags:frozenset):
        random_state = ctx.get_random_state_ptr(builder, self, state, params)
        low_ptr = ctx.get_param_or_state_ptr(builder, self, LOW, param_struct_ptr=params)
        high_ptr = ctx.get_param_or_state_ptr(builder, self, HIGH, param_struct_ptr=params)
        ret_val_ptr = builder.alloca(ctx.float_ty)
        # Draw a standard uniform sample into ret_val_ptr
        uniform_rand_f = ctx.get_uniform_dist_function_by_state(random_state)
        builder.call(uniform_rand_f, [random_state, ret_val_ptr])

        ret_val = builder.load(ret_val_ptr)
        high = pnlvm.helpers.load_extract_scalar_array_one(builder, high_ptr)
        low = pnlvm.helpers.load_extract_scalar_array_one(builder, low_ptr)
        scale = builder.fsub(high, low)

        # Rescale the sample: low + (high - low) * sample
        ret_val = builder.fmul(ret_val, scale)
        ret_val = builder.fadd(ret_val, low)

        # Unwrap nested single-element output arrays down to the scalar element before storing
        while isinstance(arg_out.type.pointee, pnlvm.ir.ArrayType):
            assert len(arg_out.type.pointee) == 1
            arg_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
        builder.store(ret_val, arg_out)
        return builder
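
Both UniformDist code paths amount to the affine rescaling low + (high - low) * u of a standard uniform sample u. The compiled path does this explicitly with `fsub`, `fmul`, and `fadd`, and NumPy's `uniform(low, high)` computes the same quantity internally. The following NumPy-only sketch compares the two using identically seeded generators. The seed and bounds are illustrative, and the agreement of the draws relies on NumPy's legacy `RandomState` implementation rather than on anything PsyNeuLink guarantees, so the check uses a tolerance.

```python
import numpy as np

LOW, HIGH = -3.0, 5.0   # illustrative bounds
SEED = 7

# NumPy's uniform sampler, as used by UniformDist._function.
direct = np.random.RandomState(SEED).uniform(LOW, HIGH, size=5)

# Explicit rescaling of standard uniform draws, mirroring the compiled path:
# scale = high - low; sample = u * scale + low.
u = np.random.RandomState(SEED).rand(5)
rescaled = u * (HIGH - LOW) + LOW

print(direct)
print(rescaled)
```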


class GammaDist(DistributionFunction):
    """
    GammaDist(                   \
             scale=1.0,          \
             dist_shape=1.0,     \
             params=None,        \
             owner=None,         \
             prefs=None          \
             )

    .. _GammaDist:

    Return a random sample from a gamma distribution using numpy.random.gamma.  The mean of the distribution is the
    product of `dist_shape <GammaDist.dist_shape>` and `scale <GammaDist.scale>`.

    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `scale <GammaDist.scale>`
    | *ADDITIVE_PARAM:* `dist_shape <GammaDist.dist_shape>`
    |

    Arguments
    ---------

    scale : float : default 1.0
        The scale of the gamma distribution. Should be greater than zero.

    dist_shape : float : default 1.0
        The shape of the gamma distribution. Should be greater than zero.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    scale : float : default 1.0
        The scale of the gamma distribution. Should be greater than zero.

    dist_shape : float : default 1.0
        The shape of the gamma distribution. Should be greater than zero.

    random_state : numpy.random.RandomState
        private pseudorandom number generator

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    """

    componentName = GAMMA_DIST_FUNCTION

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                dist_shape
                    see `dist_shape <GammaDist.dist_shape>`

                    :default value: 1.0
                    :type: ``float``

                random_state
                    see `random_state <GammaDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

                scale
                    see `scale <GammaDist.scale>`

                    :default value: 1.0
                    :type: ``float``
        """
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        dist_shape = Parameter(1.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 scale=None,
                 dist_shape=None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            scale=scale,
            dist_shape=dist_shape,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):

        random_state = self._get_current_parameter_value('random_state', context)
        scale = self._get_current_parameter_value(SCALE, context)
        dist_shape = self._get_current_parameter_value(DIST_SHAPE, context)

        result = random_state.gamma(dist_shape, scale)

        return self.convert_output_type(result)
790


class WaldDist(DistributionFunction):
    """
    WaldDist(       \
        scale=1.0,  \
        mean=1.0,   \
        params=None,\
        owner=None, \
        prefs=None  \
        )

    .. _WaldDist:

    Return a random sample from a Wald (inverse Gaussian) distribution, using ``numpy.random.wald``.
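
    The sampling step can be illustrated with NumPy alone. The following is a minimal sketch, not part of
    PsyNeuLink: it assumes an arbitrarily seeded ``numpy.random.RandomState`` standing in for
    `random_state <WaldDist.random_state>`, and calls ``wald(mean, scale)`` with the same argument order used by
    ``_function``. For a large sample, the sample mean should approach ``mean`` and the sample variance
    ``mean**3 / scale``.

    ```python
    import numpy as np

    # Illustrative values standing in for WaldDist.mean and WaldDist.scale (both must be > 0)
    mean, scale = 1.0, 2.0

    # Seeded generator, analogous to the Function's private random_state
    random_state = np.random.RandomState(0)
    samples = random_state.wald(mean, scale, size=200_000)

    # Theoretical moments of the Wald (inverse Gaussian) distribution
    expected_mean = mean
    expected_var = mean ** 3 / scale

    print(samples.mean(), expected_mean)
    print(samples.var(), expected_var)
    ```

    Unlike this sketch, ``_function`` draws a single sample per call (no ``size`` argument).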

806
    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `scale <WaldDist.scale>`
    | *ADDITIVE_PARAM:* `mean <WaldDist.mean>`
    |

    Arguments
    ---------

    scale : float : default 1.0
        Scale parameter of the Wald distribution. Should be greater than zero.

    mean : float : default 1.0
        Mean of the Wald distribution. Should be greater than zero.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        the `PreferenceSet` for the Function. If it is not specified, a default is assigned using `classPreferences`
        defined in __init__.py (see `Preferences` for details).


    Attributes
    ----------

    random_state : numpy.RandomState
        private pseudorandom number generator.

    scale : float : default 1.0
        Scale parameter of the Wald distribution. Should be greater than zero.

    mean : float : default 1.0
        Mean of the Wald distribution. Should be greater than zero.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        the `PreferenceSet` for the Function. If it is not specified, a default is assigned using `classPreferences`
        defined in __init__.py (see `Preferences` for details).


    """

    componentName = WALD_DIST_FUNCTION

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                random_state
                    see `random_state <WaldDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

                mean
                    see `mean <WaldDist.mean>`

                    :default value: 1.0
                    :type: ``float``

                scale
                    see `scale <WaldDist.scale>`

                    :default value: 1.0
                    :type: ``float``
        """
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        mean = Parameter(1.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 scale=None,
                 mean=None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            scale=scale,
            seed=seed,
            mean=mean,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):

        random_state = self._get_current_parameter_value('random_state', context)
        scale = self._get_current_parameter_value(SCALE, context)
        mean = self._get_current_parameter_value(DIST_MEAN, context)

        result = random_state.wald(mean, scale)

        return self.convert_output_type(result)


# Note:  For any of these that correspond to args, value must match the name of the corresponding arg in __init__()
DRIFT_RATE = 'drift_rate'
DRIFT_RATE_VARIABILITY = 'DDM_DriftRateVariability'
THRESHOLD_VARIABILITY = 'DDM_ThresholdRateVariability'
STARTING_VALUE = 'starting_value'
STARTING_VALUE_VARIABILITY = "DDM_StartingPointVariability"
NON_DECISION_TIME = 'non_decision_time'


def _DriftDiffusionAnalytical_bias_getter(owning_component=None, context=None):
    starting_value = owning_component.parameters.starting_value._get(context)
    threshold = owning_component.parameters.threshold._get(context)
    try:
        return np.asarray((starting_value + threshold) / (2 * threshold))
    except TypeError:
        return None


# QUESTION: IF VARIABLE IS AN ARRAY, DOES IT RETURN AN ARRAY FOR EACH RETURN VALUE (RT, ER, ETC.)
class DriftDiffusionAnalytical(DistributionFunction):  # -------------------------------------------------------------------------------
    """
    DriftDiffusionAnalytical(   \
        default_variable=None,  \
        drift_rate=1.0,         \
        threshold=1.0,          \
        starting_value=0.0,     \
        non_decision_time=0.2,  \
        noise=0.5,              \
        params=None,            \
        owner=None,             \
        prefs=None,             \
        shenhav_et_al_compat_mode=False \
        )

    .. _DriftDiffusionAnalytical:

    Return the mean response time (RT) and mean error rate (ER) of the drift diffusion process, computed
    analytically as described in `Bogacz et al (2006) <https://www.ncbi.nlm.nih.gov/pubmed/17014301>`_, followed by
    the mean, variance, and skewness of the RT distributions conditional on each boundary.
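
    For an unbiased start (``starting_value`` = 0, so that ``bias`` = 0.5) and a positive drift rate, the
    expressions evaluated in ``_function`` reduce to the closed forms in Bogacz et al. (2006): with
    ``ztilde = threshold / drift_rate`` and ``atilde = (drift_rate / noise)**2``, the mean decision time is
    ``ztilde * tanh(ztilde * atilde)`` and the error rate is ``1 / (1 + exp(2 * ztilde * atilde))``. The sketch
    below illustrates only that special case with NumPy; ``unbiased_ddm_rt_er`` is a hypothetical helper, and it
    omits the biased-start terms, the small-drift limit, and the Shenhav et al. compatibility handling.

    ```python
    import numpy as np


    def unbiased_ddm_rt_er(drift_rate, threshold, noise, non_decision_time):
        # Hypothetical helper; assumes starting_value == 0 and drift_rate > 0
        ztilde = threshold / drift_rate       # normalized threshold
        atilde = (drift_rate / noise) ** 2    # normalized signal-to-noise ratio
        decision_time = ztilde * np.tanh(ztilde * atilde)
        er = 1.0 / (1.0 + np.exp(2.0 * ztilde * atilde))
        return decision_time + non_decision_time, er


    rt, er = unbiased_ddm_rt_er(drift_rate=1.0, threshold=1.0, noise=1.0, non_decision_time=0.2)
    # rt == tanh(1) + 0.2 (about 0.962); er == 1 / (1 + e**2) (about 0.119)
    ```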

965
    *Modulatory Parameters:*

    | *MULTIPLICATIVE_PARAM:* `drift_rate <DriftDiffusionAnalytical.drift_rate>`
    | *ADDITIVE_PARAM:* `starting_value <DriftDiffusionAnalytical.starting_value>`
    |

    Arguments
    ---------

    default_variable : number, list or array : default class_defaults.variable
        specifies a template for decision variable(s); if it is a list or array, a separate solution is computed
        independently for each element.

    drift_rate : float, list or 1d array : default 1.0
        specifies the drift_rate of the drift diffusion process.  If it is a list or array,
        it must be the same length as `default_variable <Component_Variable>`.

    threshold : float, list or 1d array : default 1.0
        specifies the threshold (boundary) of the drift diffusion process.  If it is a list or array,
        it must be the same length as `default_variable <Component_Variable>`.

    starting_value : float, list or 1d array : default 0.0
        specifies the initial value of the decision variable for the drift diffusion process.  If it is a list or
        array, it must be the same length as `default_variable <Component_Variable>`.

    noise : float, list or 1d array : default 0.5
        specifies the noise term (corresponding to the diffusion component) of the drift diffusion process.
        If it is a float, it must be a number from 0 to 1.  If it is a list or array, it must be the same length as
        `default_variable <Component_Variable>` and all elements must be floats from 0 to 1.

    non_decision_time : float, list or 1d array : default 0.2
        specifies the non-decision time for the solution. If it is a float, it must be a number from 0 to 1.  If it
        is a list or array, it must be the same length as `default_variable <Component_Variable>` and all
        elements must be floats from 0 to 1.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    shenhav_et_al_compat_mode : bool : default False
        specifies whether Shenhav et al. compatibility mode is enabled (see the ``shenhav_et_al_compat_mode``
        property).


    Attributes
    ----------

    variable : number or 1d array
        holds initial value assigned to :keyword:`default_variable` argument; it is multiplied by
        `drift_rate <DriftDiffusionAnalytical.drift_rate>` to obtain the effective drift rate used by
        `function <DriftDiffusionAnalytical.function>`.

    drift_rate : float or 1d array
        determines the drift component of the drift diffusion process.

    threshold : float or 1d array
        determines the threshold (boundary) of the drift diffusion process (i.e., at which the integration
        process is assumed to terminate).

    starting_value : float or 1d array
        determines the initial value of the decision variable for the drift diffusion process.

    noise : float or 1d array
        determines the diffusion component of the drift diffusion process (its square is the variance of a
        Gaussian random process).

    non_decision_time : float or 1d array
        determines the non-decision time, which is added to the decision time to give the response time returned
        by the solution.

    bias : float or 1d array
        normalized starting point:
        (`starting_value <DriftDiffusionAnalytical.starting_value>` + `threshold <DriftDiffusionAnalytical.threshold>`) /
        (2 * `threshold <DriftDiffusionAnalytical.threshold>`)
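
        As a worked example (a minimal sketch; ``normalized_bias`` and ``clipped_bias`` are hypothetical helpers
        that mirror ``_DriftDiffusionAnalytical_bias_getter`` and the guards in ``_function``), a
        ``starting_value`` of 0 lies midway between the boundaries and maps to 0.5, ``-threshold`` maps to 0, and
        ``+threshold`` maps to 1; ``_function`` then replaces values at or beyond 0 and 1 with 1e-8 and 1 - 1e-8.

        ```python
        def normalized_bias(starting_value, threshold):
            # Map a starting point in [-threshold, threshold] onto [0, 1]
            return (starting_value + threshold) / (2 * threshold)


        def clipped_bias(starting_value, threshold, eps=1e-8):
            # Same guards as in DriftDiffusionAnalytical._function (avoid division by zero and log(0))
            bias = normalized_bias(starting_value, threshold)
            if bias <= 0:
                bias = eps
            if bias >= 1:
                bias = 1 - eps
            return bias


        print(normalized_bias(0.0, 1.0))   # 0.5
        print(normalized_bias(0.5, 1.0))   # 0.75
        print(clipped_bias(1.0, 1.0))      # 1 - 1e-8
        ```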

1047
    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).

    """

    componentName = DRIFT_DIFFUSION_ANALYTICAL_FUNCTION

    class Parameters(DistributionFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <DriftDiffusionAnalytical.bias>`

                    :default value: 0.5
                    :type: ``float``
                    :read only: True

                drift_rate
                    see `drift_rate <DriftDiffusionAnalytical.drift_rate>`

                    :default value: 1.0
                    :type: ``float``

                enable_output_type_conversion
                    see `enable_output_type_conversion <DriftDiffusionAnalytical.enable_output_type_conversion>`

                    :default value: False
                    :type: ``bool``
                    :read only: True

                noise
                    see `noise <DriftDiffusionAnalytical.noise>`

                    :default value: 0.5
                    :type: ``float``

                starting_value
                    see `starting_value <DriftDiffusionAnalytical.starting_value>`

                    :default value: 0.0
                    :type: ``float``

                non_decision_time
                    see `non_decision_time <DriftDiffusionAnalytical.non_decision_time>`

                    :default value: 0.2
                    :type: ``float``

                threshold
                    see `threshold <DriftDiffusionAnalytical.threshold>`

                    :default value: 1.0
                    :type: ``float``
        """
        drift_rate = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        starting_value = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        threshold = Parameter(1.0, modulable=True)
        noise = Parameter(0.5, modulable=True, setter=_noise_setter)
        non_decision_time = Parameter(.200, modulable=True)
        bias = Parameter(
            0.5,
            read_only=True,
            getter=_DriftDiffusionAnalytical_bias_getter,
            dependencies=['starting_value', 'threshold']
        )
        # this is read only because conversion is disabled for this function
        # this occurs in other places as well
        enable_output_type_conversion = Parameter(
            False,
            stateful=False,
            loggable=False,
            pnl_internal=True,
            read_only=True
        )

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 drift_rate: Optional[ValidParamSpecType] = None,
                 starting_value: Optional[ValidParamSpecType] = None,
                 threshold: Optional[ValidParamSpecType] = None,
                 noise: Optional[ValidParamSpecType] = None,
                 non_decision_time: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None,
                 shenhav_et_al_compat_mode=False):

        self._shenhav_et_al_compat_mode = shenhav_et_al_compat_mode

        super().__init__(
            default_variable=default_variable,
            drift_rate=drift_rate,
            starting_value=starting_value,
            threshold=threshold,
            noise=noise,
            non_decision_time=non_decision_time,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    @property
    def shenhav_et_al_compat_mode(self):
        """
        Get whether the function is set to Shenhav et al. compatibility mode. This mode allows
        the analytic computations of mean error rate and reaction time to match exactly the
        computations made in the MATLAB DDM code (Matlab/ddmSimFRG.m). These compatibility changes
        should only affect edge cases that involve the following:

            - Floating point overflows and underflows are ignored when computing mean RT and mean ER.
            - Exponential expressions used in calculating mean RT and mean ER are bounded to lie between 1e-12 and 1e12.
            - Decision time is not permitted to be negative and is set to 0 in such cases, so that RT
              equals the non-decision time.
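
        The first two items can be illustrated with NumPy alone. In the sketch below (``exp_with_policy`` is a
        hypothetical helper), ``numpy.errstate`` turns an overflow in ``np.exp`` into a ``FloatingPointError``
        under the ``'raise'`` policy used by default, which ``_function`` catches to fall back to its
        near-deterministic solution; under the ``'ignore'`` policy used in compatibility mode, the result is
        ``inf``, which is then bounded with ``np.nanmin`` as in ``_function``.

        ```python
        import numpy as np


        def exp_with_policy(x, policy):
            # policy: 'raise' (default mode) or 'ignore' (compatibility mode)
            with np.errstate(over=policy, under=policy):
                return np.exp(x)


        raised = False
        try:
            exp_with_policy(1000.0, 'raise')
        except FloatingPointError:
            raised = True  # default mode: _function handles this case in its except clause

        overflowed = exp_with_policy(1000.0, 'ignore')      # inf, no exception raised
        bounded = np.nanmin(np.array([1e12, overflowed]))  # 1e12, the compatibility-mode upper bound
        ```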

        Returns
        -------
        Shenhav et al. compatible mode setting : (bool)

        """
        return self._shenhav_et_al_compat_mode

    @shenhav_et_al_compat_mode.setter
    def shenhav_et_al_compat_mode(self, value):
        """
        Set whether the function is set to Shenhav et al. compatibility mode. This mode allows
        the analytic computations of mean error rate and reaction time to match exactly the
        computations made in the MATLAB DDM code (Matlab/ddmSimFRG.m). These compatibility changes
        should only affect edge cases that involve the following:

            - Floating point overflows and underflows are ignored when computing mean RT and mean ER.
            - Exponential expressions used in calculating mean RT and mean ER are bounded to lie between 1e-12 and 1e12.
            - Decision time is not permitted to be negative and is set to 0 in such cases, so that RT
              equals the non-decision time.

        Arguments
        ---------

        value : bool
            Set True to turn on Shenhav et al. compatibility mode, False for off.
        """
        self._shenhav_et_al_compat_mode = value

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """
        Return: mean response time (RT), mean error rate (ER), and the mean, variance, and skewness of the RT
        distributions conditional on each boundary.

        Arguments
        ---------

        variable : 2d array
            a single stimulus drift rate; it is multiplied by `drift_rate <DriftDiffusionAnalytical.drift_rate>`
            to obtain the effective drift rate.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.


        Returns
        -------
        mean RT, mean ER, mean_rt_plus, var_rt_plus, skew_rt_plus, mean_rt_minus, var_rt_minus, skew_rt_minus : array

        """

        attentional_drift_rate = self._get_current_parameter_value(DRIFT_RATE, context).item()
        stimulus_drift_rate = variable.item()
        drift_rate = attentional_drift_rate * stimulus_drift_rate
        threshold = self._get_current_parameter_value(THRESHOLD, context)
        starting_value = self._get_current_parameter_value(STARTING_VALUE, context).item()
        noise = self._get_current_parameter_value(NOISE, context).item()
        non_decision_time = self._get_current_parameter_value(NON_DECISION_TIME, context).item()

        # drift_rate = float(self.drift_rate) * float(variable)
        # threshold = float(self.threshold)
        # starting_value = float(self.starting_value)
        # noise = float(self.noise)
        # non_decision_time = float(self.non_decision_time)

        bias = (starting_value + threshold) / (2 * threshold)

        # Prevents div by 0 issue below:
        if bias <= 0:
            bias = 1e-8
        if bias >= 1:
            bias = 1 - 1e-8

        # drift_rate close to or at 0 (avoid float comparison)
        if np.abs(drift_rate) < 1e-8:
            # back to absolute bias in order to apply limit
            bias_abs = bias * 2 * threshold - threshold
            # use expression for limit a->0 from Srivastava et al. 2016
            rt = non_decision_time + (threshold ** 2 - bias_abs ** 2) / (noise ** 2)
            er = (threshold - bias_abs) / (2 * threshold)
        else:
            drift_rate_normed = np.abs(drift_rate)
            ztilde = threshold / drift_rate_normed
            atilde = (drift_rate_normed / noise) ** 2

            is_neg_drift = drift_rate < 0
            bias_adj = (is_neg_drift == 1) * (1 - bias) + (is_neg_drift == 0) * bias
            y0tilde = ((noise ** 2) / 2) * np.log(bias_adj / (1 - bias_adj))
            if np.abs(y0tilde) > threshold:
                # First difference between Shenhav et al. DDM code and PNL's.
                if self.shenhav_et_al_compat_mode:
                    y0tilde = -1 * (y0tilde < 0) * threshold + (y0tilde >= 0) * threshold
                else:
                    y0tilde = -1 * (is_neg_drift == 1) * threshold + (is_neg_drift == 0) * threshold

            x0tilde = y0tilde / drift_rate_normed

            # Whether we should ignore or raise floating point over and underflow exceptions.
            # Shenhav et al. MATLAB code ignores them.
            ignore_or_raise = "raise"
            if self.shenhav_et_al_compat_mode:
                ignore_or_raise = "ignore"

            with np.errstate(over=ignore_or_raise, under=ignore_or_raise):
                try:
                    # Let's precompute these common sub-expressions
                    exp_neg2_x0tilde_atilde = np.exp(-2 * x0tilde * atilde)
                    exp_2_ztilde_atilde = np.exp(2 * ztilde * atilde)
                    exp_neg2_ztilde_atilde = np.exp(-2 * ztilde * atilde)

                    if self.shenhav_et_al_compat_mode:
                        exp_neg2_x0tilde_atilde = np.nanmax(
                            convert_to_np_array([1e-12, exp_neg2_x0tilde_atilde])
                        )
                        exp_2_ztilde_atilde = np.nanmin(
                            convert_to_np_array([1e12, exp_2_ztilde_atilde])
                        )
                        exp_neg2_ztilde_atilde = np.nanmax(
                            convert_to_np_array([1e-12, exp_neg2_ztilde_atilde])
                        )

                    rt = ztilde * np.tanh(ztilde * atilde) + \
                         ((2 * ztilde * (1 - exp_neg2_x0tilde_atilde)) / (
                                 exp_2_ztilde_atilde - exp_neg2_ztilde_atilde) - x0tilde)
                    er = 1 / (1 + exp_2_ztilde_atilde) - \
                         ((1 - exp_neg2_x0tilde_atilde) / (exp_2_ztilde_atilde - exp_neg2_ztilde_atilde))

                    # Fail-safe to prevent negative mean RTs; Shenhav et al. do this.
                    if self.shenhav_et_al_compat_mode:
                        if rt < 0:
                            rt = 0

                    rt = rt + non_decision_time

                except FloatingPointError:
                    # Per Mike Shvartsman:
                    # If ±2*ztilde*atilde (~ 2*z*a/c^2) gets very large, the diffusion vanishes relative to drift
                    # and the problem is near-deterministic. Without diffusion, error rate goes to 0 or 1
                    # depending on the sign of the drift, and so decision time goes to a point mass on z/a - x0, and
                    # generates a "RuntimeWarning: overflow encountered in exp"
                    er = 0
                    rt = ztilde / atilde - x0tilde + non_decision_time

            # This last line makes it report back in terms of a fixed reference point
            #    (i.e., closer to 1 always means higher p(upper boundary))
            # If you comment this out it will report errors in the reference frame of the drift rate
            #    (i.e., reports p(upper) if drift is positive, and p(lower) if drift is negative)
            er = (is_neg_drift == 1) * (1 - er) + (is_neg_drift == 0) * (er)

        # Compute moments (mean, variance, skewness) of conditional response time distributions
        moments = DriftDiffusionAnalytical._compute_conditional_rt_moments(drift_rate, noise, threshold, bias, non_decision_time)

        return convert_all_elements_to_np_array([
            rt, er,
            moments['mean_rt_plus'], moments['var_rt_plus'], moments['skew_rt_plus'],
            moments['mean_rt_minus'], moments['var_rt_minus'], moments['skew_rt_minus']
        ])

    @staticmethod
    def _compute_conditional_rt_moments(drift_rate, noise, threshold, starting_value, non_decision_time):
        """
        This is a helper function for computing the conditional decision time moments for the DDM.
        It is based entirely on Matlab\\DDMFunctions\\ddm_metrics_cond_Mat.m.

        :param drift_rate: The drift rate of the DDM.
        :param noise: The diffusion rate.
        :param threshold: The symmetric threshold of the DDM.
        :param starting_value: The normalized starting point (bias, in [0, 1]); it is re-centered at 0 and
            rescaled by threshold before use.
        :param non_decision_time: The non-decision time.
        :return: A dictionary containing the following key-value pairs:
         mean_rt_plus: The mean RT of positive responses.
         mean_rt_minus: The mean RT of negative responses.
         var_rt_plus: The variance of RT of positive responses.
         var_rt_minus: The variance of RT of negative responses.
         skew_rt_plus: The skewness of RT of positive responses.
         skew_rt_minus: The skewness of RT of negative responses.
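
        A quick consistency check on these expressions: with an unbiased start (X = 0), the identity
        2Z coth(2Z) - Z coth(Z) = Z tanh(Z) makes both conditional mean decision times equal to the unconditional
        mean decision time of Bogacz et al. (2006), (threshold / drift_rate) * tanh(drift_rate * threshold / noise**2).
        The standalone sketch below re-implements only the mean decision-time terms (without ``non_decision_time``
        and without the clamping of X and Z); ``conditional_mean_dt`` is a hypothetical helper.

        ```python
        import numpy as np


        def coth(x):
            return 1.0 / np.tanh(x)


        def conditional_mean_dt(drift_rate, noise, threshold, x0):
            # x0 is the starting point already centered at 0, i.e. in [-threshold, threshold]
            X = drift_rate * x0 / noise ** 2
            Z = drift_rate * threshold / noise ** 2
            scale = noise ** 2 / drift_rate ** 2
            plus = scale * (2 * Z * coth(2 * Z) - (X + Z) * coth(X + Z))
            minus = scale * (2 * Z * coth(2 * Z) - (-X + Z) * coth(-X + Z))
            return plus, minus


        drift_rate, noise, threshold = 0.8, 1.0, 1.2
        plus, minus = conditional_mean_dt(drift_rate, noise, threshold, x0=0.0)
        unconditional = threshold / drift_rate * np.tanh(drift_rate * threshold / noise ** 2)
        ```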
        """

        #  transform starting point to be centered at 0
        starting_value = (starting_value - 0.5) * 2.0 * threshold

        if abs(drift_rate) < 0.01:
            drift_rate = 0.01

        X = drift_rate * starting_value / noise**2
        Z = drift_rate * threshold / noise**2

        X = max(-100, min(100, X))

        Z = max(-100, min(100, Z))

        if abs(Z) < 0.0001:
            Z = 0.0001

        def coth(x):
            return 1 / np.tanh(x)

        def csch(x):
            return 1 / np.sinh(x)

        moments = {}

        # Let's ignore any divide-by-zero or invalid-value errors we get. This allows NaNs to propagate.
        with np.errstate(divide='ignore', invalid='ignore'):
            moments["mean_rt_plus"] = noise**2 / (drift_rate**2) * (2 * Z * coth(2 * Z) - (X + Z) * coth(X + Z))

            moments["mean_rt_minus"] = noise**2 / (drift_rate**2) * (2 * Z * coth(2 * Z) - (-X + Z) * coth(-X + Z))

            moments["var_rt_plus"] = noise**4 / (drift_rate**4) * \
                              (((2 * Z)**2 * csch(2 * Z)**2 -
                                (Z + X)**2 * csch(Z + X)**2) +
                               ((2 * Z) * coth(2 * Z) -
                                (Z + X) * coth(Z + X)))

            moments["var_rt_minus"] = noise**4 / (drift_rate**4) * \
                              (((2 * Z)**2 * csch(2 * Z)**2 -
                                (Z - X)**2 * csch(Z - X)**2) +
                               ((2 * Z) * coth(2 * Z) -
                                (Z - X) * coth(Z - X)))

            moments["skew_rt_plus"] = noise**6 / (drift_rate**6) * \
                               (3 * ((2 * Z)**2 * csch(2 * Z)**2 -
                                     (Z + X)**2 * csch(Z + X)**2) +
                                2 * ((2 * Z)**3 * coth(2 * Z) * csch(2 * Z)**2 -
                                     (Z + X)**3 * coth(Z + X) * csch(Z + X)**2) +
                                3 * ((2 * Z) * coth(2 * Z) -
                                     (Z + X) * coth(Z + X)))

            moments["skew_rt_minus"] = noise**6 / (drift_rate**6) * \
                               (3 * ((2 * Z)**2 * csch(2 * Z)**2 -
                                     (Z - X)**2 * csch(Z - X)**2) +
                                2 * ((2 * Z)**3 * coth(2 * Z) * csch(2 * Z)**2 -
                                     (Z - X)**3 * coth(Z - X) * csch(Z - X)**2) +
                                3 * ((2 * Z) * coth(2 * Z) -
                                     (Z - X) * coth(Z - X)))

            # divide third central moment by var_rt**1.5 to get skewness
            moments['skew_rt_plus'] /= moments['var_rt_plus']**1.5
            moments['skew_rt_minus'] /= moments['var_rt_minus']**1.5

            # Add the non-decision time to the mean RTs
            moments['mean_rt_plus'] += non_decision_time
            moments['mean_rt_minus'] += non_decision_time


        return moments

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):

        def load_scalar_param(name):
            param_ptr = ctx.get_param_or_state_ptr(builder, self, name, param_struct_ptr=params)
            return pnlvm.helpers.load_extract_scalar_array_one(builder, param_ptr)

        attentional_drift_rate = load_scalar_param(DRIFT_RATE)
        threshold = load_scalar_param(THRESHOLD)
        starting_value = load_scalar_param(STARTING_VALUE)
        noise = load_scalar_param(NOISE)
        non_decision_time = load_scalar_param(NON_DECISION_TIME)

        noise_sqr = builder.fmul(noise, noise)

        # Arguments used in mechanisms are 2D
        arg_in = pnlvm.helpers.unwrap_2d_array(builder, arg_in)

        stimulus_drift_rate = pnlvm.helpers.load_extract_scalar_array_one(builder, arg_in)
        drift_rate = builder.fmul(attentional_drift_rate, stimulus_drift_rate)

        threshold_2 = builder.fmul(threshold, threshold.type(2))
        bias = builder.fadd(starting_value, threshold)
        bias = builder.fdiv(bias, threshold_2)

        bias = pnlvm.helpers.fclamp(builder, bias, 1e-8, 1 - 1e-8)

        def _get_arg_out_ptr(idx):
            ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(idx)])
            if isinstance(ptr.type.pointee, pnlvm.ir.ArrayType):
                assert len(ptr.type.pointee) == 1
                ptr = builder.gep(ptr, [ctx.int32_ty(0), ctx.int32_ty(0)])
            return ptr

        rt_ptr = _get_arg_out_ptr(0)
        er_ptr = _get_arg_out_ptr(1)

        abs_f = ctx.get_builtin("fabs", [bias.type])
        abs_drift_rate = builder.call(abs_f, [drift_rate])
        small_drift_rate = builder.fcmp_ordered("<", abs_drift_rate,
                                                abs_drift_rate.type(1e-8))

        with builder.if_else(small_drift_rate) as (then, otherwise):
            with then:
                bias_abs = builder.fmul(bias, bias.type(2))
                bias_abs = builder.fmul(bias_abs, threshold)
                bias_abs = builder.fsub(bias_abs, threshold)

                bias_abs_sqr = builder.fmul(bias_abs, bias_abs)
                threshold_sqr = builder.fmul(threshold, threshold)
                rt = builder.fsub(threshold_sqr, bias_abs_sqr)
                rt = builder.fdiv(rt, noise_sqr)
                rt = builder.fadd(non_decision_time, rt)
                builder.store(rt, rt_ptr)

                er = builder.fsub(threshold, bias_abs)
                er = builder.fdiv(er, threshold_2)
                builder.store(er, er_ptr)
            with otherwise:
                drift_rate_normed = builder.call(abs_f, [drift_rate])
                ztilde = builder.fdiv(threshold, drift_rate_normed)
                atilde = builder.fdiv(drift_rate_normed, noise)
                atilde = builder.fmul(atilde, atilde)

                is_neg_drift = builder.fcmp_ordered("<", drift_rate,
                                                    drift_rate.type(0))
                bias_rev = builder.fsub(bias.type(1), bias)
                bias_adj = builder.select(is_neg_drift, bias_rev, bias)

                noise_tmp = builder.fdiv(noise_sqr, noise_sqr.type(2))

                log_f = ctx.get_builtin("log", [bias_adj.type])
                bias_tmp = builder.fsub(bias_adj.type(1), bias_adj)
                bias_tmp = builder.fdiv(bias_adj, bias_tmp)
                bias_log = builder.call(log_f, [bias_tmp])
                y0tilde = builder.fmul(noise_tmp, bias_log)

                assert not self.shenhav_et_al_compat_mode
1✔
1502
                threshold_neg = pnlvm.helpers.fneg(builder, threshold)
1✔
1503
                new_y0tilde = builder.select(is_neg_drift, threshold_neg,
1✔
1504
                                                           threshold)
1505
                abs_y0tilde = builder.call(abs_f, [y0tilde])
1✔
1506
                abs_y0tilde_above_threshold = \
1✔
1507
                    builder.fcmp_ordered(">", abs_y0tilde, threshold)
1508
                y0tilde = builder.select(abs_y0tilde_above_threshold,
1✔
1509
                                         new_y0tilde, y0tilde)
1510

1511
                x0tilde = builder.fdiv(y0tilde, drift_rate_normed)
1✔
1512

1513
                exp_f = ctx.get_builtin("exp", [bias_adj.type])
1✔
1514
                # Pre-compute the same values as Python above
1515
                neg2_x0tilde_atilde = builder.fmul(x0tilde.type(-2), x0tilde)
1✔
1516
                neg2_x0tilde_atilde = builder.fmul(neg2_x0tilde_atilde, atilde)
1✔
1517
                exp_neg2_x0tilde_atilde = builder.call(exp_f, [neg2_x0tilde_atilde])
1✔
1518

1519
                n2_ztilde_atilde = builder.fmul(ztilde.type(2), ztilde)
1✔
1520
                n2_ztilde_atilde = builder.fmul(n2_ztilde_atilde, atilde)
1✔
1521
                exp_2_ztilde_atilde = builder.call(exp_f, [n2_ztilde_atilde])
1✔
1522

1523
                neg2_ztilde_atilde = builder.fmul(ztilde.type(-2), ztilde)
1✔
1524
                neg2_ztilde_atilde = builder.fmul(neg2_ztilde_atilde, atilde)
1✔
1525
                exp_neg2_ztilde_atilde = builder.call(exp_f, [neg2_ztilde_atilde])
1✔
1526
                # The final computation er
1527
                er_tmp1 = builder.fadd(exp_2_ztilde_atilde.type(1),
1✔
1528
                                       exp_2_ztilde_atilde)
1529
                er_tmp1 = builder.fdiv(er_tmp1.type(1), er_tmp1)
1✔
1530
                er_tmp2 = builder.fsub(exp_neg2_x0tilde_atilde.type(1),
1✔
1531
                                       exp_neg2_x0tilde_atilde)
1532
                er_tmp3 = builder.fsub(exp_2_ztilde_atilde,
1✔
1533
                                       exp_neg2_ztilde_atilde)
1534
                er_tmp = builder.fdiv(er_tmp2, er_tmp3)
1✔
1535
                er = builder.fsub(er_tmp1, er_tmp)
1✔
1536
                comp_er = builder.fsub(er.type(1), er)
1✔
1537
                er = builder.select(is_neg_drift, comp_er, er)
1✔
1538
                builder.store(er, er_ptr)
1✔
1539

1540
                # The final computation rt
1541
                rt_tmp0 = builder.fmul(ztilde, atilde)
1✔
1542
                rt_tmp0 = pnlvm.helpers.tanh(ctx, builder, rt_tmp0)
1✔
1543
                rt_tmp0 = builder.fmul(ztilde, rt_tmp0)
1✔
1544

1545
                rt_tmp1a = builder.fmul(ztilde.type(2), ztilde)
1✔
1546
                rt_tmp1b = builder.fsub(exp_neg2_x0tilde_atilde.type(1),
1✔
1547
                                       exp_neg2_x0tilde_atilde)
1548
                rt_tmp1 = builder.fmul(rt_tmp1a, rt_tmp1b)
1✔
1549

1550
                rt_tmp2 = builder.fsub(exp_2_ztilde_atilde,
1✔
1551
                                       exp_neg2_ztilde_atilde)
1552

1553
                rt = builder.fdiv(rt_tmp1, rt_tmp2)
1✔
1554
                rt = builder.fsub(rt, x0tilde)
1✔
1555
                rt = builder.fadd(rt_tmp0, rt)
1✔
1556
                rt = builder.fadd(rt, non_decision_time)
1✔
1557
                builder.store(rt, rt_ptr)
1✔
1558

        # Calculate moments
        mean_rt_plus_ptr = _get_arg_out_ptr(2)
        var_rt_plus_ptr = _get_arg_out_ptr(3)
        skew_rt_plus_ptr = _get_arg_out_ptr(4)
        mean_rt_minus_ptr = _get_arg_out_ptr(5)
        var_rt_minus_ptr = _get_arg_out_ptr(6)
        skew_rt_minus_ptr = _get_arg_out_ptr(7)

        # Transform starting point to be centered at 0
        starting_value = bias
        starting_value = builder.fsub(starting_value, starting_value.type(0.5))
        starting_value = builder.fmul(starting_value, starting_value.type(2))
        starting_value = builder.fmul(starting_value, threshold)

        drift_rate_limit = abs_drift_rate.type(0.01)
        small_drift = builder.fcmp_ordered("<", abs_drift_rate, drift_rate_limit)
        drift_rate = builder.select(small_drift, drift_rate_limit, drift_rate)

        X = builder.fmul(drift_rate, starting_value)
        X = builder.fdiv(X, noise_sqr)
        X = pnlvm.helpers.fclamp(builder, X, X.type(-100), X.type(100))

        Z = builder.fmul(drift_rate, threshold)
        Z = builder.fdiv(Z, noise_sqr)
        Z = pnlvm.helpers.fclamp(builder, Z, Z.type(-100), Z.type(100))

        abs_Z = builder.call(abs_f, [Z])
        tiny_Z = builder.fcmp_ordered("<", abs_Z, Z.type(0.0001))
        Z = builder.select(tiny_Z, Z.type(0.0001), Z)

        # Mean helpers
        drift_rate_sqr = builder.fmul(drift_rate, drift_rate)
        Z2 = builder.fmul(Z, Z.type(2))
        coth_Z2 = pnlvm.helpers.coth(ctx, builder, Z2)
        Z2_coth_Z2 = builder.fmul(Z2, coth_Z2)
        ZpX = builder.fadd(Z, X)
        coth_ZpX = pnlvm.helpers.coth(ctx, builder, ZpX)
        ZpX_coth_ZpX = builder.fmul(ZpX, coth_ZpX)
        ZmX = builder.fsub(Z, X)
        coth_ZmX = pnlvm.helpers.coth(ctx, builder, ZmX)
        ZmX_coth_ZmX = builder.fmul(ZmX, coth_ZmX)

        # Mean plus
        mrtp_tmp = builder.fsub(Z2_coth_Z2, ZpX_coth_ZpX)
        m_rt_p = builder.fdiv(noise_sqr, drift_rate_sqr)
        m_rt_p = builder.fmul(m_rt_p, mrtp_tmp)
        m_rt_p = builder.fadd(m_rt_p, non_decision_time)
        builder.store(m_rt_p, mean_rt_plus_ptr)

        # Mean minus
        mrtm_tmp = builder.fsub(Z2_coth_Z2, ZmX_coth_ZmX)
        m_rt_m = builder.fdiv(noise_sqr, drift_rate_sqr)
        m_rt_m = builder.fmul(m_rt_m, mrtm_tmp)
        m_rt_m = builder.fadd(m_rt_m, non_decision_time)
        builder.store(m_rt_m, mean_rt_minus_ptr)

        # Variance helpers
        noise_q = builder.fmul(noise_sqr, noise_sqr)
        drift_rate_q = builder.fmul(drift_rate_sqr, drift_rate_sqr)
        noise_q_drift_q = builder.fdiv(noise_q, drift_rate_q)

        Z2_sqr = builder.fmul(Z2, Z2)
        csch_Z2 = pnlvm.helpers.csch(ctx, builder, Z2)
        csch_Z2_sqr = builder.fmul(csch_Z2, csch_Z2)
        Z2_sqr_csch_Z2_sqr = builder.fmul(Z2_sqr, csch_Z2_sqr)

        ZpX_sqr = builder.fmul(ZpX, ZpX)
        csch_ZpX = pnlvm.helpers.csch(ctx, builder, ZpX)
        csch_ZpX_sqr = builder.fmul(csch_ZpX, csch_ZpX)
        ZpX_sqr_csch_ZpX_sqr = builder.fmul(ZpX_sqr, csch_ZpX_sqr)

        ZmX_sqr = builder.fmul(ZmX, ZmX)
        csch_ZmX = pnlvm.helpers.csch(ctx, builder, ZmX)
        csch_ZmX_sqr = builder.fmul(csch_ZmX, csch_ZmX)
        ZmX_sqr_csch_ZmX_sqr = builder.fmul(ZmX_sqr, csch_ZmX_sqr)

        # Variance plus
        v_rt_pA = builder.fsub(Z2_sqr_csch_Z2_sqr, ZpX_sqr_csch_ZpX_sqr)
        v_rt_pB = builder.fsub(Z2_coth_Z2, ZpX_coth_ZpX)
        v_rt_p = builder.fadd(v_rt_pA, v_rt_pB)
        v_rt_p = builder.fmul(noise_q_drift_q, v_rt_p)
        builder.store(v_rt_p, var_rt_plus_ptr)

        pow_f = ctx.get_builtin("pow", [v_rt_p.type, v_rt_p.type])
        v_rt_p_1_5 = builder.call(pow_f, [v_rt_p, v_rt_p.type(1.5)])

        # Variance minus
        v_rt_mA = builder.fsub(Z2_sqr_csch_Z2_sqr, ZmX_sqr_csch_ZmX_sqr)
        v_rt_mB = builder.fsub(Z2_coth_Z2, ZmX_coth_ZmX)
        v_rt_m = builder.fadd(v_rt_mA, v_rt_mB)
        v_rt_m = builder.fmul(noise_q_drift_q, v_rt_m)
        builder.store(v_rt_m, var_rt_minus_ptr)

        pow_f = ctx.get_builtin("pow", [v_rt_m.type, v_rt_m.type])
        v_rt_m_1_5 = builder.call(pow_f, [v_rt_m, v_rt_m.type(1.5)])

        # Skew helpers
        noise_6 = builder.fmul(noise_q, noise_sqr)
        drift_rate_6 = builder.fmul(drift_rate_q, drift_rate_sqr)

        srt_tmp0 = builder.fdiv(noise_6, drift_rate_6)

        Z2_cub_coth_Z2_csch_Z2_sqr = builder.fmul(Z2_coth_Z2, Z2_sqr_csch_Z2_sqr)
        ZpX_cub_coth_ZpX_csch_Z2_sqr = builder.fmul(ZpX_coth_ZpX, ZpX_sqr_csch_ZpX_sqr)
        ZmX_cub_coth_ZmX_csch_Z2_sqr = builder.fmul(ZmX_coth_ZmX, ZmX_sqr_csch_ZmX_sqr)

        # Skew plus
        s_rt_p_tmpA = builder.fsub(Z2_sqr_csch_Z2_sqr, ZpX_sqr_csch_ZpX_sqr)
        s_rt_p_tmpA = builder.fmul(s_rt_p_tmpA, s_rt_p_tmpA.type(3))

        s_rt_p_tmpB = builder.fsub(Z2_cub_coth_Z2_csch_Z2_sqr,
                                   ZpX_cub_coth_ZpX_csch_Z2_sqr)
        s_rt_p_tmpB = builder.fadd(s_rt_p_tmpB, s_rt_p_tmpB)

        s_rt_p_tmpC = builder.fsub(Z2_coth_Z2, ZpX_coth_ZpX)
        s_rt_p_tmpC = builder.fmul(s_rt_p_tmpC, s_rt_p_tmpC.type(3))

        s_rt_p = builder.fadd(s_rt_p_tmpA, s_rt_p_tmpB)
        s_rt_p = builder.fadd(s_rt_p, s_rt_p_tmpC)

        s_rt_p = builder.fmul(srt_tmp0, s_rt_p)
        s_rt_p = builder.fdiv(s_rt_p, v_rt_p_1_5)
        builder.store(s_rt_p, skew_rt_plus_ptr)

        # Skew minus
        s_rt_m_tmpA = builder.fsub(Z2_sqr_csch_Z2_sqr, ZmX_sqr_csch_ZmX_sqr)
        s_rt_m_tmpA = builder.fmul(s_rt_m_tmpA, s_rt_m_tmpA.type(3))

        s_rt_m_tmpB = builder.fsub(Z2_cub_coth_Z2_csch_Z2_sqr,
                                   ZmX_cub_coth_ZmX_csch_Z2_sqr)
        s_rt_m_tmpB = builder.fadd(s_rt_m_tmpB, s_rt_m_tmpB)

        s_rt_m_tmpC = builder.fsub(Z2_coth_Z2, ZmX_coth_ZmX)
        s_rt_m_tmpC = builder.fmul(s_rt_m_tmpC, s_rt_m_tmpC.type(3))

        s_rt_m = builder.fadd(s_rt_m_tmpA, s_rt_m_tmpB)
        s_rt_m = builder.fadd(s_rt_m, s_rt_m_tmpC)

        s_rt_m = builder.fmul(srt_tmp0, s_rt_m)
        s_rt_m = builder.fdiv(s_rt_m, v_rt_m_1_5)
        builder.store(s_rt_m, skew_rt_minus_ptr)

        return builder
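The `then`/`otherwise` branches of `_gen_llvm_function_body` can be cross-checked against a plain NumPy sketch of the same arithmetic. The version below assumes scalar float inputs and mirrors both branches, including the bias clamp to `[1e-8, 1 - 1e-8]` and the clipping of `y0tilde` to ±`threshold`; it does not model `shenhav_et_al_compat_mode`, which the IR asserts is off. `ddm_er_rt` is a hypothetical helper name, not a PsyNeuLink API.

```python
import numpy as np


def ddm_er_rt(drift_rate, threshold, starting_value, noise, non_decision_time):
    """Hypothetical NumPy mirror of the IR: returns (error_rate, mean_decision_time)."""
    # Map the starting value in [-threshold, threshold] to a bias in (0, 1)
    bias = np.clip((starting_value + threshold) / (2 * threshold), 1e-8, 1 - 1e-8)

    if abs(drift_rate) < 1e-8:
        # Near-zero drift: the closed-form limit used by the `then` branch
        bias_abs = 2 * bias * threshold - threshold
        rt = non_decision_time + (threshold ** 2 - bias_abs ** 2) / noise ** 2
        er = (threshold - bias_abs) / (2 * threshold)
        return er, rt

    drift_normed = abs(drift_rate)
    ztilde = threshold / drift_normed
    atilde = (drift_normed / noise) ** 2

    # For negative drift, mirror the bias so the formulas can assume drift > 0
    is_neg_drift = drift_rate < 0
    bias_adj = 1 - bias if is_neg_drift else bias

    # Starting point implied by the bias, clipped to the threshold
    y0tilde = (noise ** 2 / 2) * np.log(bias_adj / (1 - bias_adj))
    if abs(y0tilde) > threshold:
        y0tilde = -threshold if is_neg_drift else threshold
    x0tilde = y0tilde / drift_normed

    exp_neg2_x0_a = np.exp(-2 * x0tilde * atilde)
    exp_2_z_a = np.exp(2 * ztilde * atilde)
    exp_neg2_z_a = np.exp(-2 * ztilde * atilde)

    er = 1 / (1 + exp_2_z_a) - (1 - exp_neg2_x0_a) / (exp_2_z_a - exp_neg2_z_a)
    if is_neg_drift:
        er = 1 - er

    rt = (ztilde * np.tanh(ztilde * atilde)
          + 2 * ztilde * (1 - exp_neg2_x0_a) / (exp_2_z_a - exp_neg2_z_a)
          - x0tilde
          + non_decision_time)
    return er, rt


er, rt = ddm_er_rt(drift_rate=1.0, threshold=1.0, starting_value=0.0,
                   noise=1.0, non_decision_time=0.0)
```

With an unbiased start (`starting_value=0`), the sketch reduces to the familiar expressions ER = 1/(1 + exp(2Az/c²)) and DT = (z/A)·tanh(Az/c²) + t0, which makes it a quick sanity check for compiled results.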

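The moment block of `_gen_llvm_function_body` can likewise be sketched in NumPy, again assuming scalar inputs. The sketch reproduces the IR's re-centred starting point, the replacement of drift rates with magnitude below 0.01 by +0.01, the ±100 clamp on X and Z, and the replacement of |Z| < 1e-4 by 1e-4, then evaluates mean, variance, and skewness for both boundaries from the same coth/csch terms. `ddm_rt_moments`, `_coth`, and `_csch` are hypothetical names.

```python
import numpy as np


def _coth(x):
    return 1.0 / np.tanh(x)


def _csch(x):
    return 1.0 / np.sinh(x)


def ddm_rt_moments(drift_rate, threshold, starting_value, noise, non_decision_time):
    """Hypothetical NumPy mirror of the IR moment block.

    Returns {"plus": (mean, var, skew), "minus": (mean, var, skew)}.
    """
    bias = np.clip((starting_value + threshold) / (2 * threshold), 1e-8, 1 - 1e-8)
    # Re-centre the starting point at 0, in [-threshold, threshold]
    x0 = (bias - 0.5) * 2 * threshold

    # As in the IR, small drift rates are replaced by +0.01 (sign not kept)
    if abs(drift_rate) < 0.01:
        drift_rate = 0.01

    noise_sqr = noise ** 2
    X = np.clip(drift_rate * x0 / noise_sqr, -100, 100)
    Z = np.clip(drift_rate * threshold / noise_sqr, -100, 100)
    if abs(Z) < 1e-4:
        Z = 1e-4

    scale = noise_sqr / drift_rate ** 2          # c^2 / A^2
    Z2 = 2 * Z
    Z2_coth = Z2 * _coth(Z2)                     # 2Z coth(2Z)
    Z2_csch_sq = Z2 ** 2 * _csch(Z2) ** 2        # (2Z)^2 csch^2(2Z)

    moments = {}
    for label, W in (("plus", Z + X), ("minus", Z - X)):
        W_coth = W * _coth(W)
        W_csch_sq = W ** 2 * _csch(W) ** 2
        coth_term = Z2_coth - W_coth
        csch_term = Z2_csch_sq - W_csch_sq
        cub_term = Z2_coth * Z2_csch_sq - W_coth * W_csch_sq

        mean = scale * coth_term + non_decision_time
        var = scale ** 2 * (csch_term + coth_term)
        skew = scale ** 3 * (3 * csch_term + 2 * cub_term + 3 * coth_term) / var ** 1.5
        moments[label] = (mean, var, skew)
    return moments


m = ddm_rt_moments(drift_rate=1.0, threshold=1.0, starting_value=0.0,
                   noise=1.0, non_decision_time=0.0)
```

For an unbiased start, X is 0, so the plus and minus moments coincide and the mean reduces to (z/A)·tanh(Az/c²) + t0, a convenient check against the error-rate/decision-time branch.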
    def derivative(self, output=None, input=None, context=None):
        """
        derivative(output, input)

        Calculate the derivative of :math:`\\frac{1}{reward rate}` with respect to the threshold (**output** arg)
        and drift_rate (**input** arg).  The reciprocal of the reward rate (:math:`RR`) is assumed to be:

            :math:`\\frac{1}{RR} = delay_{ITI} + \\frac{Z}{A} + ED`;

        the derivative of :math:`\\frac{1}{RR}` with respect to the `threshold <DriftDiffusionAnalytical.threshold>` is:

            :math:`\\frac{1}{A} - \\frac{E}{A} - 2\\frac{A}{c^2}ED`;

        and the derivative of :math:`\\frac{1}{RR}` with respect to the `drift_rate <DriftDiffusionAnalytical.drift_rate>` is:

            :math:`-\\frac{Z}{A^2} + \\frac{Z}{A^2}E - \\frac{2Z}{c^2}ED`

        where:

            *A* = `drift_rate <DriftDiffusionAnalytical.drift_rate>`,

            *Z* = `threshold <DriftDiffusionAnalytical.threshold>`,

            *c* = `noise <DriftDiffusionAnalytical.noise>`,

            *E* = :math:`e^{-2\\frac{ZA}{c^2}}`,

            *D* = :math:`delay_{ITI} + delay_{penalty} - \\frac{Z}{A}`,

            :math:`delay_{ITI}` is the intertrial interval and :math:`delay_{penalty}` is a penalty delay.


        Returns
        -------

        derivatives :  List[float, float]
            of :math:`\\frac{1}{RR}` with respect to `threshold <DriftDiffusionAnalytical.threshold>` and `drift_rate
            <DriftDiffusionAnalytical.drift_rate>`.

        """
        Z = output or self._get_current_parameter_value(THRESHOLD, context)
        A = input or self._get_current_parameter_value(DRIFT_RATE, context)
        c = self._get_current_parameter_value(NOISE, context)
        c_sq = c ** 2
        E = np.exp(-2 * Z * A / c_sq)
        D_iti = 0
        D_pen = 0
        # D as defined in the docstring: delay_ITI + delay_penalty - Z/A
        D = D_iti + D_pen - Z / A
        # RR =  1/(D_iti + Z/A + (E*D))

        # Signs follow the docstring: d(1/RR)/dZ = 1/A - E/A - 2(A/c^2)ED
        dRR_dZ = 1 / A - E / A - (2 * A / c_sq) * E * D
        dRR_dA = -Z / A ** 2 + (Z / A ** 2) * E - (2 * Z / c_sq) * E * D

        return [dRR_dZ, dRR_dA]
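The two closed-form derivatives in the `derivative` docstring can be checked numerically. This sketch evaluates 1/RR exactly as the docstring defines it (with D = delay_ITI + delay_penalty - Z/A) and compares the analytic expressions against central finite differences. `inv_reward_rate`, `inv_reward_rate_grad`, `finite_difference_grad`, and the sample delay values are assumptions for illustration, not part of PsyNeuLink.

```python
import numpy as np


def inv_reward_rate(Z, A, c, d_iti=0.0, d_pen=0.0):
    """1/RR = delay_ITI + Z/A + E*D, with the docstring's E and D (hypothetical helper)."""
    E = np.exp(-2 * Z * A / c ** 2)
    D = d_iti + d_pen - Z / A
    return d_iti + Z / A + E * D


def inv_reward_rate_grad(Z, A, c, d_iti=0.0, d_pen=0.0):
    """Analytic d(1/RR)/dZ and d(1/RR)/dA, transcribed from the docstring."""
    E = np.exp(-2 * Z * A / c ** 2)
    D = d_iti + d_pen - Z / A
    d_dZ = 1 / A - E / A - 2 * (A / c ** 2) * E * D
    d_dA = -Z / A ** 2 + (Z / A ** 2) * E - (2 * Z / c ** 2) * E * D
    return d_dZ, d_dA


def finite_difference_grad(Z, A, c, d_iti=0.0, d_pen=0.0, h=1e-6):
    """Central-difference estimates of the same two partial derivatives."""
    d_dZ = (inv_reward_rate(Z + h, A, c, d_iti, d_pen)
            - inv_reward_rate(Z - h, A, c, d_iti, d_pen)) / (2 * h)
    d_dA = (inv_reward_rate(Z, A + h, c, d_iti, d_pen)
            - inv_reward_rate(Z, A - h, c, d_iti, d_pen)) / (2 * h)
    return d_dZ, d_dA


analytic = inv_reward_rate_grad(1.2, 0.8, 1.0, d_iti=0.5, d_pen=1.0)
numeric = finite_difference_grad(1.2, 0.8, 1.0, d_iti=0.5, d_pen=1.0)
```

Agreement between `analytic` and `numeric` to roughly 1e-6 indicates the docstring formulas are consistent with its definition of 1/RR; a sign error in either term would show up as a clear mismatch.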

© 2025 Coveralls, Inc