PrincetonUniversity / PsyNeuLink / build 15917088825

05 Jun 2025 04:18AM UTC coverage: 84.482% (+0.5%) from 84.017%
push (github, web-flow): Merge pull request #3271 from PrincetonUniversity/devel ("Devel")

9909 of 12966 branches covered (76.42%); branch coverage included in aggregate %.
1708 of 1908 new or added lines in 54 files covered (89.52%).
25 existing lines in 14 files now uncovered.
34484 of 39581 relevant lines covered (87.12%); 0.87 hits per line.

Source File (88.16% covered):
/psyneulink/core/components/functions/nonstateful/transferfunctions.py

#
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
#
#
# *******************************************  TRANSFER FUNCTIONS  *****************************************************
"""

**Deterministic**
    * `Identity`
    * `Linear`
    * `Exponential`
    * `Logistic`
    * `Tanh`
    * `ReLU`
    * `Gaussian`

**Probabilistic**
    * `GaussianDistort`
    * `BinomialDistort`
    * `Dropout`
    * `SoftMax`

**Other**
    * `Angle`
    * `TransferWithCosts`

Overview
--------

TransferFunctions transform their variable but maintain its shape.  There are two subclasses of TransferFunctions --
`Deterministic <Deterministic>` and `Probabilistic <Probabilistic>` -- that have specialized attributes and/or methods.

.. _TransferFunction_StandardAttributes:

Standard Attributes
~~~~~~~~~~~~~~~~~~~

All TransferFunctions have a `range <TransferFunction.range>` attribute that specifies the lower and upper limits
of the function's result.  For some subclasses, this may be modified by other parameters.  In addition, all
TransferFunctions have a pair of modulable parameters as described below.

.. _TransferFunction_Modulable_Params:

* **multiplicative_param** and **additive_param**:
  each of these is assigned the name of one of the function's parameters (for example, `slope <Linear.slope>` and
  `intercept <Linear.intercept>` for `Linear`, respectively) and is used by `ModulatoryProjections
  <ModulatoryProjection>` to modulate the output of the TransferFunction's function (see `Function_Modulatory_Params`).

.. _TransferFunction_Derivative:

Derivatives
~~~~~~~~~~~

Most TransferFunctions have a derivative method.  These take both an **input** and an **output** argument.  In
general, the **input** is used to compute the derivative of the function at that value.  If that is not provided,
some Functions can compute the derivative using the function's output, either directly (such as `Logistic.derivative`)
or by inferring the input from the **output** and then computing the derivative for that value (such as
`ReLU.derivative`).
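
For example, the following is a minimal sketch of both conventions (assuming default values for all parameters
other than those shown; illustrative only)::

    from psyneulink.core.components.functions.nonstateful.transferfunctions import Logistic

    L = Logistic(gain=2.0)
    L.derivative(input=0.0)    # computed from the input:  2 * 0.5 * (1 - 0.5) = 0.5
    L.derivative(output=0.5)   # computed directly from a previously obtained output:  same value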

TransferFunction Class References
---------------------------------

"""

import numbers
import types
import warnings
from enum import Flag, auto
from math import e, pi, sqrt

import numpy as np
try:
    import torch
except ImportError:
    torch = None
from beartype import beartype

from psyneulink._typing import Callable, Mapping, Optional, Union

from psyneulink.core import llvm as pnlvm
from psyneulink.core.components.component import parameter_keywords
from psyneulink.core.components.functions.function import (
    DEFAULT_SEED, Function, Function_Base, FunctionError, _random_state_getter, _seed_setter, function_keywords,
    get_matrix, is_function_type,
)
from psyneulink.core.components.functions.nonstateful.transformfunctions import LinearCombination
from psyneulink.core.components.functions.nonstateful.selectionfunctions import OneHot, ARG_MAX, ARG_MAX_INDICATOR
from psyneulink.core.components.functions.stateful.integratorfunctions import SimpleIntegrator
from psyneulink.core.components.shellclasses import Projection
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.utilities import is_numeric_scalar
from psyneulink.core.globals.keywords import (
    ADAPTIVE, ADDITIVE_PARAM, ALL, ANGLE_FUNCTION, BIAS, BINOMIAL_DISTORT_FUNCTION,
    DETERMINISTIC_TRANSFER_FUNCTION_TYPE, DROPOUT_FUNCTION,
    EXPONENTIAL_FUNCTION, GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION,
    IDENTITY_FUNCTION, INTERCEPT, LEAK, LINEAR_FUNCTION, LOGISTIC_FUNCTION,
    MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, OFF, OFFSET, ON, OUTPUT_TYPE,
    PER_ITEM, PROB, PRODUCT, PROB_INDICATOR, PROBABILISTIC_TRANSFER_FUNCTION_TYPE,
    RATE, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM,
    TANH_FUNCTION, TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION,
    VARIANCE, VARIABLE, X_0, PREFERENCE_SET_NAME, DEFAULT,
)
from psyneulink.core.globals.parameters import \
    FunctionParameter, Parameter, get_validator_by_function, check_user_specified, copy_parameter_value
from psyneulink.core.globals.preferences.basepreferenceset import \
    REPORT_OUTPUT_PREF, PreferenceEntry, PreferenceLevel, ValidPrefSet
from psyneulink.core.globals.utilities import (
    ValidParamSpecType, convert_all_elements_to_np_array, safe_len, is_matrix_keyword)

__all__ = ['Angle', 'BinomialDistort', 'Dropout', 'Exponential', 'Gaussian', 'GaussianDistort', 'Identity',
           'Linear', 'Logistic', 'ReLU', 'SoftMax', 'Tanh', 'TransferFunction', 'TransferWithCosts'
           ]


def _range_getter_using_scale_and_offset(owning_component=None, context=None):
    """Reassign range based on scale and offset applied to function's default_range
    """
    default_range = owning_component.default_range
    scale = owning_component.parameters.scale._get(context)
    offset = owning_component.parameters.offset._get(context)

    # Deal with lower bound = None:
    lower_bound = -np.inf if default_range[0] is None else default_range[0]
    output_for_fct_lower_bound = scale * lower_bound + offset

    # Deal with upper bound = None:
    upper_bound = np.inf if default_range[1] is None else default_range[1]
    output_for_fct_upper_bound = scale * upper_bound + offset

    # Need to do this since scale could be negative, reversing upper and lower range:
    if np.isscalar(scale) or np.isscalar(offset):
        lower_bound = min(output_for_fct_lower_bound, output_for_fct_upper_bound)
        upper_bound = max(output_for_fct_lower_bound, output_for_fct_upper_bound)
    else:
        lower_bound = np.minimum(output_for_fct_lower_bound, output_for_fct_upper_bound)
        upper_bound = np.maximum(output_for_fct_lower_bound, output_for_fct_upper_bound)

    return (lower_bound, upper_bound)
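
# For example (hypothetical values, assuming a scalar scale and offset; not part of the library):
#     a function with default_range = (0, 1), scale = -2.0 and offset = 1.0 maps its bounds to
#     -2.0 * 0 + 1.0 = 1.0 and -2.0 * 1 + 1.0 = -1.0; because scale < 0 reverses their order,
#     the min/max above re-sorts them, and the getter returns range = (-1.0, 1.0).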


class TransferFunction(Function_Base):
    """Function that transforms variable but maintains its shape.

    Abstract base class for TransferFunctions.

    In addition to the Parameters listed below, all TransferFunctions have a
    `multiplicative_param <Function_Modulatory_Params>` and an `additive_param <Function_Modulatory_Params>` --
    see `multiplicative and additive params <TransferFunction_Modulable_Params>` for additional information.

    Attributes
    ----------

    range : tuple(lower bound: float, upper bound: float)
      read-only Parameter that indicates the lower and upper limits of the function's result.  The two items of the
      tuple indicate the lower and upper bounds of range, respectively, with `None` as the entry if there are no
      bounds.  Some subclasses of TransferFunction may have other Parameters that influence the range, which are
      described under the `range <TransferFunction.range>` attribute of the relevant subclasses.

    default_range : tuple(lower bound: float, upper bound: float)
       class attribute that indicates the upper and lower limits of the Function's result.
    """

    componentType = TRANSFER_FUNCTION_TYPE

    class Parameters(Function_Base.Parameters):
        """
            Attributes
            ----------

                range
                    see `range <TransferFunction.range>`

                    :default value: (None, None)
                    :type:

        """
        range = Parameter((None, None), read_only=True)

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        assert isinstance(arg_in.type.pointee, pnlvm.ir.ArrayType)
        assert arg_in.type == arg_out.type

        is_2d = isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType)

        with pnlvm.helpers.array_ptr_loop(builder, arg_in, "transfer_loop") as (b, idx):
            if is_2d:
                vi = b.gep(arg_in, [ctx.int32_ty(0), idx])
                vo = b.gep(arg_out, [ctx.int32_ty(0), idx])
                with pnlvm.helpers.array_ptr_loop(b, vi, "nested_transfer_loop") as args:
                    self._gen_llvm_transfer(ctx=ctx, vi=vi, vo=vo,
                                            params=params, state=state, *args, tags=tags)
            else:
                self._gen_llvm_transfer(b, idx, ctx=ctx, vi=arg_in, vo=arg_out,
                                        params=params, state=state, tags=tags)

        return builder


class DeterministicTransferFunction(TransferFunction):
    """Subclass of TransferFunction that computes a deterministic function.

    Abstract base class for TransferFunctions that take scale and offset as parameters.

    In addition to the `standard attributes <TransferFunction_StandardAttributes>` of a TransferFunction,
    all DeterministicTransferFunctions have a `scale <DeterministicTransferFunction.scale>` and `offset
    <DeterministicTransferFunction.offset>` Parameter, which are used to determine the `range
    <TransferFunction.range>`.

    Attributes
    ----------

    default_range : tuple(lower bound: float, upper bound: float)
       class attribute that indicates the upper and lower limits of the Function's result, before `scale
       <DeterministicTransferFunction.scale>` or `offset <DeterministicTransferFunction.offset>` are applied.

    range : tuple(lower bound: float, upper bound: float)
      read-only Parameter that indicates the lower and upper limits of the Function's result, after the `scale
      <DeterministicTransferFunction.scale>` and `offset <DeterministicTransferFunction.offset>` Parameters
      have been applied to the Function's default_range:  :math:`default\\_range(lower, upper) * scale + offset`.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset
      <TransferFunction.offset>` is added.

    offset : float
      determines the value added to the result of the function after `scale <TransferFunction.scale>` has been applied.

    """
    componentType = DETERMINISTIC_TRANSFER_FUNCTION_TYPE

    class Parameters(TransferFunction.Parameters):
        """
            Attributes
            ----------

                range
                    see `range <TransferFunction.range>`

                    :default value: None
                    :type:

                scale
                    see `scale <DeterministicTransferFunction.scale>`

                    :default value: 1.0
                    :type: float

                offset
                    see `offset <DeterministicTransferFunction.offset>`

                    :default value: 0.0
                    :type: float
        """
        range = Parameter((None, None),
                          getter=_range_getter_using_scale_and_offset,
                          read_only=True,
                          dependencies={'scale', 'offset'})
        scale = Parameter(1.0, modulable=True)
        offset = Parameter(0.0, modulable=True)
265
# **********************************************************************************************************************
266
#                                                 Identity
267
# **********************************************************************************************************************
268

269

270
class Identity(DeterministicTransferFunction):  #
1✔
271
    # ----------------------------------------------------------------------
272
    """
273
    Identity(                  \
274
             default_variable, \
275
             params=None,      \
276
             owner=None,       \
277
             name=None,        \
278
             prefs=None        \
279
            )
280

281
    .. _Identity:
282

283
    Returns variable.
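
    For example (a trivial sketch; the input is returned unchanged, in whatever array form the Function
    machinery gives it)::

        from psyneulink.core.components.functions.nonstateful.transferfunctions import Identity

        Identity()([1, 2, 3])    # -> [1, 2, 3]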

    Arguments
    ---------

    variable : number or np.array : default class_defaults.variable
        specifies a template for the value to be returned.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or np.array
        contains value to be returned.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = IDENTITY_FUNCTION

    classPreferences = {
        PREFERENCE_SET_NAME: 'IdentityClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }
    default_range = (None, None)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):
        super().__init__(default_variable=default_variable,
                         params=params,
                         owner=owner,
                         prefs=prefs,
                         )

        # self.functionOutputType = None

    def _function(
        self,
        variable=None,
        context=None,
        params=None,
    ):
        """
        Return: `variable <Identity.variable>`

        Arguments
        ---------

        variable : number or np.array : default class_defaults.variable
           a single value or array to be returned.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        variable : number or np.array

        """
        # outputType = self.functionOutputType

        return variable

    def _gen_llvm_function_body(self, ctx, builder, _1, _2, arg_in, arg_out, *, tags:frozenset):
        val = builder.load(arg_in)
        builder.store(val, arg_out)
        return builder

    def _gen_pytorch_fct(self, device, context=None):
        return lambda x: x


# **********************************************************************************************************************
#                                                    Linear
# **********************************************************************************************************************


class Linear(DeterministicTransferFunction):  # ------------------------------------------------------------------------
    """
    Linear(                \
         default_variable, \
         slope=1.0,        \
         intercept=0.0,    \
         scale=1.0,        \
         offset=0.0,       \
         params=None,      \
         owner=None,       \
         name=None,        \
         prefs=None        \
         )

    .. _Linear:

    `function <Linear._function>` returns linear transform of `variable <Linear.variable>`:

    .. math::
        scale * (slope * variable + intercept) + offset

    .. note::
       Whereas `scale <Linear.scale>` and `offset <Linear.offset>` have effects similar to `slope <Linear.slope>`
       and `intercept <Linear.intercept>`, they are applied after those Parameters have been applied to `variable
       <Linear.variable>`, and thus are not identical; rather, they can be thought of as "amplifying" and
       "displacing" the Linear function, respectively.

    .. note::
       The default values for `slope <Linear.slope>`, `intercept <Linear.intercept>`, `scale
       <DeterministicTransferFunction.scale>`, and `offset <DeterministicTransferFunction.offset>`
       implement the *IDENTITY_FUNCTION*.  This may cause the Linear function to be replaced with the
       `Identity` Function in some circumstances (e.g., `compilation <Composition_Compilation>`).

    `derivative <Linear.derivative>` returns the derivative of the Linear Function:

    .. math::
        scale*slope
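
    For example, a minimal sketch (values chosen for illustration)::

        import numpy as np
        from psyneulink.core.components.functions.nonstateful.transferfunctions import Linear

        f = Linear(slope=2.0, intercept=1.0, scale=3.0, offset=-1.0)
        f(np.array([1.0, 2.0]))    # -> [8., 14.]:  3 * (2*x + 1) - 1
        f.derivative()             # -> 6.0:  slope * scale, independent of input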

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value to be transformed.

    slope : float : default 1.0
        specifies a value by which to multiply `variable <Linear.variable>`.

    intercept : float : default 0.0
        specifies a value to add to each element of `variable <Linear.variable>` after applying `slope <Linear.slope>`.

    scale : float : default 1.0
      specifies the value by which the result of the function is multiplied, before `offset <Linear.offset>` is added.

    offset : float : default 0.0
      specifies the value added to the result of the function after `scale <Linear.scale>` has been applied.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        contains value to be transformed.

    slope : float
        value by which each element of `variable <Linear.variable>` is multiplied before applying the
        `intercept <Linear.intercept>` (if it is specified).

    intercept : float
        value added to each element of `variable <Linear.variable>` after applying the `slope <Linear.slope>`
        (if it is specified).

    range : (None, None)
        modified by `scale <Linear.scale>` and/or `offset <Linear.offset>` if they are specified.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset <Linear.offset>` is added.

    offset : float
      determines the value added to the result of the function after `scale <Linear.scale>` has been applied.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = LINEAR_FUNCTION

    classPreferences = {
        PREFERENCE_SET_NAME: 'LinearClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    _model_spec_class_name_is_generic = True

    default_range = (None, None)

    class Parameters(DeterministicTransferFunction.Parameters):
        """
            Attributes
            ----------

                intercept
                    see `intercept <Linear.intercept>`

                    :default value: 0.0
                    :type: ``float``

                slope
                    see `slope <Linear.slope>`

                    :default value: 1.0
                    :type: ``float``
        """
        slope = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        intercept = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 slope: Optional[ValidParamSpecType] = None,
                 intercept: Optional[ValidParamSpecType] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            slope=slope,
            intercept=intercept,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        linear transformation of variable : number or array

        """
        slope = self._get_current_parameter_value(SLOPE, context)
        intercept = self._get_current_parameter_value(INTERCEPT, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        try:
            # By default, result should be returned as np.array with same dimensionality as input
            result = scale * (variable * slope + intercept) + offset
        except TypeError:
            if hasattr(variable, "dtype"):
                # If variable is an array with mixed sizes or types, try item-by-item operation
                if variable.dtype == object:
                    result = np.zeros_like(variable)
                    for i, item in enumerate(variable):
                        try:
                            result[i] = scale * (variable[i] * slope + intercept) + offset
                        except TypeError:
                            owner_str = f" of '{self.owner.name}'" if self.owner else ""
                            if variable[i] is None:
                                err_msg = (f"Item {i} of {VARIABLE} passed to {self.name}{owner_str} is 'None'; "
                                           f"may be due to missing afferent projection to input_ports[{i}]")
                            else:
                                err_msg = (f"Unrecognized type for item {i} of {VARIABLE} (variable[i]) "
                                           f"passed to {self.name}{owner_str}.")
                            raise FunctionError(err_msg)
                else:
                    owner_str = f"'{self.owner.name}'" if self.owner else ""
                    raise FunctionError(f"Unrecognized type for {VARIABLE} ({variable}) "
                                        f"passed to {self.name}{owner_str}.")
            # KAM 6/28/18: If the variable does not have a "dtype" attr but made it to this line, then it must be of a
            # type that even np does not recognize -- typically a custom OutputPort variable with items of different
            # shapes (e.g. variable = [[0.0], [0.0], array([[0.0, 0.0]])] )
            elif isinstance(variable, list):
                result = []
                for variable_item in variable:
                    result.append(scale * (np.multiply(variable_item, slope) + intercept) + offset)
            else:
                raise FunctionError("Unrecognized type for {} of {} ({})".format(VARIABLE, self.name, variable))

        return self.convert_output_type(result)

    @handle_external_context()
    def derivative(self, input=None, output=None, context=None):
        """
        derivative(input)

        Derivative of `function <Linear._function>` at **input**.

        Arguments
        ---------

        input : number
            value of the input to the Linear transform at which derivative is to be taken.

        Returns
        -------

        Slope of function :  number or array

        """
        return self._get_current_parameter_value(SLOPE, context) * self._get_current_parameter_value(SCALE, context)

    def _is_identity(self, context=None, defaults=False):
        if defaults:
            slope = self.defaults.slope
            intercept = self.defaults.intercept
            scale = self.defaults.scale
            offset = self.defaults.offset
        else:
            slope = self.parameters.slope._get(context)
            intercept = self.parameters.intercept._get(context)
            scale = self.parameters.scale._get(context)
            offset = self.parameters.offset._get(context)

        return slope == 1 and intercept == 0 and scale == 1 and offset == 0

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])
        slope_ptr = ctx.get_param_or_state_ptr(builder, self, SLOPE, param_struct_ptr=params)
        intercept_ptr = ctx.get_param_or_state_ptr(builder, self, INTERCEPT, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        slope = pnlvm.helpers.load_extract_scalar_array_one(builder, slope_ptr)
        intercept = pnlvm.helpers.load_extract_scalar_array_one(builder, intercept_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)

        if "derivative" in tags:
            # f'(x) = m * scale
            val = slope
            val = builder.fmul(val, scale)
        else:
            # f(x) = scale * (mx + b) + offset
            val = builder.load(ptri)
            val = builder.fmul(val, slope)
            val = builder.fadd(val, intercept)
            val = builder.fmul(val, scale)
            val = builder.fadd(val, offset)

        builder.store(val, ptro)

    def _gen_pytorch_fct(self, device, context=None):
        slope = self._get_pytorch_fct_param_value('slope', device, context)
        intercept = self._get_pytorch_fct_param_value('intercept', device, context)
        scale = self._get_pytorch_fct_param_value('scale', device, context)
        offset = self._get_pytorch_fct_param_value('offset', device, context)
        return lambda x: scale * (x * slope + intercept) + offset
# **********************************************************************************************************************
690
#                                                    Exponential
691
# **********************************************************************************************************************
692

693
class Exponential(DeterministicTransferFunction):  # -------------------------------------------------------------------
1✔
694
    """
695
    Exponential(           \
696
         default_variable, \
697
         rate=1.0,         \
698
         bias=0.0,         \
699
         scale=1.0,        \
700
         offset=0.0,       \
701
         params=None,      \
702
         owner=None,       \
703
         name=None,        \
704
         prefs=None        \
705
         )
706

707
    .. _Exponential:
708

709
    `function <Exponential._function>` returns exponential transform of `variable <Exponential.variable>`:
710

711
    .. math::
712
         scale * e^{rate*variable+bias} + offset
713

714
    `derivative <Exponential.derivative>` returns the derivative of the Exponential Function:
715

716
    .. math::
717
        scale*rate*(input+bias)*e^{rate*input+bias}
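
    For example, a minimal sketch (values chosen for illustration; other parameters at their defaults)::

        from psyneulink.core.components.functions.nonstateful.transferfunctions import Exponential

        f = Exponential(rate=2.0, scale=0.5)
        f(1.0)                    # -> 0.5 * e**2        (about 3.6945)
        f.derivative(input=1.0)   # -> 0.5 * 2 * e**2    (about 7.3891)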

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value to be transformed.

    rate : float : default 1.0
        specifies a value by which to multiply `variable <Exponential.variable>` before exponentiation.

    bias : float : default 0.0
        specifies a value to add to `variable <Exponential.variable>` after multiplying by `rate <Exponential.rate>`
        and before exponentiation.

    scale : float : default 1.0
      specifies the value by which the result of the function is multiplied, before `offset <Exponential.offset>` is
      added.

    offset : float : default 0.0
      specifies the value added to the result of the function after `scale <Exponential.scale>` has been applied.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        contains value to be transformed.

    rate : float
        value by which `variable <Exponential.variable>` is multiplied before exponentiation;
        assigned as *MULTIPLICATIVE_PARAM* of the Exponential Function.

    bias : float
        value added to `variable <Exponential.variable>` after multiplying by `rate <Exponential.rate>`
        and before exponentiation;  assigned as *ADDITIVE_PARAM* of the Exponential Function.

    range : (0, None)
        modified by `scale <Exponential.scale>` and/or `offset <Exponential.offset>` if they are specified.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset <Exponential.offset>`
      is added.

    offset : float
      determines the value added to the result of the function after `scale <Exponential.scale>` has been applied.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = EXPONENTIAL_FUNCTION
    default_range = (0, None)


    class Parameters(DeterministicTransferFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <Exponential.bias>`

                    :default value: 0.0
                    :type: ``float``

                rate
                    see `rate <Exponential.rate>`

                    :default value: 1.0
                    :type: ``float``
        """
        rate = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 rate: Optional[ValidParamSpecType] = None,
                 bias: Optional[ValidParamSpecType] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):
        super().__init__(
            default_variable=default_variable,
            rate=rate,
            bias=bias,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be exponentiated.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        Exponential transformation of variable : number or array

        """
        rate = self._get_current_parameter_value(RATE, context)
        bias = self._get_current_parameter_value(BIAS, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        result = scale * e**(rate * variable + bias) + offset
        return self.convert_output_type(result)

    @handle_external_context()
    def derivative(self, input, output=None, context=None):
        """
        derivative(input)

        Derivative of `function <Exponential._function>` at **input**.

        Arguments
        ---------

        input : number
            value of the input to the Exponential transform at which derivative is to be taken.

        Returns
        -------
        derivative :  number or array
        """

        rate = self._get_current_parameter_value(RATE, context)
        scale = self._get_current_parameter_value(SCALE, context)
        bias = self._get_current_parameter_value(BIAS, context)

        return scale * rate * e**(rate * input + bias)

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        rate_ptr = ctx.get_param_or_state_ptr(builder, self, RATE, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        rate = pnlvm.helpers.load_extract_scalar_array_one(builder, rate_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)

        exp_f = ctx.get_builtin("exp", [ctx.float_ty])
        val = builder.load(ptri)
        val = builder.fmul(val, rate)
        val = builder.fadd(val, bias)
        val = builder.call(exp_f, [val])

        if "derivative" in tags:
            # f'(x) = s*r*e^(r*x + b)
            val = builder.fmul(val, scale)
            val = builder.fmul(val, rate)
        else:
            # f(x) = s*e^(r*x + b) + o
            val = builder.fmul(val, scale)
            val = builder.fadd(val, offset)

        builder.store(val, ptro)

    def _gen_pytorch_fct(self, device, context=None):
        rate = self._get_pytorch_fct_param_value('rate', device, context)
        scale = self._get_pytorch_fct_param_value('scale', device, context)
        bias = self._get_pytorch_fct_param_value('bias', device, context)
        offset = self._get_pytorch_fct_param_value('offset', device, context)

        # Return a callable, as the other _gen_pytorch_fct methods do (the original returned an
        # expression that referenced the undefined name `input`), applying the same transform as
        # `_function`:  scale * e**(rate * x + bias) + offset
        return lambda x: scale * torch.exp(rate * x + bias) + offset


# **********************************************************************************************************************
#                                                     Logistic
# **********************************************************************************************************************


class Logistic(DeterministicTransferFunction):  # ----------------------------------------------------------------------
    """
    Logistic(              \
         default_variable, \
         gain=1.0,         \
         bias=0.0,         \
         x_0=0.0,          \
         scale=1.0,        \
         offset=0.0,       \
         params=None,      \
         owner=None,       \
         name=None,        \
         prefs=None        \
         )

    .. _Logistic_Function:

    `function <Logistic._function>` returns logistic transform of `variable <Logistic.variable>`:

    .. math::
         scale * \\frac{1}{1 + e^{ - gain ( variable + bias - x_{0} )}}  + offset

    (this is a vertically offset and scaled version of `Tanh`, which is centered on the origin).

    .. _Logistic_Note:

    .. note::
        The `bias <Logistic.bias>` and `x_0 <Logistic.x_0>` Parameters have identical effects, apart from having
        opposite signs: `bias <Logistic.bias>` is included to accommodate the convention in the machine learning
        community; `x_0 <Logistic.x_0>` is included to match the `standard form of the Logistic Function
        <https://en.wikipedia.org/wiki/Logistic_function>`_ (in which `gain <Logistic.gain>` corresponds to
        the *k* parameter and `scale <Logistic.scale>` corresponds to the *L* parameter); `offset <Logistic.offset>`
        implements a translation of the function along the vertical axis that is *not* modulated by gain.

    `derivative <Logistic.derivative>` returns the derivative of the Logistic using its **output**:

    .. math::
        scale * gain * output * (1-output)

    See `note <Logistic_Note>` above for the effects of `scale <Logistic.scale>` and `offset <Logistic.offset>`.
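
    For example, a minimal sketch (gain and x_0 chosen for illustration; other parameters at their defaults)::

        from psyneulink.core.components.functions.nonstateful.transferfunctions import Logistic

        f = Logistic(gain=2.0, x_0=1.0)
        out = f(1.0)               # -> 0.5:  x_0 shifts the sigmoid's midpoint to 1.0
        f.derivative(output=out)   # -> 0.5:  2 * 0.5 * (1 - 0.5)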
977

978
    Arguments
979
    ---------
980

981
    default_variable : number or array : default class_defaults.variable
982
        specifies a template for the value to be transformed.
983

984
    gain : float : default 1.0
985
        specifies value by which to multiply each element of `variable <Logistic.variable>` after it is
986
        adjusted by `bias <Logistic.bias>` and/or `x_0 <Logistic.x_0>`, but before logistic transformation
987
        (see `note <Logistic_Note>` above).
988

989
    bias : float : default 0.0
990
        specifies value to add to each element of `variable <Logistic.variable>` before applying `gain <Logistic.gain>`;
991
        this argument has an effect identical to x_0, but with the opposite sign (see `note <Logistic_Note>` above).
992

993
    x_0 : float : default 0.0
994
        specifies value to add to each element of `variable <Logistic.variable>` before applying `gain <Logistic.gain>`;
995
        this argument has an effect identical to bias, but with the opposite sign (see `note <Logistic_Note>` above).
996

997
    scale : float : default 1.0
998
      specifies the value by which the result of the function is multiplied, before `offset <Logistic.offset>` is added.
999

1000
    offset : float : default 0.0
1001
      specifies the value added to the result of the function after `scale <Logistic.scale>` has been applied.
1002

1003
    params : Dict[param keyword: param value] : default None
1004
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
1005
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
1006
        arguments of the constructor.
1007

1008
    owner : Component
1009
        `component <Component>` to which to assign the Function.
1010

1011
    name : str : default see `name <Function.name>`
1012
        specifies the name of the Function.
1013

1014
    prefs : PreferenceSet or specification dict : default Function.classPreferences
1015
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
1016

1017
    Attributes
1018
    ----------
1019

1020
    variable : number or array
1021
        contains value to be transformed.
1022

1023
    gain : float
1024
        value by which to multiply each element of `variable <Logistic.variable>` after it is adjusted by
1025
        `bias <Logistic.bias>` and/or `x_0 <Logistic.x_0>`, but before logistic transformation has been applied
1026
        (see `note <Logistic_Note>` above).
1027

1028
    bias : float
1029
        value to add to each element of `variable <Logistic.variable>` before applying `gain <Logistic.gain>`;
1030
        this argument has an effect identical to x_0, but with the opposite sign (see `note <Logistic_Note>` above).
1031

1032
    x_0 : float
1033
        value to add to each element of `variable <Logistic.variable>` before applying `gain <Logistic.gain>`;
1034
        this argument has an effect identical to bias, but with the opposite sign (see `note <Logistic_Note>` above).
1035

1036
    range : (0, 1)
1037
        modified by `scale <Gaussian.scale> and/or `offset <Gaussian.offset>` if they are specified.
1038

1039
    scale : float
1040
      determines the value by which the result of the function is multiplied, before `offset <Logistic.offset>`
1041
      is added.
1042

1043
    offset : float
1044
      determines the value added to the result of the function after `scale <Logistic.scale>` has been applied.
1045

1046
    owner : Component
1047
        `component <Component>` to which the Function has been assigned.
1048

1049
    name : str
1050
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
1051
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
1052

1053
    prefs : PreferenceSet or specification dict : Function.classPreferences
1054
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
1055
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
1056
        for details).
1057
    """
1058

1059
    componentName = LOGISTIC_FUNCTION
1✔
1060
    # parameter_keywords.update({GAIN, BIAS})
1061
    _model_spec_class_name_is_generic = True
1✔
1062
    default_range = (0, 1)
1✔
1063

1064

1065
    class Parameters(DeterministicTransferFunction.Parameters):
1✔
1066
        """
1067
            Attributes
1068
            ----------
1069

1070
                bias
1071
                    see `bias <Logistic.bias>`
1072

1073
                    :default value: 0.0
1074
                    :type: ``float``
1075

1076
                gain
1077
                    see `gain <Logistic.gain>`
1078

1079
                    :default value: 1.0
1080
                    :type: ``float``
1081

1082
                x_0
1083
                    see `x_0 <Logistic.x_0>`
1084

1085
                    :default value: 0.0
1086
                    :type: ``float``
1087
        """
1088
        gain = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
1✔
1089
        x_0 = Parameter(0.0, modulable=True)
1✔
1090
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
1✔
1091

1092
    @check_user_specified
1✔
1093
    @beartype
1✔
1094
    def __init__(self,
1✔
1095
                 default_variable=None,
1096
                 gain: Optional[ValidParamSpecType] = None,
1097
                 x_0=None,
1098
                 bias=None,
1099
                 scale: Optional[ValidParamSpecType] = None,
1100
                 offset: Optional[ValidParamSpecType] = None,
1101
                 params=None,
1102
                 owner=None,
1103
                 prefs:  Optional[ValidPrefSet] = None,
1104
                 **kwargs):
1105
        super().__init__(
1✔
1106
            default_variable=default_variable,
1107
            gain=gain,
1108
            x_0=x_0,
1109
            bias=bias,
1110
            scale=scale,
1111
            offset=offset,
1112
            params=params,
1113
            owner=owner,
1114
            prefs=prefs,
1115
            **kwargs
1116
        )
1117

1118
    def _function(self,
1✔
1119
                 variable=None,
1120
                 context=None,
1121
                 params=None,
1122
                 ):
1123
        """
1124

1125
        Arguments
1126
        ---------
1127

1128
        variable : number or array : default class_defaults.variable
1129
           a single value or array to be transformed.
1130

1131
        params : Dict[param keyword: param value] : default None
1132
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
1133
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
1134
            arguments of the constructor.
1135

1136
        Returns
1137
        -------
1138

1139
        Logistic transformation of variable : number or array
1140

1141
        """
1142
        gain = self._get_current_parameter_value(GAIN, context)
1✔
1143
        bias = self._get_current_parameter_value(BIAS, context)
1✔
1144
        x_0 = self._get_current_parameter_value(X_0, context)
1✔
1145
        offset = self._get_current_parameter_value(OFFSET, context)
1✔
1146
        scale = self._get_current_parameter_value(SCALE, context)
1✔
1147

1148
        result = scale * (1. / (1 + e**(-gain * (variable + bias - x_0)))) + offset
1✔
1149

1150
        return self.convert_output_type(result)
1✔
1151

1152
    @handle_external_context()
1✔
1153
    def derivative(self, input=None, output=None, context=None):
1✔
1154
        """
1155
        derivative(input=None, output=None)
1156

1157
        Derivative of `function <Logistic._function>` at either **input** or **output**.
1158

1159
        COMMENT:  RESTORE WHEN TEST IN DERIVATIVE IS RESTORED
1160
        Either **input** or **output** must be specified.
1161
        If **output** is not specified, it is computed from  **input**.
1162
        If both are specified, **input** is ignored unless paramValidationPref is set, in which case
1163
        an error is generated if **output** does not correspond to `function <Logistic._function>`\\(**input**).
1164
        COMMENT
1165
        Either **input** or **output** must be specified.
1166
        If **output** is not specified, derivative is computed from **input**.
1167
        If both are specified, **input** is ignored and derivative is computed from **output**
1168
        .. technical_note::
1169
           allowing both to be specified is supported for consistency with `BackPropagation` `LearningFunction`
1170
           which uses output to compute Logistic
1171

1172
        Arguments
1173
        ---------
1174

1175
        input : number
1176
            value of the input to the Logistic transform at which derivative is to be taken.
1177

1178
        output : number
1179
            value of the output of the Logistic transform at which derivative is to be taken.
1180

1181
        Returns
1182
        -------
1183
        derivative  of logistic transform at output :  number or array
1184
        """
1185

1186
        gain = self._get_current_parameter_value(GAIN, context)
1✔
1187
        scale = self._get_current_parameter_value(SCALE, context)
1✔
1188

1189
        # Favor use of output: compute it from input if it is not provided
1190
        if output is None:
1✔
1191
            output = self.function(input, context=context)
1✔
1192

1193
        return gain * scale * output * (1 - output)
1✔
1194

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        gain_ptr = ctx.get_param_or_state_ptr(builder, self, GAIN, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        x_0_ptr = ctx.get_param_or_state_ptr(builder, self, X_0, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        gain = pnlvm.helpers.load_extract_scalar_array_one(builder, gain_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        x_0 = pnlvm.helpers.load_extract_scalar_array_one(builder, x_0_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        exp_f = ctx.get_builtin("exp", [ctx.float_ty])
        val = builder.load(ptri)

        if "derivative_out" not in tags:
            val = builder.fadd(val, bias)             # variable + bias
            val = builder.fsub(val, x_0)              # variable + bias - x_0
            val = builder.fmul(val, gain)             # gain * (variable + bias - x_0)
            val = builder.fneg(val)                   # -gain * (variable + bias - x_0)
            val = builder.call(exp_f, [val])          # e^(-gain * (variable + bias - x_0))
            val = builder.fadd(val.type(1), val)      # 1 + e^(-gain * (variable + bias - x_0))
            val = builder.fdiv(val.type(1), val)      # 1 / (1 + e^(-gain * (variable + bias - x_0)))
            val = builder.fmul(val, scale)            # scale * (1 / (1 + e^(-gain * (variable + bias - x_0))))
            val = builder.fadd(val, offset)           # scale * (1 / (1 + e^(-gain * (variable + bias - x_0)))) + offset

        if "derivative" in tags or "derivative_out" in tags:
            # derivative: gain * scale * output * (1 - output)
            function_val = val
            val = builder.fsub(function_val.type(1), function_val)
            val = builder.fmul(function_val, val)
            val = builder.fmul(gain, val)
            val = builder.fmul(scale, val)

        builder.store(val, ptro)

    def _gen_pytorch_fct(self, device, context=None):
        scale = self._get_pytorch_fct_param_value('scale', device, context)
        gain = self._get_pytorch_fct_param_value('gain', device, context)
        bias = self._get_pytorch_fct_param_value('bias', device, context)
        offset = self._get_pytorch_fct_param_value('offset', device, context)
        return lambda x: scale / (1 + torch.exp(-gain * (x + bias))) + offset

    def as_mdf_model(self):
        model = super().as_mdf_model()

        # x_0 is included in bias in MDF logistic
        self._set_mdf_arg(model, 'bias', np.array(model.args['bias'] - model.args['x_0']))
        self._set_mdf_arg(model, 'x_0', np.array(0))

        if model.args['scale'] != 1.0:
            warnings.warn(
                f"Scale (set to {model.args['scale']}) is not a supported"
                ' parameter for MDF logistic'
            )
        return model

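    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # the bias/x_0 folding above relies on the transform depending on bias and
    # x_0 only through their difference (bias - x_0); a minimal NumPy check:
    @staticmethod
    def _example_mdf_bias_folding():
        import numpy as np
        gain, bias, x_0 = 2.0, 0.7, 0.3
        x = np.linspace(-2.0, 2.0, 9)
        original = 1 / (1 + np.exp(-gain * (x + bias - x_0)))
        folded = 1 / (1 + np.exp(-gain * (x + (bias - x_0))))   # x_0 folded into bias
        assert np.allclose(original, folded)

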
# **********************************************************************************************************************
#                                                    Tanh
# **********************************************************************************************************************

class Tanh(DeterministicTransferFunction):  # --------------------------------------------------------------------------
    """
    Tanh(                  \
         default_variable, \
         gain=1.0,         \
         bias=0.0,         \
         x_0=0.0,          \
         scale=1.0,        \
         offset=0.0,       \
         params=None,      \
         owner=None,       \
         name=None,        \
         prefs=None        \
         )

    .. _Tanh_Function:

    `function <Tanh._function>` returns hyperbolic tangent of `variable <Tanh.variable>`:

    .. math::
        scale*\\frac{1 - e^{-2(gain*(variable+bias-x\\_0)+offset)}}{1 + e^{-2(gain*(variable+bias-x\\_0)+offset)}}

    .. note::

       The `Logistic` function is an offset and scaled version of this function.
       The parameters used here have the same meaning as those used for the `Logistic` Function.

    `derivative <Tanh.derivative>` returns the derivative of the hyperbolic tangent at its **input**:

    .. math::
        \\frac{scale*gain}{(\\frac{1+e^{-2(gain*(variable+bias-x\\_0)+offset)}}{2e^{-(gain*(variable+bias-x\\_0)+offset)}})^2}

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies template for the value to be transformed.

    gain : float : default 1.0
        specifies value by which to multiply `variable <Tanh.variable>` before Tanh transformation.

    bias : float : default 0.0
        specifies value to add to each element of `variable <Tanh.variable>` before applying `gain <Tanh.gain>`
        and before Tanh transformation. This argument is identical to x_0, with the opposite sign.

    x_0 : float : default 0.0
        specifies value to subtract from each element of `variable <Tanh.variable>` before applying `gain <Tanh.gain>`
        and before Tanh transformation. This argument is identical to bias, with the opposite sign.

    scale : float : default 1.0
      specifies the value by which the result of the function is multiplied, before `offset <Tanh.offset>` is added.

    offset : float : default 0.0
      specifies the value added to the result of the function after `scale <Tanh.scale>` has been applied.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        contains value to be transformed.

    gain : float : default 1.0
        value by which each element of `variable <Tanh.variable>` is multiplied before applying the
        `bias <Tanh.bias>` (if it is specified).

    bias : float : default 0.0
        value added to each element of `variable <Tanh.variable>` before applying the `gain <Tanh.gain>`
        (if it is specified). This attribute is identical to x_0, with the opposite sign.

    x_0 : float : default 0.0
        value subtracted from each element of `variable <Tanh.variable>` before applying the `gain <Tanh.gain>`
        (if it is specified). This attribute is identical to bias, with the opposite sign.

    range : (-1, 1)
        modified by `scale <Tanh.scale>` and/or `offset <Tanh.offset>` if they are specified.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset <Tanh.offset>` is added.

    offset : float
      determines the value added to the result of the function after `scale <Tanh.scale>` has been applied.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = TANH_FUNCTION
    # parameter_keywords.update({GAIN, BIAS, OFFSET})
    default_range = (-1, 1)


    class Parameters(DeterministicTransferFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <Tanh.bias>`

                    :default value: 0.0
                    :type: ``float``

                gain
                    see `gain <Tanh.gain>`

                    :default value: 1.0
                    :type: ``float``

                x_0
                    see `x_0 <Tanh.x_0>`

                    :default value: 0.0
                    :type: ``float``
        """
        gain = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        x_0 = Parameter(0.0, modulable=True)
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 gain: Optional[ValidParamSpecType] = None,
                 x_0=None,
                 bias=None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None,
                 **kwargs):
        super().__init__(
            default_variable=default_variable,
            gain=gain,
            x_0=x_0,
            bias=bias,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
            **kwargs
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        hyperbolic tangent of variable : number or array

        """
        gain = self._get_current_parameter_value(GAIN, context)
        bias = self._get_current_parameter_value(BIAS, context)
        x_0 = self._get_current_parameter_value(X_0, context)
        offset = self._get_current_parameter_value(OFFSET, context)
        scale = self._get_current_parameter_value(SCALE, context)

        exponent = -2 * (gain * (variable + bias - x_0) + offset)
        result = scale * (1 - e**exponent) / (1 + e**exponent)

        return self.convert_output_type(result)

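    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # the exponential form computed by _function above is algebraically
    # scale * tanh(u) with u = gain * (variable + bias - x_0) + offset;
    # a minimal NumPy check:
    @staticmethod
    def _example_tanh_equivalence():
        import numpy as np
        gain, bias, x_0, scale, offset = 1.5, 0.2, 0.1, 2.0, 0.3
        x = np.linspace(-2.0, 2.0, 9)
        u = gain * (x + bias - x_0) + offset
        exponent = -2 * u
        via_exp = scale * (1 - np.exp(exponent)) / (1 + np.exp(exponent))
        assert np.allclose(via_exp, scale * np.tanh(u))
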
    @handle_external_context()
    def derivative(self, input, output=None, context=None):
        """
        derivative(input)

        Derivative of `function <Tanh._function>` at **input**.

        Arguments
        ---------

        input : number
            value of the input to the Tanh transform at which derivative is to be taken.

        Returns
        -------
        derivative : number or array
        """

        gain = self._get_current_parameter_value(GAIN, context)
        bias = self._get_current_parameter_value(BIAS, context)
        x_0 = self._get_current_parameter_value(X_0, context)
        offset = self._get_current_parameter_value(OFFSET, context)
        scale = self._get_current_parameter_value(SCALE, context)

        exponent = -2 * (gain * (input + bias - x_0) + offset)
        mult = -2 * gain * scale
        numerator = -2 * e**(exponent)
        denominator = (1 + e**(exponent))**2

        return mult * (numerator / denominator)

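    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # a finite-difference check of the closed form returned by derivative()
    # above, using only NumPy:
    @staticmethod
    def _example_tanh_derivative_check():
        import numpy as np
        gain, bias, x_0, scale, offset = 1.5, 0.2, 0.1, 2.0, 0.3
        x = np.linspace(-2.0, 2.0, 801)
        u = gain * (x + bias - x_0) + offset
        exponent = -2 * u
        analytic = (-2 * gain * scale) * (-2 * np.exp(exponent)) / (1 + np.exp(exponent)) ** 2
        numeric = np.gradient(scale * np.tanh(u), x)   # finite-difference reference
        assert np.allclose(analytic, numeric, atol=1e-3)
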
    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        gain_ptr = ctx.get_param_or_state_ptr(builder, self, GAIN, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        x_0_ptr = ctx.get_param_or_state_ptr(builder, self, X_0, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)

        gain = pnlvm.helpers.load_extract_scalar_array_one(builder, gain_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        x_0 = pnlvm.helpers.load_extract_scalar_array_one(builder, x_0_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)

        variable = builder.load(ptri)
        exp_f = ctx.get_builtin("exp", [ctx.float_ty])

        if "derivative" in tags:
            exponent = builder.fadd(variable, bias)
            exponent = builder.fsub(exponent, x_0)
            exponent = builder.fmul(gain, exponent)
            exponent = builder.fadd(exponent, offset)
            exponent = builder.fmul(exponent.type(-2), exponent)

            mult = builder.fmul(gain, scale)
            mult = builder.fmul(mult.type(-2), mult)

            exp_val = builder.call(exp_f, [exponent])
            numerator = builder.fmul(exp_val.type(-2), exp_val)

            denominator = builder.fadd(exp_val.type(1), exp_val)
            denominator = builder.fmul(denominator, denominator)

            val = builder.fdiv(numerator, denominator)
            val = builder.fmul(val, mult)
        else:
            exp_val = builder.fadd(variable, bias)
            exp_val = builder.fsub(exp_val, x_0)
            exp_val = builder.fmul(exp_val, gain)
            exp_val = builder.fadd(exp_val, offset)
            exp_val = builder.fmul(exp_val.type(-2), exp_val)

            val = builder.call(exp_f, [exp_val])
            val1 = builder.fsub(val.type(1), val)
            val2 = builder.fadd(val.type(1), val)
            val = builder.fdiv(val1, val2)
            val = builder.fmul(val, scale)

        builder.store(val, ptro)

    def _gen_pytorch_fct(self, device, context=None):
        gain = self._get_pytorch_fct_param_value('gain', device, context)
        bias = self._get_pytorch_fct_param_value('bias', device, context)
        offset = self._get_pytorch_fct_param_value('offset', device, context)
        # equivalent to _function with the default x_0=0 and scale=1:
        # tanh(gain * (x + bias) + offset)
        return lambda x: torch.tanh(gain * (x + bias) + offset)

# **********************************************************************************************************************
#                                                    ReLU
# **********************************************************************************************************************

class ReLU(DeterministicTransferFunction):  # --------------------------------------------------------------------------
    """
    ReLU(                  \
         default_variable, \
         gain=1.0,         \
         bias=0.0,         \
         leak=0.0,         \
         scale=1.0,        \
         offset=0.0,       \
         params=None,      \
         owner=None,       \
         name=None,        \
         prefs=None        \
         )

    .. _RelU_Function:

    `function <ReLU._function>` returns rectified linear transform of `variable <ReLU.variable>`:

    .. math::
        x = gain * (variable - bias)

    .. math::
        scale * max(x, leak * x) + offset

    Commonly used by `ReLU <https://en.wikipedia.org/wiki/Rectifier_(neural_networks)>`_ units in neural networks.

    `derivative <ReLU.derivative>` returns the derivative of the rectified linear transform at its **input**:

    .. math::
        scale * gain\\ if\\ input > 0,\\ scale * gain * leak\\ otherwise

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value to be transformed.

    gain : float : default 1.0
        specifies a value by which to multiply `variable <ReLU.variable>` after `bias <ReLU.bias>` is subtracted
        from it.

    bias : float : default 0.0
        specifies a value to subtract from each element of `variable <ReLU.variable>`; functions as threshold.

    leak : float : default 0.0
        specifies a scaling factor between 0 and 1 applied when (variable - bias) is less than or equal to 0.

    scale : float : default 1.0
      specifies the value by which the result of the function is multiplied, before `offset <ReLU.offset>` is added.

    offset : float : default 0.0
      specifies the value added to the result of the function after `scale <ReLU.scale>` has been applied.

    owner : Component
        `component <Component>` to which to assign the Function.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        contains value to be transformed.

    gain : float : default 1.0
        value by which to multiply `variable <ReLU.variable>` after `bias <ReLU.bias>` is subtracted
        from it.

    bias : float : default 0.0
        value to subtract from each element of `variable <ReLU.variable>`; functions as threshold.

    leak : float : default 0.0
        scaling factor between 0 and 1 applied when (variable - bias) is less than or equal to 0.

    range : (None, None)
        modified by `scale <ReLU.scale>` and/or `offset <ReLU.offset>` if they are specified.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset <ReLU.offset>` is added.

    offset : float
      determines the value added to the result of the function after `scale <ReLU.scale>` has been applied.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = RELU_FUNCTION
    # parameter_keywords.update({GAIN, BIAS, LEAK})
    default_range = (None, None)


    class Parameters(DeterministicTransferFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <ReLU.bias>`

                    :default value: 0.0
                    :type: ``float``

                gain
                    see `gain <ReLU.gain>`

                    :default value: 1.0
                    :type: ``float``

                leak
                    see `leak <ReLU.leak>`

                    :default value: 0.0
                    :type: ``float``
        """
        gain = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        leak = Parameter(0.0, modulable=True)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 gain: Optional[ValidParamSpecType] = None,
                 bias: Optional[ValidParamSpecType] = None,
                 leak: Optional[ValidParamSpecType] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None,
                 **kwargs):
        super().__init__(
            default_variable=default_variable,
            gain=gain,
            bias=bias,
            leak=leak,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
            **kwargs
        )

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        ReLU transformation of variable : number or array
        """
        gain = self._get_current_parameter_value(GAIN, context)
        bias = self._get_current_parameter_value(BIAS, context)
        leak = self._get_current_parameter_value(LEAK, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        # KAM modified 2/15/19 to match https://en.wikipedia.org/wiki/Rectifier_(neural_networks)#Leaky_ReLUs
        x = gain * (variable - bias)
        result = scale * np.maximum(x, leak * x) + offset

        return self.convert_output_type(result)

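    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # the leaky-ReLU transform computed by _function above, restated with plain
    # NumPy for a concrete input:
    @staticmethod
    def _example_relu_transform():
        import numpy as np
        gain, bias, leak, scale, offset = 2.0, 1.0, 0.1, 1.5, 0.5
        variable = np.array([-1.0, 0.5, 1.0, 3.0])
        x = gain * (variable - bias)                      # shift by bias, then apply gain
        result = scale * np.maximum(x, leak * x) + offset
        # elements with variable > bias pass through with slope scale * gain;
        # the rest are attenuated by leak
        assert np.allclose(result, [-0.1, 0.35, 0.5, 6.5])
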
    @handle_external_context()
    def derivative(self, input=None, output=None, context=None):
        """
        derivative(input or else output)

        Derivative of `function <ReLU._function>` at **input** or **output**.  If **input** is specified, that
        is used to compute the derivative;  if **input** is not specified, it is inferred from the **output**
        and then used to compute the derivative.

        Arguments
        ---------

        input : number
            value of the input to the ReLU transform at which derivative is to be taken.

        output : number
            value of the output of the ReLU transform, from which input is inferred if **input** is not specified.

        Returns
        -------
        derivative : number or array
        """

        gain = self._get_current_parameter_value(GAIN, context)
        leak = self._get_current_parameter_value(LEAK, context)
        bias = self._get_current_parameter_value(BIAS, context)
        scale = self._get_current_parameter_value(SCALE, context)

        if input is not None:
            # Use input if provided
            variable = np.array(input) - bias
        else:
            # Infer input from output
            variable = np.array(output) / gain

        value = np.where(variable > 0, scale * gain, scale * gain * leak)
        return value

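    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # with bias=0 the output-based branch above infers the input as output/gain,
    # which preserves the sign of the input whenever gain, scale and leak are
    # positive, so the same piecewise slope is selected either way:
    @staticmethod
    def _example_relu_derivative_from_output():
        import numpy as np
        gain, leak, scale = 2.0, 0.1, 1.5
        x = np.array([-2.0, -0.5, 1.0, 3.0])
        out = scale * np.maximum(gain * x, leak * gain * x)       # forward transform (bias=0, offset=0)
        slope_from_input = np.where(x > 0, scale * gain, scale * gain * leak)
        slope_from_output = np.where(out / gain > 0, scale * gain, scale * gain * leak)
        assert np.allclose(slope_from_input, slope_from_output)
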
    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        gain_ptr = ctx.get_param_or_state_ptr(builder, self, GAIN, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        leak_ptr = ctx.get_param_or_state_ptr(builder, self, LEAK, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        gain = pnlvm.helpers.load_extract_scalar_array_one(builder, gain_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        leak = pnlvm.helpers.load_extract_scalar_array_one(builder, leak_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)

        # Maxnum for some reason needs full function prototype
        max_f = ctx.get_builtin("maxnum", [ctx.float_ty])
        var = builder.load(ptri)
        if "derivative_out" in tags:
            val = builder.fdiv(var, gain)
        else:
            val = builder.fsub(var, bias)

        if "derivative" in tags or "derivative_out" in tags:
            predicate = builder.fcmp_ordered('>', val, val.type(0))
            gain = builder.fmul(gain, scale)
            val = builder.select(predicate, gain, builder.fmul(gain, leak))
        else:
            val1 = builder.fmul(val, gain)
            val2 = builder.fmul(val1, leak)

            val = builder.call(max_f, [val1, val2])
            val = builder.fmul(val, scale)
            val = builder.fadd(val, offset)

        builder.store(val, ptro)

    def _gen_pytorch_fct(self, device, context=None):
        gain = self._get_pytorch_fct_param_value('gain', device, context)
        bias = self._get_pytorch_fct_param_value('bias', device, context)
        leak = self._get_pytorch_fct_param_value('leak', device, context)
        # equivalent to _function with the default scale=1 and offset=0:
        # gain * max(x - bias, 0) + gain * leak * min(x - bias, 0)
        return lambda x: (torch.max(input=(x - bias), other=torch.tensor([0], device=device).double()) * gain +
                            torch.min(input=(x - bias), other=torch.tensor([0], device=device).double()) * gain * leak)


# **********************************************************************************************************************
#                                                    Gaussian
# **********************************************************************************************************************

class Gaussian(DeterministicTransferFunction):  # ----------------------------------------------------------------------
    """
    Gaussian(                    \
         default_variable,       \
         standard_deviation=1.0, \
         bias=0.0,               \
         scale=1.0,              \
         offset=0.0,             \
         params=None,            \
         owner=None,             \
         name=None,              \
         prefs=None              \
         )

    .. _Gaussian_Function:

    `function <Gaussian._function>` returns Gaussian transform of `variable <Gaussian.variable>`:

    .. math::
      scale*\\frac{e^{-\\frac{(variable-bias)^{2}}{2\\sigma^{2}}}}{\\sqrt{2\\pi}\\sigma}+offset

    where :math:`\\sigma` = `standard_deviation <Gaussian.standard_deviation>`

    .. note::
        The value returned is deterministic (i.e., the value of the probability density function at variable),
        not a randomly chosen sample from the Gaussian distribution; for the latter, use `GaussianDistort`.

    `derivative <Gaussian.derivative>` returns derivative of the Gaussian transform of `variable <Gaussian.variable>`:

    .. math::

       \\frac{-(variable-bias)*e^{-\\frac{(variable-bias)^{2}}{2\\sigma^{2}}}}{\\sqrt{2\\pi}\\sigma^{3}}

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value used as the mean for the Gaussian transform.

    standard_deviation : float : default 1.0
        specifies "width" of the Gaussian transform applied to each element of `variable <Gaussian.variable>`.

    bias : float : default 0.0
        value to add to each element of `variable <Gaussian.variable>` before applying Gaussian transform.

    scale : float : default 1.0
      specifies the value by which the result of the function is multiplied, before `offset <Gaussian.offset>` is added.

    offset : float : default 0.0
      specifies the value added to the result of the function after `scale <Gaussian.scale>` has been applied.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        value used as the mean of the Gaussian transform.

    standard_deviation : float : default 1.0
        standard_deviation used for Gaussian transform.

    bias : float : default 0.0
        value added to each element of `variable <Gaussian.variable>` before applying the Gaussian transform.

    range : (None, None)
        modified by `scale <Gaussian.scale>` and/or `offset <Gaussian.offset>` if they are specified.

    scale : float
      determines the value by which the result of the function is multiplied, before `offset <Gaussian.offset>`
      is added.

    offset : float
      determines the value added to the result of the function after `scale <Gaussian.scale>` has been applied.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = GAUSSIAN_FUNCTION
    # parameter_keywords.update({STANDARD_DEVIATION, BIAS, SCALE, OFFSET})
    default_range = (None, None)

    class Parameters(TransferFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <Gaussian.bias>`

                    :default value: 0.0
                    :type: ``float``

                standard_deviation
                    see `standard_deviation <Gaussian.standard_deviation>`

                    :default value: 1.0
                    :type: ``float``
        """
        standard_deviation = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 standard_deviation: Optional[ValidParamSpecType] = None,
                 bias: Optional[ValidParamSpecType] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None,
                 **kwargs):
        super().__init__(
            default_variable=default_variable,
            standard_deviation=standard_deviation,
            bias=bias,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
            **kwargs
        )

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        standard_deviation_ptr = ctx.get_param_or_state_ptr(builder, self, STANDARD_DEVIATION, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        standard_deviation = pnlvm.helpers.load_extract_scalar_array_one(builder, standard_deviation_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)

        exp_f = ctx.get_builtin("exp", [ctx.float_ty])
        sqrt_f = ctx.get_builtin("sqrt", [ctx.float_ty])

        var = builder.load(ptri)
        exp_num = builder.fsub(var, bias)
        exp_num = builder.fmul(exp_num, exp_num)
        exp_num = pnlvm.helpers.fneg(builder, exp_num)

        exp_denom = builder.fmul(standard_deviation, standard_deviation)
        exp_denom = builder.fmul(exp_denom.type(2), exp_denom)
        exp = builder.fdiv(exp_num, exp_denom)
        numerator = builder.call(exp_f, [exp])

        denom = builder.fmul(standard_deviation.type(2 * pi), standard_deviation)
        denom = builder.call(sqrt_f, [denom])
        val = builder.fdiv(numerator, denom)

        val = builder.fmul(scale, val)
        val = builder.fadd(offset, val)

        builder.store(val, ptro)

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed by the Gaussian function.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        Gaussian transformation of variable : number or array

        """
        standard_deviation = self._get_current_parameter_value(STANDARD_DEVIATION, context)
        bias = self._get_current_parameter_value(BIAS, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        gaussian = e**(-(variable - bias)**2 / (2 * standard_deviation**2)) / sqrt(2 * pi * standard_deviation)
        result = scale * gaussian + offset

        return self.convert_output_type(result)

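    # Illustrative sketch (added editorially, not part of the PsyNeuLink API):
    # the density computed by _function above for a concrete input, restated
    # with plain NumPy; it mirrors the expression in _function verbatim
    # (sigma = standard_deviation), and sigma=1 is chosen so the reference
    # values are those of the standard normal density.
    @staticmethod
    def _example_gaussian_value():
        import numpy as np
        sigma, bias, scale, offset = 1.0, 0.0, 1.0, 0.0
        variable = np.array([0.0, 1.0])
        gaussian = np.exp(-(variable - bias) ** 2 / (2 * sigma ** 2)) / np.sqrt(2 * np.pi * sigma)
        result = scale * gaussian + offset
        assert np.allclose(result, [0.39894228, 0.24197072])
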
    @handle_external_context()
    def derivative(self, input, output=None, context=None):
        """
        derivative(input)

        Derivative of `function <Gaussian._function>` at **input**.

        Arguments
        ---------

        input : number
            value of the input to the Gaussian transform at which derivative is to be taken.

        Returns
        -------

        Derivative of Gaussian of variable : number or array

        """
        sigma = self._get_current_parameter_value(STANDARD_DEVIATION, context)
        bias = self._get_current_parameter_value(BIAS, context)

        adjusted_input = input - bias
        result = (-adjusted_input * e**(-(adjusted_input**2 / (2 * sigma**2)))) / sqrt(2 * pi * sigma**3)

        return self.convert_output_type(result)


# **********************************************************************************************************************
#                                               GaussianDistort
# **********************************************************************************************************************

class GaussianDistort(TransferFunction):  #-----------------------------------------------------------------------------
    """
    GaussianDistort(       \
         default_variable, \
         variance=1.0,     \
         bias=0.0,         \
         scale=1.0,        \
         offset=0.0,       \
         seed=None,        \
         params=None,      \
         owner=None,       \
         name=None,        \
         prefs=None        \
         )

    .. _GaussianDistort_Function:

    `function <GaussianDistort._function>` returns a random value from a Gaussian distribution with
    mean = `variable <GaussianDistort.variable>` and variance = `variance <GaussianDistort.variance>`.

    .. note::
        If the Gaussian transform of `variable <GaussianDistort.variable>` is desired (i.e., the value of the
        probability density function at `variable <GaussianDistort.variable>`, rather than a randomly chosen
        sample from the Gaussian distribution), then use `Gaussian`.

    COMMENT:
    `derivative <GaussianDistort.derivative>` returns derivative of the Gaussian transform of
    `variable <GaussianDistort.variable>`:

    .. math::

       \\frac{-(variable-bias)*e^{-\\frac{(variable-bias)^{2}}{2\\sigma^{2}}}}{\\sqrt{2\\pi}\\sigma^{3}}
    COMMENT

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value(s) used as the mean of the Gaussian distribution from which each sample is
        drawn.

    variance : float : default 1.0
        specifies "width" of the Gaussian distribution around each element of `variable <GaussianDistort.variable>`
        from which sample is drawn.

    bias : float : default 0.0
        specifies value to add to each element of `variable <GaussianDistort.variable>` before drawing sample.

    scale : float : default 1.0
        specifies value by which to multiply each sample.

    offset : float : default 0.0
        specifies value to add to each sample after it is drawn and `scale <GaussianDistort.scale>` is applied.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        each element determines mean of the Gaussian distribution from which each sample is drawn.

    variance : float
        determines variance of Gaussian distribution from which each sample is drawn.

    bias : float
        determines value added to each element of `variable <GaussianDistort.variable>` before drawing sample.

    scale : float
        determines value by which each sample is multiplied after it is drawn.

    offset : float
        determines value added to each sample after it is drawn and `scale <GaussianDistort.scale>` is applied.

    random_state : numpy.RandomState
        private pseudorandom number generator

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = GAUSSIAN_DISTORT_FUNCTION
    # parameter_keywords.update({VARIANCE, BIAS, SCALE, OFFSET})

    class Parameters(TransferFunction.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <GaussianDistort.bias>`

                    :default value: 0.0
                    :type: ``float``

                offset
                    see `offset <GaussianDistort.offset>`

                    :default value: 0.0
                    :type: ``float``

                random_state
                    see `random_state <GaussianDistort.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

                scale
                    see `scale <GaussianDistort.scale>`

                    :default value: 1.0
                    :type: ``float``

                variance
                    see `variance <GaussianDistort.variance>`

                    :default value: 1.0
                    :type: ``float``
        """
        variance = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        bias = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        scale = Parameter(1.0, modulable=True)
        offset = Parameter(0.0, modulable=True)
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
        range = (None, None)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 variance: Optional[ValidParamSpecType] = None,
                 bias: Optional[ValidParamSpecType] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            variance=variance,
            bias=bias,
            scale=scale,
            offset=offset,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        variance_ptr = ctx.get_param_or_state_ptr(builder, self, VARIANCE, param_struct_ptr=params)
        bias_ptr = ctx.get_param_or_state_ptr(builder, self, BIAS, param_struct_ptr=params)
        scale_ptr = ctx.get_param_or_state_ptr(builder, self, SCALE, param_struct_ptr=params)
        offset_ptr = ctx.get_param_or_state_ptr(builder, self, OFFSET, param_struct_ptr=params)

        variance = pnlvm.helpers.load_extract_scalar_array_one(builder, variance_ptr)
        bias = pnlvm.helpers.load_extract_scalar_array_one(builder, bias_ptr)
        scale = pnlvm.helpers.load_extract_scalar_array_one(builder, scale_ptr)
        offset = pnlvm.helpers.load_extract_scalar_array_one(builder, offset_ptr)

        rvalp = builder.alloca(ptri.type.pointee, name="random_out")
        rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
        normal_f = ctx.get_normal_dist_function_by_state(rand_state_ptr)
        builder.call(normal_f, [rand_state_ptr, rvalp])

        rval = builder.load(rvalp)
        rval = builder.fmul(rval, variance)
        val = builder.load(ptri)
        val = builder.fadd(val, bias)
        val = builder.fadd(rval, val)
        val = builder.fmul(val, scale)
        val = builder.fadd(offset, val)

        builder.store(val, ptro)

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        Sample from Gaussian distribution for each element of variable : number or array

        """
        variance = self._get_current_parameter_value(VARIANCE, context)
        bias = self._get_current_parameter_value(BIAS, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)
        random_state = self._get_current_parameter_value('random_state', context)

        result = scale * random_state.normal(variable + bias, variance) + offset

        return self.convert_output_type(result)

    # def derivative(self, output, input=None, context=None):
    #     """
    #     derivative(output, input):
    #
    #     Derivative of `function <GaussianDistort.function>`:
    #
    #         -input/:math:`{variance^3}*\\sqrt{2\\pi}`
    #
    #
    #     Returns
    #     -------
    #
    #     Derivative of Gaussian of variable :  number or array
    #
    #     """
    #     variance = self._get_current_parameter_value(VARIANCE, context)
    #     bias = self._get_current_parameter_value(BIAS, context)
    #     scale = self._get_current_parameter_value(SCALE, context)
    #     offset = self._get_current_parameter_value(OFFSET, context)
    #
    #     # The following doesn't work with autograd (https://github.com/HIPS/autograd/issues/416)
    #     f = scale * np.random.normal(input+bias, variance) + offset
    #
    #     # FIX: SHOULD THIS BE variance**1.5 (since variance = sd**2 and term below is supposed to be sd**3)??
    #     df = -input(variance**3 * np.sqrt(2 * np.pi))
    #
    #     return self.convert_output_type(df*f)


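# Illustrative sketch (added editorially, not part of the PsyNeuLink API):
# seeded sampling equivalent to GaussianDistort._function above, using a plain
# numpy.random.RandomState in place of the Function's private random_state.
def _example_gaussian_distort_sampling():
    import numpy as np
    variance, bias, scale, offset = 0.5, 0.1, 2.0, 1.0
    variable = np.array([0.0, 1.0, 2.0])
    random_state = np.random.RandomState(seed=42)
    # each element of variable (plus bias) is the mean of its own draw
    result = scale * random_state.normal(variable + bias, variance) + offset
    assert result.shape == variable.shape

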
# **********************************************************************************************************************
#                                               BinomialDistort
# **********************************************************************************************************************

class BinomialDistort(TransferFunction):  #-----------------------------------------------------------------------------
    """
    BinomialDistort(          \
         default_variable,    \
         p=0.5,               \
         seed=None,           \
         params=None,         \
         owner=None,          \
         name=None,           \
         prefs=None           \
         )

    .. _BinomialDistort:

    `function <BinomialDistort._function>` returns `variable <BinomialDistort.variable>` with elements randomly
    zeroed with probability **p**:

    .. math::

       if \\ \\ rand[0,1] < p: output_i=0 \\\\
       else: \\ output_i = variable_i

    `derivative <BinomialDistort.derivative>` is not yet implemented; calling it raises a `FunctionError`.

    Arguments
    ---------

    default_variable : number or array : default class_defaults.variable
        specifies a template for the value(s), each element of which is randomly zeroed with probability **p**.

    p : float : default 0.5
        specifies the probability with which each element of `variable` is replaced with zero.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        contains the value(s), each element of which is randomly zeroed with probability `p <BinomialDistort.p>`.

    p : float
        the probability with which each element of `variable` is replaced with zero.

    random_state : numpy.RandomState
        private pseudorandom number generator

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentName = BINOMIAL_DISTORT_FUNCTION

    classPreferences = {
        PREFERENCE_SET_NAME: 'BinomialClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    class Parameters(TransferFunction.Parameters):
        """
            Attributes
            ----------
                p
                    see `p <BinomialDistort.p>`

                    :default value: 0.5
                    :type: ``float``

                random_state
                    see `random_state <BinomialDistort.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

        """
        p = Parameter(0.5, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
        range = (None, None)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 p: Optional[ValidParamSpecType] = None,
                 seed=None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            p=p,
            seed=seed,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])

        p_ptr = ctx.get_param_or_state_ptr(builder, self, 'p', param_struct_ptr=params)
        p = builder.load(p_ptr)
        mod_p = builder.fsub(p.type(1), p)    # keep-probability: 1 - p
        p_mod_ptr = builder.alloca(mod_p.type)
        builder.store(mod_p, p_mod_ptr)

        n_ptr = builder.alloca(ctx.int32_ty)
        builder.store(n_ptr.type.pointee(1), n_ptr)

        rand_state_ptr = ctx.get_random_state_ptr(builder, self, state, params)
        binomial_f = ctx.get_binomial_dist_function_by_state(rand_state_ptr)

        rvalp = builder.alloca(binomial_f.args[-1].type.pointee, name="random_out")
        builder.call(binomial_f, [rand_state_ptr, n_ptr, p_mod_ptr, rvalp])

        val = builder.load(ptri)
        rval = builder.load(rvalp)
        rval = builder.uitofp(rval, val.type)
        val = builder.fmul(val, rval)

        builder.store(val, ptro)

2511
    def _function(self,
1✔
2512
                 variable=None,
2513
                 context=None,
2514
                 params=None,
2515
                 ):
2516
        """
2517

2518
        Arguments
2519
        ---------
2520

2521
        variable : number or array : default class_defaults.variable
2522
           a single value or array to be randomly zeroed.
2523

2524
        params : Dict[param keyword: param value] : default None
2525
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
2526
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
2527
            arguments of the constructor.
2528

2529

2530
        Returns
2531
        -------
2532

2533
        variable with elements zeroed with probability p : number or array
2534

2535
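
        A minimal NumPy sketch of the computation performed here (values are illustrative)::

            import numpy as np

            rng = np.random.RandomState(0)
            variable = np.array([1.0, 2.0, 3.0, 4.0])
            p = 0.3
            # binomial(n=1, p=1 - p) draws 1 (keep) with probability 1 - p, and 0 (zero out) with probability p
            result = variable * rng.binomial(size=len(variable), n=1, p=(1 - p))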
        """
2536
        p = self._get_current_parameter_value('p', context)
1✔
2537
        random_state = self._get_current_parameter_value('random_state', context)
1✔
2538
        result = variable * random_state.binomial(size=len(variable), n=1, p=(1 - p))
1✔
2539
        return self.convert_output_type(result)
1✔
2540

2541
    def _is_identity(self, context=None, defaults=False):
1✔
2542
        if defaults:
×
2543
            p = self.defaults.p
×
2544
        else:
2545
            p = self.parameters.p._get(context)
×
2546
        return p == 0.0
×
2547

2548
    def derivative(self, output, input=None, context=None):
1✔
2549
        raise FunctionError(f"Derivative of BinomialDistort not yet supported.")
2550
    #     """
2551
    #     derivative(input, output):
2552
    #
2553
    #     Derivative of `function <BinomialDistort.function>`:
2554
    #
2555
    #         -input/:math:`{variance^3}*\\sqrt{2\\pi}`
2556
    #
2557
    #
2558
    #     Returns
2559
    #     -------
2560
    #
2561
    #     Derivative of Binomial of variable :  number or array
2562
    #
2563
    #     """
2564
    #     bias = self._get_current_parameter_value(BIAS, context)
2565
    #     scale = self._get_current_parameter_value(SCALE, context)
2566
    #     offset = self._get_current_parameter_value(OFFSET, context)
2567
    #
2568
    #     # The following doesn't work with autograd (https://github.com/HIPS/autograd/issues/416)
2569
    #     f = scale * np.random.normal(input+bias, variance) + offset
2570
    #
2571
    # # FIX: ?WHICH IF EITHER IS CORRECT?:
2572
    # return self._get_current_parameter_value(VARIABLE, context)
2573
    # # return 1.0
2574

2575

2576
# **********************************************************************************************************************
2577
#                                                    Dropout
2578
# **********************************************************************************************************************
2579

2580
class Dropout(TransferFunction):  #
1✔
2581
    # -------------------------------------------------------------------------------------
2582
    """
2583
    Dropout(               \
2584
         default_variable, \
2585
         p=0.5,            \
2586
         params=None,      \
2587
         owner=None,       \
2588
         name=None,        \
2589
         prefs=None        \
2590
         )
2591

2592
    .. _Dropout:
2593

2594
    `function <Dropout._function>` returns `variable <Dropout.variable>` with elements randomly zeroed with
2595
    probability **p** during learning; otherwise functions as `Identity` Function.  During learning, the output
2596
    of the function is scaled by :math:`\\frac{1}{(1-p)}`, which implements the inverse scaling form of `dropout
2597
    <https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html?highlight=dropout>`_ used by PyTorch.
2598

2599
    .. math::
2600

2601
       if \\ (context.runmode == ContextFlags.LEARNING\\_MODE) \\ and \\ (rand[0,1] < p):  output_i = 0 \\\\
2602
       elif \\ (context.runmode == ContextFlags.LEARNING\\_MODE): \\ output_i = \\frac{1}{(1-p)}variable_i \\\\
       else: \\ output_i = variable_i
2603

2604
    .. technical_note::
2605
       Dropout uses ``context.runmode`` == `ContextFlags.LEARNING_MODE`
2606
       to determine when learning is in effect
2607

2608
    `derivative <Dropout.derivative>` returns `variable`
2609
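
    The inverse scaling preserves the expected value of each element during learning:
    :math:`E[output_i] = (1-p) \\cdot \\frac{variable_i}{(1-p)} = variable_i`.  A minimal NumPy sketch of the
    learning-mode computation (values are illustrative)::

        import numpy as np

        rng = np.random.RandomState(0)
        variable = np.array([1.0, 2.0, 3.0, 4.0])
        p = 0.5
        mask = rng.binomial(size=len(variable), n=1, p=(1 - p))   # 1 = keep, 0 = drop
        output = variable * mask / (1 - p)                        # surviving elements are scaled up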

2610
    Arguments
2611
    ---------
2612

2613
    default_variable : number or array : default class_defaults.variable
2614
        specifies a template for the value to be transformed.
2615

2616
    p : float : default 0.5
2617
        specifies the probability with which each element of `variable` is replaced with zero.
2618

2619
    params : Dict[param keyword: param value] : default None
2620
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
2621
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
2622
        arguments of the constructor.
2623

2624
    owner : Component
2625
        `component <Component>` to which to assign the Function.
2626

2627
    name : str : default see `name <Function.name>`
2628
        specifies the name of the Function.
2629

2630
    prefs : PreferenceSet or specification dict : default Function.classPreferences
2631
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
2632

2633
    Attributes
2634
    ----------
2635

2636
    variable : number or array
2637
        contains value to be transformed.
2638

2639
    p : float
2640
        the probability with which each element of `variable` is replaced with zero.
2641

2642
    random_state : numpy.RandomState
2643
        private pseudorandom number generator
2644

2645
    owner : Component
2646
        `component <Component>` to which the Function has been assigned.
2647

2648
    name : str
2649
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
2650
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
2651

2652
    prefs : PreferenceSet or specification dict : Function.classPreferences
2653
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
2654
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
2655
        for details).
2656
    """
2657

2658
    componentName = DROPOUT_FUNCTION
1✔
2659

2660
    classPreferences = {
1✔
2661
        PREFERENCE_SET_NAME: 'DropoutClassPreferences',
2662
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
2663
    }
2664

2665
    class Parameters(TransferFunction.Parameters):
1✔
2666
        """
2667
            Attributes
2668
            ----------
2669

2670
                p
2671
                    see `p <Dropout.p>`
2672

2673
                    :default value: 0.5
2674
                    :type: ``float``
2675

2676
                random_state
2677
                    see `random_state <Dropout.random_state>`
2678

2679
                    :default value: None
2680
                    :type: ``numpy.random.RandomState``
2681
        """
2682
        p = Parameter(0.5, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
1✔
2683
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
1✔
2684
        seed = Parameter(DEFAULT_SEED(), modulable=True, fallback_value=DEFAULT, setter=_seed_setter)
1✔
2685

2686
    @check_user_specified
1✔
2687
    @beartype
1✔
2688
    def __init__(self,
1✔
2689
                 default_variable=None,
2690
                 p: Optional[ValidParamSpecType] = None,
2691
                 params=None,
2692
                 owner=None,
2693
                 prefs: Optional[ValidPrefSet]  = None):
2694
        self.binomial_distort = BinomialDistort(default_variable=default_variable, p=p)
1✔
2695

2696
        super().__init__(
1✔
2697
            default_variable=default_variable,
2698
            p=p,
2699
            params=params,
2700
            owner=owner,
2701
            prefs=prefs,
2702
        )
2703

2704
    def _function(self,
1✔
2705
                 variable=None,
2706
                 context=None,
2707
                 params=None,
2708
                 ):
2709
        """
2710

2711
        Arguments
2712
        ---------
2713

2714
        variable : number or array : default class_defaults.variable
2715
           a single value or array to be randomly zeroed.
2716

2717
        params : Dict[param keyword: param value] : default None
2718
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
2719
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
2720
            arguments of the constructor.
2721

2722
        Returns
2723
        -------
2724

2725
        During learning, variable with elements zeroed with probability p, else scaled by :math:`\\frac{1}{(1-p)}`;
2726
        otherwise returns variable : number or array
2727

2728
        """
2729
        p = self._get_current_parameter_value('p', context)
1✔
2730

2731
        if context.runmode != ContextFlags.LEARNING_MODE:
1!
2732
            result = variable
1✔
2733

2734
        else:
2735
            p = self.defaults.p if p is None else p  # don't treat p == 0.0 as unset
×
2736
            self.binomial_distort.parameters.p.set(p, context)
×
2737
            result = self.binomial_distort(variable) * (1 / (1 - p))
×
2738

2739
        return self.convert_output_type(result)
1✔
2740

2741
    @handle_external_context()
1✔
2742
    def derivative(self, input=None, output=None, context=None):
1✔
2743
        # raise FunctionError(f"Derivative of Dropout not yet supported.")
2744
        """
2745
        derivative(input)
2746

2747
        Derivative of `function <Dropout._function>` at **input**.
2748

2749
        Arguments
2750
        ---------
2751

2752
        input : number or array
2753
            value of the input to the Dropout function at which derivative is to be taken.
2754

2755
        Returns
2756
        -------
2757

2758
        variable :  number or array
2759

2760
        """
2761
        # FIX: ?WHICH IS CORRECT:
2762
        # return self._get_current_parameter_value(VARIABLE, context)
2763
        return 1.0
×
2764

2765
    def _is_identity(self, context=None, defaults=False):
1✔
2766
        if defaults:
×
2767
            p = self.defaults.p
×
2768
        else:
2769
            p = self.parameters.p._get(context)
×
2770

2771
        return (context.runmode != ContextFlags.LEARNING_MODE) or (p == 0.0)
×
2772

2773
    def _gen_llvm_transfer(self, builder, index, ctx, vi, vo, params, state, *, tags:frozenset):
1✔
2774
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
1✔
2775
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])
1✔
2776

2777
        val = builder.load(ptri)
1✔
2778
        builder.store(val, ptro)
1✔
2779

2780
    def _gen_pytorch_fct(self, device, context=None):
1✔
2781
        prob = self._get_pytorch_fct_param_value('p')
×
2782
        return lambda x: (torch.dropout(input=x, p=prob, train=False))
×
2783

2784

2785
# **********************************************************************************************************************
2786
#                                                   SoftMax
2787
# **********************************************************************************************************************
2788

2789
softmax_modes = {ALL, ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR, PROB, PROB_INDICATOR}
1✔
2790

2791

2792
class SoftMax(TransferFunction):
1✔
2793
    """
2794
    SoftMax(                        \
2795
         default_variable,          \
2796
         gain=1.0,                  \
2797
         mask_threshold=None,       \
2798
         adapt_scale=1,             \
2799
         adapt_base=1,              \
2800
         adapt_entropy_weighting=0.95, \
2801
         output=ALL,                \
         per_item=True,             \
2802
         params=None,               \
2803
         owner=None,                \
2804
         name=None,                 \
2805
         prefs=None                 \
2806
         )
2807

2808
    .. _SoftMax:
2809

2810
    SoftMax transform of `variable <SoftMax.variable>`
2811

2812
    `function <SoftMax._function>` returns SoftMax transform of `variable <SoftMax.variable>`:
2813

2814
    .. math::
2815

2816
        \\frac{e^{gain * variable_i}}{\\sum\\limits_{j=1}^{len(variable)}e^{gain * variable_j}}
2817

2818
    filtered by `output <SoftMax.output>` specification (see `The Softmax function and its derivative
2819
    <http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/>`_ for a nice discussion).
2820

2821
        .. note::
2822
           If `variable <SoftMax.variable>` is all zeros, the SoftMax transform returns all zeros.
2823
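
    A minimal NumPy sketch of the transform (the gain value is illustrative)::

        import numpy as np

        def softmax(x, gain=1.0):
            v = np.exp(gain * x - np.max(gain * x))   # shift by max for numerical stability
            return v / np.sum(v)

        softmax(np.array([1.0, 0.0]))             # -> array([0.73105858, 0.26894142])
        softmax(np.array([1.0, 0.0]), gain=4.0)   # -> array([0.98201379, 0.01798621])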

2824
    .. _SoftMax_AdaptGain:
2825

2826
    *Thresholding and Adaptive Gain*
2827

2828
    For cases in which SoftMax is used with sparse vectors (e.g., one-hots), the value(s) of the most significant
2829
    entries (e.g., the 1s in a one-hot) can be sensitive to (diminished by) the number of other values in the vector
2830
    (i.e., its length). For example, whereas for ``[1 0]`` the SoftMax is ``[0.73105858 0.26894142]``, for ``[1 0 0 0]``
2831
    it is ``[0.47536689 0.1748777  0.1748777  0.1748777]``. This can be addressed in one of two ways: either by
2832
    thresholding `variable <SoftMax.variable>` before applying the SoftMax function, or by adapting the `gain
2833
    <SoftMax.gain>` parametrically based on the `variable <SoftMax.variable>`:
2834

2835
    - *mask_threshold* -- setting the **mask_threshold** argument to a scalar value causes the `variable
2836
      <SoftMax.variable>` to be thresholded by that value before applying the SoftMax function: each element in
2837
      `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`; then, any elements with an absolute
2838
      value below *mask_threshold* are set to negative infinity (``-inf``), effectively masking them since
2839
      ``exp(-inf) = 0``, and the remaining values are passed through the SoftMax function (see the sketch following
2840
      this list). This only applies if the **gain** argument is specified as a scalar; if it is specified as
2841
      *ADAPTIVE*, the **mask_threshold** argument is ignored.
2842

2843
    - *ADAPTIVE* -- setting the **gain** argument to *ADAPTIVE* causes it to be dynamically adjusted,
2844
      based on the entropy and length of the variable, to keep the mass of the distribution around the highest values
2845
      as consistent as possible over different sized vectors. If *ADAPTIVE* is specified, then the `mask_threshold
2846
      <SoftMax.mask_threshold>` argument is ignored. The gain is adapted by calling the SoftMax function's `adapt_gain
2847
      <SoftMax.adapt_gain>` method. This can be finicky, and may need to be further tuned to the length of `variable
2848
      <SoftMax.variable>`, which can be done using the SoftMax Function's **adapt_scale**, **adapt_base**, and
2849
      **adapt_entropy_weighting** arguments.
2850
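
    A minimal NumPy sketch of the thresholding option described above (threshold and input are illustrative)::

        import numpy as np

        def masked_softmax(x, gain=1.0, mask_threshold=0.5):
            v = gain * x
            v = np.where(np.abs(v) > mask_threshold, v, -np.inf)   # mask sub-threshold elements
            if np.any(v != -np.inf):
                v = v - np.max(v)                                  # numerical stability
            v = np.exp(v)                                          # exp(-inf) = 0 for masked elements
            return v / np.sum(v) if np.any(v) else v

        masked_softmax(np.array([1.0, 0.2, 0.1, 0.0]))   # only the first element survives the mask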

2851
    .. _SoftMax_Derivative:
2852

2853
    *Derivative*
2854

2855
    `derivative <SoftMax.derivative>` returns the derivative of the SoftMax.  If *OUTPUT_TYPE* for the SoftMax
2856
    is *ALL*, returns Jacobian matrix (derivative for each element of the output array with respect to each of the
2857
    others):
2858

2859
    .. math::
2860
        D_jS_i = S_i(\\delta_{i,j} - S_j),\\ where\\ \\delta_{i,j}=1\\ if\\ i=j\\ and\\ \\delta_{i,j}=0\\ if\\ i≠j.
2861

2862
    If *OUTPUT_TYPE* is *ARG_MAX*, *ARG_MAX_INDICATOR*, *MAX_VAL*, *MAX_INDICATOR*, returns 1d array of the
2863
    derivatives of the maximum value(s) with respect to the others (calculated as above). If *OUTPUT_TYPE* is *PROB*,
2864
    raises an exception (since it is ambiguous as to which element would have been chosen by the SoftMax function).
2865
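
    For the *ALL* case, the Jacobian can be computed in a few lines of NumPy directly from the formula above
    (the input is illustrative)::

        import numpy as np

        sm = np.array([0.66524096, 0.24472847, 0.09003057])   # SoftMax of [2., 1., 0.]
        jacobian = np.diag(sm) - np.outer(sm, sm)             # D_jS_i = S_i * (delta_ij - S_j)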

2866
    Arguments
2867
    ---------
2868

2869
    default_variable : 1d array : default class_defaults.variable
2870
        specifies a template for the value to be transformed.
2871

2872
    gain : scalar or ADAPTIVE : default 1.0
2873
        specifies the value by which to multiply `variable <SoftMax.variable>` before SoftMax transformation,
2874
        which functions as the inverse "temperature" of the function.  If it is a scalar, it must be greater
2875
        than zero.  If *ADAPTIVE* is specified, the value is determined dynamically based on the `variable
2876
        <SoftMax.variable>` (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
2877

2878
    mask_threshold : scalar : default None
2879
        specifies the threshold (if any) used to mask the `variable <SoftMax.variable>` before applying the SoftMax function;
2880
        this only applies if `gain <SoftMax.gain>` is specified as a scalar;  otherwise it is ignored
2881
        (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
2882

2883
    adapt_scale : scalar : default 1
2884
        specifies the *scale* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
2885

2886
    adapt_base : scalar : default 1
2887
        specifies the *base* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
2888

2889
    adapt_entropy_weighting : scalar : default 0.95
2890
        specifies the *entropy_weighting* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method
2891
        (see method for details).
2892

2893
    output : ALL, ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR, or PROB : default ALL
2894
        specifies the format of array returned by `function <SoftMax._function>`
2895
        (see `output <SoftMax.output>` for details).
2896

2897
    per_item : boolean : default True
2898
        for 2d variables, determines whether the SoftMax function will be applied to the entire variable (per_item =
2899
        False), or applied to each item in the variable separately (per_item = True).
2900

2901
    params : Dict[param keyword: param value] : default None
2902
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
2903
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
2904
        arguments of the constructor.
2905

2906
    owner : Component
2907
        `component <Component>` to which to assign the Function.
2908

2909
    name : str : default see `name <Function.name>`
2910
        specifies the name of the Function.
2911

2912
    prefs : PreferenceSet or specification dict : default Function.classPreferences
2913
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
2914

2915
    Attributes
2916
    ----------
2917

2918
    variable : 1d array
2919
        contains value to be transformed.
2920

2921
    gain : scalar or ADAPTIVE
2922
        determines how `variable <SoftMax.variable>` is scaled before the SoftMax transformation, determining the
2923
        "sharpness" of the distribution (it is equivalent to the inverse of the temperature of the SoftMax function);
2924
        if it is *ADAPTIVE*, it is adjusted dynamically using the `adapt_gain <SoftMax.adapt_gain>` method
2925
        (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for additional details).
2926

2927
    mask_threshold : scalar or None
2928
        determines whether the `variable <SoftMax.variable>` is thresholded before applying the SoftMax function; if
2929
        it is a scalar, each element of `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`. Then
2930
        only elements with an absolute value greater than **mask_threshold** are considered when applying the SoftMax
2931
        function, while all other elements are set to ``-inf`` effectively masking them since ``exp(-inf) = 0``.
2932
        This only applies if `gain <SoftMax.gain>` is specified as a scalar;  otherwise it is ignored
2933
        (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
2934

2935
    adapt_scale : scalar
2936
        determines the *scale* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
2937

2938
    adapt_base : scalar
2939
        determines the *base* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
2940

2941
    adapt_entropy_weighting : scalar
2942
        determines the *entropy_weighting* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method
2943
        (see method for details).
2944

2945
    output : ALL, ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR, or PROB
2946
        determines how the SoftMax-transformed values of the elements in `variable <SoftMax.variable>` are reported
2947
        in the array returned by `function <SoftMax._function>` (see the sketch following this list):
2948
            * *ALL*: array of all SoftMax-transformed values (the default);
2949
            * *ARG_MAX*: SoftMax-transformed value for the single element with the maximum such value, 0 for all others
2950
              (the one with the lowest index if there are multiple maximum values);
2951
            * *ARG_MAX_INDICATOR*: 1 for a single element with the maximum SoftMax-transformed value, 0 for all others
2952
              (the one with the lowest index if there are multiple maximum values);
2953
            * *MAX_VAL*: SoftMax-transformed value for the element(s) with the maximum such value, 0 for all others;
2954
            * *MAX_INDICATOR*: 1 for the element(s) with the maximum SoftMax-transformed value, 0 for all others;
2955
            * *PROB*: probabilistically chosen element based on SoftMax-transformed values after setting the
2956
              sum of values to 1 (i.e., their `Luce Ratio <https://en.wikipedia.org/wiki/Luce%27s_choice_axiom>`_),
2957
              0 for all others.
2958
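
        As a minimal sketch of a few of these options, using plain NumPy and the semantics described above
        (values are illustrative)::

            import numpy as np

            sm = np.array([0.66524096, 0.24472847, 0.09003057])   # *ALL*: the full SoftMax transform
            max_val = np.where(sm == np.max(sm), sm, 0.)          # *MAX_VAL*: [0.66524096, 0., 0.]
            max_ind = np.where(sm == np.max(sm), 1., 0.)          # *MAX_INDICATOR*: [1., 0., 0.]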

2959
    per_item : boolean : default True
2960
        for 2d variables, determines whether the SoftMax function is applied to the entire variable (per_item =
2961
        False), or applied to each item in the variable separately (per_item = True).
2962

2963
    range : None if `output <SoftMax.output>` in {ARG_MAX, MAX_VAL}, else (0, 1) : default (0, 1)
2964

2965
    owner : Component
2966
        `component <Component>` to which the Function has been assigned.
2967

2968
    name : str
2969
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
2970
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
2971

2972
    prefs : PreferenceSet or specification dict : Function.classPreferences
2973
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
2974
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
2975
        for details).
2976
    """
2977

2978
    componentName = SOFTMAX_FUNCTION
1✔
2979

2980
    class Parameters(TransferFunction.Parameters):
1✔
2981
        """
2982
            Attributes
2983
            ----------
2984

2985
                variable
2986
                    see `variable <SoftMax.variable>`
2987

2988
                    :default value: numpy.array([[0.]])
2989
                    :type: ``numpy.ndarray``
2990
                    :read only: True
2991

2992
                adapt_scale
2993
                    see `adapt_scale <SoftMax.adapt_scale>`
2994

2995
                    :default value: 1.0
2996
                    :type: ``float``
2997

2998
                adapt_base
2999
                    see `adapt_base <SoftMax.adapt_base>`
3000

3001
                    :default value: 1.0
3002
                    :type: ``float``
3003

3004
                adapt_entropy_weighting
3005
                    see `adapt_entropy_weighting <SoftMax.adapt_entropy_weighting>`
3006

3007
                    :default value: 0.95
3008
                    :type: ``float``
3009

3010
                range
3011
                    see `range <SoftMax.range>`
3012

3013
                    :default value: (0, 1)
3014
                    :type: <class 'tuple'>
3015

3016
                gain
3017
                    see `gain <SoftMax.gain>`
3018

3019
                    :default value: 1.0
3020
                    :type: ``float``
3021

3022
                output
3023
                    see `output <SoftMax.output>`
3024

3025
                    :default value: `ALL`
3026
                    :type: ``str``
3027

3028
                per_item
3029
                    see `per_item <SoftMax.per_item>`
3030

3031
                    :default value: True
3032
                    :type: ``bool``
3033

3034
                mask_threshold
3035
                    see `mask_threshold <SoftMax.mask_threshold>`
3036

3037
                    :default value: None
3038
                    :type: ``float``
3039
        """
3040
        variable = Parameter(np.array([[0.0]]), read_only=True, pnl_internal=True, constructor_argument='default_variable')
1✔
3041
        gain = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
1✔
3042
        mask_threshold = Parameter(None, modulable=True)
1✔
3043
        adapt_scale = Parameter(1.0, modulable=True)
1✔
3044
        adapt_base = Parameter(1.0, modulable=True)
1✔
3045
        adapt_entropy_weighting = Parameter(0.95, modulable=True)
1✔
3046
        range = (0, 1)
1✔
3047
        output = ALL
1✔
3048
        per_item = Parameter(True, pnl_internal=True)
1✔
3049
        one_hot_function = Parameter(None, stateful=False, loggable=False)
1✔
3050

3051
        def _validate_gain(self, gain):
1✔
3052
            if is_numeric_scalar(gain):
1!
3053
                if gain <= 0:
1!
3054
                    return 'must be a scalar greater than 0'
×
3055
            elif isinstance(gain, str):
×
3056
                if gain != ADAPTIVE:
×
3057
                    return f'the keyword for adaptive gain is {ADAPTIVE}'
×
3058
            else:
3059
                return f'must be a scalar greater than 0 or the keyword {ADAPTIVE}'
×
3060

3061
        def _validate_mask_threshold(self, mask_threshold):
1✔
3062
            if mask_threshold is not None:
1✔
3063
                if is_numeric_scalar(mask_threshold):
1!
3064
                    if mask_threshold <= 0:
1!
3065
                        return 'must be a scalar greater than 0'
×
3066
                    return None
1✔
3067
                return f'must be a scalar greater than 0'
×
3068

3069
        def _validate_adapt_scale(self, adapt_scale):
1✔
3070
            if is_numeric_scalar(adapt_scale):
1!
3071
                if adapt_scale <= 0:
1!
3072
                    return 'must be a scalar greater than 0'
×
3073
                return None
1✔
3074
            return f'must be a scalar greater than 0'
×
3075

3076
        def _validate_adapt_base(self, adapt_base):
1✔
3077
            if is_numeric_scalar(adapt_base):
1!
3078
                if adapt_base <= 0:
1!
3079
                    return 'must be a scalar greater than 0'
×
3080
                return None
1✔
3081
            return f'must be a scalar greater than 0'
×
3082

3083
        def _validate_adapt_entropy_weighting(self, adapt_entropy_weighting):
1✔
3084
            if is_numeric_scalar(adapt_entropy_weighting):
1!
3085
                if adapt_entropy_weighting <= 0:
1!
3086
                    return 'must be a scalar greater than 0'
×
3087
                return None
1✔
3088
            return f'must be a scalar greater than 0'
×
3089

3090
        def _validate_output(self, output):
1✔
3091
            if output not in softmax_modes:
1!
3092
                return 'not one of {0}'.format(softmax_modes)
×
3093

3094
    @check_user_specified
1✔
3095
    @beartype
1✔
3096
    def __init__(self,
1✔
3097
                 default_variable=None,
3098
                 gain: Optional[ValidParamSpecType] = None,
3099
                 mask_threshold: Optional[ValidParamSpecType] = None,
3100
                 adapt_scale: Optional[ValidParamSpecType] = None,
3101
                 adapt_base: Optional[ValidParamSpecType] = None,
3102
                 adapt_entropy_weighting: Optional[ValidParamSpecType] = None,
3103
                 output=None,
3104
                 per_item=None,
3105
                 params: Optional[Mapping] = None,
3106
                 owner=None,
3107
                 prefs:  Optional[ValidPrefSet] = None):
3108

3109
        try:
1✔
3110
            # needed because one_hot_function is initialized here based
3111
            # on output argument, which may also be passed in params
3112
            output = params['output']
1✔
3113
        except (TypeError, KeyError):
1✔
3114
            pass
1✔
3115

3116
        if output not in {None, ALL}:
1✔
3117
            one_hot_function = OneHot(mode=output)
1✔
3118
        else:
3119
            one_hot_function = None
1✔
3120

3121
        super().__init__(
1✔
3122
            default_variable=default_variable,
3123
            gain=gain,
3124
            mask_threshold=mask_threshold,
3125
            adapt_scale=adapt_scale,
3126
            adapt_base=adapt_base,
3127
            adapt_entropy_weighting=adapt_entropy_weighting,
3128
            per_item=per_item,
3129
            output=output,
3130
            one_hot_function=one_hot_function,
3131
            params=params,
3132
            owner=owner,
3133
            prefs=prefs,
3134
        )
3135

3136
        self._negative_input_warning = False
1✔
3137

3138
    def _parse_one_hot_function_variable(self, variable):
1✔
3139
        if self.defaults.per_item and len(np.shape(variable)) > 1:
1✔
3140
            variable = variable[0]
1✔
3141

3142
        if self.defaults.output in {PROB, PROB_INDICATOR}:
1✔
3143
            prob_dist = np.asarray(variable)
1✔
3144
            # creates probability distribution in shape of variable
3145
            prob_dist = np.ones(variable.shape) / safe_len(prob_dist)
1✔
3146

3147
            variable = np.asarray([variable, prob_dist])
1✔
3148

3149
        return variable
1✔
3150

3151
    def _validate_variable(self, variable, context=None):
1✔
3152
        if variable is None:
1!
3153
            try:
×
3154
                return self.defaults.variable
×
3155
            except AttributeError:
×
3156
                return self.class_defaults.variable
×
3157

3158
        return np.asarray(variable)
1✔
3159

3160
    def apply_softmax(self, input_value, gain, mask_threshold, output_type):
1✔
3161
        # Modulate input_value by gain
3162
        v = gain * input_value
1✔
3163

3164
        # Mask threshold
3165
        if mask_threshold is not None:
1✔
3166
            if np.any(v < 0):
1!
3167
                warnings.warn(f"SoftMax function: mask_threshold is set "
×
3168
                              f"to {mask_threshold} but input_value contains negative values."
3169
                              f"Masking will be applied to the magnitude of the input.")
3170

3171
            v = np.where(np.abs(v) > mask_threshold, v, -np.inf)
1✔
3172

3173
        # Make numerically stable by shifting by max value
3174
        if np.any(v != -np.inf):
1✔
3175
            v = v - np.max(v)
1✔
3176

3177
        # Exponentiate
3178
        v = np.exp(v)
1✔
3179

3180
        # Normalize (to sum to 1)
3181
        if not np.any(v):
1✔
3182
            # If v is all zeros, avoid divide by zero in normalize and return all zeros for softmax
3183
            sm = v
1✔
3184
        else:
3185
            sm = v / np.sum(v)
1✔
3186

3187
        # Generate one-hot encoding based on selected output_type
3188
        if output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:
1✔
3189
            return self.one_hot_function(sm)
1✔
3190
        elif output_type in {PROB, PROB_INDICATOR}:
1✔
3191
            return self.one_hot_function([input_value, sm])
1✔
3192
        else:
3193
            return sm
1✔
3194
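    # Example (illustrative): with gain=2.0 and no masking, apply_softmax amplifies the largest element:
    #     apply_softmax(np.array([1., 0., 0.]), gain=2.0, mask_threshold=None, output_type=ALL)
    #     -> array([0.786986, 0.106507, 0.106507])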

3195
    def _function(self,
1✔
3196
                 variable=None,
3197
                 context=None,
3198
                 params=None,
3199
                 ):
3200
        """
3201

3202
        Arguments
3203
        ---------
3204

3205
        variable : 1d array : default class_defaults.variable
3206
           an array to be transformed.
3207

3208
        params : Dict[param keyword: param value] : default None
3209
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
3210
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
3211
            arguments of the constructor.
3212

3213
        Returns
3214
        -------
3215

3216
        SoftMax transformation of variable : number or array
3217

3218
        """
3219
        # Assign the params and return the result
3220
        output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
1✔
3221
        gain = self._get_current_parameter_value(GAIN, context)
1✔
3222
        mask_threshold = self._get_current_parameter_value('mask_threshold', context)
1✔
3223
        if isinstance(gain, str) and gain == ADAPTIVE:
1!
3224
            gain = self.adapt_gain(variable, context)
×
3225
        per_item = self._get_current_parameter_value(PER_ITEM, context)
1✔
3226

3227
        # Compute softmax and assign to sm
3228
        if per_item and len(np.shape(variable)) > 1:
1✔
3229
            output = []
1✔
3230
            for item in variable:
1✔
3231
                output.append(self.apply_softmax(item, gain, mask_threshold, output_type))
1✔
3232
            output = convert_all_elements_to_np_array(output)
1✔
3233
        else:
3234
            output = self.apply_softmax(variable, gain, mask_threshold, output_type)
1✔
3235

3236
        return self.convert_output_type(output)
1✔
3237

3238
    def adapt_gain(self, v, context)->float:
1✔
3239
        """Compute the softmax gain (inverse temperature) based on the entropy of the distribution of values.
3240
        Uses base, scale, and entropy_weighting parameters of SoftMax function to compute gain:
3241

3242
        .. math:: gain = scale * (base + (log(len(v)) * entropy\\_weighting * log(entropy(logistic(v)))))
3243
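
        A rough numeric sketch, computed directly from the formula above with the default parameter values::

            import numpy as np

            v = np.array([3., 0., 0., 0.])
            s = 1. / (1. + np.exp(-v))                        # logistic(v)
            entropy = -np.sum(s * np.log(s))
            gain = 1.0 * (1.0 + (np.log(len(v)) * 0.95) * np.log(entropy))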
        """
3244
        scale = self._get_current_parameter_value('adapt_scale', context)
×
3245
        base = self._get_current_parameter_value('adapt_base', context)
×
3246
        entropy_weighting = self._get_current_parameter_value('adapt_entropy_weighting', context)
×
3247
        entropy_weighting = np.log(len(v)) * entropy_weighting
×
3248

3249
        v = np.squeeze(v)
×
3250
        gain = scale * (base +
×
3251
                        (entropy_weighting *
3252
                         np.log(
3253
                             -1 * np.sum((1 / (1 + np.exp(-1 * v))) * np.log(1 / (1 + np.exp(-1 * v)))))))
3254
        return gain
×
3255

3256
    @handle_external_context()
1✔
3257
    def derivative(self, input=None, output=None, context=None):
1✔
3258
        """
3259
        derivative(output)
3260

3261
        .. technical note::
3262
           If ARG_MAX or MAX_VAL is specified for the `output <SoftMax.output>` parameter, and there is more than one
3263
           equivalent maximum value, the element with the lowest index is used to compute the derivative (see
3264
           IMPLEMENTATION NOTE below).
3265

3266
        Returns
3267
        -------
3268
        derivative of values returned by SoftMax :  1d or 2d array (depending on *OUTPUT_TYPE* of SoftMax)
3269
        """
3270

3271
        if output is None:
1✔
3272
            output = self.function(input, params={OUTPUT_TYPE: ALL}, context=context)
1✔
3273
        elif np.any(np.equal(0, output)) and context.source == ContextFlags.CONSTRUCTOR:
1!
3274
            # Allow derivative to be computed when output is 0 during initialization
3275
            output = np.where(output == 0, 1, output)
×
3276
        else:
3277
            assert not np.any(np.equal(0, output)), \
1✔
3278
                f"Derivative of SoftMax function for '{self.owner.name}' is not defined when output is 0."
3279

3280
        per_item = self._get_current_parameter_value(PER_ITEM, context)
1✔
3281
        if not per_item:
1✔
3282
            output = [output]
1✔
3283

3284
        if np.array(output).ndim == 1:
1✔
3285
            output = np.atleast_2d(output)
1✔
3286

3287
        result = []
1✔
3288
        for sm in output:
1✔
3289
            size = len(sm)
1✔
3290

3291
            output_type = self._get_current_parameter_value(OUTPUT_TYPE, context)
1✔
3292
            if output_type == ALL:
1✔
3293
                # Return full Jacobian matrix of derivatives using Kronecker's delta method:
3294
                derivative = np.empty([size, size])
1✔
3295
                for i, j in np.ndindex(size, size):
1✔
3296
                    if i == j:
1✔
3297
                        d = 1
1✔
3298
                    else:
3299
                        d = 0
1✔
3300
                    derivative[j, i] = sm[i] * (d - sm[j])
1✔
3301
            elif output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:
1✔
3302
                # Return 1d array of derivatives for max element (i.e., the one chosen by SoftMax)
3303
                derivative = np.empty(size)
1✔
3304
                # Get the element of output returned as non-zero (max val) when output_type is not ALL
3305
                # IMPLEMENTATION NOTE:
3306
                #    if there is a tie for max, this chooses the item in sm with the lowest index in sm:
3307
                index_of_max = int(np.where(sm == np.max(sm))[-1][0])
1✔
3308
                #    the following would randomly choose a value in case of a tie,
3309
                #    but may cause problems with compilation:
3310
                # index_of_max = np.where(sm == np.max(sm))[0]
3311
                # if len(index_of_max)>1:
3312
                #     index_of_max = int(np.random.choice(index_of_max))
3313
                max_val = sm[index_of_max]
1✔
3314
                for i in range(size):
1✔
3315
                    if i == index_of_max:
1✔
3316
                        d = 1
1✔
3317
                    else:
3318
                        d = 0
1✔
3319
                    derivative[i] = sm[i] * (d - max_val)
1✔
3320
            else:
3321
                raise FunctionError("Can't assign derivative for SoftMax function{} since OUTPUT_TYPE is PROB "
3322
                                    "(and therefore the relevant element is ambiguous)".format(self.owner_name))
3323

3324
            result.append(derivative)
1✔
3325

3326
        assert per_item or len(result) == 1
1✔
3327
        return result[0] if not per_item or np.array(result).ndim == 3 else result
1✔
3328

3329
    def __gen_llvm_exp_sum(self, builder, index, ctx, vi, gain, exp_sum_ptr):
1✔
3330
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
1✔
3331

3332
        exp_f = ctx.get_builtin("exp", [ctx.float_ty])
1✔
3333
        orig_val = builder.load(ptri)
1✔
3334
        val = builder.fmul(orig_val, gain)
1✔
3335
        exp_val = builder.call(exp_f, [val])
1✔
3336

3337
        exp_sum = builder.load(exp_sum_ptr)
1✔
3338
        new_exp_sum = builder.fadd(exp_sum, exp_val)
1✔
3339
        builder.store(new_exp_sum, exp_sum_ptr)
1✔
3340

3341
    def __gen_llvm_exp_div(self, builder, index, ctx, vi, vo, gain, exp_sum):
1✔
3342
        ptro = builder.gep(vo, [ctx.int32_ty(0), index])
1✔
3343
        ptri = builder.gep(vi, [ctx.int32_ty(0), index])
1✔
3344
        exp_f = ctx.get_builtin("exp", [ctx.float_ty])
1✔
3345
        orig_val = builder.load(ptri)
1✔
3346
        val = builder.fmul(orig_val, gain)
1✔
3347
        val = builder.call(exp_f, [val])
1✔
3348
        val = builder.fdiv(val, exp_sum)
1✔
3349

3350
        builder.store(val, ptro)
1✔
3351

3352
    def __gen_llvm_apply(self, ctx, builder, params, state, arg_in, arg_out, output_type, tags:frozenset):
1✔
3353
        exp_sum_ptr = builder.alloca(ctx.float_ty)
1✔
3354
        builder.store(exp_sum_ptr.type.pointee(0), exp_sum_ptr)
1✔
3355

3356
        gain_ptr = ctx.get_param_or_state_ptr(builder, self, GAIN, param_struct_ptr=params)
1✔
3357
        gain = pnlvm.helpers.load_extract_scalar_array_one(builder, gain_ptr)
1✔
3358

3359
        with pnlvm.helpers.array_ptr_loop(builder, arg_in, "exp_sum_max") as args:
1✔
3360
            self.__gen_llvm_exp_sum(*args, ctx=ctx, vi=arg_in, gain=gain,
1✔
3361
                                    exp_sum_ptr=exp_sum_ptr)
3362

3363
        exp_sum = builder.load(exp_sum_ptr)
1✔
3364

3365
        if output_type == ALL:
1✔
3366
            one_hot_p = ctx.get_param_or_state_ptr(builder, self, 'one_hot_function', param_struct_ptr=params, state_struct_ptr=state)
1✔
3367

3368
            # Derivative first gets the output_type == ALL result even if the selected output type is different.
3369
            assert self.output != output_type or one_hot_p.type.pointee.elements == (), \
1✔
3370
                "OneHot parameter should be empty for output_type == ALL: {}".format(one_hot_p)
3371
            with pnlvm.helpers.array_ptr_loop(builder, arg_in, "exp_div") as args:
1✔
3372
                self.__gen_llvm_exp_div(ctx=ctx, vi=arg_in, vo=arg_out,
1✔
3373
                                        gain=gain, exp_sum=exp_sum, *args)
3374
            return builder
1✔
3375

3376
        one_hot_p, one_hot_s = ctx.get_param_or_state_ptr(builder, self, 'one_hot_function', param_struct_ptr=params, state_struct_ptr=state)
1✔
3377
        one_hot_f = ctx.import_llvm_function(self.one_hot_function, tags=tags)
1✔
3378

3379
        assert one_hot_f.args[3].type == arg_out.type
1✔
3380
        one_hot_out = arg_out
1✔
3381
        one_hot_in = builder.alloca(one_hot_f.args[2].type.pointee)
1✔
3382

3383
        if output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:
1✔
3384
            with pnlvm.helpers.array_ptr_loop(builder, arg_in, "exp_div") as (b, i):
1✔
3385
                self.__gen_llvm_exp_div(ctx=ctx, vi=arg_in, vo=one_hot_in,
1✔
3386
                                        gain=gain, exp_sum=exp_sum, builder=b, index=i)
3387

3388
            builder.call(one_hot_f, [one_hot_p, one_hot_s, one_hot_in, one_hot_out])
1✔
3389

3390
        elif output_type in {PROB, PROB_INDICATOR}:
1✔
3391
            one_hot_in_data = builder.gep(one_hot_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
1✔
3392
            one_hot_in_dist = builder.gep(one_hot_in, [ctx.int32_ty(0), ctx.int32_ty(1)])
1✔
3393

3394
            with pnlvm.helpers.array_ptr_loop(builder, arg_in, "exp_div") as (b, i):
1✔
3395
                self.__gen_llvm_exp_div(ctx=ctx, vi=arg_in, vo=one_hot_in_dist,
1✔
3396
                                        gain=gain, exp_sum=exp_sum, builder=b, index=i)
3397

3398
                dist_in = b.gep(arg_in, [ctx.int32_ty(0), i])
1✔
3399
                dist_out = b.gep(one_hot_in_data, [ctx.int32_ty(0), i])
1✔
3400
                b.store(b.load(dist_in), dist_out)
1✔
3401

3402

3403
            builder.call(one_hot_f, [one_hot_p, one_hot_s, one_hot_in, one_hot_out])
1✔
3404
        else:
3405
            assert False, "Unsupported output in {} for LLVM execution mode: {}".format(self, output_type)
3406

3407
        return builder
1✔
3408

3409
    def _gen_llvm_function_derivative_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
1✔
3410
        assert "derivative" in tags or "derivative_out" in tags
1✔
3411
        assert arg_in.type == arg_out.type
1✔
3412
        forward_tags = tags.difference({"derivative", "derivative_out"})
1✔
3413

3414
        # SoftMax derivative is calculated from the "ALL" results.
3415
        if "derivative_out" in tags:
1✔
3416
            all_out = arg_in
1✔
3417
        else:
3418
            all_out = builder.alloca(arg_out.type.pointee)
1✔
3419
            builder = self._gen_llvm_function_body(ctx, builder, params, state, arg_in, all_out, output_type=ALL, tags=forward_tags)
1✔
3420

3421
        if self.parameters.per_item.get():
1✔
3422
            assert isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType)
1✔
3423
            assert isinstance(arg_out.type.pointee.element, pnlvm.ir.ArrayType)
1✔
3424
            for i in range(arg_in.type.pointee.count):
1✔
3425
                inner_all_out = builder.gep(all_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
1✔
3426
                inner_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
1✔
3427
                builder = self.__gen_llvm_apply_derivative(ctx, builder, params, state, inner_all_out, inner_out, tags=tags)
1✔
3428
            return builder
1✔
3429
        else:
3430
            return self.__gen_llvm_apply_derivative(ctx, builder, params, state, all_out, arg_out, tags=tags)
1✔
3431

3432
    def __gen_llvm_apply_derivative(self, ctx, builder, params, state, all_out, arg_out, *, tags:frozenset):
1✔
3433

3434
        assert self.output in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}, (
1✔
3435
            "Derivative of SoftMax is only implemented for ARG_MAX and ARG_MAX_INDICATOR "
3436
            "in LLVM execution mode ({})".format(self.output))
3437

3438
        max_pos_ptr = builder.alloca(ctx.int32_ty)
1✔
3439
        builder.store(max_pos_ptr.type.pointee(-1), max_pos_ptr)
1✔
3440
        max_val_ptr = builder.alloca(arg_out.type.pointee.element)
1✔
3441
        builder.store(max_val_ptr.type.pointee(float("NaN")), max_val_ptr)
1✔
3442

3443
        with pnlvm.helpers.array_ptr_loop(builder, all_out, id="max") as (b, idx):
1✔
3444
            val_ptr = b.gep(all_out, [ctx.int32_ty(0), idx])
1✔
3445
            val = b.load(val_ptr)
1✔
3446
            max_val = b.load(max_val_ptr)
1✔
3447
            new_max = b.fcmp_unordered(">", val, max_val)
1✔
3448
            with b.if_then(new_max):
1✔
3449
                b.store(val, max_val_ptr)
1✔
3450
                b.store(idx, max_pos_ptr)
1✔
3451

3452
        max_val = builder.load(max_val_ptr)
1✔
3453
        max_pos = builder.load(max_pos_ptr)
1✔
3454

3455
        with pnlvm.helpers.array_ptr_loop(builder, all_out, id="derivative") as (b, idx):
1✔
3456
            val_ptr = b.gep(all_out, [ctx.int32_ty(0), idx])
1✔
3457
            val = b.load(val_ptr)
1✔
3458
            is_max_pos = b.icmp_unsigned("==", idx, max_pos)
1✔
3459

3460
            d = b.select(is_max_pos, val.type(1), val.type(0))
1✔
3461
            dv = b.fsub(d, max_val)
1✔
3462
            val = b.fmul(val, dv)
1✔
3463

3464
            out_ptr = b.gep(arg_out, [ctx.int32_ty(0), idx])
1✔
3465
            b.store(val, out_ptr)
1✔
3466

3467
        return builder
1✔
3468

3469
    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, output_type=None, *, tags:frozenset):
1✔
3470
        output_type = self.output if output_type is None else output_type
1✔
3471
        if "derivative" in tags or "derivative_out" in tags:
1✔
3472
            return self._gen_llvm_function_derivative_body(ctx, builder, params, state, arg_in, arg_out, tags=tags)
1✔
3473

3474
        if self.parameters.per_item.get():
1✔
3475
            assert isinstance(arg_in.type.pointee.element, pnlvm.ir.ArrayType)
1✔
3476
            assert isinstance(arg_out.type.pointee.element, pnlvm.ir.ArrayType)
1✔
3477
            for i in range(arg_in.type.pointee.count):
1✔
3478
                inner_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(i)])
1✔
3479
                inner_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(i)])
1✔
3480
                builder = self.__gen_llvm_apply(ctx, builder, params, state, inner_in, inner_out, output_type, tags=tags)
1✔
3481
            return builder
1✔
3482
        else:
3483
            return self.__gen_llvm_apply(ctx, builder, params, state, arg_in, arg_out, output_type, tags=tags)
1✔
3484

3485
    def _gen_pytorch_fct(self, device, context=None):
1✔
3486
        gain = self._get_pytorch_fct_param_value('gain', device, context)
1✔
3487
        mask_threshold = self._get_pytorch_fct_param_value('mask_threshold', device, context)
1✔
3488

3489
        if isinstance(gain, str) and gain == ADAPTIVE:
1!
3490
            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))
×
3491

3492
        elif mask_threshold is not None:
1!
3493
            def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
1✔
3494
                v = gain * _input
1✔
3495

3496
                # Apply threshold-based masking
3497
                if mask_threshold is not None:
1!
3498
                    if torch.any(_input < 0) and not self._negative_input_warning:
1✔
3499
                        warnings.warn(f"Softmax function: mask_threshold is set to {mask_threshold}, "
1✔
3500
                                      f"but input contains negative values. "
3501
                                      f"Masking will be applied to the magnitude of the input.")
3502
                        self._negative_input_warning = True
1✔
3503

3504
                    # Create a mask where values below threshold are set to -inf
3505
                    mask = torch.abs(v) > mask_threshold
1✔
3506
                    v = v.masked_fill(~mask, float('-inf'))  # More stable than torch.where()
1✔
3507

3508
                # Handle case where all values are masked (return tensor with gradient support)
3509
                if torch.all(~mask):
1✔
3510
                    return torch.full_like(v, 0.0, requires_grad=True)
1✔
3511

3512
                # Make numerically stable by shifting max value
3513
                max_v = torch.max(v[mask])  # Avoid computing max over -inf
1✔
3514
                v = v - max_v
1✔
3515

3516
                # Compute softmax (PyTorch handles -inf correctly)
3517
                exp_v = torch.exp(v)
1✔
3518
                sm = exp_v / torch.sum(exp_v, dim=-1, keepdim=True)
1✔
3519

3520
                return sm
1✔
3521
            # Return the function
3522
            return pytorch_thresholded_softmax
1✔
3523

3524
        else:
3525
            return lambda x: (torch.softmax(gain * x, -1))
×
3526

3527
    def _gen_pytorch_adapt_gain_fct(self, device, context=None):
1✔
3528
        scale = self._get_pytorch_fct_param_value('adapt_scale', device, context)
×
3529
        base = self._get_pytorch_fct_param_value('adapt_base', device, context)
×
3530
        entropy_weighting = self._get_pytorch_fct_param_value('adapt_entropy_weighting', device, context)
×
3531
        # v = torch.squeeze(v)
3532
        return lambda x : scale * (base +
×
3533
                                   (entropy_weighting * np.log(len(x)) *
3534
                                    torch.log(-1 * torch.sum((1 / (1 + torch.exp(-1 * x)))
3535
                                                             * torch.log(1 / (1 + torch.exp(-1 * x)))))))
3536

3537

3538
# **********************************************************************************************************************
3539
#                                                    Angle
3540
# **********************************************************************************************************************
3541

3542
# FIX: VALIDATE LEN(VARIABLE)>=2
3543

3544
class Angle(TransferFunction):  # -------------------------------------------------------------------------------------
1✔
3545
    """
3546
    Angle(                 \
3547
         default_variable, \
3548
         params=None,      \
3549
         owner=None,       \
3550
         name=None,        \
3551
         prefs=None        \
3552
         )
3553

3554
    .. _Angle_Function:
3555

3556
    `function <Angle._function>` returns Angle transform of vector in `variable <Angle.variable>`:
3557

3558
    COMMENT:
3559
    FIX: WITH PROPER MATHEMATICAL DEFN
3560
    .. math::
3561

3562
        slope * variable + intercept
3563

3564
    `derivative <Angle.derivative>` returns `slope <Angle.slope>`.
3565
    COMMENT
3566
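
    As read directly from the implementation of `_angle <Angle._angle>` (pending the proper definition flagged in
    the COMMENT above), an input vector :math:`v` of length :math:`n` is mapped to an output :math:`a` of length
    :math:`n+1`:

    .. math::

        a_0 = \\cos(v_0), \\quad a_j = \\cos(v_j)\\prod_{k=j+1}^{n-1}\\sin(v_k) \\quad (1 \\le j \\le n-1),
        \\quad a_n = \\prod_{k=1}^{n-1}\\sin(v_k)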

3567
    Arguments
3568
    ---------
3569

3570
    default_variable : 1d array : default class_defaults.variable
3571
        specifies a template for the value to be transformed;  length must be at least 2.
3572

3573
    params : Dict[param keyword: param value] : default None
3574
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
3575
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
3576
        arguments of the constructor.
3577

3578
    owner : Component
3579
        `component <Component>` to which to assign the Function.
3580

3581
    name : str : default see `name <Function.name>`
3582
        specifies the name of the Function.
3583

3584
    prefs : PreferenceSet or specification dict : default Function.classPreferences
3585
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
3586

3587
    Attributes
3588
    ----------
3589

3590
    variable : 1d array
3591
        contains value to be transformed.
3592

3593
    owner : Component
3594
        `component <Component>` to which the Function has been assigned.
3595

3596
    name : str
3597
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
3598
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
3599

3600
    prefs : PreferenceSet or specification dict : Function.classPreferences
3601
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
3602
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
3603
        for details).
3604
    """
3605

3606
    componentName = ANGLE_FUNCTION
1✔
3607

3608
    classPreferences = {
1✔
3609
        PREFERENCE_SET_NAME: 'AngleClassPreferences',
3610
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
3611
    }
3612

3613
    _model_spec_class_name_is_generic = True
1✔
3614

3615
    class Parameters(TransferFunction.Parameters):
1✔
3616
        """
3617
            Attributes
3618
            ----------
3619

3620
                variable
3621
                    see `variable <Angle.variable>`
3622

3623
                    :default value: numpy.array([1, 1])
3624
                    :type: ``numpy.ndarray``
3625
                    :read only: True
3626

3627
        """
3628
        variable = Parameter(np.array([1,1]),
1✔
3629
                             read_only=True,
3630
                             pnl_internal=True,
3631
                             constructor_argument='default_variable')
3632

3633
        def _validate_variable(self, variable):
1✔
3634
            variable = np.squeeze(variable)
1✔
3635
            if variable.ndim != 1 or len(variable) < 2:
1!
3636
                return f"must be list or 1d array of length 2 or greater."
×
3637

3638
    @check_user_specified
1✔
3639
    @beartype
1✔
3640
    def __init__(self,
1✔
3641
                 default_variable=None,
3642
                 params=None,
3643
                 owner=None,
3644
                 prefs:  Optional[ValidPrefSet] = None):
3645

3646
        super().__init__(
1✔
3647
            default_variable=default_variable,
3648
            params=params,
3649
            owner=owner,
3650
            prefs=prefs,
3651
        )
3652

3653
    def _function(self,
1✔
3654
                 variable=None,
3655
                 context=None,
3656
                 params=None,
3657
                 ):
3658
        """
3659

3660
        Arguments
3661
        ---------
3662

3663
        variable : ndarray : default class_defaults.variable
3664
           an array of n angular coordinates to be transformed to an (n+1)-dimensional vector of coordinates on a sphere;  must be of length 2 or greater.
3665

3666
        params : Dict[param keyword: param value] : default None
3667
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
3668
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
3669
            arguments of the constructor.
3670

3671
        Returns
3672
        -------
3673

3674
        Angle transformation of variable : ndarray of variable.ndim+1
3675

3676
        """
3677
        try:
            # By default, result should be returned as np.ndarray with same dimensionality as input
            result = self._angle(variable)
        except TypeError:
            if hasattr(variable, "dtype"):
                # If variable is an array with mixed sizes or types, try item-by-item operation
                if variable.dtype == object:
                    result = np.zeros_like(variable)
                    for i, item in enumerate(variable):
                        result[i] = self._angle(item)
                else:
                    raise FunctionError("Unrecognized type for {} of {} ({})".format(VARIABLE, self.name, variable))
            # KAM 6/28/18: If the variable does not have a "dtype" attr but made it to this line, then it must be of a
            # type that even np does not recognize -- typically a custom OutputPort variable with items of different
            # shapes (e.g. variable = [[0.0], [0.0], array([[0.0, 0.0]])] )
            elif isinstance(variable, list):
                result = []
                for variable_item in variable:
                    result.append(self._angle(variable_item))
            else:
                raise FunctionError("Unrecognized type for {} of {} ({})".format(VARIABLE, self.name, variable))

        return self.convert_output_type(result)

    def _angle(self, value):
        """Take an nd vector and return the (n+1)d coordinates of the corresponding point on a sphere"""
        value = np.squeeze(value)
        dim = len(value) + 1
        angle = np.zeros(dim)
        sin_value = np.sin(value)
        cos_value = np.cos(value)
        angle[0] = cos_value[0]
        prod_a = np.cumprod(np.flip(sin_value))[:-1]
        angle[dim - 1] = prod_a[-1]
        prod_a[-1] = 1.

        # going down from the top of cumprod we skip: 2 edge values + 1 extra for output size
        for j in range(1, dim - 1):
            angle[j] = prod_a[dim - 3 - j] * cos_value[j]
        return angle
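
    # Editor's sketch (illustrative, not part of the class): a direct numpy
    # transcription of _angle, showing the structure of the result for an input
    # v of length n -- element 0 is cos(v[0]), element j (0 < j < n) is cos(v[j])
    # scaled by the product of sin(v[k]) for k > j, and element n is the product
    # of sin(v[k]) for k = 1..n-1:
    #
    #     >>> import numpy as np
    #     >>> v = np.array([0.5, 1.0, 1.5])
    #     >>> out = np.empty(len(v) + 1)
    #     >>> out[0] = np.cos(v[0])
    #     >>> for j in range(1, len(v)):
    #     ...     out[j] = np.cos(v[j]) * np.prod(np.sin(v[j + 1:]))
    #     >>> out[-1] = np.prod(np.sin(v[1:]))
    #     >>> np.allclose(out, Angle()._angle(v))
    #     True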

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        assert isinstance(arg_in.type.pointee, pnlvm.ir.ArrayType)
        assert isinstance(arg_out.type.pointee, pnlvm.ir.ArrayType)
        assert len(arg_in.type.pointee) + 1 == len(arg_out.type.pointee)

        # The first cos
        res0_ptr = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])
        val0_ptr = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
        val0 = builder.load(val0_ptr)
        cos_f = ctx.get_builtin("cos", [val0.type])
        cos_val0 = builder.call(cos_f, [val0])
        builder.store(cos_val0, res0_ptr)

        # calculate suffix product
        sin_f = ctx.get_builtin("sin", [val0.type])
        prod_ptr = builder.alloca(val0.type)
        builder.store(prod_ptr.type.pointee(1.0), prod_ptr)

        dim_m1 = ctx.int32_ty(len(arg_out.type.pointee) - 1)
        with pnlvm.helpers.for_loop(builder, dim_m1.type(1), dim_m1, dim_m1.type(1), id="suff_prod") as (b, idx):
            # reverse the index to go from the end
            idx = b.sub(dim_m1, idx)

            prod = b.load(prod_ptr)
            val_ptr = b.gep(arg_in, [ctx.int32_ty(0), idx])
            val = b.load(val_ptr)

            # calculate suffix product of sin(input)
            val_sin = b.call(sin_f, [val])
            new_prod = b.fmul(prod, val_sin)
            b.store(new_prod, prod_ptr)

            # output value is suffix product * cos(val)
            val_cos = b.call(cos_f, [val])
            res = b.fmul(prod, val_cos)
            res_ptr = b.gep(arg_out, [ctx.int32_ty(0), idx])
            b.store(res, res_ptr)

        # The last element is just the suffix product * 1
        last_ptr = builder.gep(arg_out, [ctx.int32_ty(0), dim_m1])
        builder.store(builder.load(prod_ptr), last_ptr)

        return builder

    # @handle_external_context()
    # def derivative(self, input=None, output=None, context=None):
    #     """
    #     derivative(input)
    #
    #     Derivative of `function <Angle._function>` at **input**.
    #
    #     Arguments
    #     ---------
    #
    #     input : number
    #         value of the input to the Angle transform at which derivative is to be taken.
    #
    #     Returns
    #     -------
    #
    #     Slope of function :  number or array
    #
    #     """
    #
    #     return self._get_current_parameter_value(SLOPE, context)
    #
    # def _is_identity(self, context=None):
    #     return (
    #         self.parameters.slope._get(context) == 1
    #         and self.parameters.intercept._get(context) == 0
    #     )


# **********************************************************************************************************************
#                                             TransferWithCosts
# **********************************************************************************************************************

# Keywords for TransferWithCosts arguments, cost functions and their parameters ----------------------------------------

# Make accessible externally
__all__.extend(['ENABLED_COST_FUNCTIONS',
                'INTENSITY_COST',
                'INTENSITY_COST_FUNCTION',
                'INTENSITY_COST_FCT_MULTIPLICATIVE_PARAM',
                'INTENSITY_COST_FCT_ADDITIVE_PARAM',
                'ADJUSTMENT_COST',
                'ADJUSTMENT_COST_FUNCTION',
                'ADJUSTMENT_COST_FCT_MULTIPLICATIVE_PARAM',
                'ADJUSTMENT_COST_FCT_ADDITIVE_PARAM',
                'DURATION_COST',
                'DURATION_COST_FUNCTION',
                'DURATION_COST_FCT_MULTIPLICATIVE_PARAM',
                'DURATION_COST_FCT_ADDITIVE_PARAM',
                'COMBINED_COSTS',
                'COMBINE_COSTS_FUNCTION',
                'COMBINE_COSTS_FCT_MULTIPLICATIVE_PARAM',
                'COMBINE_COSTS_FCT_ADDITIVE_PARAM',
                'costFunctionNames', 'CostFunctions'
                ])

ENABLED_COST_FUNCTIONS = 'enabled_cost_functions'

# These are assigned to TransferWithCosts Function to make them accessible for modulation
INTENSITY_COST = 'intensity_cost'
INTENSITY_COST_FUNCTION = 'intensity_cost_fct'
INTENSITY_COST_FCT_MULTIPLICATIVE_PARAM = 'intensity_cost_fct_mult_param'
INTENSITY_COST_FCT_ADDITIVE_PARAM = 'intensity_cost_fct_add_param'

ADJUSTMENT_COST = 'adjustment_cost'
ADJUSTMENT_COST_FUNCTION = 'adjustment_cost_fct'
ADJUSTMENT_COST_FCT_MULTIPLICATIVE_PARAM = 'adjustment_cost_fct_mult_param'
ADJUSTMENT_COST_FCT_ADDITIVE_PARAM = 'adjustment_cost_fct_add_param'

DURATION_COST = 'duration_cost'
DURATION_COST_FUNCTION = 'duration_cost_fct'
DURATION_COST_FCT_MULTIPLICATIVE_PARAM = 'duration_cost_fct_mult_param'
DURATION_COST_FCT_ADDITIVE_PARAM = 'duration_cost_fct_add_param'

COMBINED_COSTS = 'combined_costs'
COMBINE_COSTS_FUNCTION = 'combine_costs_fct'
COMBINE_COSTS_FCT_MULTIPLICATIVE_PARAM = 'combine_costs_fct_mult_param'
COMBINE_COSTS_FCT_ADDITIVE_PARAM = 'combine_costs_fct_add_param'

costFunctionNames = [INTENSITY_COST_FUNCTION,
                     ADJUSTMENT_COST_FUNCTION,
                     DURATION_COST_FUNCTION,
                     COMBINE_COSTS_FUNCTION]


class CostFunctions(Flag):
    """Options for selecting constituent cost functions to be used by a `TransferWithCosts` Function.

    These can be used alone or in combination with one another, by enabling or disabling each using the
    `TransferWithCosts` Function's `enable_costs <TransferWithCosts.enable_costs>`,
    `disable_costs <TransferWithCosts.disable_costs>`, `toggle_cost <TransferWithCosts.toggle_cost>` and
    `assign_costs <TransferWithCosts.assign_costs>` methods.

    Attributes
    ----------

    NONE
        `cost <TransferWithCosts.cost>` is not computed.

    INTENSITY
        `intensity_cost_fct` is used to calculate a contribution to the `cost <TransferWithCosts.cost>`
        based on its current `intensity <TransferWithCosts.intensity>` value.

    ADJUSTMENT
        `adjustment_cost_fct` is used to calculate a contribution to the `cost <TransferWithCosts.cost>`
        based on the change in its `intensity <TransferWithCosts.intensity>` from its last value.

    DURATION
        `duration_cost_fct` is used to calculate a contribution to the `cost <TransferWithCosts.cost>`
        based on its integral (i.e., its accumulated value over multiple executions).

    ALL
        all of the cost functions are used to calculate `cost <TransferWithCosts.cost>`.

    DEFAULTS
        assigns the default set of cost functions (`NONE`).

    """
    NONE          = 0
    INTENSITY     = auto()
    ADJUSTMENT    = auto()
    DURATION      = auto()
    ALL           = INTENSITY | ADJUSTMENT | DURATION
    DEFAULTS      = NONE
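
# Editor's sketch (illustrative): CostFunctions is a standard ``enum.Flag``, so
# members can be combined with ``|`` and tested with ``&``, which is how
# TransferWithCosts (below) decides which cost functions to execute:
#
#     >>> enabled = CostFunctions.INTENSITY | CostFunctions.DURATION
#     >>> bool(enabled & CostFunctions.INTENSITY)
#     True
#     >>> bool(enabled & CostFunctions.ADJUSTMENT)
#     False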


TRANSFER_FCT = 'transfer_fct'
INTENSITY_COST_FCT = 'intensity_cost_fct'
ADJUSTMENT_COST_FCT = 'adjustment_cost_fct'
DURATION_COST_FCT = 'duration_cost_fct'
COMBINE_COSTS_FCT = 'combine_costs_fct'

class TransferWithCosts(TransferFunction):
    """
    TransferWithCosts(                       \
        default_variable=None,               \
        input_shapes=None,                   \
        transfer_fct=Linear,                 \
        enabled_cost_functions=None,         \
        intensity_cost_fct=Exponential,      \
        adjustment_cost_fct=Linear,          \
        duration_cost_fct=SimpleIntegrator,  \
        combine_costs_fct=LinearCombination, \
        params=None,                         \
        owner=None,                          \
        prefs=None                           \
        )

    .. _TransferWithCosts:

    returns value of `variable <TransferWithCosts.variable>` transformed by `transfer_fct
    <TransferWithCosts.transfer_fct>`, after calling any cost functions that are enabled and assigning
    the result(s) to the corresponding parameter(s), as described below.

    .. _TransferWithCosts_Cost_Functions:

    **Cost Functions**

    The TransferWithCosts function has three individual cost functions that it can execute when its `function
    <TransferWithCosts._function>` is executed, which assign their results to the attributes indicated below:

    * `intensity_cost_fct <TransferWithCosts.intensity_cost_fct>` -> `intensity_cost <TransferWithCosts.intensity_cost>`;
    * `adjustment_cost_fct <TransferWithCosts.adjustment_cost_fct>` -> `adjustment_cost <TransferWithCosts.adjustment_cost>`;
    * `duration_cost_fct <TransferWithCosts.duration_cost_fct>` -> `duration_cost <TransferWithCosts.duration_cost>`;

    Which functions are called is determined by the settings in `enabled_cost_functions
    <TransferWithCosts.enabled_cost_functions>`, which can be initialized in the constructor using the
    **enabled_cost_functions** argument, and later modified using the `enable_costs <TransferWithCosts.enable_costs>`,
    `disable_costs <TransferWithCosts.disable_costs>`, `toggle_cost <TransferWithCosts.toggle_cost>` and
    `assign_costs <TransferWithCosts.assign_costs>` methods.  The value of any cost for which its function has
    *never* been enabled is None;  otherwise, it is the value assigned when it was last enabled and executed
    (see `duration_cost_fct <TransferWithCosts.duration_cost_fct>` for additional details concerning that function).

    If any cost functions are enabled, then the `combine_costs_fct <TransferWithCosts.combine_costs_fct>` function
    is executed, which sums the results of those that are enabled (Hadamard style, if the costs are arrays), and
    stores the result in the `combined_costs <TransferWithCosts.combined_costs>` attribute.  Its value is None if no
    cost functions have ever been enabled;  otherwise it is the value assigned the last time one or more cost functions
    were enabled.

    .. _TransferWithCosts_Modulation_of_Cost_Params:

    **Modulation of Cost Function Parameters**

    The `multiplicative_param <Function_Modulatory_Params>` and `additive_param <Function_Modulatory_Params>` of each
    `cost function <TransferWithCosts_Cost_Functions>` is assigned as a parameter of the TransferWithCosts `Function`.
    This makes them accessible for `modulation <ModulatorySignal_Modulation>` when the Function is assigned to a
    `Port` (e.g., as the default `function <ControlSignal._function>` of a `ControlSignal`), or a `Mechanism
    <Mechanism>`.  They can be referred to in the **modulation** argument of a `ModulatorySignal`\\'s constructor
    (see `ModulatorySignal_Types`) using the following keywords:

        *INTENSITY_COST_FCT_MULTIPLICATIVE_PARAM*
        *INTENSITY_COST_FCT_ADDITIVE_PARAM*
        *ADJUSTMENT_COST_FCT_MULTIPLICATIVE_PARAM*
        *ADJUSTMENT_COST_FCT_ADDITIVE_PARAM*
        *DURATION_COST_FCT_MULTIPLICATIVE_PARAM*
        *DURATION_COST_FCT_ADDITIVE_PARAM*
        *COMBINE_COSTS_FCT_MULTIPLICATIVE_PARAM*
        *COMBINE_COSTS_FCT_ADDITIVE_PARAM*
    |
    See `example <ControlSignal_Example_Modulate_Costs>` of how these keywords can be used to
    modulate the parameters of the cost functions of a TransferWithCosts Function assigned to a ControlSignal.

    Arguments
    ---------

    variable : list or 1d array of numbers: Default class_defaults.variable
        specifies shape and default value of the array for the variable used by `transfer_fct
        <TransferWithCosts.transfer_fct>`, on which costs are calculated.

    input_shapes : int : None
        specifies length of the array for `variable <TransferWithCosts.variable>` used by `function
        <TransferWithCosts._function>` and on which costs are calculated;  can be used in place of
        default_variable, in which case zeros are assigned as the value(s). An error is generated if both are
        specified but input_shapes != len(default_variable).

    transfer_fct : TransferFunction : Linear
        specifies the primary function, used to generate the value it returns.

    enabled_cost_functions : CostFunctions or List[CostFunctions] : None
        specifies the costs to execute when `function <TransferWithCosts._function>` is called, and
        include in the computation of `combined_costs <TransferWithCosts.combined_costs>`.

    intensity_cost_fct : Optional[`TransferFunction`] : default `Exponential`
        specifies the function used to compute the `intensity_cost <TransferWithCosts.intensity_cost>`.

    adjustment_cost_fct : Optional[`TransferFunction`] : default `Linear`
        specifies the function used to compute the `adjustment_cost <TransferWithCosts.adjustment_cost>`.

    duration_cost_fct : `IntegratorFunction` : default `SimpleIntegrator`
        specifies the function used to compute the `duration_cost <TransferWithCosts.duration_cost>`.

    combine_costs_fct : function : default `LinearCombination`
        specifies the function used to compute `combined_costs <TransferWithCosts.combined_costs>`.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).


    Attributes
    ----------

    variable : 1d array
        value used by `function <TransferWithCosts._function>`, and on which `intensity <TransferWithCosts.intensity>`
        and associated costs are calculated.

    input_shapes : int
        length of array for `variable <TransferWithCosts.variable>`.

    intensity : 1d array
        the result of `transfer_fct <TransferWithCosts.transfer_fct>`, and the value returned by
        `function <TransferWithCosts._function>`.

    function : TransferFunction
        primary function, specified by **transfer_fct** argument of constructor, and also stored in
        `transfer_fct <TransferWithCosts.transfer_fct>`.

    transfer_fct : TransferFunction
        the TransferWithCosts Function's primary function, used to generate the value it returns;
        same as `function <TransferWithCosts._function>`.

    enabled_cost_functions : CostFunctions or None
        boolean combination of currently enabled CostFunctions;  determines which `cost functions
        <TransferWithCosts_Cost_Functions>` are calculated when `function <TransferWithCosts._function>`
        is called, and are included in the computation of `combined_costs <TransferWithCosts.combined_costs>`
        (see `Cost Functions <TransferWithCosts_Cost_Functions>` for additional details).

    intensity_cost : float or None
        cost computed by `intensity_cost_fct <TransferWithCosts.intensity_cost_fct>` for current `intensity
        <TransferWithCosts.intensity>`.  Value is None if `intensity_cost_fct <TransferWithCosts.intensity_cost_fct>`
        has not been enabled (see `Cost Functions <TransferWithCosts_Cost_Functions>` for additional details).

    intensity_cost_fct : TransferFunction
        calculates `intensity_cost` from the current value of `intensity <TransferWithCosts.intensity>`.
        It can be any `TransferFunction`, or any other function that takes and returns a scalar value.
        The default is `Exponential`.

    intensity_cost_fct_mult_param : value
        references value of the `multiplicative_param <Function_Modulatory_Params>` of `intensity_cost_fct
        <TransferWithCosts.intensity_cost_fct>`.

    intensity_cost_fct_add_param : value
        references value of the `additive_param <Function_Modulatory_Params>` of `intensity_cost_fct
        <TransferWithCosts.intensity_cost_fct>`.

    adjustment_cost : float or None
        cost of change in `intensity <TransferWithCosts.intensity>` from the last time `function
        <TransferWithCosts._function>` was executed.  Value is None if `adjustment_cost_fct
        <TransferWithCosts.adjustment_cost_fct>` has not been enabled (see `Cost Functions
        <TransferWithCosts_Cost_Functions>` for additional details).

    adjustment_cost_fct : TransferFunction
        calculates `adjustment_cost <TransferWithCosts.adjustment_cost>` based on the change in `intensity
        <TransferWithCosts.intensity>` from its value the last time `function <TransferWithCosts._function>` was
        executed. It can be any `TransferFunction`, or any other function that takes and returns a scalar value.

    adjustment_cost_fct_mult_param : value
        references value of the `multiplicative_param <Function_Modulatory_Params>` of `adjustment_cost_fct
        <TransferWithCosts.adjustment_cost_fct>`.

    adjustment_cost_fct_add_param : value
        references value of the `additive_param <Function_Modulatory_Params>` of `adjustment_cost_fct
        <TransferWithCosts.adjustment_cost_fct>`.

    duration_cost : float or None
        integral of `intensity <TransferWithCosts.intensity>`,  computed by `duration_cost_fct
        <TransferWithCosts.duration_cost_fct>`.  Value is None if `duration_cost_fct
        <TransferWithCosts.duration_cost_fct>` has not been enabled; otherwise, the integral of
        `intensity <TransferWithCosts.intensity>` is only for those executions of `function
        <TransferWithCosts._function>` in which `duration_cost_fct <TransferWithCosts.duration_cost_fct>` was enabled.

    duration_cost_fct : IntegratorFunction
        calculates an integral of `intensity <TransferWithCosts.intensity>`.  It can be any `IntegratorFunction`,
        or any other function that takes a list or array of two values and returns a scalar value.

    duration_cost_fct_mult_param : value
        references value of the `multiplicative_param <Function_Modulatory_Params>` of `duration_cost_fct
        <TransferWithCosts.duration_cost_fct>`.

    duration_cost_fct_add_param : value
        references value of the `additive_param <Function_Modulatory_Params>` of `duration_cost_fct
        <TransferWithCosts.duration_cost_fct>`.

    combined_costs : float or None
        combined result of all `cost functions <TransferWithCosts_Cost_Functions>` that are enabled;
        computed by `combine_costs_fct <TransferWithCosts.combine_costs_fct>` for current `intensity
        <TransferWithCosts.intensity>`.  Value is None if no costs have been enabled (see
        `Cost Functions <TransferWithCosts_Cost_Functions>` for additional details).

    combine_costs_fct : function
        combines the results of all `cost functions <TransferWithCosts_Cost_Functions>` that are enabled, and assigns
        the result to `combined_costs <TransferWithCosts.combined_costs>`. It can be any function that takes an array
        and returns a scalar value.

    combine_costs_fct_mult_param : value
        references value of the `multiplicative_param <Function_Modulatory_Params>` of `combine_costs_fct
        <TransferWithCosts.combine_costs_fct>`.

    combine_costs_fct_add_param : value
        references value of the `additive_param <Function_Modulatory_Params>` of `combine_costs_fct
        <TransferWithCosts.combine_costs_fct>`.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    name : str
        name of the Function.

    owner : Component
        `component <Component>` to which to assign the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        determines the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
    """
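
    # Editor's sketch (illustrative; all values arbitrary): constructing an
    # instance with a custom transfer function and a plain callable as the
    # intensity cost function (plain callables are wrapped in a
    # UserDefinedFunction during instantiation, per _instantiate_cost_functions
    # below):
    #
    #     >>> f = TransferWithCosts(transfer_fct=Linear(slope=2),
    #     ...                       enabled_cost_functions=CostFunctions.INTENSITY,
    #     ...                       intensity_cost_fct=lambda x: x ** 2)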

    componentName = TRANSFER_WITH_COSTS_FUNCTION

    classPreferences = {
        PREFERENCE_SET_NAME: 'TransferWithCostsClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    class Parameters(TransferFunction.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <TransferWithCosts.variable>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``

                LinearCombination
                    see `LinearCombination <TransferWithCosts.LinearCombination>`

                    :default value: `LinearCombination`
                    :type: `Function`

                SimpleIntegrator
                    see `SimpleIntegrator <TransferWithCosts.SimpleIntegrator>`

                    :default value: `SimpleIntegrator`
                    :type: `Function`

                adjustment_cost
                    see `adjustment_cost <TransferWithCosts.adjustment_cost>`

                    :default value: None
                    :type:

                adjustment_cost_fct
                    see `adjustment_cost_fct <TransferWithCosts.adjustment_cost_fct>`

                    :default value: `Linear`
                    :type: `Function`

                adjustment_cost_fct_add_param
                    see `adjustment_cost_fct_add_param <TransferWithCosts.adjustment_cost_fct_add_param>`

                    :default value: None
                    :type:

                adjustment_cost_fct_mult_param
                    see `adjustment_cost_fct_mult_param <TransferWithCosts.adjustment_cost_fct_mult_param>`

                    :default value: None
                    :type:

                combine_costs_fct
                    see `combine_costs_fct <TransferWithCosts.combine_costs_fct>`

                    :default value: `LinearCombination`
                    :type: `Function`

                combine_costs_fct_add_param
                    see `combine_costs_fct_add_param <TransferWithCosts.combine_costs_fct_add_param>`

                    :default value: None
                    :type:

                combine_costs_fct_mult_param
                    see `combine_costs_fct_mult_param <TransferWithCosts.combine_costs_fct_mult_param>`

                    :default value: None
                    :type:

                combined_costs
                    see `combined_costs <TransferWithCosts.combined_costs>`

                    :default value: None
                    :type:

                duration_cost
                    see `duration_cost <TransferWithCosts.duration_cost>`

                    :default value: None
                    :type:

                duration_cost_fct
                    see `duration_cost_fct <TransferWithCosts.duration_cost_fct>`

                    :default value: `SimpleIntegrator`
                    :type: `Function`

                duration_cost_fct_add_param
                    see `duration_cost_fct_add_param <TransferWithCosts.duration_cost_fct_add_param>`

                    :default value: None
                    :type:

                duration_cost_fct_mult_param
                    see `duration_cost_fct_mult_param <TransferWithCosts.duration_cost_fct_mult_param>`

                    :default value: None
                    :type:

                enabled_cost_functions
                    see `enabled_cost_functions <TransferWithCosts.enabled_cost_functions>`

                    :default value: CostFunctions.DEFAULTS
                    :type: `CostFunctions`

                intensity
                    see `intensity <TransferWithCosts.intensity>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``

                intensity_cost
                    see `intensity_cost <TransferWithCosts.intensity_cost>`

                    :default value: None
                    :type:

                intensity_cost_fct
                    see `intensity_cost_fct <TransferWithCosts.intensity_cost_fct>`

                    :default value: `Exponential`
                    :type: `Function`

                intensity_cost_fct_add_param
                    see `intensity_cost_fct_add_param <TransferWithCosts.intensity_cost_fct_add_param>`

                    :default value: None
                    :type:

                intensity_cost_fct_mult_param
                    see `intensity_cost_fct_mult_param <TransferWithCosts.intensity_cost_fct_mult_param>`

                    :default value: None
                    :type:

                transfer_fct
                    see `transfer_fct <TransferWithCosts.transfer_fct>`

                    :default value: `Linear`
                    :type: `Function`

                transfer_fct_add_param
                    see `transfer_fct_add_param <TransferWithCosts.transfer_fct_add_param>`

                    :default value: None
                    :type:

                transfer_fct_mult_param
                    see `transfer_fct_mult_param <TransferWithCosts.transfer_fct_mult_param>`

                    :default value: None
                    :type:
        """
        variable = Parameter(np.array([0]), history_min_length=1, constructor_argument='default_variable')

        intensity = Parameter(np.zeros_like(variable.default_value),
                              history_min_length=1)

        # Create primary function's modulation params for TransferWithCosts
        transfer_fct = Parameter(Linear, stateful=False)
        _validate_transfer_fct = get_validator_by_function(is_function_type)
        transfer_fct_mult_param = FunctionParameter(
            aliases=MULTIPLICATIVE_PARAM,
            modulation_combination_function=PRODUCT,
            function_name='transfer_fct',
            function_parameter_name=MULTIPLICATIVE_PARAM,
        )
        transfer_fct_add_param = FunctionParameter(
            aliases=ADDITIVE_PARAM,
            modulation_combination_function=SUM,
            function_name='transfer_fct',
            function_parameter_name=ADDITIVE_PARAM,
        )

        enabled_cost_functions = Parameter(
            CostFunctions.DEFAULTS,
            valid_types=(CostFunctions, list)
        )

        # Create versions of cost functions' modulation params for TransferWithCosts

        intensity_cost = None
        intensity_cost_fct = Parameter(Exponential, stateful=False)
        _validate_intensity_cost_fct = get_validator_by_function(is_function_type)
        intensity_cost_fct_mult_param = FunctionParameter(
            modulation_combination_function=PRODUCT,
            function_name='intensity_cost_fct',
            function_parameter_name=MULTIPLICATIVE_PARAM,
        )
        intensity_cost_fct_add_param = FunctionParameter(
            modulation_combination_function=SUM,
            function_name='intensity_cost_fct',
            function_parameter_name=ADDITIVE_PARAM,
        )

        adjustment_cost = None
        adjustment_cost_fct = Parameter(Linear, stateful=False)
        _validate_adjustment_cost_fct = get_validator_by_function(is_function_type)
        adjustment_cost_fct_mult_param = FunctionParameter(
            modulation_combination_function=PRODUCT,
            function_name='adjustment_cost_fct',
            function_parameter_name=MULTIPLICATIVE_PARAM,
        )
        adjustment_cost_fct_add_param = FunctionParameter(
            modulation_combination_function=SUM,
            function_name='adjustment_cost_fct',
            function_parameter_name=ADDITIVE_PARAM,
        )

        duration_cost = None
        duration_cost_fct = Parameter(SimpleIntegrator, stateful=False)
        _validate_duration_cost_fct = get_validator_by_function(is_function_type)
        duration_cost_fct_mult_param = FunctionParameter(
            modulation_combination_function=PRODUCT,
            function_name='duration_cost_fct',
            function_parameter_name=MULTIPLICATIVE_PARAM,
        )
        duration_cost_fct_add_param = FunctionParameter(
            modulation_combination_function=SUM,
            function_name='duration_cost_fct',
            function_parameter_name=ADDITIVE_PARAM,
        )

        combined_costs = None
        combine_costs_fct = Parameter(LinearCombination, stateful=False)
        _validate_combine_costs_fct = get_validator_by_function(is_function_type)
        combine_costs_fct_mult_param = FunctionParameter(
            modulation_combination_function=PRODUCT,
            function_name='combine_costs_fct',
            function_parameter_name=MULTIPLICATIVE_PARAM,
        )
        combine_costs_fct_add_param = FunctionParameter(
            modulation_combination_function=SUM,
            function_name='combine_costs_fct',
            function_parameter_name=ADDITIVE_PARAM,
        )
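
    # Editor's sketch (illustrative): the FunctionParameters defined above expose
    # each cost function's modulatory parameters on the TransferWithCosts Function
    # itself, so they can be read (and modulated) without reaching into the cost
    # function; instance and call below are hypothetical:
    #
    #     >>> f = TransferWithCosts(enabled_cost_functions=CostFunctions.INTENSITY)
    #     >>> f.parameters.intensity_cost_fct_mult_param.get()   # proxies the
    #     ...     # multiplicative_param of the Exponential intensity_cost_fct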

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 input_shapes=None,
                 transfer_fct: Optional[Callable] = None,
                 enabled_cost_functions: Optional[Union[CostFunctions, list]] = None,
                 intensity_cost_fct: Optional[Callable] = None,
                 adjustment_cost_fct: Optional[Callable] = None,
                 duration_cost_fct: Optional[Callable] = None,
                 combine_costs_fct: Optional[Callable] = None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):

        # if input_shapes:
        #     if default_variable is None:
        #         default_variable = np.zeros(input_shapes)
        #     elif input_shapes != len(default_variable):
        #         raise FunctionError(f"Both {repr(DEFAULT_VARIABLE)} ({default_variable}) and {repr(SIZE)} ({input_shapes}) "
        #                             f"are specified for {self.name} but are {SIZE}!=len({DEFAULT_VARIABLE}).")

        super().__init__(
            default_variable=default_variable,
            transfer_fct=transfer_fct,
            enabled_cost_functions=enabled_cost_functions,
            intensity_cost_fct=intensity_cost_fct,
            adjustment_cost_fct=adjustment_cost_fct,
            duration_cost_fct=duration_cost_fct,
            combine_costs_fct=combine_costs_fct,
            params=params,
            owner=owner,
            prefs=prefs,
        )

        # # MODIFIED 6/12/19 NEW: [JDC]
        # self._variable_shape_flexibility = DefaultsFlexibility.FLEXIBLE
        # # MODIFIED 6/12/19 END

    def _instantiate_attributes_before_function(self, function=None, context=None):
        """Instantiate `cost functions <TransferWithCosts_Cost_Functions>` specified in `enabled_cost_functions
        <TransferWithCosts.enabled_cost_functions>`.
        """
        super()._instantiate_attributes_before_function(function=function, context=context)
        self._instantiate_cost_functions(context=context)

    def _instantiate_cost_functions(self, context):
        """Instantiate cost functions and the multiplicative and additive modulatory parameters for them.

        Parse specification of cost functions to enable
        Instantiate cost functions specified in constructor arguments, and enable ones in enabled_cost_functions
        Assign default value for multiplicative and additive parameters for each, from the values of those parameters
            on the respective cost functions just instantiated.
        Initialize intensity_cost
        """

        if self.enabled_cost_functions:
            self.assign_costs(self.enabled_cost_functions)

        def instantiate_fct(fct_name, fct):
            if not fct:
                self.toggle_cost(fct_name, OFF)
                return None
            elif isinstance(fct, Function):
                return fct
            elif isinstance(fct, (types.FunctionType, types.MethodType)):
                from psyneulink.core.components.functions.userdefinedfunction import UserDefinedFunction
                return UserDefinedFunction(#default_variable=function_variable,
                        custom_function=fct,
                        owner=self,
                        context=context)
            elif issubclass(fct, Function):
                return fct()
            else:
                raise FunctionError(f"{fct} is not a valid cost function for {fct_name}.")

        self.intensity_cost_fct = instantiate_fct(INTENSITY_COST_FUNCTION, self.intensity_cost_fct)
        self.adjustment_cost_fct = instantiate_fct(ADJUSTMENT_COST_FUNCTION, self.adjustment_cost_fct)
        self.duration_cost_fct = instantiate_fct(DURATION_COST_FUNCTION, self.duration_cost_fct)
        self.combine_costs_fct = instantiate_fct(COMBINE_COSTS_FUNCTION, self.combine_costs_fct)

        # Initialize intensity attributes
        if self.enabled_cost_functions:
            # Default cost params
            if self.owner:
                if self.owner.context.initialization_status != ContextFlags.DEFERRED_INIT:
                    self.intensity_cost = self.intensity_cost_fct(self.owner.defaults.variable)
                else:
                    self.intensity_cost = self.intensity_cost_fct(self.owner.class_defaults.variable)
            else:
                self.intensity_cost = self.intensity_cost_fct(self.defaults.variable)
                self.defaults.intensity_cost = self.intensity_cost

    def _function(self,
                 variable=None,
                 params=None,
                 context=None):
        """

        Arguments
        ---------

        variable : number or array : default class_defaults.variable
           a single value or array to be transformed.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the function.
            Values specified for parameters in the dictionary override any assigned to those parameters in arguments
            of the constructor.

        Returns
        -------

        transformation of variable using `transfer_fct <TransferWithCosts.transfer_fct>` : number or array

        """

        self._check_args(variable=variable, params=params, context=context)

        # FIRST, DEAL WITH CURRENT INTENSITY

        # Compute current intensity
        intensity = self.parameters.transfer_fct._get(context)(variable, context=context)

        # THEN, DEAL WITH COSTS
        # Note: only compute costs that are enabled;  others are left as None, or with their value when last enabled.

        # Get costs for each cost function that is enabled in enabled_cost_functions
        enabled_cost_functions = self.parameters.enabled_cost_functions._get(context)
        enabled_costs = [] # Used to aggregate costs that are enabled and submit to combine_costs_fct
        if enabled_cost_functions:

            # For each cost function that is enabled:
            # - get params for the cost function using _get_current_parameter_value:
            #   - if TransferWithCosts is owned by a Mechanism, get value from ParameterPort for param
            #   - otherwise, get from TransferWithCosts modulation parameter (which is also subject to modulation)

            # Compute intensity_cost
            if enabled_cost_functions & CostFunctions.INTENSITY:
                # Execute intensity_cost function
                intensity_cost = self.intensity_cost_fct(intensity, context=context)
                self.parameters.intensity_cost._set(intensity_cost, context)
                enabled_costs.append(intensity_cost)

            # Compute adjustment_cost
            if enabled_cost_functions & CostFunctions.ADJUSTMENT:
                # Compute intensity change
                try:
                    intensity_change = np.abs(intensity - self.parameters.intensity._get(context))
                except TypeError:
                    intensity_change = np.zeros_like(self.parameters.intensity._get(context))
                # Execute adjustment_cost function
                adjustment_cost = self.adjustment_cost_fct(intensity_change, context=context)
                self.parameters.adjustment_cost._set(adjustment_cost, context)
                enabled_costs.append(adjustment_cost)

            # Compute duration_cost
            if enabled_cost_functions & CostFunctions.DURATION:
                # Execute duration_cost function
                duration_cost = self.duration_cost_fct(intensity, context=context)
                self.parameters.duration_cost._set(duration_cost, context)
                enabled_costs.append(duration_cost)

            # Always execute combine_costs_fct if *any* costs are enabled
            # Execute combine_costs function
            combined_costs = self.combine_costs_fct(enabled_costs,
                                                    context=context)
            self.parameters.combined_costs._set(combined_costs, context)

        # Store current intensity
        self.parameters.intensity._set(copy_parameter_value(intensity), context)

        return intensity
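
    # Editor's usage sketch (illustrative; values arbitrary): with INTENSITY and
    # DURATION enabled, each call returns the transfer result and updates the
    # corresponding cost attributes:
    #
    #     >>> f = TransferWithCosts(
    #     ...     default_variable=[0],
    #     ...     enabled_cost_functions=CostFunctions.INTENSITY | CostFunctions.DURATION)
    #     >>> intensity = f([2.0])      # transfer_fct applied to [2.0] (Linear by default)
    #     >>> f.intensity_cost          # Exponential applied to the intensity
    #     >>> f.duration_cost           # running integral of the intensity
    #     >>> f.combined_costs          # sum of the enabled costs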

    def _is_identity(self, context=None, defaults=False):
        transfer_fct = self.parameters.transfer_fct.get()

        if defaults:
            enabled_cost_functions = self.defaults.enabled_cost_functions
        else:
            enabled_cost_functions = self.parameters.enabled_cost_functions.get(context)

        return transfer_fct._is_identity(context, defaults=defaults) and enabled_cost_functions == CostFunctions.NONE

    @beartype
    def assign_costs(self, cost_functions: Union[CostFunctions, list], execution_context=None):
        """Assigns specified functions; all others are disabled.

        Arguments
        ---------
        cost_functions: CostFunctions or List[CostFunctions]
            `cost function <TransferWithCosts_Cost_Functions>` or list of ones to be used;  all others will be
            disabled.

        Returns
        -------
        enabled_cost_functions :  boolean combination of CostFunctions
            current value of `enabled_cost_functions <TransferWithCosts.enabled_cost_functions>`.

        """
        if isinstance(cost_functions, CostFunctions):
            cost_functions = [cost_functions]
        self.parameters.enabled_cost_functions.set(CostFunctions.NONE, execution_context)
        return self.enable_costs(cost_functions, execution_context)

    @beartype
    def enable_costs(self, cost_functions: Union[CostFunctions, list], execution_context=None):
        """Enable specified `cost functions <TransferWithCosts_Cost_Functions>`;
        settings for all other cost functions are left intact.

        Arguments
        ---------
        cost_functions: CostFunctions or List[CostFunctions]
            `cost function <TransferWithCosts_Cost_Functions>` or list of ones to be enabled,
            in addition to any that are already enabled.

        Returns
        -------
        enabled_cost_functions :  boolean combination of CostFunctions
            current value of `enabled_cost_functions <TransferWithCosts.enabled_cost_functions>`.
        """
        if isinstance(cost_functions, CostFunctions):
            cost_functions = [cost_functions]
        enabled_cost_functions = self.parameters.enabled_cost_functions.get(execution_context)
        for cost_function in cost_functions:
            enabled_cost_functions |= cost_function

        self.parameters.enabled_cost_functions.set(enabled_cost_functions, execution_context)
        return enabled_cost_functions

    @beartype
    def disable_costs(self, cost_functions: Union[CostFunctions, list], execution_context=None):
        """Disable specified `cost functions <TransferWithCosts_Cost_Functions>`;
        settings for all other cost functions are left intact.

        Arguments
        ---------
        cost_functions: CostFunctions or List[CostFunctions]
            `cost function <TransferWithCosts_Cost_Functions>` or list of ones to be disabled.

        Returns
        -------
        enabled_cost_functions :  boolean combination of CostFunctions
            current value of `enabled_cost_functions <TransferWithCosts.enabled_cost_functions>`.
        """
        if isinstance(cost_functions, CostFunctions):
            cost_functions = [cost_functions]
        enabled_cost_functions = self.parameters.enabled_cost_functions.get(execution_context)
        for cost_function in cost_functions:
            enabled_cost_functions &= ~cost_function

        self.parameters.enabled_cost_functions.set(enabled_cost_functions, execution_context)
        return enabled_cost_functions

    def toggle_cost(self, cost_function_name: Union[str, CostFunctions],
                    assignment: bool = ON,
                    execution_context=None):
        """Enable/disable a `cost function <TransferWithCosts_Cost_Functions>`.

        Arguments
        ---------
        cost_function_name : str or CostFunctions
            Must be the name of a `cost function <TransferWithCosts_Cost_Functions>` or a value of the CostFunctions
            enum.

        Returns
        -------
        enabled_cost_functions :  boolean combination of CostFunctions
            current value of `enabled_cost_functions <TransferWithCosts.enabled_cost_functions>`.

        """
        if cost_function_name in {INTENSITY_COST_FUNCTION, CostFunctions.INTENSITY}:
            cost_function = CostFunctions.INTENSITY
            cost_function_name = INTENSITY_COST_FUNCTION
        elif cost_function_name in {ADJUSTMENT_COST_FUNCTION, CostFunctions.ADJUSTMENT}:
            cost_function = CostFunctions.ADJUSTMENT
            cost_function_name = ADJUSTMENT_COST_FUNCTION
        elif cost_function_name in {DURATION_COST_FUNCTION, CostFunctions.DURATION}:
            cost_function = CostFunctions.DURATION
            cost_function_name = DURATION_COST_FUNCTION
        elif cost_function_name == COMBINE_COSTS_FUNCTION:
            raise FunctionError("{} cannot be disabled".format(COMBINE_COSTS_FUNCTION))
        else:
            raise FunctionError("toggle_cost: unrecognized cost function: {}".format(cost_function_name))

        enabled_cost_functions = self.parameters.enabled_cost_functions.get(execution_context)
        if assignment:
            if cost_function_name not in self.parameters.names():
                raise FunctionError("Unable to toggle {} ON as function assignment is 'None'".
                                    format(cost_function_name))
            if not enabled_cost_functions:
                enabled_cost_functions = cost_function
            else:
                enabled_cost_functions |= cost_function
        else:
            enabled_cost_functions &= ~cost_function

        self.parameters.enabled_cost_functions.set(enabled_cost_functions, execution_context)
        return enabled_cost_functions
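
    # Editor's sketch (illustrative): the management methods above differ in how
    # they combine with the current setting -- assign_costs replaces it wholesale,
    # enable_costs/disable_costs OR-in or mask out individual flags, and
    # toggle_cost flips a single one, identified by name or enum value:
    #
    #     >>> f = TransferWithCosts()
    #     >>> f.assign_costs(CostFunctions.INTENSITY)      # only INTENSITY
    #     >>> f.enable_costs([CostFunctions.DURATION])     # INTENSITY | DURATION
    #     >>> f.toggle_cost(CostFunctions.DURATION, OFF)   # back to INTENSITY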

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        # Run transfer function first
        transfer_f = self.parameters.transfer_fct
        trans_f = ctx.import_llvm_function(transfer_f.get())
        trans_p, trans_s = ctx.get_param_or_state_ptr(builder,
                                                      self,
                                                      transfer_f.name,
                                                      param_struct_ptr=params,
                                                      state_struct_ptr=state)
        trans_in = arg_in
        trans_out = arg_out
        builder.call(trans_f, [trans_p, trans_s, trans_in, trans_out])

        intensity_ptr = ctx.get_state_space(builder, self, state, self.parameters.intensity)

        costs = [(self.parameters.intensity_cost_fct, CostFunctions.INTENSITY, self.parameters.intensity_cost),
                 (self.parameters.adjustment_cost_fct, CostFunctions.ADJUSTMENT, self.parameters.adjustment_cost),
                 (self.parameters.duration_cost_fct, CostFunctions.DURATION, self.parameters.duration_cost)]

        for (func, flag, res_param) in costs:

            cost_in = trans_out
            cost_out = ctx.get_state_space(builder, self, state, res_param)

            # The check for enablement is structural and has to be done in Python.
            # If a cost function is not enabled the cost parameter is None
            if flag in self.parameters.enabled_cost_functions.get():
                cost_f = ctx.import_llvm_function(func.get())
                cost_p, cost_s = ctx.get_param_or_state_ptr(builder,
                                                            self,
                                                            func,
                                                            param_struct_ptr=params,
                                                            state_struct_ptr=state)

                if flag == CostFunctions.ADJUSTMENT:
                    old_intensity = pnlvm.helpers.load_extract_scalar_array_one(builder, intensity_ptr)
                    new_intensity = pnlvm.helpers.load_extract_scalar_array_one(builder, trans_out)
                    adjustment = builder.fsub(new_intensity, old_intensity)

                    fabs_f = ctx.get_builtin("fabs", [adjustment.type])
                    adjustment = builder.call(fabs_f, [adjustment])

                    cost_in = builder.alloca(cost_in.type.pointee)
                    builder.store(adjustment, builder.gep(cost_in, [ctx.int32_ty(0), ctx.int32_ty(0)]))

                builder.call(cost_f, [cost_p, cost_s, cost_in, cost_out])
            else:
                # Intensity is [1] when the cost function is disabled but other cost functions are enabled
                # https://github.com/PrincetonUniversity/PsyNeuLink/issues/2711
                exp_out_len = 0 if self.parameters.enabled_cost_functions.get() == CostFunctions.NONE or flag != CostFunctions.INTENSITY else 1
                assert len(cost_out.type.pointee) == exp_out_len, "Unexpected out struct for {}: {}".format(flag, cost_out.type.pointee)


        # TODO: combine above costs via a call to combine_costs_fct
        # depends on: https://github.com/PrincetonUniversity/PsyNeuLink/issues/2712
        # This function is still used in OCM so track both state and parameters
        combine_p, combine_s = ctx.get_param_or_state_ptr(builder,
                                                          self,
                                                          self.parameters.combine_costs_fct,
                                                          param_struct_ptr=params,
                                                          state_struct_ptr=state)

        builder.store(builder.load(trans_out), intensity_ptr)

        return builder