PrincetonUniversity / PsyNeuLink / 7318475939

15 Dec 2023 11:58PM UTC coverage: 84.852% (+0.3%) from 84.589%

Merge pull request #2852 from PrincetonUniversity/devel

Devel

13470 of 16609 branches covered (81.1%)

Branch coverage included in aggregate %.

2168 of 2345 new or added lines in 47 files covered. (92.45%)

59 existing lines in 11 files now uncovered.

31551 of 36449 relevant lines covered (86.56%)

0.87 hits per line

Source File

76.06
/psyneulink/core/components/functions/function.py
1
#
2
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
3
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
4
#     http://www.apache.org/licenses/LICENSE-2.0
5
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
6
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7
# See the License for the specific language governing permissions and limitations under the License.
8
#
9
#
10
# ***********************************************  Function ************************************************************
11

12
"""
1✔
13
|
14
Function
15
  * `Function_Base`
16

17
Example function:
18
  * `ArgumentTherapy`
19

20

21
.. _Function_Overview:
22

23
Overview
24
--------
25

26
A Function is a `Component <Component>` that "packages" a function for use by other Components.
27
Every Component in PsyNeuLink is assigned a Function; when that Component is executed, its
28
Function's `function <Function_Base.function>` is executed.  The `function <Function_Base.function>` can be any callable
29
operation, although most commonly it is a mathematical operation (and, for those, almost always uses a call to one or
30
more numpy functions).  There are two reasons PsyNeuLink packages functions in a Function Component:
31

32
* **Manage parameters** -- parameters are attributes of a Function that either remain stable over multiple calls to the
33
  function (e.g., the `gain <Logistic.gain>` or `bias <Logistic.bias>` of a `Logistic` function, or the learning rate
34
  of a learning function); or, if they change, they do so less frequently or under the control of different factors
35
  than the function's variable (i.e., its input).  As a consequence, it is useful to manage these separately from the
36
  function's variable, and not have to provide them every time the function is called.  To address this, every
37
  PsyNeuLink Function has a set of attributes corresponding to the parameters of the function, that can be specified at
38
  the time the Function is created (in arguments to its constructor), and can be modified independently
39
  of a call to its :keyword:`function`. Modifications can be made directly (e.g., in a script), or by the operation of other
40
  PsyNeuLink Components (e.g., `ModulatoryMechanisms`) by way of `ControlProjections <ControlProjection>`.
41
..
42
* **Modularity** -- by providing a standard interface, any Function assigned to a Component in PsyNeuLink can be
43
  replaced with other PsyNeuLink Functions, or with user-written custom functions so long as they adhere to certain
44
  standards (the PsyNeuLink `Function API <LINK>`).
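
The parameter-management idea above can be sketched without PsyNeuLink itself. The class below is a hypothetical stand-in: `LogisticSketch` is not a PsyNeuLink class, its formula is a simplified logistic, and its attribute handling is far simpler than anything `Function_Base` does. It only illustrates how `gain` and `bias` are set once at construction, persist across calls, and can be changed independently of the input.

```python
import numpy as np


class LogisticSketch:
    """Hypothetical stand-in for a Function; not part of PsyNeuLink."""

    def __init__(self, default_variable=0.0, gain=1.0, bias=0.0):
        # The default variable fixes the expected input format and default value.
        self.default_variable = np.atleast_1d(np.asarray(default_variable, dtype=float))
        # Parameters live on the object, so callers need not pass them on every call.
        self.gain = gain
        self.bias = bias

    def function(self, variable=None):
        # Fall back to the default variable when no input is supplied.
        if variable is None:
            variable = self.default_variable
        variable = np.atleast_1d(np.asarray(variable, dtype=float))
        # Simplified logistic (an assumption of this sketch).
        return 1.0 / (1.0 + np.exp(-self.gain * (variable + self.bias)))

    # Calling the object is the same as calling its function.
    __call__ = function


logistic = LogisticSketch(default_variable=[0.0, 0.0], gain=2.0)
first = logistic([0.0, 1.0])   # uses gain=2.0
logistic.gain = 4.0            # modified independently of any call
second = logistic([0.0, 1.0])  # same input, new gain
```

Because the parameters are ordinary attributes in this sketch, replacing `LogisticSketch` with another object that exposes the same `function` interface would not require changes in the caller, which is the modularity point made above.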
45

46
.. _Function_Creation:
47

48
Creating a Function
49
-------------------
50

51
A Function can be created directly by calling its constructor.  Functions are also created automatically whenever
52
any other type of PsyNeuLink Component is created (and its :keyword:`function` is not otherwise specified). The
53
constructor for a Function has an argument for its `variable <Function_Base.variable>` and each of the parameters of
54
its `function <Function_Base.function>`.  The `variable <Function_Base.variable>` argument is used both to format the
55
input to the `function <Function_Base.function>`, and assign its default value.  The arguments for each parameter can
56
be used to specify the default value for that parameter; the values can later be modified in various ways as described
57
below.
58

59
.. _Function_Structure:
60

61
Structure
62
---------
63

64
.. _Function_Core_Attributes:
65

66
*Core Attributes*
67
~~~~~~~~~~~~~~~~~
68

69
Every Function has the following core attributes:
70

71
* `variable <Function_Base.variable>` -- provides the input to the Function's `function <Function_Base.function>`.
72
..
73
* `function <Function_Base.function>` -- determines the computation carried out by the Function; it must be a
74
  callable object (that is, a python function or method of some kind). Unlike other PsyNeuLink `Components
75
  <Component>`, it *cannot* be (another) Function object (it can't be "turtles" all the way down!).
76

77
A Function also has an attribute for each of the parameters of its `function <Function_Base.function>`.
78

79
*Owner*
80
~~~~~~~
81

82
If a Function has been assigned to another `Component`, then it also has an `owner <Function_Base.owner>` attribute
83
that refers to that Component.  The Function itself is assigned as the Component's
84
`function <Component.function>` attribute.  Each of the Function's attributes is also assigned
85
as an attribute of the `owner <Function_Base.owner>`, and those are each associated with a
86
`parameterPort <ParameterPort>` of the `owner <Function_Base.owner>`.  Projections to those parameterPorts can be
87
used by `ControlProjections <ControlProjection>` to modify the Function's parameters.
88

89

90
COMMENT:
91
.. _Function_Output_Type_Conversion:
92

93
If the `function <Function_Base.function>` returns a single numeric value, and the Function's class implements
94
FunctionOutputTypeConversion, then the type of value returned by its `function <Function>` can be specified using the
95
`output_type` attribute, by assigning it one of the following `FunctionOutputType` values:
96
    * FunctionOutputType.RAW_NUMBER: return "exposed" number;
97
    * FunctionOutputType.NP_1D_ARRAY: return 1d np.array
98
    * FunctionOutputType.NP_2D_ARRAY: return 2d np.array.
99

100
To implement FunctionOutputTypeConversion, the Function's FUNCTION_OUTPUT_TYPE_CONVERSION parameter must be set to True,
101
and function type conversion must be implemented by its `function <Function_Base.function>` method
102
(see `Linear` for an example).
103
COMMENT
104

105
.. _Function_Modulatory_Params:
106

107
*Modulatory Parameters*
108
~~~~~~~~~~~~~~~~~~~~~~~
109

110
Some classes of Functions also implement a pair of modulatory parameters: `multiplicative_param` and `additive_param`.
111
Each of these is assigned the name of one of the function's parameters. These are used by `ModulatorySignals
112
<ModulatorySignal>` to modulate the `function <Port_Base.function>` of a `Port <Port>` and thereby its `value
113
<Port_Base.value>` (see `ModulatorySignal_Modulation` and `figure <ModulatorySignal_Detail_Figure>` for additional
114
details). For example, a `ControlSignal` typically uses the `multiplicative_param` to modulate the value of a parameter
115
of a Mechanism's `function <Mechanism_Base.function>`, whereas a `LearningSignal` uses the `additive_param` to increment
116
the `value <ParameterPort.value>` of the `matrix <MappingProjection.matrix>` parameter of a `MappingProjection`.
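
As a rough illustration of the naming scheme described above, the sketch below uses a hypothetical `LinearSketch` class whose `multiplicative_param` and `additive_param` each hold the name of an ordinary parameter. The class, the `modulated_value` helper, and the way a signal is combined with the named parameter are all assumptions of this sketch; it does not reproduce how PsyNeuLink's Ports and ModulatorySignals carry out modulation.

```python
class LinearSketch:
    """Hypothetical linear function; not part of PsyNeuLink."""

    # Each modulatory parameter holds the *name* of an ordinary parameter.
    multiplicative_param = "slope"
    additive_param = "intercept"

    def __init__(self, slope=1.0, intercept=0.0):
        self.slope = slope
        self.intercept = intercept

    def function(self, variable):
        return self.slope * variable + self.intercept


def modulated_value(fn, variable, multiplicative=1.0, additive=0.0):
    """Evaluate fn with its modulatory parameters temporarily adjusted."""
    mult_name, add_name = fn.multiplicative_param, fn.additive_param
    saved = (getattr(fn, mult_name), getattr(fn, add_name))
    try:
        # A multiplicative signal scales the named parameter ...
        setattr(fn, mult_name, saved[0] * multiplicative)
        # ... and an additive signal increments the other one.
        setattr(fn, add_name, saved[1] + additive)
        return fn.function(variable)
    finally:
        # Restore the base values so modulation does not accumulate.
        setattr(fn, mult_name, saved[0])
        setattr(fn, add_name, saved[1])


linear = LinearSketch(slope=2.0, intercept=1.0)
base = modulated_value(linear, 3.0)                        # 2*3 + 1
scaled = modulated_value(linear, 3.0, multiplicative=0.5)  # 1*3 + 1
shifted = modulated_value(linear, 3.0, additive=2.0)       # 2*3 + 3
```

Looking the parameter up by name means the helper never needs to know which concrete parameter a given function class designates as multiplicative or additive.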
117

118
COMMENT:
119
FOR DEVELOPERS:  `multiplicative_param` and `additive_param` are implemented as aliases to the relevant
120
parameters of a given Function, declared in the Parameters subclass within the Function's class declaration.
121
COMMENT
122

123

124
.. _Function_Execution:
125

126
Execution
127
---------
128

129
Functions are executable objects that can be called directly.  More commonly, however, they are called when
130
their `owner <Function_Base.owner>` is executed.  The parameters
131
of the `function <Function_Base.function>` can be modified when it is executed, by assigning a
132
`parameter specification dictionary <ParameterPort_Specification>` to the **params** argument in the
133
call to the `function <Function_Base.function>`.
134

135
For `Mechanisms <Mechanism>`, this can also be done by specifying `runtime_params <Composition_Runtime_Params>` in the
136
`Run` method of their `Composition`.
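
The per-call override described above can be sketched as follows. `ScaleSketch` and its `parameter_names` tuple are hypothetical and not part of PsyNeuLink. The sketch only mirrors the general idea that values supplied in **params** (or as keyword arguments naming a parameter) apply to a single call and leave the stored values untouched.

```python
class ScaleSketch:
    """Hypothetical Function-like object; not part of PsyNeuLink."""

    parameter_names = ("scale", "offset")

    def __init__(self, scale=1.0, offset=0.0):
        self.scale = scale
        self.offset = offset

    def function(self, variable, params=None, **kwargs):
        # Copy so the caller's dictionary is never mutated.
        params = dict(params or {})
        # Keyword arguments that name a parameter are treated as runtime params.
        for key in list(kwargs):
            if key in self.parameter_names:
                params[key] = kwargs.pop(key)
        if kwargs:
            raise TypeError(f"unexpected arguments: {sorted(kwargs)}")

        # Runtime values apply to this call only; stored values are untouched.
        scale = params.get("scale", self.scale)
        offset = params.get("offset", self.offset)
        return scale * variable + offset

    # Calling the object is the same as calling its function.
    __call__ = function


fn = ScaleSketch(scale=2.0)
default_result = fn(5.0)                      # stored scale: 2*5 + 0
dict_result = fn(5.0, params={"scale": 3.0})  # one-off override via params
kwarg_result = fn(5.0, offset=1.0)            # one-off override via keyword
after_result = fn(5.0)                        # stored values are unchanged
```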
137

138
Class Reference
139
---------------
140

141
"""
142

143
import abc
1✔
144
import inspect
1✔
145
import numbers
1✔
146
import types
1✔
147
import warnings
1✔
148
from enum import Enum, IntEnum
1✔
149

150
import numpy as np
1✔
151
try:
1✔
152
    import torch
1✔
NEW
153
except ImportError:
×
NEW
154
    torch = None
×
155
from beartype import beartype
1✔
156

157
from psyneulink._typing import Optional, Union, Callable
1✔
158

159
from psyneulink.core.components.component import Component, ComponentError, DefaultsFlexibility
1✔
160
from psyneulink.core.components.shellclasses import Function, Mechanism
1✔
161
from psyneulink.core.globals.context import ContextFlags, handle_external_context
1✔
162
from psyneulink.core.globals.keywords import (
1✔
163
    ARGUMENT_THERAPY_FUNCTION, AUTO_ASSIGN_MATRIX, EXAMPLE_FUNCTION_TYPE, FULL_CONNECTIVITY_MATRIX,
164
    FUNCTION_COMPONENT_CATEGORY, FUNCTION_OUTPUT_TYPE, FUNCTION_OUTPUT_TYPE_CONVERSION, HOLLOW_MATRIX,
165
    IDENTITY_MATRIX, INVERSE_HOLLOW_MATRIX, NAME, PREFERENCE_SET_NAME, RANDOM_CONNECTIVITY_MATRIX, VALUE, VARIABLE,
166
    MODEL_SPEC_ID_MDF_VARIABLE, MatrixKeywordLiteral, ZEROS_MATRIX
167
)
168
from psyneulink.core.globals.mdf import _get_variable_parameter_name
1✔
169
from psyneulink.core.globals.parameters import Parameter, check_user_specified
1✔
170
from psyneulink.core.globals.preferences.basepreferenceset import REPORT_OUTPUT_PREF, ValidPrefSet
1✔
171
from psyneulink.core.globals.preferences.preferenceset import PreferenceEntry, PreferenceLevel
1✔
172
from psyneulink.core.globals.registry import register_category
1✔
173
from psyneulink.core.globals.utilities import (
1✔
174
    convert_to_np_array, get_global_seed, is_instance_or_subclass, object_has_single_value, parameter_spec, parse_valid_identifier, safe_len,
175
    SeededRandomState, contains_type, is_numeric, NumericCollections,
176
    random_matrix
177
)
178

179
__all__ = [
1✔
180
    'ArgumentTherapy', 'EPSILON', 'Function_Base', 'function_keywords', 'FunctionError', 'FunctionOutputType',
181
    'FunctionRegistry', 'get_param_value_for_function', 'get_param_value_for_keyword', 'is_Function',
182
    'is_function_type', 'PERTINACITY', 'PROPENSITY', 'RandomMatrix'
183
]
184

185
EPSILON = np.finfo(float).eps
1✔
186
# numeric to allow modulation, invalid to identify unseeded state
187
DEFAULT_SEED = -1
1✔
188

189
FunctionRegistry = {}
1✔
190

191
function_keywords = {FUNCTION_OUTPUT_TYPE, FUNCTION_OUTPUT_TYPE_CONVERSION}
1✔
192

193

194
class FunctionError(ComponentError):
1✔
195
    pass
1✔
196

197

198
class FunctionOutputType(IntEnum):
1✔
199
    RAW_NUMBER = 0
1✔
200
    NP_1D_ARRAY = 1
1✔
201
    NP_2D_ARRAY = 2
1✔
202
    DEFAULT = 3
1✔
203

204

205
# Typechecking *********************************************************************************************************
206

207
# TYPE_CHECK for Function Instance or Class
208
def is_Function(x):
1✔
209
    if not x:
×
210
        return False
×
211
    elif isinstance(x, Function):
×
212
        return True
×
213
    elif isinstance(x, type) and issubclass(x, Function):
×
214
        return True
×
215
    else:
216
        return False
×
217

218

219
def is_function_type(x):
1✔
220
    if callable(x):
1✔
221
        return True
1✔
222
    elif not x:
1!
223
        return False
×
224
    elif isinstance(x, (Function, types.FunctionType, types.MethodType, types.BuiltinFunctionType, types.BuiltinMethodType)):
1!
225
        return True
×
226
    elif isinstance(x, type) and issubclass(x, Function):
1!
227
        return True
×
228
    else:
229
        return False
1✔
230

231
# *******************************   get_param_value_for_keyword ********************************************************
232

233
def get_param_value_for_keyword(owner, keyword):
1✔
234
    """Return the value for a keyword used by a subclass of Function
235

236
    Parameters
237
    ----------
238
    owner : Component
239
    keyword : str
240

241
    Returns
242
    -------
243
    value
244

245
    """
246
    try:
1✔
247
        return owner.function.keyword(owner, keyword)
1✔
248
    except FunctionError as e:
1!
249
        # assert(False)
250
        # prefs is not always created when this is called, so check
251
        try:
×
252
            owner.prefs
×
253
            has_prefs = True
×
254
        except AttributeError:
×
255
            has_prefs = False
×
256

257
        if has_prefs and owner.prefs.verbosePref:
×
258
            print("{} of {}".format(e, owner.name))
×
259
        # return None
260
        else:
261
            raise FunctionError(e)
262
    except AttributeError:
1✔
263
        # prefs is not always created when this is called, so check
264
        try:
1✔
265
            owner.prefs
1✔
266
            has_prefs = True
1✔
267
        except AttributeError:
1✔
268
            has_prefs = False
1✔
269

270
        if has_prefs and owner.prefs.verbosePref:
1!
271
            print("Keyword ({}) not recognized for {}".format(keyword, owner.name))
×
272
        return None
1✔
273

274

275
def get_param_value_for_function(owner, function):
1✔
276
    try:
×
277
        return owner.function.param_function(owner, function)
×
278
    except FunctionError as e:
×
279
        if owner.prefs.verbosePref:
×
280
            print("{} of {}".format(e, owner.name))
×
281
        return None
×
282
    except AttributeError:
×
283
        if owner.prefs.verbosePref:
×
284
            print("Function ({}) can't be evaluated for {}".format(function, owner.name))
×
285
        return None
×
286

287
# Parameter Mixins *****************************************************************************************************
288

289
# KDM 6/21/18: Below is left in for consideration; doesn't really gain much to justify relaxing the assumption
290
# that every Parameters class has a single parent
291

292
# class ScaleOffsetParamMixin:
293
#     scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
294
#     offset = Parameter(1.0, modulable=True, aliases=[ADDITIVE_PARAM])
295

296

297
# Function Definitions *************************************************************************************************
298

299

300
# KDM 8/9/18: below is added for future use when function methods are completely functional
301
# used as a decorator for Function methods
302
# def enable_output_conversion(func):
303
#     @functools.wraps(func)
304
#     def wrapper(*args, **kwargs):
305
#         result = func(*args, **kwargs)
306
#         return convert_output_type(result)
307
#     return wrapper
308

309
# this should eventually be moved to a unified validation method
310
def _output_type_setter(value, owning_component):
1✔
311
    # Can't convert from arrays of length > 1 to number
312
    if (
×
313
        owning_component.defaults.variable is not None
314
        and safe_len(owning_component.defaults.variable) > 1
315
        and owning_component.output_type is FunctionOutputType.RAW_NUMBER
316
    ):
317
        raise FunctionError(
318
            f"{owning_component.__class__.__name__} can't be set to return a "
319
            "single number since its variable has more than one number."
320
        )
321

322
    # warn if user overrides the 2D setting for mechanism functions
323
    # may be removed when
324
    # https://github.com/PrincetonUniversity/PsyNeuLink/issues/895 is solved
325
    # properly(meaning Mechanism values may be something other than 2D np array)
326
    try:
×
327
        if (
×
328
            isinstance(owning_component.owner, Mechanism)
329
            and (
330
                value == FunctionOutputType.RAW_NUMBER
331
                or value == FunctionOutputType.NP_1D_ARRAY
332
            )
333
        ):
334
            warnings.warn(
×
335
                f'Functions that are owned by a Mechanism but do not return a '
336
                '2D numpy array may cause unexpected behavior if llvm '
337
                'compilation is enabled.'
338
            )
339
    except (AttributeError, ImportError):
×
340
        pass
×
341

342
    return value
×
343

344

345
def _seed_setter(value, owning_component, context):
1✔
346
    if value in {None, DEFAULT_SEED}:
1✔
347
        value = get_global_seed()
1✔
348

349
    # Remove any old PRNG state
350
    owning_component.parameters.random_state.set(None, context=context)
1✔
351
    return int(value)
1✔
352

353

354
def _random_state_getter(self, owning_component, context):
1✔
355

356
    seed_param = owning_component.parameters.seed
1✔
357
    try:
1✔
358
        is_modulated = seed_param.port.is_modulated(context)
1✔
359
    except AttributeError:
1✔
360
        is_modulated = False
1✔
361

362
    if is_modulated:
1✔
363
        seed_value = [int(owning_component._get_current_parameter_value(seed_param, context))]
1✔
364
    else:
365
        seed_value = [int(seed_param._get(context=context))]
1✔
366

367
    if seed_value == [DEFAULT_SEED]:
1✔
368
        raise FunctionError(
369
            "Invalid seed for {} in context: {} ({})".format(
370
                owning_component, context.execution_id, seed_param
371
            )
372
        )
373

374
    current_state = self.values.get(context.execution_id, None)
1✔
375
    if current_state is None:
1✔
376
        return SeededRandomState(seed_value)
1✔
377
    if current_state.used_seed != seed_value:
1✔
378
        return type(current_state)(seed_value)
1✔
379

380
    return current_state
1✔
381

382

383
def _noise_setter(value, owning_component, context):
1✔
384
    def has_function(x):
1✔
385
        return (
1✔
386
            is_instance_or_subclass(x, (Function_Base, types.FunctionType))
387
            or contains_type(x, (Function_Base, types.FunctionType))
388
        )
389

390
    noise_param = owning_component.parameters.noise
1✔
391
    value_has_function = has_function(value)
1✔
392
    # initial set
393
    if owning_component.is_initializing:
1✔
394
        if value_has_function:
1✔
395
            # is changing a parameter attribute like this ok?
396
            noise_param.stateful = False
1✔
397
    else:
398
        default_value_has_function = has_function(noise_param.default_value)
1✔
399

400
        if default_value_has_function and not value_has_function:
1✔
401
            warnings.warn(
1✔
402
                'Setting noise to a numeric value after instantiation'
403
                ' with a value containing functions will not remove the'
404
                ' noise ParameterPort or make noise stateful.'
405
            )
406
        elif not default_value_has_function and value_has_function:
1✔
407
            warnings.warn(
1✔
408
                'Setting noise to a value containing functions after'
409
                ' instantiation with a numeric value will not create a'
410
                ' noise ParameterPort or make noise stateless.'
411
            )
412

413
    return value
1✔
414

415

416
class Function_Base(Function):
1✔
417
    """
418
    Function_Base(           \
419
         default_variable,   \
420
         params=None,        \
421
         owner=None,         \
422
         name=None,          \
423
         prefs=None          \
424
    )
425

426
    Implement abstract class for Function category of Component class
427

428
    COMMENT:
429
        Description:
430
            Functions are used to "wrap" functions used by other components;
431
            They are defined here (on top of standard libraries) to provide a uniform interface for managing parameters
432
             (including defaults)
433
            NOTE:   the Function category definition serves primarily as a shell, and as an interface to the Function
434
                       class, to maintain consistency of structure with the other function categories;
435
                    it also ensures implementation of .function for all Function Components
436
                    (as distinct from other Function subclasses, which can use a FUNCTION param
437
                        to implement .function instead of doing so directly)
438
                    Function Components are the end of the recursive line; as such:
439
                        they don't implement functionParams
440
                        in general, don't bother implementing function, rather...
441
                        they rely on Function_Base.function which passes on the return value of .function
442

443
        Variable and Parameters:
444
        IMPLEMENTATION NOTE:  ** DESCRIBE VARIABLE HERE AND HOW/WHY IT DIFFERS FROM PARAMETER
445
            - Parameters can be assigned and/or changed individually or in sets, by:
446
              - including them in the initialization call
447
              - calling the _instantiate_defaults method (which changes their default values)
448
              - including them in a call to the function method (which changes their values just for that call)
449
            - Parameters must be specified in a params dictionary:
450
              - the key for each entry should be the name of the parameter (used also to name associated Projections)
451
              - the value for each entry is the value of the parameter
452

453
        Return values:
454
            The output_type can be used to specify type conversion for single-item return values:
455
            - it can only be used for numbers or a single-number list; other values will generate an exception
456
            - if self.output_type is set to:
457
                FunctionOutputType.RAW_NUMBER, return value is "exposed" as a number
458
                FunctionOutputType.NP_1D_ARRAY, return value is 1d np.array
459
                FunctionOutputType.NP_2D_ARRAY, return value is 2d np.array
460
            - it must be enabled for a subclass by setting params[FUNCTION_OUTPUT_TYPE_CONVERSION] = True
461
            - it must be implemented in the execute method of the subclass
462
            - see Linear for an example
463

464
        FunctionRegistry:
465
            All Function functions are registered in FunctionRegistry, which maintains a dict for each subclass,
466
              a count for all instances of that type, and a dictionary of those instances
467

468
        Naming:
469
            Function functions are named by their componentName attribute (usually = componentType)
470

471
        Class attributes:
472
            + componentCategory: FUNCTION_COMPONENT_CATEGORY
473
            + className (str): kwMechanismFunctionCategory
474
            + suffix (str): " <className>"
475
            + registry (dict): FunctionRegistry
476
            + classPreference (PreferenceSet): BasePreferenceSet, instantiated in __init__()
477
            + classPreferenceLevel (PreferenceLevel): PreferenceLevel.CATEGORY
478

479
        Class methods:
480
            none
481

482
        Instance attributes:
483
            + componentType (str):  assigned by subclasses
484
            + componentName (str):   assigned by subclasses
485
            + variable (value) - used as input to function's execute method
486
            + value (value) - output of execute method
487
            + name (str) - if not specified as an arg, a default based on the class is assigned in register_category
488
            + prefs (PreferenceSet) - if not specified as an arg, default is created by copying BasePreferenceSet
489

490
        Instance methods:
491
            The following method MUST be overridden by an implementation in the subclass:
492
            - execute(variable, params)
493
            The following can be implemented, to customize validation of the function variable and/or params:
494
            - [_validate_variable(variable)]
495
            - [_validate_params(request_set, target_set, context)]
496
    COMMENT
497

498
    Arguments
499
    ---------
500

501
    variable : value : default class_defaults.variable
502
        specifies the format and a default value for the input to `function <Function>`.
503

504
    params : Dict[param keyword: param value] : default None
505
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
506
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
507
        arguments of the constructor.
508

509
    owner : Component
510
        `component <Component>` to which to assign the Function.
511

512
    name : str : default see `name <Function.name>`
513
        specifies the name of the Function.
514

515
    prefs : PreferenceSet or specification dict : default Function.classPreferences
516
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
517

518

519
    Attributes
520
    ----------
521

522
    variable: value
523
        format and default value can be specified by the :keyword:`variable` argument of the constructor;  otherwise,
524
        they are specified by the Function's :keyword:`class_defaults.variable`.
525

526
    function : function
527
        called by the Function's `owner <Function_Base.owner>` when it is executed.
528

529
    COMMENT:
530
    enable_output_type_conversion : Bool : False
531
        specifies whether `function output type conversion <Function_Output_Type_Conversion>` is enabled.
532

533
    output_type : FunctionOutputType : None
534
        used to determine the return type for the `function <Function_Base.function>`;  `functionOutputTypeConversion`
535
        must be enabled and implemented for the class (see `FunctionOutputType <Function_Output_Type_Conversion>`
536
        for details).
537

538
    changes_shape : bool : False
539
        specifies whether the shape of the function's return value differs from the shape of its `variable <Function_Base.variable>`.  Used to determine whether the shape of the inputs to the `Component` to which the function is assigned should be based on the `variable <Function_Base.variable>` of the function or its `value <Function.value>`.
540
    COMMENT
541

542
    owner : Component
543
        `component <Component>` to which the Function has been assigned.
544

545
    name : str
546
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
547
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
548

549
    prefs : PreferenceSet or specification dict : Function.classPreferences
550
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
551
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
552
        for details).
553

554
    """
555

556
    componentCategory = FUNCTION_COMPONENT_CATEGORY
1✔
557
    className = componentCategory
1✔
558
    suffix = " " + className
1✔
559

560
    registry = FunctionRegistry
1✔
561

562
    classPreferenceLevel = PreferenceLevel.CATEGORY
1✔
563

564
    _model_spec_id_parameters = 'args'
1✔
565
    _mdf_stateful_parameter_indices = {}
1✔
566

567
    _specified_variable_shape_flexibility = DefaultsFlexibility.INCREASE_DIMENSION
1✔
568

569
    class Parameters(Function.Parameters):
1✔
570
        """
571
            Attributes
572
            ----------
573

574
                variable
575
                    see `variable <Function_Base.variable>`
576

577
                    :default value: numpy.array([0])
578
                    :type: ``numpy.ndarray``
579
                    :read only: True
580

581
                enable_output_type_conversion
582
                    see `enable_output_type_conversion <Function_Base.enable_output_type_conversion>`
583

584
                    :default value: False
585
                    :type: ``bool``
586

587
                changes_shape
588
                    see `changes_shape <Function_Base.changes_shape>`
589

590
                    :default value: False
591
                    :type: bool
592

593
                output_type
594
                    see `output_type <Function_Base.output_type>`
595

596
                    :default value: FunctionOutputType.DEFAULT
597
                    :type: `FunctionOutputType`
598

599
        """
600
        variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable')
1✔
601

602
        output_type = Parameter(
1✔
603
            FunctionOutputType.DEFAULT,
604
            stateful=False,
605
            loggable=False,
606
            pnl_internal=True,
607
            valid_types=FunctionOutputType
608
        )
609
        enable_output_type_conversion = Parameter(False, stateful=False, loggable=False, pnl_internal=True)
1✔
610

611
        changes_shape = Parameter(False, stateful=False, loggable=False, pnl_internal=True)
1✔
612
        def _validate_changes_shape(self, param):
1✔
613
            if not isinstance(param, bool):
1!
614
                return 'must be a bool.'
×
615

616
    # Note: the following enforce encoding as 1D np.ndarrays (one array per variable)
617
    variableEncodingDim = 1
1✔
618

619
    @check_user_specified
1✔
620
    @abc.abstractmethod
1✔
621
    def __init__(
1✔
622
        self,
623
        default_variable,
624
        params,
625
        owner=None,
626
        name=None,
627
        prefs=None,
628
        context=None,
629
        **kwargs
630
    ):
631
        """Assign category-level preferences, register category, and call super.__init__
632

633
        Initialization arguments:
634
        - default_variable (anything): establishes type for the variable, used for validation
635
        Note: if parameter_validation is off, validation is suppressed (for efficiency) (Function class default = on)
636

637
        :param default_variable: (anything but a dict) - value to assign as self.defaults.variable
638
        :param params: (dict) - params to be assigned as instance defaults
639
        :param log: (ComponentLog enum) - log entry types set in self.componentLog
640
        :param name: (string) - optional, overrides assignment of default (componentName of subclass)
641
        :return:
642
        """
643

644
        if self.initialization_status == ContextFlags.DEFERRED_INIT:
1!
645
            self._assign_deferred_init_name(name)
×
646
            self._init_args[NAME] = name
×
647
            return
×
648

649
        register_category(entry=self,
1✔
650
                          base_class=Function_Base,
651
                          registry=FunctionRegistry,
652
                          name=name,
653
                          )
654
        self.owner = owner
1✔
655

656
        super().__init__(
1✔
657
            default_variable=default_variable,
658
            param_defaults=params,
659
            name=name,
660
            prefs=prefs,
661
            **kwargs
662
        )
663

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)

    def __deepcopy__(self, memo):
        new = super().__deepcopy__(memo)
        # ensure copy does not have identical name
        register_category(new, Function_Base, new.name, FunctionRegistry)
        if "random_state" in new.parameters:
            # HACK: Make sure any copies are re-seeded to avoid dependent RNG.
            # functions with "random_state" param must have "seed" parameter
            for ctx in new.parameters.seed.values:
                new.parameters.seed.set(
                    DEFAULT_SEED, ctx, skip_log=True, skip_history=True
                )

        return new

    @handle_external_context()
    def function(self,
                 variable=None,
                 context=None,
                 params=None,
                 target_set=None,
                 **kwargs):

        # IMPLEMENTATION NOTE:
        # The following is a convenience feature that supports specification of params directly in call to function
        # by moving them to a params dict, which treats them as runtime_params
        if kwargs:
            for key in kwargs.copy():
                if key in self.parameters.names():
                    if not params:
                        params = {key: kwargs.pop(key)}
                    else:
                        params.update({key: kwargs.pop(key)})

        # Validate variable and assign to variable, and validate params
        variable = self._check_args(variable=variable,
                                    context=context,
                                    params=params,
                                    target_set=target_set,
                                    )
        # Execute function
        try:
            value = self._function(variable=variable,
                                   context=context,
                                   params=params,
                                   **kwargs)
        except ValueError as err:
            err_msg = f"Problem with '{self}' in '{self.owner.name if self.owner else self.__class__.__name__}': {err}"
            raise FunctionError(err_msg) from err
        self.most_recent_context = context
        self.parameters.value._set(value, context=context)
        self._reset_runtime_parameters(context)
        return value

    @abc.abstractmethod
    def _function(
        self,
        variable=None,
        context=None,
        params=None,
    ):
        pass

    def _parse_arg_generic(self, arg_val):
        if isinstance(arg_val, list):
            return np.asarray(arg_val)
        else:
            return arg_val

    def _validate_parameter_spec(self, param, param_name, numeric_only=True):
        """Validates function param
        Replace direct call to parameter_spec in tc, which seems to not get called by Function __init__()'s
        """
        if not parameter_spec(param, numeric_only):
            # Leading space keeps the owner clause separate from the class name in the message
            owner_name = ' of ' + self.owner_name if self.owner else ""
            raise FunctionError(f"{param} is not a valid specification for "
                                f"the {param_name} argument of {self.__class__.__name__}{owner_name}.")

    def _get_current_parameter_value(self, param_name, context=None):
        try:
            param = getattr(self.parameters, param_name)
        except TypeError:
            param = param_name
        except AttributeError:
            # don't accept strings that don't correspond to Parameters
            # on this function
            raise

        return super()._get_current_parameter_value(param, context)

    def get_previous_value(self, context=None):
        # temporary method until previous values are integrated for all parameters
        value = self.parameters.previous_value._get(context)

        return value

    def convert_output_type(self, value, output_type=None):
        if output_type is None:
            if not self.enable_output_type_conversion or self.output_type is None:
                return value
            else:
                output_type = self.output_type

        value = convert_to_np_array(value)

        # Type conversion (specified by output_type):

        # MODIFIED 6/21/19 NEW: [JDC]
        # Convert to same format as variable
        if isinstance(output_type, (list, np.ndarray)):
            shape = np.array(output_type).shape
            return np.array(value).reshape(shape)
        # MODIFIED 6/21/19 END

        # Convert to 2D array, irrespective of value type:
        if output_type is FunctionOutputType.NP_2D_ARRAY:
            # KDM 8/10/18: mimicking the conversion that Mechanism does to its values, because
            # this is what we actually wanted this method for. Can be changed to pure 2D np array in
            # future if necessary

            converted_to_2d = np.atleast_2d(value)
            # If return_value is a list of heterogeneous elements, return as is
            #     (satisfies requirement that return_value be an array of possibly multidimensional values)
            if converted_to_2d.dtype == object:
                pass
            # Otherwise, return value converted to 2d np.array
            else:
                value = converted_to_2d

        # Convert to 1D array, irrespective of value type:
        # Note: if 2D array (or higher) has more than one item in the outer dimension, generate exception
        elif output_type is FunctionOutputType.NP_1D_ARRAY:
            # If variable is 2D
            if value.ndim >= 2:
                # If there is only one item:
                if len(value) == 1:
                    value = value[0]
                else:
                    raise FunctionError(f"Can't convert value ({value}: 2D np.ndarray object "
                                        f"with more than one array) to 1D array.")
            elif value.ndim == 1:
                pass
            elif value.ndim == 0:
                value = np.atleast_1d(value)
            else:
                raise FunctionError(f"Can't convert value ({value}) to 1D array.")

        # Convert to raw number, irrespective of value type:
        # Note: if 2D or 1D array has more than one item, generate exception
        elif output_type is FunctionOutputType.RAW_NUMBER:
            if object_has_single_value(value):
                value = float(value)
            else:
                raise FunctionError(f"Can't convert value ({value}) with more than a single number to a raw number.")

        return value

    @property
    def owner_name(self):
        try:
            return self.owner.name
        except AttributeError:
            return '<no owner>'

    def _is_identity(self, context=None, defaults=False):
        # should return True in subclasses if the parameters for context are such that
        # the Function's output will be the same as its input
        # Used to bypass execute when unnecessary
        return False

    @property
    def _model_spec_parameter_blacklist(self):
        return super()._model_spec_parameter_blacklist.union({
            'multiplicative_param', 'additive_param',
        })

    def _assign_to_mdf_model(self, model, input_id) -> str:
        """Adds an MDF representation of this function to MDF object
        **model**, including all necessary auxiliary functions.
        **input_id** is the input to the singular MDF function or first
        function representing this psyneulink Function, if applicable.

        Returns:
            str: the identifier of the final MDF function representing
            this psyneulink Function
        """
        import modeci_mdf.mdf as mdf

        extra_noise_functions = []

        self_model = self.as_mdf_model()

        def handle_noise(noise):
            if is_instance_or_subclass(noise, Component):
                if inspect.isclass(noise) and issubclass(noise, Component):
                    noise = noise()
                noise_func_model = noise.as_mdf_model()
                extra_noise_functions.append(noise_func_model)
                return noise_func_model.id
            elif isinstance(noise, (list, np.ndarray)):
                return type(noise)(handle_noise(item) for item in noise)
            else:
                return None

        try:
            noise_val = handle_noise(self.defaults.noise)
        except AttributeError:
            noise_val = None

        if noise_val is not None:
            noise_func = mdf.Function(
                id=f'{model.id}_{parse_valid_identifier(self.name)}_noise',
                value=MODEL_SPEC_ID_MDF_VARIABLE,
                args={MODEL_SPEC_ID_MDF_VARIABLE: noise_val},
            )
            self._set_mdf_arg(self_model, 'noise', noise_func.id)

            model.functions.extend(extra_noise_functions)
            model.functions.append(noise_func)

        self_model.id = f'{model.id}_{self_model.id}'
        self._set_mdf_arg(self_model, _get_variable_parameter_name(self), input_id)
        model.functions.append(self_model)

        # assign stateful parameters
        for name, index in self._mdf_stateful_parameter_indices.items():
            # in this case, parameter gets updated to its function's final value
            param = getattr(self.parameters, name)

            try:
                initializer_value = self_model.args[param.initializer]
            except KeyError:
                initializer_value = self_model.metadata[param.initializer]

            index_str = f'[{index}]' if index is not None else ''

            model.parameters.append(
                mdf.Parameter(
                    id=param.mdf_name if param.mdf_name is not None else param.name,
                    default_initial_value=initializer_value,
                    value=f'{self_model.id}{index_str}'
                )
            )

        return self_model.id

    def as_mdf_model(self):
        import modeci_mdf.mdf as mdf
        import modeci_mdf.functions.standard as mdf_functions

        parameters = self._mdf_model_parameters
        metadata = self._mdf_metadata
        stateful_params = set()

        # add stateful parameters into metadata for mechanism to get
        for name in parameters[self._model_spec_id_parameters]:
            try:
                param = getattr(self.parameters, name)
            except AttributeError:
                continue

            if param.initializer is not None:
                stateful_params.add(name)

        # stateful parameters cannot show up as args or they will not be
        # treated statefully in mdf
        for sp in stateful_params:
            del parameters[self._model_spec_id_parameters][sp]

        model = mdf.Function(
            id=parse_valid_identifier(self.name),
            **parameters,
            **metadata,
        )

        try:
            model.value = self.as_expression()
        except AttributeError:
            if self._model_spec_generic_type_name is not NotImplemented:
                typ = self._model_spec_generic_type_name
            else:
                try:
                    typ = self.custom_function.__name__
                except AttributeError:
                    typ = type(self).__name__.lower()

            if typ not in mdf_functions.mdf_functions:
                warnings.warn(f'{typ} is not an MDF standard function, this is likely to produce an incompatible model.')

            model.function = typ

        return model

    def _get_pytorch_fct_param_value(self, param_name, device, context):
        """Return the current value of param_name for the function

        Use default value if not yet assigned
        Convert using torch.tensor if val is an array
        """
        val = self._get_current_parameter_value(param_name, context=context)
        if val is None:
            val = getattr(self.defaults, param_name)
        if isinstance(val, (str, type(None))):
            return val
        # NOTE: np.isscalar() returns False for every np.ndarray (including 0-d arrays),
        #       so this branch is not taken as written
        elif np.isscalar(np.array(val)):
            return float(val)
        try:
            return torch.tensor(val, device=device).double()
        except Exception:
            assert False, (f"PROGRAM ERROR: unsupported value of parameter '{param_name}' ({val}) "
                           f"encountered in pytorch_function_creator().")

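# The keyword handling at the top of Function_Base.function() can be sketched in isolation.
# The helper below is hypothetical (its name and signature are not part of PsyNeuLink) and
# assumes only that the set of valid parameter names is known in advance: keyword arguments
# that name a parameter are moved into the params dict, and all others are left in kwargs.

```python
def split_runtime_params(kwargs, parameter_names, params=None):
    """Move keyword arguments that name a parameter into a params dict.

    Hypothetical helper: returns (params, remaining_kwargs) and does not mutate its inputs.
    params stays None when nothing was routed and none was supplied.
    """
    params = dict(params) if params else None
    remaining = {}
    for key, val in kwargs.items():
        if key in parameter_names:
            if params is None:
                params = {}
            params[key] = val
        else:
            remaining[key] = val
    return params, remaining


# "slope" names a parameter, so it is routed to params; "verbose" stays in kwargs
routed, rest = split_runtime_params(
    {"slope": 2.0, "verbose": True}, parameter_names={"slope", "intercept"}
)
```

# Unlike the method above, this sketch copies its inputs instead of popping from kwargs in place.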

# *****************************************   EXAMPLE FUNCTION   *******************************************************
PROPENSITY = "PROPENSITY"
PERTINACITY = "PERTINACITY"


class ArgumentTherapy(Function_Base):
    """
    ArgumentTherapy(                   \
         variable,                     \
         propensity=Manner.CONTRARIAN, \
         pertinacity=10.0,             \
         params=None,                  \
         owner=None,                   \
         name=None,                    \
         prefs=None                    \
         )

    .. _ArgumentTherapist:

    Return `True` or :keyword:`False` according to the manner of the therapist.

    Arguments
    ---------

    variable : boolean or statement that resolves to one : default class_defaults.variable
        assertion for which a therapeutic response will be offered.

    propensity : Manner value : default Manner.CONTRARIAN
        specifies preferred therapeutic manner

    pertinacity : float : default 10.0
        specifies therapeutic consistency

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).


    Attributes
    ----------

    variable : boolean
        assertion to which a therapeutic response is made.

    propensity : Manner value : default Manner.CONTRARIAN
        determines therapeutic manner:  tendency to agree or disagree.

    pertinacity : float : default 10.0
        determines consistency with which the manner complies with the propensity.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).


    """

    # Function componentName and type (defined at top of module)
    componentName = ARGUMENT_THERAPY_FUNCTION
    componentType = EXAMPLE_FUNCTION_TYPE

    classPreferences = {
        PREFERENCE_SET_NAME: 'ExampleClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    # Mode indicators
    class Manner(Enum):
        OBSEQUIOUS = 0
        CONTRARIAN = 1

    # Parameter class defaults
    # These are used both to type-cast the params, and as defaults if none are assigned
    #  in the initialization call or later (using either _instantiate_defaults or during a function call)

    # Defaults follow the class docstring: propensity is a Manner, pertinacity is a number
    @check_user_specified
    def __init__(self,
                 default_variable=None,
                 propensity=Manner.CONTRARIAN,
                 pertinacity=10.0,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            propensity=propensity,
            pertinacity=pertinacity,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_variable(self, variable, context=None):
        """Validates variable and returns validated value

        This overrides the class method, to perform more detailed type checking
        See explanation in class method.
        Note: this method (or the class version) is called only if the parameter_validation attribute is `True`

        :param variable: (anything but a dict) - variable to be validated:
        :param context: (str)
        :return variable: - validated
        """

        if type(variable) == type(self.class_defaults.variable) or \
                (isinstance(variable, numbers.Number) and isinstance(self.class_defaults.variable, numbers.Number)):
            return variable
        else:
            raise FunctionError(f"Variable must be {type(self.class_defaults.variable)}.")

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validates variable and/or params and assigns to targets

        This overrides the class method, to perform more detailed type checking
        See explanation in class method.
        Note: this method (or the class version) is called only if the parameter_validation attribute is `True`

        :param request_set: (dict) - params to be validated
        :param target_set: (dict) - destination of validated params
        :return none:
        """

        message = ""

        # Check params
        for param_name, param_value in request_set.items():

            if param_name == PROPENSITY:
                if isinstance(param_value, ArgumentTherapy.Manner):
                    # target_set[self.PROPENSITY] = param_value
                    pass  # This leaves param in request_set, clear to be assigned to target_set in call to super below
                else:
                    message = "Propensity must be of type ArgumentTherapy.Manner"
                continue

            # Validate param
            if param_name == PERTINACITY:
                if isinstance(param_value, numbers.Number) and 0 <= param_value <= 10:
                    # target_set[PERTINACITY] = param_value
                    pass  # This leaves param in request_set, clear to be assigned to target_set in call to super below
                else:
                    message += "Pertinacity must be a number between 0 and 10"
                continue

        if message:
            raise FunctionError(message)

        super()._validate_params(request_set, target_set, context)

    def _function(self,
                  variable=None,
                  context=None,
                  params=None,
                  ):
        """
        Returns a boolean that is (or tends to be) the same as or opposite the one passed in.

        Arguments
        ---------

        variable : boolean : default class_defaults.variable
           an assertion to which a therapeutic response is made.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.


        Returns
        -------

        therapeutic response : boolean

        """
        # Compute the function
        statement = variable
        propensity = self._get_current_parameter_value(PROPENSITY, context)
        pertinacity = self._get_current_parameter_value(PERTINACITY, context)
        whim = np.random.randint(-10, 10)

        if propensity == self.Manner.OBSEQUIOUS:
            value = whim < pertinacity

        elif propensity == self.Manner.CONTRARIAN:
            value = whim > pertinacity

        else:
            raise FunctionError("This should not happen if parameter_validation == True;  check its value")

        return self.convert_output_type(value)

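# The comparison made in ArgumentTherapy._function can be tried without PsyNeuLink.
# This is a minimal sketch under stated assumptions: Manner is redefined locally,
# therapeutic_response is a hypothetical name, and the standard-library random module
# stands in for np.random.randint(-10, 10), drawing integers from the same range (-10 to 9).

```python
import random
from enum import Enum


class Manner(Enum):
    OBSEQUIOUS = 0
    CONTRARIAN = 1


def therapeutic_response(propensity, pertinacity, rng=random):
    """Hypothetical stand-alone version of the whim/pertinacity comparison."""
    whim = rng.randrange(-10, 10)  # integer in -10..9 (upper bound excluded)
    if propensity is Manner.OBSEQUIOUS:
        return whim < pertinacity
    if propensity is Manner.CONTRARIAN:
        return whim > pertinacity
    raise ValueError(f"Unrecognized propensity: {propensity!r}")


rng = random.Random(0)
# With pertinacity at its maximum of 10, whim (at most 9) is always below it
obsequious_always_true = all(
    therapeutic_response(Manner.OBSEQUIOUS, 10.0, rng) for _ in range(100)
)
contrarian_ever_true = any(
    therapeutic_response(Manner.CONTRARIAN, 10.0, rng) for _ in range(100)
)
```

# Lower pertinacity values make either manner less consistent, since whim can then fall on both sides of it.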

kwEVCAuxFunction = "EVC AUXILIARY FUNCTION"
kwEVCAuxFunctionType = "EVC AUXILIARY FUNCTION TYPE"
kwValueFunction = "EVC VALUE FUNCTION"
CONTROL_SIGNAL_GRID_SEARCH_FUNCTION = "EVC CONTROL SIGNAL GRID SEARCH FUNCTION"
CONTROLLER = 'controller'

class EVCAuxiliaryFunction(Function_Base):
    """Base class for EVC auxiliary functions
    """
    componentType = kwEVCAuxFunctionType

    class Parameters(Function_Base.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <Function_Base.variable>`

                    :default value: numpy.array([0])
                    :type: numpy.ndarray
                    :read only: True

        """
        variable = Parameter(None, pnl_internal=True, constructor_argument='default_variable')

    classPreferences = {
        PREFERENCE_SET_NAME: 'ValueFunctionCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }

    @check_user_specified
    @beartype
    def __init__(self,
                 function,
                 variable=None,
                 params=None,
                 owner=None,
                 prefs:   Optional[ValidPrefSet] = None,
                 context=None):
        self.aux_function = function

        super().__init__(default_variable=variable,
                         params=params,
                         owner=owner,
                         prefs=prefs,
                         context=context,
                         function=function,
                         )


class RandomMatrix:
    """Function that returns matrix with random elements distributed uniformly around **center** across **range**.

    The **center** and **range** arguments are passed at construction, and used for all subsequent calls.
    Once constructed, the function must be called with two integers, **sender_size** and **receiver_size**,
    that specify the number of rows and columns of the matrix, respectively.

    Can be used to specify the `matrix <MappingProjection.matrix>` parameter of a `MappingProjection
    <MappingProjection_Matrix_Specification>`, and to specify a default matrix for Projections in the
    construction of a `Pathway` (see `Pathway_Specification_Projections`) or in a call to a Composition's
    `add_linear_processing_pathway<Composition.add_linear_processing_pathway>` method.

    .. technical_note::
       A call to the class calls `random_matrix <Utilities.random_matrix>`, passing **sender_size** and
       **receiver_size** to `random_matrix <Utilities.random_matrix>` as its **num_rows** and **num_cols**
       arguments, respectively, and passing the `center <RandomMatrix.offset>`-0.5 and `range <RandomMatrix.scale>`
       attributes specified at construction to `random_matrix <Utilities.random_matrix>` as its **offset**
       and **scale** arguments, respectively.

    Arguments
    ----------
    center : float
        specifies the value around which the matrix elements are distributed in all calls to the function.
    range : float
        specifies range over which all matrix elements are distributed in all calls to the function.

    Attributes
    ----------
    center : float
        determines the center of the distribution of the matrix elements;
    range : float
        determines the range of the distribution of the matrix elements;
    """

    def __init__(self, center: float = 0.0, range: float = 1.0):
        self.center = center
        self.range = range

    def __call__(self, sender_size: int, receiver_size: int):
        return random_matrix(sender_size, receiver_size, offset=self.center - 0.5, scale=self.range)

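# The distribution described in the RandomMatrix docstring can be illustrated with the
# standard library alone. This is a sketch, not the library's implementation: uniform_matrix
# and its width argument are hypothetical names, and it assumes that "distributed uniformly
# around center across range" means uniform on [center - range/2, center + range/2).
# How random_matrix actually combines its offset and scale arguments is not shown in this file.

```python
import random


def uniform_matrix(rows, cols, center=0.0, width=1.0, rng=random):
    """Return rows x cols nested lists with elements uniform around center.

    Each element is center + width * (u - 0.5) with u drawn uniformly from [0, 1).
    """
    return [
        [center + width * (rng.random() - 0.5) for _ in range(cols)]
        for _ in range(rows)
    ]


# 2 senders x 3 receivers, elements within 0.25 of 1.0
m = uniform_matrix(2, 3, center=1.0, width=0.5, rng=random.Random(0))
```

# Passing a seeded random.Random instance, as above, makes the matrix reproducible across runs.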

def get_matrix(specification, rows=1, cols=1, context=None):
    """Returns matrix conforming to specification with dimensions = rows x cols or None

     Specification can be a matrix keyword, filler value or np.ndarray

     Specification (validated in _validate_params):
        + single number (used to fill self.matrix)
        + matrix keyword:
            + AUTO_ASSIGN_MATRIX: IDENTITY_MATRIX if it is square, otherwise FULL_CONNECTIVITY_MATRIX
            + IDENTITY_MATRIX: 1's on diagonal, 0's elsewhere (must be square matrix), otherwise generates error
            + HOLLOW_MATRIX: 0's on diagonal, 1's elsewhere (must be square matrix), otherwise generates error
            + INVERSE_HOLLOW_MATRIX: 0's on diagonal, -1's elsewhere (must be square matrix), otherwise generates error
            + FULL_CONNECTIVITY_MATRIX: all 1's
            + ZEROS_MATRIX: all 0's
            + RANDOM_CONNECTIVITY_MATRIX (random floats uniformly distributed between 0 and 1)
            + RandomMatrix (random floats uniformly distributed around a specified center value with a specified range)
        + 2D list or np.ndarray of numbers

     Returns 2D array with length=rows in dim 0 and length=cols in dim 1, or None if specification is not recognized
    """

    # Matrix provided (and validated in _validate_params); convert to array
    if isinstance(specification, (list, np.matrix)):
        if is_numeric(specification):
            return convert_to_np_array(specification)
        else:
            return
        # MODIFIED 4/9/22 END

    if isinstance(specification, np.ndarray):
        if specification.ndim == 2:
            return specification
        # FIX: MAKE THIS AN np.array WITH THE SAME DIMENSIONS??
        elif specification.ndim < 2:
            return np.atleast_2d(specification)
        else:
            raise FunctionError("Specification of np.array for matrix ({}) is more than 2d".
                                format(specification))

    if specification == AUTO_ASSIGN_MATRIX:
        if rows == cols:
            specification = IDENTITY_MATRIX
        else:
            specification = FULL_CONNECTIVITY_MATRIX

    if specification == FULL_CONNECTIVITY_MATRIX:
        return np.full((rows, cols), 1.0)

    if specification == ZEROS_MATRIX:
        return np.zeros((rows, cols))

    if specification == IDENTITY_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return np.identity(rows)

    if specification == HOLLOW_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return 1 - np.identity(rows)

    if specification == INVERSE_HOLLOW_MATRIX:
        if rows != cols:
            raise FunctionError("Sender length ({}) must equal receiver length ({}) to use {}".
                                format(rows, cols, specification))
        return (1 - np.identity(rows)) * -1

    if specification == RANDOM_CONNECTIVITY_MATRIX:
        return np.random.rand(rows, cols)

    # Function is specified, so assume it uses random.rand() and call with sender_len and receiver_len
    if isinstance(specification, (types.FunctionType, RandomMatrix)):
        return specification(rows, cols)

    # (7/12/17 CW) this is a PATCH (like the one in MappingProjection) to allow users to
    # specify 'matrix' as a string (e.g. r = RecurrentTransferMechanism(matrix='1 2; 3 4'))
    if isinstance(specification, str):
        try:
            return np.array(np.matrix(specification))
        except (ValueError, NameError, TypeError):
            # np.matrix(specification) will give ValueError if specification is a bad value (e.g. 'abc', '1; 1 2')
            #                          [JDC] actually gives NameError if specification is a string (e.g., 'abc')
            pass

    # Specification not recognized
    return None

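# The keyword matrices listed in the get_matrix docstring are simple enough to spell out
# by hand. The sketch below uses nested lists instead of NumPy arrays so that it runs with
# the standard library only; the function names are hypothetical and are not PsyNeuLink
# keywords or APIs.

```python
def identity(n):
    """1's on the diagonal, 0's elsewhere."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


def hollow(n):
    """0's on the diagonal, 1's elsewhere."""
    return [[0.0 if i == j else 1.0 for j in range(n)] for i in range(n)]


def inverse_hollow(n):
    """0's on the diagonal, -1's elsewhere."""
    return [[0.0 if i == j else -1.0 for j in range(n)] for i in range(n)]


def auto_assign(rows, cols):
    """Identity when square, otherwise all 1's (full connectivity)."""
    if rows == cols:
        return identity(rows)
    return [[1.0] * cols for _ in range(rows)]


square = auto_assign(2, 2)  # identity, because rows == cols
wide = auto_assign(2, 3)    # all 1's, because the shape is not square
```

# The square-only keywords take a single size here; get_matrix instead raises FunctionError when rows != cols.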

# Valid types for a matrix specification; note that this does not ensure that ND arrays are 1D or 2D like the
# above code does.
ValidMatrixSpecType = Union[MatrixKeywordLiteral, Callable, str, NumericCollections, np.matrix]