PrincetonUniversity / PsyNeuLink / 13611321408

25 Feb 2025 01:16AM UTC coverage: 84.017% (-0.02%) from 84.037%
Build 13611321408 (push, github, web-flow): Merge pull request #3211 from PrincetonUniversity/devel, "Devel"

9625 of 12688 branches covered (75.86%); branch coverage is included in the aggregate %.
561 of 630 new or added lines in 23 files covered (89.05%).
6 existing lines in 4 files now uncovered.
33432 of 38560 relevant lines covered (86.7%).
0.87 hits per line.

Source file: /psyneulink/core/components/functions/function.py (77.39% covered)

#
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
#
#
# ***********************************************  Function ************************************************************

"""
|
Function
  * `Function_Base`

Example function:
  * `ArgumentTherapy`


.. _Function_Overview:

Overview
--------

A Function is a `Component <Component>` that "packages" a function for use by other Components.
Every Component in PsyNeuLink is assigned a Function; when that Component is executed, its
Function's `function <Function_Base.function>` is executed.  The `function <Function_Base.function>` can be any callable
operation, although most commonly it is a mathematical operation (and, for those, almost always uses a call to one or
more numpy functions).  There are two reasons PsyNeuLink packages functions in a Function Component:

* **Manage parameters** -- parameters are attributes of a Function that either remain stable over multiple calls to the
  function (e.g., the `gain <Logistic.gain>` or `bias <Logistic.bias>` of a `Logistic` function, or the learning rate
  of a learning function); or, if they change, they do so less frequently or under the control of different factors
  than the function's variable (i.e., its input).  As a consequence, it is useful to manage these separately from the
  function's variable, rather than having to provide them every time the function is called.  To address this, every
  PsyNeuLink Function has a set of attributes corresponding to the parameters of the function that can be specified at
  the time the Function is created (in arguments to its constructor), and can be modified independently
  of a call to its :keyword:`function`. Modifications can be made directly (e.g., in a script), or by the operation of
  other PsyNeuLink Components (e.g., `ModulatoryMechanisms`) by way of `ControlProjections <ControlProjection>`.
..
* **Modularity** -- by providing a standard interface, any Function assigned to a Component in PsyNeuLink can be
  replaced with other PsyNeuLink Functions, or with user-written custom functions so long as they adhere to certain
  standards (the PsyNeuLink `Function API <LINK>`).
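
The separation between parameters and variable described under **Manage parameters** can be illustrated with a
minimal plain-Python sketch.  It does not use PsyNeuLink: ``SketchLogistic`` is a hypothetical stand-in (not the
`Logistic` class), and the standard logistic form 1 / (1 + exp(-gain * (x - bias))) is assumed for illustration.
Its ``gain`` and ``bias`` are set once in the constructor and reused on every call, while only the variable is passed
per call; reassigning an attribute changes a parameter without touching the input.

```python
import numpy as np


class SketchLogistic:
    # Hypothetical illustration only; not PsyNeuLink's Logistic class.
    # Parameters are stored on the object, separately from the input (variable).
    def __init__(self, gain=1.0, bias=0.0):
        self.gain = gain
        self.bias = bias

    def __call__(self, variable):
        # Only the variable is supplied per call; parameters come from attributes.
        x = np.asarray(variable, dtype=float)
        return 1.0 / (1.0 + np.exp(-self.gain * (x - self.bias)))


f = SketchLogistic(gain=2.0, bias=0.5)
first = f([0.5, 1.0])    # parameters reused without being passed in
f.gain = 4.0             # parameter modified independently of any call
second = f([0.5, 1.0])
```

Because the parameters live on the object, code that calls ``f`` only has to supply the input; something else (a
script, or, in PsyNeuLink, a ControlProjection) can adjust ``gain`` between calls.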

.. _Function_Creation:

Creating a Function
-------------------

A Function can be created directly by calling its constructor.  Functions are also created automatically whenever
any other type of PsyNeuLink Component is created (and its :keyword:`function` is not otherwise specified). The
constructor for a Function has an argument for its `variable <Function_Base.variable>` and for each of the parameters
of its `function <Function_Base.function>`.  The `variable <Function_Base.variable>` argument is used both to specify
the format of the input to the `function <Function_Base.function>` and to assign its default value.  The argument for
each parameter can be used to specify the default value for that parameter; the values can later be modified in
various ways, as described
below.
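
A minimal sketch of how one constructor argument for the variable can do double duty (fixing the expected input
format and supplying a default input), while keyword arguments set the default values of parameters.
``SketchFunction``, its ``slope`` and ``intercept`` parameters, and the shape check are assumptions made for
illustration; they are not PsyNeuLink's constructor or validation logic.

```python
import numpy as np


class SketchFunction:
    # Hypothetical sketch (not PsyNeuLink's implementation): default_variable fixes
    # both the expected input shape and its default value, and each keyword
    # argument sets the default value of one parameter.
    def __init__(self, default_variable=(0.0,), slope=1.0, intercept=0.0):
        self.default_variable = np.asarray(default_variable, dtype=float)
        self.slope = slope
        self.intercept = intercept

    def __call__(self, variable=None):
        if variable is None:
            # No input supplied: fall back on the default value.
            variable = self.default_variable
        variable = np.asarray(variable, dtype=float)
        if variable.shape != self.default_variable.shape:
            raise ValueError(
                f"expected shape {self.default_variable.shape}, got {variable.shape}"
            )
        return self.slope * variable + self.intercept


g = SketchFunction(default_variable=[1.0, 2.0], slope=3.0)
from_default = g()              # uses default_variable: [3.0, 6.0]
from_input = g([0.0, 1.0])      # [0.0, 3.0]
```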

.. _Function_Structure:

Structure
---------

.. _Function_Core_Attributes:

*Core Attributes*
~~~~~~~~~~~~~~~~~

Every Function has the following core attributes:

* `variable <Function_Base.variable>` -- provides the input to the Function's `function <Function_Base.function>`.
..
* `function <Function_Base.function>` -- determines the computation carried out by the Function; it must be a
  callable object (that is, a Python function or method of some kind). Unlike other PsyNeuLink `Components
  <Component>`, it *cannot* be (another) Function object (it can't be "turtles" all the way down!).

A Function also has an attribute for each of the parameters of its `function <Function_Base.function>`.
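
The rule that a Function's `function <Function_Base.function>` must be a plain callable, and not another Function
object, can be sketched as a constructor check.  ``SketchFunctionWrapper`` is hypothetical, not a PsyNeuLink class.
The order of the checks matters: the wrapper itself defines ``__call__``, so it would otherwise pass the ``callable``
test.

```python
import math


class SketchFunctionWrapper:
    # Hypothetical sketch of the rule above: the wrapped 'function' must be a
    # plain callable (e.g., a Python function), not another wrapper instance.
    def __init__(self, function):
        # Check for a wrapper first, since wrapper instances are themselves callable.
        if isinstance(function, SketchFunctionWrapper):
            raise TypeError("function cannot be another wrapper object")
        if not callable(function):
            raise TypeError("function must be callable")
        self.function = function

    def __call__(self, variable):
        return self.function(variable)


wrapped = SketchFunctionWrapper(math.sqrt)
result = wrapped(16.0)   # 4.0
```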

*Owner*
~~~~~~~

If a Function has been assigned to another `Component`, then it also has an `owner <Function_Base.owner>` attribute
that refers to that Component.  The Function itself is assigned as the Component's
`function <Component.function>` attribute.  Each of the Function's parameters is also assigned
as an attribute of the `owner <Function_Base.owner>`, and each of those is associated with a
`ParameterPort <ParameterPort>` of the `owner <Function_Base.owner>`; those ParameterPorts can be
used by `ControlProjections <ControlProjection>` to modify the Function's parameters.
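
A rough sketch of the owner relationship, under simplifying assumptions: ``SketchGain`` stands in for a Function,
``SketchOwner`` for the Component that owns it, and the ``control`` dict loosely plays the role of the owner's
ParameterPorts (one entry per function parameter).  None of these names are PsyNeuLink API, and multiplying the
stored value by the control value is only one way a control signal might act.

```python
class SketchGain:
    # Hypothetical stand-in for a Function with a single 'gain' parameter.
    def __init__(self, gain=1.0):
        self.gain = gain
        self.owner = None  # set when the Function is assigned to a Component

    def __call__(self, variable, gain=None):
        # A gain passed in for this call takes precedence over the stored one.
        if gain is None:
            gain = self.gain
        return gain * variable


class SketchOwner:
    # Hypothetical stand-in for a Component that owns a Function. The 'control'
    # dict loosely plays the role of ParameterPorts: one entry per parameter.
    def __init__(self, function):
        self.function = function
        function.owner = self            # the Function refers back to its owner
        self.control = {"gain": 1.0}     # neutral control value

    @property
    def gain(self):
        # The Function's parameter is also exposed as an attribute of the owner.
        return self.function.gain

    def execute(self, variable):
        # Combine the stored parameter with its control value for this execution.
        effective_gain = self.function.gain * self.control["gain"]
        return self.function(variable, gain=effective_gain)


mech = SketchOwner(SketchGain(gain=2.0))
mech.control["gain"] = 3.0   # e.g., a control signal arrives
out = mech.execute(5.0)      # 2.0 * 3.0 * 5.0 = 30.0
```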


COMMENT:
.. _Function_Output_Type_Conversion:

If the `function <Function_Base.function>` returns a single numeric value, and the Function's class implements
FunctionOutputTypeConversion, then the type of value returned by its `function <Function_Base.function>` can be
specified using the `output_type` attribute, by assigning it one of the following `FunctionOutputType` values:
    * FunctionOutputType.NP_0D_ARRAY: return 0d np.array
    * FunctionOutputType.NP_1D_ARRAY: return 1d np.array
    * FunctionOutputType.NP_2D_ARRAY: return 2d np.array

To implement FunctionOutputTypeConversion, the Function's FUNCTION_OUTPUT_TYPE_CONVERSION parameter must be set to
True, and function type conversion must be implemented by its `function <Function_Base.function>` method
(see `Linear` for an example).
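
The three conversions listed above amount to reshaping a single-number result.  The sketch below assumes a
hypothetical helper ``convert_output`` and a local ``OutputType`` enum whose values mirror `FunctionOutputType` in
this module; it is not the conversion code used by `Linear` or any other Function.  Treating ``DEFAULT`` as "leave
the result unchanged" is likewise an assumption.

```python
from enum import IntEnum

import numpy as np


class OutputType(IntEnum):
    # Mirrors the values of FunctionOutputType defined in this module.
    NP_0D_ARRAY = 0
    NP_1D_ARRAY = 1
    NP_2D_ARRAY = 2
    DEFAULT = 3


def convert_output(value, output_type):
    # Hypothetical helper: reshape a single-number result to the requested form.
    arr = np.asarray(value)
    if output_type is OutputType.DEFAULT:
        return arr
    if arr.size != 1:
        raise ValueError("only a single numeric value can be converted")
    if output_type is OutputType.NP_0D_ARRAY:
        return arr.reshape(())
    if output_type is OutputType.NP_1D_ARRAY:
        return arr.reshape((1,))
    return arr.reshape((1, 1))


shapes = [convert_output([[7.0]], t).shape for t in OutputType]
# shapes == [(), (1,), (1, 1), (1, 1)]
```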
COMMENT

.. _Function_Modulatory_Params:

*Modulatory Parameters*
~~~~~~~~~~~~~~~~~~~~~~~

Some classes of Functions also implement a pair of modulatory parameters: `multiplicative_param` and `additive_param`.
Each of these is assigned the name of one of the function's parameters. These are used by `ModulatorySignals
<ModulatorySignal>` to modulate the `function <Port_Base.function>` of a `Port <Port>` and thereby its `value
<Port_Base.value>` (see `ModulatorySignal_Modulation` and `figure <ModulatorySignal_Detail_Figure>` for additional
details). For example, a `ControlSignal` typically uses the `multiplicative_param` to modulate the value of a parameter
of a Mechanism's `function <Mechanism_Base.function>`, whereas a `LearningSignal` uses the `additive_param` to increment
the `value <ParameterPort.value>` of the `matrix <MappingProjection.matrix>` parameter of a `MappingProjection`.

COMMENT:
FOR DEVELOPERS:  `multiplicative_param` and `additive_param` are implemented as aliases to the relevant
parameters of a given Function, declared in the Parameters class nested in the Function's class definition.
COMMENT
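
How a modulatory signal can use these names is sketched below in plain Python.  ``SketchLinear`` declares which of
its parameters is multiplicative and which is additive, and the hypothetical helper ``modulate`` looks the names up
rather than hard-coding them, which is what lets one kind of signal work with different Functions.  For simplicity the
sketch overwrites the stored parameter value; it does not model how a `Port` applies modulation to its
`value <Port_Base.value>`.

```python
class SketchLinear:
    # Hypothetical stand-in for a Function whose multiplicative_param is 'slope'
    # and whose additive_param is 'intercept' (names chosen for illustration).
    multiplicative_param = "slope"
    additive_param = "intercept"

    def __init__(self, slope=1.0, intercept=0.0):
        self.slope = slope
        self.intercept = intercept

    def __call__(self, variable):
        return self.slope * variable + self.intercept


def modulate(function, multiplicative=None, additive=None):
    # Hypothetical helper: scale the parameter named by multiplicative_param
    # and/or increment the one named by additive_param.
    if multiplicative is not None:
        name = function.multiplicative_param
        setattr(function, name, getattr(function, name) * multiplicative)
    if additive is not None:
        name = function.additive_param
        setattr(function, name, getattr(function, name) + additive)
    return function


lin = SketchLinear(slope=2.0, intercept=1.0)
modulate(lin, multiplicative=3.0)   # slope: 2.0 becomes 6.0
modulate(lin, additive=0.5)         # intercept: 1.0 becomes 1.5
y = lin(1.0)                        # 6.0 * 1.0 + 1.5 = 7.5
```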


.. _Function_Execution:

Execution
---------

Functions are executable objects that can be called directly.  More commonly, however, they are called when
their `owner <Function_Base.owner>` is executed.  The parameters
of the `function <Function_Base.function>` can be modified when it is executed, by assigning a
`parameter specification dictionary <ParameterPort_Specification>` to the **params** argument in the
call to the `function <Function_Base.function>`.

For `Mechanisms <Mechanism>`, this can also be done by specifying `runtime_params <Composition_Runtime_Params>` in the
`Run` method of their `Composition`.
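
A minimal sketch of per-call parameter overrides, assuming (as the term *runtime params* suggests) that values passed
in **params** apply to that call only and leave the stored parameters unchanged.  ``SketchRuntimeParams`` is
hypothetical and does not reproduce PsyNeuLink's handling of parameter specification dictionaries.

```python
class SketchRuntimeParams:
    # Hypothetical sketch: entries in 'params' override the stored values for a
    # single call only; the stored values themselves are never modified.
    def __init__(self, slope=1.0, intercept=0.0):
        self.slope = slope
        self.intercept = intercept

    def function(self, variable, params=None):
        current = {"slope": self.slope, "intercept": self.intercept}
        current.update(params or {})   # runtime overrides for this call only
        return current["slope"] * variable + current["intercept"]


h = SketchRuntimeParams(slope=2.0)
with_override = h.function(3.0, params={"slope": 10.0})   # 30.0
without_override = h.function(3.0)                         # 6.0; stored slope unchanged
```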

Class Reference
---------------

"""

import abc
import inspect
import numbers
import types
import warnings
from enum import Enum, IntEnum

import numpy as np
try:
    import torch
except ImportError:
    torch = None
from beartype import beartype

from psyneulink._typing import Optional, Union, Callable

from psyneulink.core.components.component import Component, ComponentError, DefaultsFlexibility
from psyneulink.core.components.shellclasses import Function, Mechanism
from psyneulink.core.globals.context import ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import (
    ARGUMENT_THERAPY_FUNCTION, AUTO_ASSIGN_MATRIX, EXAMPLE_FUNCTION_TYPE, FULL_CONNECTIVITY_MATRIX,
    FUNCTION_COMPONENT_CATEGORY, FUNCTION_OUTPUT_TYPE, FUNCTION_OUTPUT_TYPE_CONVERSION, HOLLOW_MATRIX,
    IDENTITY_MATRIX, INVERSE_HOLLOW_MATRIX, NAME, PREFERENCE_SET_NAME, RANDOM_CONNECTIVITY_MATRIX, VALUE, VARIABLE,
    MODEL_SPEC_ID_MDF_VARIABLE, MatrixKeywordLiteral, ZEROS_MATRIX
)
from psyneulink.core.globals.mdf import _get_variable_parameter_name
from psyneulink.core.globals.parameters import Parameter, check_user_specified, copy_parameter_value
from psyneulink.core.globals.preferences.basepreferenceset import REPORT_OUTPUT_PREF, ValidPrefSet
from psyneulink.core.globals.preferences.preferenceset import PreferenceEntry, PreferenceLevel
from psyneulink.core.globals.registry import register_category
from psyneulink.core.globals.utilities import (
    NumericCollections,
    SeededRandomState,
    _get_global_seed,
    array_from_matrix_string,
    contains_type,
    convert_all_elements_to_np_array,
    convert_to_np_array,
    is_instance_or_subclass,
    is_numeric,
    is_numeric_scalar,
    object_has_single_value,
    parameter_spec,
    parse_valid_identifier,
    random_matrix,
    safe_len,
    try_extract_0d_array_item,
)

__all__ = [
    'ArgumentTherapy', 'EPSILON', 'Function_Base', 'function_keywords', 'FunctionError', 'FunctionOutputType',
    'FunctionRegistry', 'get_param_value_for_function', 'get_param_value_for_keyword', 'is_Function',
    'is_function_type', 'PERTINACITY', 'PROPENSITY', 'RandomMatrix'
]

EPSILON = np.finfo(float).eps


# numeric to allow modulation, invalid to identify unseeded state
def DEFAULT_SEED():
    return np.array(-1)


FunctionRegistry = {}

function_keywords = {FUNCTION_OUTPUT_TYPE, FUNCTION_OUTPUT_TYPE_CONVERSION}


class FunctionError(ComponentError):
    pass


class FunctionOutputType(IntEnum):
    NP_0D_ARRAY = 0
    NP_1D_ARRAY = 1
    NP_2D_ARRAY = 2
    DEFAULT = 3


# Typechecking *********************************************************************************************************

# TYPE_CHECK for Function Instance or Class
def is_Function(x):
    if not x:
        return False
    elif isinstance(x, Function):
        return True
    elif isinstance(x, type) and issubclass(x, Function):
        # issubclass() raises TypeError for non-class arguments, so check for a class first
        return True
    else:
        return False


def is_function_type(x):
    if callable(x):
        return True
    elif not x:
        return False
    elif isinstance(x, (Function, types.FunctionType, types.MethodType, types.BuiltinFunctionType, types.BuiltinMethodType)):
        return True
    elif isinstance(x, type) and issubclass(x, Function):
        return True
    else:
        return False

# *******************************   get_param_value_for_keyword ********************************************************

def get_param_value_for_keyword(owner, keyword):
    """Return the value for a keyword used by a subclass of Function

    Parameters
    ----------
    owner : Component
    keyword : str

    Returns
    -------
    value

    """
    try:
        return owner.function.keyword(owner, keyword)
    except FunctionError as e:
        # assert(False)
        # prefs is not always created when this is called, so check
        try:
            owner.prefs
            has_prefs = True
        except AttributeError:
            has_prefs = False

        if has_prefs and owner.prefs.verbosePref:
            print("{} of {}".format(e, owner.name))
        # return None
        else:
            raise FunctionError(e)
    except AttributeError:
        # prefs is not always created when this is called, so check
        try:
            owner.prefs
            has_prefs = True
        except AttributeError:
            has_prefs = False

        if has_prefs and owner.prefs.verbosePref:
            print("Keyword ({}) not recognized for {}".format(keyword, owner.name))
        return None


def get_param_value_for_function(owner, function):
    try:
        return owner.function.param_function(owner, function)
    except FunctionError as e:
        if owner.prefs.verbosePref:
            print("{} of {}".format(e, owner.name))
        return None
    except AttributeError:
        if owner.prefs.verbosePref:
            print("Function ({}) can't be evaluated for {}".format(function, owner.name))
        return None

# Parameter Mixins *****************************************************************************************************

# KDM 6/21/18: Below is left in for consideration; doesn't really gain much to justify relaxing the assumption
# that every Parameters class has a single parent

# class ScaleOffsetParamMixin:
#     scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
#     offset = Parameter(1.0, modulable=True, aliases=[ADDITIVE_PARAM])


# Function Definitions *************************************************************************************************


# KDM 8/9/18: below is added for future use when function methods are completely functional
# used as a decorator for Function methods
# def enable_output_conversion(func):
#     @functools.wraps(func)
#     def wrapper(*args, **kwargs):
#         result = func(*args, **kwargs)
#         return convert_output_type(result)
#     return wrapper

# this should eventually be moved to a unified validation method
def _output_type_setter(value, owning_component):
    # Can't convert from arrays of length > 1 to number
    if (
        owning_component.defaults.variable is not None
        and safe_len(owning_component.defaults.variable) > 1
        and owning_component.output_type is FunctionOutputType.NP_0D_ARRAY
    ):
        raise FunctionError(
            f"{owning_component.__class__.__name__} can't be set to return a "
            "single number since its variable has more than one number."
        )

    # warn if user overrides the 2D setting for mechanism functions
    # may be removed when
    # https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved
    # properly (meaning Mechanism values may be something other than 2D np array)
    try:
        if (
            isinstance(owning_component.owner, Mechanism)
            and (
                value == FunctionOutputType.NP_0D_ARRAY
                or value == FunctionOutputType.NP_1D_ARRAY
            )
        ):
            warnings.warn(
                'Functions that are owned by a Mechanism but do not return a '
                '2D numpy array may cause unexpected behavior if llvm '
                'compilation is enabled.'
            )
    except (AttributeError, ImportError):
        pass

    return value


def _seed_setter(value, owning_component, context, *, compilation_sync):
    if compilation_sync:
        # compilation sync should provide shared memory 0d array with a floating point value.
        assert value is not None
        assert value != DEFAULT_SEED()
        assert value.shape == ()

        return value

    value = try_extract_0d_array_item(value)
    if value is None or value == DEFAULT_SEED():
        value = _get_global_seed()

    # Remove any old PRNG state
    owning_component.parameters.random_state.set(None, context=context)
    return np.asarray(value)


def _random_state_getter(self, owning_component, context, modulated=False):

    seed_param = owning_component.parameters.seed
    try:
        has_modulation = seed_param.port.has_modulation(context.composition)
    except AttributeError:
        has_modulation = False

    # 'has_modulation' indicates that seed has an active modulatory projection
    # 'modulated' indicates that the modulated value is requested
    if has_modulation and modulated:
        seed_value = [int(owning_component._get_current_parameter_value(seed_param, context).item())]
    else:
        seed_value = [int(seed_param._get(context=context))]

    if seed_value == [DEFAULT_SEED()]:
        raise FunctionError(
            "Invalid seed for {} in context: {} ({})".format(
                owning_component, context.execution_id, seed_param
            )
        )

    current_state = self.values.get(context.execution_id, None)
    if current_state is None:
        return SeededRandomState(seed_value)
    if current_state.used_seed != seed_value:
        return type(current_state)(seed_value)

    return current_state


def _noise_setter(value, owning_component, context):
    def has_function(x):
        return (
            is_instance_or_subclass(x, (Function_Base, types.FunctionType))
            or contains_type(x, (Function_Base, types.FunctionType))
        )

    noise_param = owning_component.parameters.noise
    value_has_function = has_function(value)
    # initial set
    if owning_component.is_initializing:
        if value_has_function:
            # is changing a parameter attribute like this ok?
            noise_param.stateful = False
    else:
        default_value_has_function = has_function(noise_param.default_value)

        if default_value_has_function and not value_has_function:
            warnings.warn(
                'Setting noise to a numeric value after instantiation'
                ' with a value containing functions will not remove the'
                ' noise ParameterPort or make noise stateful.'
            )
        elif not default_value_has_function and value_has_function:
            warnings.warn(
                'Setting noise to a value containing functions after'
                ' instantiation with a numeric value will not create a'
                ' noise ParameterPort or make noise stateless.'
            )

    return value


class Function_Base(Function):
    """
    Function_Base(           \
         default_variable,   \
         params=None,        \
         owner=None,         \
         name=None,          \
         prefs=None          \
    )

    Abstract base class for Function category of Component class

    COMMENT:
        Description:
            Functions are used to "wrap" functions used by other components;
            They are defined here (on top of standard libraries) to provide a uniform interface for managing parameters
             (including defaults)
            NOTE:   the Function category definition serves primarily as a shell, and as an interface to the Function
                       class, to maintain consistency of structure with the other function categories;
                    it also ensures implementation of .function for all Function Components
                    (as distinct from other Function subclasses, which can use a FUNCTION param
                        to implement .function instead of doing so directly)
                    Function Components are the end of the recursive line; as such:
                        they don't implement functionParams
                        in general, don't bother implementing function, rather...
                        they rely on Function_Base.function which passes on the return value of .function

        Variable and Parameters:
        IMPLEMENTATION NOTE:  ** DESCRIBE VARIABLE HERE AND HOW/WHY IT DIFFERS FROM PARAMETER
            - Parameters can be assigned and/or changed individually or in sets, by:
              - including them in the initialization call
              - calling the _instantiate_defaults method (which changes their default values)
              - including them in a call to the function method (which changes their values just for that call)
            - Parameters must be specified in a params dictionary:
              - the key for each entry should be the name of the parameter (used also to name associated Projections)
              - the value for each entry is the value of the parameter

        Return values:
            The output_type can be used to specify type conversion for single-item return values:
            - it can only be used for numbers or a single-number list; other values will generate an exception
            - if self.output_type is set to:
                FunctionOutputType.NP_0D_ARRAY, return value is "exposed" as a number
                FunctionOutputType.NP_1D_ARRAY, return value is 1d np.array
                FunctionOutputType.NP_2D_ARRAY, return value is 2d np.array
            - it must be enabled for a subclass by setting params[FUNCTION_OUTPUT_TYPE_CONVERSION] = True
            - it must be implemented in the execute method of the subclass
            - see Linear for an example

        FunctionRegistry:
            All Function functions are registered in FunctionRegistry, which maintains a dict for each subclass,
              a count for all instances of that type, and a dictionary of those instances

        Naming:
            Function functions are named by their componentName attribute (usually = componentType)

        Class attributes:
            + componentCategory: FUNCTION_COMPONENT_CATEGORY
            + className (str): kwMechanismFunctionCategory
            + suffix (str): " <className>"
            + registry (dict): FunctionRegistry
            + classPreference (PreferenceSet): BasePreferenceSet, instantiated in __init__()
            + classPreferenceLevel (PreferenceLevel): PreferenceLevel.CATEGORY

        Class methods:
            none

        Instance attributes:
            + componentType (str):  assigned by subclasses
            + componentName (str):   assigned by subclasses
            + variable (value) - used as input to function's execute method
            + value (value) - output of execute method
            + name (str) - if not specified as an arg, a default based on the class is assigned in register_category
            + prefs (PreferenceSet) - if not specified as an arg, default is created by copying BasePreferenceSet

        Instance methods:
            The following method MUST be overridden by an implementation in the subclass:
            - execute(variable, params)
            The following can be implemented, to customize validation of the function variable and/or params:
            - [_validate_variable(variable)]
            - [_validate_params(request_set, target_set, context)]
    COMMENT

    Arguments
    ---------

    default_variable : value : default class_defaults.variable
        specifies the format and a default value for the input to `function <Function_Base.function>`.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function_Base.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).


    Attributes
    ----------

    variable : number
        format and default value can be specified by the **default_variable** argument of the constructor;  otherwise,
        they are specified by the Function's :keyword:`class_defaults.variable`.

    function : function
        called by the Function's `owner <Function_Base.owner>` when it is executed.

    value : number
        the result returned by calling the Function.

    COMMENT:
    enable_output_type_conversion : Bool : False
        specifies whether `function output type conversion <Function_Output_Type_Conversion>` is enabled.

    output_type : FunctionOutputType : FunctionOutputType.DEFAULT
        used to determine the return type for the `function <Function_Base.function>`;  `function output type
        conversion <Function_Output_Type_Conversion>` must be enabled and implemented for the class (see
        `FunctionOutputType <Function_Output_Type_Conversion>` for details).

    changes_shape : bool : False
        specifies whether the shape of the return value of the function differs from the shape of either the outermost
        dimension (axis 0) of its `variable <Function_Base.variable>`, or any of the items in the next dimension
        (axis 1).  Used to determine whether the shape of the inputs to the `Component` to which the function is
        assigned should be based on the `variable <Function_Base.variable>` of the function or its
        `value <Function_Base.value>`.
    COMMENT

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for the Function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).

    """

    componentCategory = FUNCTION_COMPONENT_CATEGORY
    className = componentCategory
    suffix = " " + className

    registry = FunctionRegistry

    classPreferenceLevel = PreferenceLevel.CATEGORY

    _model_spec_id_parameters = 'args'
    _mdf_stateful_parameter_indices = {}

    _specified_variable_shape_flexibility = DefaultsFlexibility.INCREASE_DIMENSION

    class Parameters(Function.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <Function_Base.variable>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``
                    :read only: True

                enable_output_type_conversion
                    see `enable_output_type_conversion <Function_Base.enable_output_type_conversion>`

                    :default value: False
                    :type: ``bool``

                changes_shape
                    see `changes_shape <Function_Base.changes_shape>`

                    :default value: False
                    :type: ``bool``

                output_type
                    see `output_type <Function_Base.output_type>`

                    :default value: FunctionOutputType.DEFAULT
                    :type: `FunctionOutputType`

        """
        variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable')

        output_type = Parameter(
            FunctionOutputType.DEFAULT,
            stateful=False,
            loggable=False,
            pnl_internal=True,
            valid_types=FunctionOutputType
        )
        enable_output_type_conversion = Parameter(False, stateful=False, loggable=False, pnl_internal=True)

        changes_shape = Parameter(False, stateful=False, loggable=False, pnl_internal=True)
        def _validate_changes_shape(self, param):
            if not isinstance(param, bool):
                return 'must be a bool.'

    # Note: the following enforces encoding as 1D np.ndarrays (one array per variable)
    variableEncodingDim = 1

    @check_user_specified
    @abc.abstractmethod
    def __init__(
        self,
        default_variable,
        params,
        owner=None,
        name=None,
        prefs=None,
        context=None,
        **kwargs
    ):
        """Assign category-level preferences, register category, and call super.__init__

        Initialization arguments:
        - default_variable (anything): establishes type for the variable, used for validation
        Note: if parameter_validation is off, validation is suppressed (for efficiency) (Function class default = on)

        :param default_variable: (anything but a dict) - value to assign as self.defaults.variable
        :param params: (dict) - params to be assigned as instance defaults
        :param owner: (Component) - Component to which the Function is assigned
        :param name: (string) - optional, overrides assignment of default (componentName of subclass)
        :return:
        """

678
        if self.initialization_status == ContextFlags.DEFERRED_INIT:
1!
679
            self._assign_deferred_init_name(name)
×
680
            self._init_args[NAME] = name
×
681
            return
×
682

683
        register_category(entry=self,
1✔
684
                          base_class=Function_Base,
685
                          registry=FunctionRegistry,
686
                          name=name,
687
                          )
688
        self.owner = owner
1✔
689

690
        super().__init__(
1✔
691
            default_variable=default_variable,
692
            param_defaults=params,
693
            name=name,
694
            prefs=prefs,
695
            **kwargs
696
        )
697

698
    def __call__(self, *args, **kwargs):
1✔
699
        return self.function(*args, **kwargs)
1✔
700

701
    def __deepcopy__(self, memo):
1✔
702
        new = super().__deepcopy__(memo)
1✔
703

704
        if self is not new:
1✔
705
            # ensure copy does not have identical name
706
            register_category(new, Function_Base, new.name, FunctionRegistry)
1✔
707
            if "random_state" in new.parameters:
1✔
708
                # HACK: Make sure any copies are re-seeded to avoid dependent RNG.
709
                # functions with "random_state" param must have "seed" parameter
710
                for ctx in new.parameters.seed.values:
1✔
711
                    new.parameters.seed.set(
1✔
712
                        DEFAULT_SEED(), ctx, skip_log=True, skip_history=True
713
                    )
714

715
        return new
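The HACK comment in `__deepcopy__` re-seeds copies to avoid dependent random number streams. The NumPy sketch below shows the failure mode it guards against. It assumes a plain `np.random.RandomState` held by a hypothetical `NoisyThing` class, which is not a PsyNeuLink API. The `seed + 1` re-seed is arbitrary and only stands in for `DEFAULT_SEED()`.

```python
import copy

import numpy as np


class NoisyThing:
    """Hypothetical object that owns a seeded random generator."""

    def __init__(self, seed):
        self.seed = seed
        self.random_state = np.random.RandomState(seed)


original = NoisyThing(seed=42)
clone = copy.deepcopy(original)

# A deep copy duplicates the generator state, so both objects draw identical numbers.
same_stream = np.array_equal(original.random_state.rand(3), clone.random_state.rand(3))

# Re-seeding the copy (arbitrary new seed) breaks the dependence between the two streams.
clone.random_state = np.random.RandomState(clone.seed + 1)
independent = not np.array_equal(original.random_state.rand(3), clone.random_state.rand(3))

print(same_stream, independent)  # True True
```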

    @handle_external_context()
    def function(self,
                 variable=None,
                 context=None,
                 params=None,
                 target_set=None,
                 **kwargs):

        if ContextFlags.COMMAND_LINE in context.source:
            variable = copy_parameter_value(variable)

        # IMPLEMENTATION NOTE:
        # The following is a convenience feature that supports specification of params directly in the call to
        # function by moving them to a params dict, which treats them as runtime_params
        if kwargs:
            for key in kwargs.copy():
                if key in self.parameters.names():
                    if not params:
                        params = {key: kwargs.pop(key)}
                    else:
                        params.update({key: kwargs.pop(key)})

        # Validate and assign variable, and validate params
        variable = self._check_args(variable=variable,
                                    context=context,
                                    params=params,
                                    target_set=target_set,
                                    )
        # Execute function
        value = self._function(
            variable=variable, context=context, params=params, **kwargs
        )
        self.most_recent_context = context
        self.parameters.value._set(value, context=context)
        self._reset_runtime_parameters(context)
        return value
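The IMPLEMENTATION NOTE in `function` describes how keyword arguments that name a Parameter are moved into the runtime `params` dict. Below is a standalone sketch of that routing. It uses a hypothetical helper, `split_runtime_params`, and a plain set of names in place of `self.parameters.names()`.

```python
def split_runtime_params(parameter_names, params=None, **kwargs):
    """Hypothetical helper: move kwargs that name a known parameter into `params`.

    Returns (params, remaining_kwargs), mirroring the loop in Function_Base.function.
    """
    remaining = dict(kwargs)
    for key in list(remaining):
        if key in parameter_names:
            if not params:
                params = {}
            params[key] = remaining.pop(key)
    return params, remaining


params, rest = split_runtime_params({"gain", "bias"}, gain=2.0, foo="bar")
print(params, rest)  # {'gain': 2.0} {'foo': 'bar'}
```

Keys that do not name a parameter stay in the keyword arguments. As in the listing, those are passed through unchanged to `_function`.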

    @abc.abstractmethod
    def _function(
        self,
        variable=None,
        context=None,
        params=None,
    ):
        pass

    def _parse_arg_generic(self, arg_val):
        if isinstance(arg_val, list):
            return np.asarray(arg_val)
        else:
            return arg_val

    def _validate_parameter_spec(self, param, param_name, numeric_only=True):
        """Validates function param

        Replaces a direct call to parameter_spec in tc, which does not appear to be called by Function.__init__()
        """
        if not parameter_spec(param, numeric_only):
            owner_name = ' of ' + self.owner_name if self.owner else ""
            raise FunctionError(f"{param} is not a valid specification for "
                                f"the {param_name} argument of {self.__class__.__name__}{owner_name}.")

    def _get_current_parameter_value(self, param_name, context=None):
        try:
            param = getattr(self.parameters, param_name)
        except TypeError:
            param = param_name
        except AttributeError:
            # don't accept strings that don't correspond to Parameters
            # on this function
            raise

        return super()._get_current_parameter_value(param, context)

    def get_previous_value(self, context=None):
        # temporary method until previous values are integrated for all parameters
        value = self.parameters.previous_value._get(context)

        return value

    def convert_output_type(self, value, output_type=None):
        value = convert_all_elements_to_np_array(value)
        if output_type is None:
            if not self.enable_output_type_conversion or self.output_type is None:
                return value
            else:
                output_type = self.output_type

        # Type conversion (specified by output_type):

        # MODIFIED 6/21/19 NEW: [JDC]
        # Convert to same format as variable
        if isinstance(output_type, (list, np.ndarray)):
            shape = np.array(output_type).shape
            return np.array(value).reshape(shape)
        # MODIFIED 6/21/19 END

        # Convert to 2D array, irrespective of value type:
        if output_type is FunctionOutputType.NP_2D_ARRAY:
            # KDM 8/10/18: mimicking the conversion that Mechanism does to its values, because
            # this is what we actually wanted this method for. Can be changed to pure 2D np array in
            # future if necessary

            converted_to_2d = np.atleast_2d(value)
            # If return_value is a list of heterogeneous elements, return as is
            #     (satisfies requirement that return_value be an array of possibly multidimensional values)
            if converted_to_2d.dtype == object:
                pass
            # Otherwise, return value converted to 2d np.array
            else:
                value = converted_to_2d

        # Convert to 1D array, irrespective of value type:
        # Note: if 2D array (or higher) has more than one item in the outer dimension, generate exception
        elif output_type is FunctionOutputType.NP_1D_ARRAY:
            # If value is 2D (or higher)
            if value.ndim >= 2:
                # If there is only one item:
                if len(value) == 1:
                    value = value[0]
                else:
                    raise FunctionError(f"Can't convert value ({value}): 2D np.ndarray object "
                                        f"with more than one array to 1D array.")
            elif value.ndim == 1:
                pass
            elif value.ndim == 0:
                value = np.atleast_1d(value)
            else:
                raise FunctionError(f"Can't convert value ({value}) to 1D array.")

        # Convert to raw number, irrespective of value type:
        # Note: if 2D or 1D array has more than one item, generate exception
        elif output_type is FunctionOutputType.NP_0D_ARRAY:
            if object_has_single_value(value):
                value = np.asarray(value, dtype=float)
            else:
                raise FunctionError(f"Can't convert value ({value}) with more than a single number to a raw number.")

        return value
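The branches of `convert_output_type` can be exercised outside PsyNeuLink with a simplified sketch. The sketch makes several assumptions:

- A hypothetical `OutputType` enum stands in for `FunctionOutputType`.
- `np.asarray` replaces `convert_all_elements_to_np_array`.
- "Single value" is treated as `np.size(value) == 1`, which is an assumption about `object_has_single_value`.
- It raises `ValueError` instead of `FunctionError`.
- It omits the shape-template branch (list or ndarray `output_type`).

```python
from enum import Enum

import numpy as np


class OutputType(Enum):
    """Hypothetical stand-in for FunctionOutputType."""
    NP_2D_ARRAY = 2
    NP_1D_ARRAY = 1
    NP_0D_ARRAY = 0


def convert(value, output_type):
    """Simplified version of the convert_output_type branches."""
    value = np.asarray(value)
    if output_type is OutputType.NP_2D_ARRAY:
        converted = np.atleast_2d(value)
        # object-dtype (heterogeneous) values are returned unchanged
        return value if converted.dtype == object else converted
    if output_type is OutputType.NP_1D_ARRAY:
        if value.ndim >= 2:
            if len(value) == 1:
                return value[0]  # drop a singleton outer dimension
            raise ValueError(f"Can't convert {value.shape} array with more than one item to 1D")
        return np.atleast_1d(value)  # 0D becomes 1D; 1D is unchanged
    if output_type is OutputType.NP_0D_ARRAY:
        if np.size(value) == 1:
            return np.asarray(value, dtype=float)  # cast only, as in the listing
        raise ValueError(f"Can't convert {value} with more than one number to a raw number")
    return value


print(convert([1, 2, 3], OutputType.NP_2D_ARRAY).shape)  # (1, 3)
print(convert([[1, 2, 3]], OutputType.NP_1D_ARRAY))      # [1 2 3]
```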

    @property
    def owner_name(self):
        try:
            return self.owner.name
        except AttributeError:
            return '<no owner>'

    def _is_identity(self, context=None, defaults=False):
        # should return True in subclasses if the parameters for context are such that
        # the Function's output will be the same as its input
        # Used to bypass execute when unnecessary
        return False

    @property
    def _model_spec_parameter_blacklist(self):
        return super()._model_spec_parameter_blacklist.union({
            'multiplicative_param', 'additive_param',
        })

    def _assign_to_mdf_model(self, model, input_id) -> str:
        """Adds an MDF representation of this function to MDF object
        **model**, including all necessary auxiliary functions.
        **input_id** is the input to the singular MDF function or first
        function representing this psyneulink Function, if applicable.

        Returns:
            str: the identifier of the final MDF function representing
            this psyneulink Function
        """
        import modeci_mdf.mdf as mdf

        extra_noise_functions = []

        self_model = self.as_mdf_model()

        def handle_noise(noise):
            if is_instance_or_subclass(noise, Component):
                if inspect.isclass(noise) and issubclass(noise, Component):
                    noise = noise()
                noise_func_model = noise.as_mdf_model()
                extra_noise_functions.append(noise_func_model)
                return noise_func_model.id
            elif isinstance(noise, (list, np.ndarray)):
                if noise.ndim == 0:
                    return None
                return type(noise)(handle_noise(item) for item in noise)
            else:
                return None

        try:
            noise_val = handle_noise(self.defaults.noise)
        except AttributeError:
            noise_val = None

        if noise_val is not None:
            noise_func = mdf.Function(
                id=f'{model.id}_{parse_valid_identifier(self.name)}_noise',
                value=MODEL_SPEC_ID_MDF_VARIABLE,
                args={MODEL_SPEC_ID_MDF_VARIABLE: noise_val},
            )
            self._set_mdf_arg(self_model, 'noise', noise_func.id)

            model.functions.extend(extra_noise_functions)
            model.functions.append(noise_func)

        self_model.id = f'{model.id}_{self_model.id}'
        self._set_mdf_arg(self_model, _get_variable_parameter_name(self), input_id)
        model.functions.append(self_model)

        # assign stateful parameters
        for name, index in self._mdf_stateful_parameter_indices.items():
            # in this case, parameter gets updated to its function's final value
            param = getattr(self.parameters, name)

            try:
                initializer_value = self_model.args[param.initializer]
            except KeyError:
                initializer_value = self_model.metadata[param.initializer]

            index_str = f'[{index}]' if index is not None else ''

            model.parameters.append(
                mdf.Parameter(
                    id=param.mdf_name if param.mdf_name is not None else param.name,
                    default_initial_value=initializer_value,
                    value=f'{self_model.id}{index_str}'
                )
            )

        return self_model.id

    def as_mdf_model(self):
        import modeci_mdf.mdf as mdf
        import modeci_mdf.functions.standard as mdf_functions

        parameters = self._mdf_model_parameters
        metadata = self._mdf_metadata
        stateful_params = set()

        # add stateful parameters into metadata for mechanism to get
        for name in parameters[self._model_spec_id_parameters]:
            try:
                param = getattr(self.parameters, name)
            except AttributeError:
                continue

            if param.initializer is not None:
                stateful_params.add(name)

        # stateful parameters cannot show up as args or they will not be
        # treated statefully in mdf
        for sp in stateful_params:
            del parameters[self._model_spec_id_parameters][sp]

        model = mdf.Function(
            id=parse_valid_identifier(self.name),
            **parameters,
            **metadata,
        )

        try:
            model.value = self.as_expression()
        except AttributeError:
            if self._model_spec_generic_type_name is not NotImplemented:
                typ = self._model_spec_generic_type_name
            else:
                try:
                    typ = self.custom_function.__name__
                except AttributeError:
                    typ = type(self).__name__.lower()

            if typ not in mdf_functions.mdf_functions:
                warnings.warn(f'{typ} is not an MDF standard function; this is likely to produce an incompatible model.')

            model.function = typ

        return model

    def _get_pytorch_fct_param_value(self, param_name, device, context):
        """Return the current value of param_name for the function.
        Use default value if not yet assigned.
        Convert using torch.tensor if val is an array.
        """
        val = self._get_current_parameter_value(param_name, context=context)
        if val is None:
            val = getattr(self.defaults, param_name)
        if isinstance(val, (str, type(None))):
            return val
        elif np.isscalar(np.array(val)):
            return float(val)
        try:
            # return torch.tensor(val, device=device).double()
            return torch.tensor(val, device=device)
        except Exception as error:
            raise FunctionError(f"PROGRAM ERROR: unsupported value of parameter '{param_name}' ({val}) "
                                f"encountered in pytorch_function_creator(): {error.args[0]}")


# *****************************************   EXAMPLE FUNCTION   *******************************************************
PROPENSITY = "propensity"
PERTINACITY = "pertinacity"


class ArgumentTherapy(Function_Base):
    """
    ArgumentTherapy(                   \
         variable,                     \
         propensity=Manner.CONTRARIAN, \
         pertinacity=10.0,             \
         params=None,                  \
         owner=None,                   \
         name=None,                    \
         prefs=None                    \
         )

    .. _ArgumentTherapist:

    Return :keyword:`True` or :keyword:`False` according to the manner of the therapist.

    Arguments
    ---------

    variable : boolean or statement that resolves to one : default class_defaults.variable
        assertion for which a therapeutic response will be offered.

    propensity : Manner value : default Manner.CONTRARIAN
        specifies preferred therapeutic manner.

    pertinacity : float : default 10.0
        specifies therapeutic consistency.

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).


    Attributes
    ----------

    variable : boolean
        assertion to which a therapeutic response is made.

    propensity : Manner value : default Manner.CONTRARIAN
        determines therapeutic manner: tendency to agree or disagree.

    pertinacity : float : default 10.0
        determines consistency with which the manner complies with the propensity.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for the Function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).


    """

    # Function componentName and type (defined at top of module)
    componentName = ARGUMENT_THERAPY_FUNCTION
    componentType = EXAMPLE_FUNCTION_TYPE

    classPreferences = {
        PREFERENCE_SET_NAME: 'ExampleClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    class Parameters(Function_Base.Parameters):
        propensity = None
        pertinacity = None

    # Mode indicators
    class Manner(Enum):
        OBSEQUIOUS = 0
        CONTRARIAN = 1

    # Parameter class defaults
    # These are used both to type-cast the params, and as defaults if none are assigned
    # in the initialization call or later (using either _instantiate_defaults or during a function call)

    @check_user_specified
    def __init__(self,
                 default_variable=None,
                 propensity=Manner.CONTRARIAN,
                 pertinacity=10.0,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            propensity=propensity,
            pertinacity=pertinacity,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_variable(self, variable, context=None):
        """Validates variable and returns validated value

        This overrides the class method to perform more detailed type checking.
        See explanation in class method.
        Note: this method (or the class version) is called only if the parameter_validation attribute is `True`

        :param variable: (anything but a dict) - variable to be validated
        :param context: (str)
        :return variable: - validated
        """

        if type(variable) == type(self.class_defaults.variable) or \
                (isinstance(variable, numbers.Number) and isinstance(self.class_defaults.variable, numbers.Number)):
            return variable
        else:
            raise FunctionError(f"Variable must be {type(self.class_defaults.variable)}.")

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validates params and assigns them to targets

        This overrides the class method to perform more detailed type checking.
        See explanation in class method.
        Note: this method (or the class version) is called only if the parameter_validation attribute is `True`

        :param request_set: (dict) - params to be validated
        :param target_set: (dict) - destination of validated params
        :return none:
        """

        message = ""

        # Check params
        for param_name, param_value in request_set.items():

            if param_name == PROPENSITY:
                if isinstance(param_value, ArgumentTherapy.Manner):
                    # target_set[self.PROPENSITY] = param_value
                    pass  # This leaves param in request_set, clear to be assigned to target_set in call to super below
                else:
                    message += "Propensity must be of type ArgumentTherapy.Manner. "
                continue

            # Validate param
            if param_name == PERTINACITY:
                if is_numeric_scalar(param_value) and 0 <= param_value <= 10:
                    # target_set[PERTINACITY] = param_value
                    pass  # This leaves param in request_set, clear to be assigned to target_set in call to super below
                else:
                    message += "Pertinacity must be a number between 0 and 10. "
                continue

        if message:
            raise FunctionError(message.strip())

        super()._validate_params(request_set, target_set, context)

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """
        Returns a boolean that is (or tends to be) the same as or opposite the one passed in.

        Arguments
        ---------

        variable : boolean : default class_defaults.variable
            an assertion to which a therapeutic response is made.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        therapeutic response : boolean

        """
        # Compute the function
        statement = variable
        propensity = self._get_current_parameter_value(PROPENSITY, context)
        pertinacity = self._get_current_parameter_value(PERTINACITY, context)
        whim = np.random.randint(-10, 10)

        if propensity == self.Manner.OBSEQUIOUS:
            value = whim < pertinacity

        elif propensity == self.Manner.CONTRARIAN:
            value = whim > pertinacity

        else:
            raise FunctionError("This should not happen if parameter_validation == True; check its value")

        return self.convert_output_type(value)
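The decision rule in `ArgumentTherapy._function` can be sketched without the Component machinery. This version makes three assumptions. It uses a local `Manner` enum that mirrors `ArgumentTherapy.Manner`. It draws from NumPy's `Generator.integers`, which, like `np.random.randint(-10, 10)`, excludes the upper bound. It also folds in the pertinacity range check from `_validate_params`. As in the listing, the incoming statement is not consulted.

```python
from enum import Enum

import numpy as np


class Manner(Enum):
    """Local copy of ArgumentTherapy.Manner."""
    OBSEQUIOUS = 0
    CONTRARIAN = 1


def therapeutic_response(propensity, pertinacity, rng):
    """Sketch of ArgumentTherapy._function plus its pertinacity range check."""
    # Same bounds as ArgumentTherapy._validate_params
    if not isinstance(pertinacity, (int, float)) or not 0 <= pertinacity <= 10:
        raise ValueError("Pertinacity must be a number between 0 and 10")
    whim = rng.integers(-10, 10)  # -10..9 inclusive, like np.random.randint(-10, 10)
    if propensity is Manner.OBSEQUIOUS:
        return bool(whim < pertinacity)
    if propensity is Manner.CONTRARIAN:
        return bool(whim > pertinacity)
    raise ValueError(f"Unknown propensity: {propensity!r}")


rng = np.random.default_rng(0)
print([therapeutic_response(Manner.CONTRARIAN, 0, rng) for _ in range(5)])
```

`whim` never exceeds 9, so the default `pertinacity=10.0` fixes the outcome: the OBSEQUIOUS manner always returns True and the CONTRARIAN manner always returns False. Only lower pertinacity values let the response vary.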


kwEVCAuxFunction = "EVC AUXILIARY FUNCTION"
kwEVCAuxFunctionType = "EVC AUXILIARY FUNCTION TYPE"
kwValueFunction = "EVC VALUE FUNCTION"
CONTROL_SIGNAL_GRID_SEARCH_FUNCTION = "EVC CONTROL SIGNAL GRID SEARCH FUNCTION"
CONTROLLER = 'controller'


class EVCAuxiliaryFunction(Function_Base):
    """Base class for EVC auxiliary functions
    """
    componentType = kwEVCAuxFunctionType

    class Parameters(Function_Base.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <Function_Base.variable>`

                    :default value: numpy.array([0])
                    :type: numpy.ndarray
                    :read only: True

        """
        variable = Parameter(None, pnl_internal=True, constructor_argument='default_variable')

    classPreferences = {
        PREFERENCE_SET_NAME: 'ValueFunctionCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    @check_user_specified
    @beartype
    def __init__(self,
                 function,
                 variable=None,
                 params=None,
                 owner=None,
                 prefs: Optional[ValidPrefSet] = None,
                 context=None):
        self.aux_function = function

        super().__init__(default_variable=variable,
                         params=params,
                         owner=owner,
                         prefs=prefs,
                         context=context,
                         function=function,
                         )


class RandomMatrix:
    """Function that returns a matrix with random elements distributed uniformly around **center** across **range**.

    The **center** and **range** arguments are passed at construction, and used for all subsequent calls.
    Once constructed, the function must be called with two ints, **sender_size** and **receiver_size**,
    that specify the number of rows and columns of the matrix, respectively.

    Can be used to specify the `matrix <MappingProjection.matrix>` parameter of a `MappingProjection
    <MappingProjection_Matrix_Specification>`, and to specify a default matrix for Projections in the
    construction of a `Pathway` (see `Pathway_Specification_Projections`) or in a call to a Composition's
    `add_linear_processing_pathway<Composition.add_linear_processing_pathway>` method.

    .. technical_note::
       A call to the class calls `random_matrix <Utilities.random_matrix>`, passing **sender_size** and
       **receiver_size** to `random_matrix <Utilities.random_matrix>` as its **num_rows** and **num_cols**
       arguments, respectively, and passing the `center <RandomMatrix.center>` - 0.5 and `range <RandomMatrix.range>`
       attributes specified at construction to `random_matrix <Utilities.random_matrix>` as its **offset**
       and **scale** arguments, respectively.

    Arguments
    ---------
    center : float
        specifies the value around which the matrix elements are distributed in all calls to the function.
    range : float
        specifies the range over which all matrix elements are distributed in all calls to the function.

    Attributes
    ----------
    center : float
        determines the center of the distribution of the matrix elements.
    range : float
        determines the range of the distribution of the matrix elements.
    """

    def __init__(self, center: float = 0.0, range: float = 1.0):
        self.center = center
        self.range = range

    def __call__(self, sender_size: int, receiver_size: int):
        return random_matrix(sender_size, receiver_size, offset=self.center - 0.5, scale=self.range)
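The docstring describes values distributed uniformly around **center** across **range**. The exact formula inside `random_matrix` is not shown in this listing, so the sketch below is an assumption: it draws from [center - range/2, center + range/2) using NumPy's `default_rng`. The class name `RandomMatrixSketch` and its optional `seed` argument are hypothetical; the original has neither.

```python
import numpy as np


class RandomMatrixSketch:
    """Hypothetical stand-in for RandomMatrix: entries uniform in [center - range/2, center + range/2)."""

    def __init__(self, center: float = 0.0, range: float = 1.0, seed=None):
        self.center = center
        self.range = range
        self._rng = np.random.default_rng(seed)

    def __call__(self, sender_size: int, receiver_size: int) -> np.ndarray:
        # rows = sender_size, columns = receiver_size, as in RandomMatrix.__call__
        unit = self._rng.random((sender_size, receiver_size))  # uniform on [0, 1)
        return (unit - 0.5) * self.range + self.center


weights = RandomMatrixSketch(center=2.0, range=0.5, seed=1)(3, 4)
print(weights.shape)  # (3, 4)
```

Note that `get_matrix` only calls specifications that are a `types.FunctionType` or a `RandomMatrix`. An instance of this sketch class would therefore not be accepted there as-is.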


def get_matrix(specification, rows=1, cols=1, context=None):
    """Returns matrix conforming to specification with dimensions = rows x cols or None

    Specification can be a matrix keyword, filler value or np.ndarray

    Specification (validated in _validate_params):
        + single number (used to fill self.matrix)
        + matrix keyword:
            + AUTO_ASSIGN_MATRIX: IDENTITY_MATRIX if the matrix is square (rows == cols), otherwise FULL_CONNECTIVITY_MATRIX
            + IDENTITY_MATRIX: 1's on diagonal, 0's elsewhere (must be square matrix), otherwise generates error
            + HOLLOW_MATRIX: 0's on diagonal, 1's elsewhere (must be square matrix), otherwise generates error
            + INVERSE_HOLLOW_MATRIX: 0's on diagonal, -1's elsewhere (must be square matrix), otherwise generates error
            + FULL_CONNECTIVITY_MATRIX: all 1's
            + ZEROS_MATRIX: all 0's
            + RANDOM_CONNECTIVITY_MATRIX: random floats uniformly distributed between 0 and 1
            + RandomMatrix: random floats uniformly distributed around a specified center value with a specified range
        + 2D list or np.ndarray of numbers

    Returns 2D array with length=rows in dim 0 and length=cols in dim 1, or None if specification is not recognized
    """

    # Matrix provided (and validated in _validate_params); convert to array
    if isinstance(specification, (list, np.matrix)):
        if is_numeric(specification):
            return convert_to_np_array(specification)
        else:
            return None
        # MODIFIED 4/9/22 END

    if isinstance(specification, np.ndarray):
        if specification.ndim == 2:
            return specification
        # FIX: MAKE THIS AN np.array WITH THE SAME DIMENSIONS??
        elif specification.ndim < 2:
            return np.atleast_2d(specification)
        else:
            raise FunctionError("Specification of np.array for matrix ({}) is more than 2d".
                                format(specification))

    if specification == AUTO_ASSIGN_MATRIX:
        if rows == cols:
            specification = IDENTITY_MATRIX
        else:
            specification = FULL_CONNECTIVITY_MATRIX

    if specification == FULL_CONNECTIVITY_MATRIX:
        return np.full((rows, cols), 1.0)

    if specification == ZEROS_MATRIX:
        return np.zeros((rows, cols))

    if specification == IDENTITY_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return np.identity(rows)

    if specification == HOLLOW_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return 1 - np.identity(rows)

    if specification == INVERSE_HOLLOW_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return (1 - np.identity(rows)) * -1

    if specification == RANDOM_CONNECTIVITY_MATRIX:
        return np.random.rand(rows, cols)

    # Function is specified, so assume it uses random.rand() and call it with rows and cols
    if isinstance(specification, (types.FunctionType, RandomMatrix)):
        return specification(rows, cols)

    # (7/12/17 CW) this is a PATCH (like the one in MappingProjection) to allow users to
    # specify 'matrix' as a string (e.g. r = RecurrentTransferMechanism(matrix='1 2; 3 4'))
    if type(specification) == str:
        try:
            return array_from_matrix_string(specification)
        except (ValueError, NameError, TypeError):
            # np.matrix(specification) will give ValueError if specification is a bad value (e.g. 'abc', '1; 1 2')
            #                          [JDC] actually gives NameError if specification is a string (e.g., 'abc')
            pass

    # Specification not recognized
    return None
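The keyword dispatch in `get_matrix` reduces to a few NumPy constructors. Below is a condensed sketch. It assumes hypothetical string values for the keywords, since PsyNeuLink defines its own constants. For MATLAB-style strings such as `'1 2; 3 4'`, it uses a simple parser in place of `array_from_matrix_string`, whose implementation is not shown here. The random and callable branches are omitted.

```python
import numpy as np

# Hypothetical keyword values; PsyNeuLink defines its own constants for these.
AUTO_ASSIGN_MATRIX = "auto_assign"
FULL_CONNECTIVITY_MATRIX = "full"
ZEROS_MATRIX = "zeros"
IDENTITY_MATRIX = "identity"
HOLLOW_MATRIX = "hollow"
INVERSE_HOLLOW_MATRIX = "inverse_hollow"
SQUARE_ONLY = (IDENTITY_MATRIX, HOLLOW_MATRIX, INVERSE_HOLLOW_MATRIX)


def matrix_from_spec(spec, rows, cols):
    """Condensed sketch of the keyword and string branches of get_matrix."""
    if spec == AUTO_ASSIGN_MATRIX:
        spec = IDENTITY_MATRIX if rows == cols else FULL_CONNECTIVITY_MATRIX
    if spec == FULL_CONNECTIVITY_MATRIX:
        return np.ones((rows, cols))
    if spec == ZEROS_MATRIX:
        return np.zeros((rows, cols))
    if spec in SQUARE_ONLY:
        if rows != cols:
            raise ValueError(f"{spec!r} requires rows == cols, got {rows} x {cols}")
        eye = np.identity(rows)
        return {IDENTITY_MATRIX: eye, HOLLOW_MATRIX: 1 - eye, INVERSE_HOLLOW_MATRIX: eye - 1}[spec]
    if isinstance(spec, str):
        # MATLAB-style string: rows separated by ';', entries by whitespace
        try:
            parsed = [[float(x) for x in row.split()] for row in spec.split(";")]
        except ValueError:  # non-numeric entries such as 'abc'
            return None
        if len({len(row) for row in parsed}) != 1:  # ragged rows such as '1; 1 2'
            return None
        return np.array(parsed)
    return None  # specification not recognized
```

As in the listing, keywords that require a square matrix raise an error when `rows != cols` rather than returning None.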


# Valid types for a matrix specification; note that this does not ensure that ND arrays are 1D or 2D, as the
# code above does.
ValidMatrixSpecType = Union[MatrixKeywordLiteral, Callable, str, NumericCollections, np.matrix]