• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PrincetonUniversity / PsyNeuLink / 15917088825

05 Jun 2025 04:18AM UTC coverage: 84.482% (+0.5%) from 84.017%
15917088825

push

github

web-flow
Merge pull request #3271 from PrincetonUniversity/devel

Devel

9909 of 12966 branches covered (76.42%)

Branch coverage included in aggregate %.

1708 of 1908 new or added lines in 54 files covered. (89.52%)

25 existing lines in 14 files now uncovered.

34484 of 39581 relevant lines covered (87.12%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.08
/psyneulink/core/components/functions/nonstateful/transformfunctions.py
1
#
2
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
3
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
4
#     http://www.apache.org/licenses/LICENSE-2.0
5
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
6
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7
# See the License for the specific language governing permissions and limitations under the License.
8
#
9
#
10
# *****************************************  COMBINATION FUNCTIONS  ****************************************************
11

12
"""
13
* `Concatenate`
14
* `Rearrange`
15
* `Reduce`
16
* `LinearCombination`
17
* `CombineMeans`
18
* `MatrixTransform`
19
* `PredictionErrorDeltaFunction`
20

21
Overview
22
--------
23

24
Functions that combine multiple items with the same shape, yielding a result with a single item that has the same
25
shape as the individual items.
26

27
All TransformFunctions must have two attributes - **multiplicative_param** and **additive_param** -
28
each of which is assigned the name of one of the function's parameters;
29
this is for use by ModulatoryProjections (and, in particular, GatingProjections,
30
when the TransformFunction is used as the function of an InputPort or OutputPort).
31

32

33
"""
34

35
import numbers
1✔
36
import types
1✔
37
import warnings
1✔
38

39
import numpy as np
1✔
40

41
try:
1✔
42
    import torch
1✔
43
except ImportError:
×
44
    torch = None
×
45
from beartype import beartype
1✔
46

47
from psyneulink._typing import Optional, Union, Literal
1✔
48

49
from psyneulink.core import llvm as pnlvm
1✔
50
from psyneulink.core.components.functions import function
1✔
51
from psyneulink.core.components.functions.function import (
1✔
52
    Function_Base, FunctionError, FunctionOutputType, function_keywords, get_matrix)
53
from psyneulink.core.components.shellclasses import Projection
1✔
54
from psyneulink.core.globals.keywords import (
1✔
55
    ADDITIVE_PARAM, ARRANGEMENT, COMBINATION_FUNCTION_TYPE, COMBINE_MEANS_FUNCTION, CONCATENATE_FUNCTION,
56
     CROSS_ENTROPY, DEFAULT_VARIABLE, DOT_PRODUCT, EXPONENTS,
57
     HAS_INITIALIZERS, HOLLOW_MATRIX, IDENTITY_MATRIX, LINEAR_COMBINATION_FUNCTION, L0,
58
     MATRIX, MATRIX_KEYWORD_NAMES, MATRIX_TRANSFORM_FUNCTION,  MULTIPLICATIVE_PARAM, NORMALIZE,
59
     OFFSET, OPERATION, PREDICTION_ERROR_DELTA_FUNCTION, PRODUCT,
60
     REARRANGE_FUNCTION, RECEIVER, REDUCE_FUNCTION, SCALE, SUM, WEIGHTS, PREFERENCE_SET_NAME)
61
from psyneulink.core.globals.utilities import (
1✔
62
    convert_all_elements_to_np_array, convert_to_np_array, is_numeric, is_matrix_keyword, is_numeric_scalar,
63
    np_array_less_than_2d, ValidParamSpecType)
64
from psyneulink.core.globals.context import ContextFlags, handle_external_context
1✔
65
from psyneulink.core.globals.parameters import Parameter, check_user_specified, copy_parameter_value
1✔
66
from psyneulink.core.globals.preferences.basepreferenceset import \
1✔
67
    REPORT_OUTPUT_PREF, ValidPrefSet, PreferenceEntry, PreferenceLevel
68

69
__all__ = ['TransformFunction', 'Concatenate', 'CombineMeans', 'Rearrange', 'Reduce',
1✔
70
           'LinearCombination', 'MatrixTransform', 'PredictionErrorDeltaFunction']
71

72
class TransformFunction(Function_Base):
    """Function that combines multiple items, yielding a result with the same shape as its operands

    All TransformFunctions must have two attributes - multiplicative_param and additive_param -
        each of which is assigned the name of one of the function's parameters;
        this is for use by ModulatoryProjections (and, in particular, GatingProjections,
        when the TransformFunction is used as the function of an InputPort or OutputPort).

    """
    componentType = COMBINATION_FUNCTION_TYPE

    class Parameters(Function_Base.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <TransformFunction.variable>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``
                    :read only: True
        """
        variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable')

    def _gen_llvm_load_param(self, ctx, builder, params, param_name, index, default):
        """Emit IR that loads the value of *param_name* for the output element at *index*.

        An empty struct stands in for an unspecified parameter, in which case
        *default* is returned as a float constant; a length-1 array parameter is
        broadcast (element 0 is always read); otherwise the element at *index*
        is loaded.
        """
        ptr = ctx.get_param_or_state_ptr(builder, self, param_name, param_struct_ptr=params)
        pointee_ty = ptr.type.pointee

        if isinstance(pointee_ty, pnlvm.ir.LiteralStructType):
            # Unspecified parameter is encoded as an empty struct; fall back to default.
            assert len(pointee_ty) == 0
            return ctx.float_ty(default)

        if isinstance(pointee_ty, pnlvm.ir.ArrayType):
            # Broadcast a single-element array; otherwise index element-wise.
            elem_idx = ctx.int32_ty(0) if len(pointee_ty) == 1 else index
            ptr = builder.gep(ptr, [ctx.int32_ty(0), elem_idx])

        return builder.load(ptr)

    def _gen_llvm_function_body(self, ctx, builder, params, _, arg_in, arg_out, *, tags:frozenset):
        """Emit the element-wise combine loop over the output array."""
        # The output may arrive as a 2d array; operate on its flattened view.
        out = pnlvm.helpers.unwrap_2d_array(builder, arg_out)

        with pnlvm.helpers.array_ptr_loop(builder, out, "linear") as loop_args:
            self._gen_llvm_combine(ctx=ctx, vi=arg_in, vo=out, params=params, *loop_args)
        return builder
118

119

120
class Concatenate(TransformFunction):  # ------------------------------------------------------------------------
    """
    Concatenate(                                   \
         default_variable=class_defaults.variable, \
         scale=1.0,                                \
         offset=0.0,                               \
         params=None,                              \
         owner=None,                               \
         prefs=None,                               \
    )

    .. _Concatenate:

    Concatenates items in outer dimension (axis 0) of `variable <Concatenate.variable>` into a single array,
    optionally scaling and/or adding an offset to the result after concatenating.

    `function <Concatenate.function>` returns a 1d array with length equal to the sum of the lengths of the items
    in `variable <Concatenate.variable>`.

    `derivative <Concatenate.derivative>` returns `scale <Concatenate.scale>`.


    Arguments
    ---------

    default_variable : list or np.array : default class_defaults.variable
        specifies a template for the value to be transformed and its default value;  all entries must be numeric.

    scale : float
        specifies a value by which to multiply each element of the output of `function <Concatenate.function>`
        (see `scale <Concatenate.scale>` for details)

    offset : float
        specifies a value to add to each element of the output of `function <Concatenate.function>`
        (see `offset <Concatenate.offset>` for details)

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    default_variable : list or np.array
        contains template of array(s) to be concatenated.

    scale : float
        value is applied multiplicatively to each element of the concatenated array, before applying the `offset
        <Concatenate.offset>` (if it is specified).

    offset : float
        value is added to each element of the concatenated array, after `scale <Concatenate.scale>` has been
        applied (if it is specified).

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences` for
        details).
    """
    componentName = CONCATENATE_FUNCTION


    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                changes_shape
                    see `changes_shape <Function_Base.changes_shape>`

                    :default value: True
                    :type: bool

                offset
                    see `offset <Concatenate.offset>`

                    :default value: 0.0
                    :type: ``float``

                scale
                    see `scale <Concatenate.scale>`

                    :default value: 1.0
                    :type: ``float``
        """
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        offset = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        changes_shape = Parameter(True, stateful=False, loggable=False, pnl_internal=True)

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_variable(self, variable, context=None):
        """Insure that list or array is 1d and that all elements are numeric

        Args:
            variable:
            context:
        """
        variable = super()._validate_variable(variable=variable, context=context)
        if not is_numeric(variable):
            raise FunctionError("All elements of {} must be scalar values".
                                format(self.__class__.__name__))
        return variable

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate scale and offset parameters

        Check that SCALE and OFFSET are scalars.
        """

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if SCALE in target_set and target_set[SCALE] is not None:
            scale = target_set[SCALE]
            if not is_numeric_scalar(scale):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(SCALE, self.name, scale))

        if OFFSET in target_set and target_set[OFFSET] is not None:
            offset = target_set[OFFSET]
            if not is_numeric_scalar(offset):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(OFFSET, self.name, offset))

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Use numpy hstack to concatenate items in outer dimension (axis 0) of variable.

        Arguments
        ---------

        variable : list or np.array : default class_defaults.variable
           a list or np.array of numeric values.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.


        Returns
        -------

        Concatenated array of items in variable : array
            in an array that is one dimension less than `variable <Concatenate.variable>`.

        """
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        # Flatten along axis 0, then apply the linear transform element-wise.
        result = np.hstack(variable) * scale + offset

        return self.convert_output_type(result)

    @handle_external_context()
    def derivative(self, input=None, output=None, covariates=None, context=None):
        """
        derivative(input)

        Derivative of `function <Concatenate._function>` at **input**.

        Arguments
        ---------

        input : number
            value of the input to the function at which derivative is to be taken.

        covariates : 2d np.array : default class_defaults.variable[1:]
            the input(s) to the Concatenate function other than the one for which the derivative is being
            computed;  these are ignored and are accepted for consistency with other functions.

        Returns
        -------

        Scale of function :  number or array

        """
        # Concatenation is linear in its input, so the derivative is just scale.
        return self._get_current_parameter_value(SCALE, context)

    def _gen_pytorch_fct(self, device, context=None):
        """Return a torch-based callable equivalent to `_function <Concatenate._function>`."""
        scale = self._get_pytorch_fct_param_value('scale', device, context)
        offset = self._get_pytorch_fct_param_value('offset', device, context)
        return lambda x: torch.hstack(tuple(x)) * scale + offset
342

343

344
class Rearrange(TransformFunction):  # ------------------------------------------------------------------------
    """
    Rearrange(                                     \
         default_variable=class_defaults.variable, \
         arrangement=None,                         \
         scale=1.0,                                \
         offset=0.0,                               \
         params=None,                              \
         owner=None,                               \
         prefs=None,                               \
    )

    .. _Rearrange:

    Rearranges items in outer dimension (axis 0) of `variable <Rearrange.variable>`, as specified by **arrangement**,
    optionally scaling and/or adding an offset to the result after concatenating.

    .. _Rearrange_Arrangement:

    The **arrangement** argument specifies how to rearrange the items of `variable <Rearrange.variable>`, possibly
    concatenating subsets of them into single 1d arrays.  The specification must be an integer, a tuple of integers,
    or a list containing either or both.  Each integer must be an index of an item in the outer dimension (axis 0) of
    `variable <Rearrange.variable>`.  Items referenced in a tuple are concatenated in the order specified into a single
    1d array, and that 1d array is included in the resulting 2d array in the order it appears in **arrangement**.
    If **arrangement** is specified, then only the items of `variable <Rearrange.variable>` referenced in the
    specification are included in the result; if **arrangement** is not specified, all of the items of `variable
    <Rearrange.variable>` are concatenated into a single 1d array (i.e., it functions identically to `Concatenate`).

    `function <Rearrange.function>` returns a 2d array with the items of `variable` rearranged
    (and possibly concatenated) as specified by **arrangement**.

    Examples
    --------

    >>> r = Rearrange(arrangement=[(1,2),(0)])
    >>> print(r(np.array([[0,0],[1,1],[2,2]])))
    [array([1., 1., 2., 2.]) array([0., 0.])]

    >>> r = Rearrange()
    >>> print(r(np.array([[0,0],[1,1],[2,2]])))
    [0. 0. 1. 1. 2. 2.]


    Arguments
    ---------

    default_variable : list or np.array : default class_defaults.variable
        specifies a template for the value to be transformed and its default value;  all entries must be numeric.

    arrangement : int, tuple, or list : default None
        specifies ordering of items in `variable <Rearrange.variable>` and/or ones to concatenate.
        (see `above <Rearrange_Arrangement>` for details).

    scale : float
        specifies a value by which to multiply each element of the output of `function <Rearrange.function>`
        (see `scale <Rearrange.scale>` for details).

    offset : float
        specifies a value to add to each element of the output of `function <Rearrange.function>`
        (see `offset <Rearrange.offset>` for details).

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    default_variable : list or np.array
        contains template of array(s) to be concatenated.

    arrangement : list of one or more tuples
        determines ordering of items in `variable <Rearrange.variable>` and/or ones to concatenate
        (see `above <Rearrange_Arrangement>` for additional details).

    scale : float
        value is applied multiplicatively to each element of the concatenated array, before applying the `offset
        <Rearrange.offset>` (if it is specified).

    offset : float
        value is added to each element of the concatenated array, after `scale <Rearrange.scale>` has been
        applied (if it is specified).

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """
    componentName = REARRANGE_FUNCTION

    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                arrangement
                    see `arrangement <Rearrange_Arrangement>`

                    :default value: None
                    :type:

                offset
                    see `offset <Rearrange.offset>`

                    :default value: 0.0
                    :type: ``float``

                scale
                    see `scale <Rearrange.scale>`

                    :default value: 1.0
                    :type: ``float``
        """
        arrangement = Parameter(None, modulable=False)
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        offset = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 arrangement: Optional[Union[int, tuple, list]] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        super().__init__(
            default_variable=default_variable,
            arrangement=arrangement,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_variable(self, variable, context=None):
        """Insure that all elements are numeric and that list or array is at least 2d
        """
        variable = super()._validate_variable(variable=variable, context=context)
        if not is_numeric(variable):
            raise FunctionError(
                    f"All elements of {repr(DEFAULT_VARIABLE)} for {self.__class__.__name__} must be scalar values.")

        if self.parameters.variable._user_specified and np.array(variable).ndim<2:
            raise FunctionError(f"{repr(DEFAULT_VARIABLE)} for {self.__class__.__name__} must be at least 2d.")

        return variable

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate arrangement, scale and offset parameters"""

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if ARRANGEMENT in target_set and target_set[ARRANGEMENT] is not None:

            # If default_variable was specified by user, validate indices in arrangement
            owner_str = ''
            if self.owner:
                owner_str = f' of {self.owner.name}'
            for i in self._indices:
                if not isinstance(i, int):
                    raise FunctionError(f"Index specified in {repr(ARRANGEMENT)} arg for "
                                        f"{self.name}{owner_str} ({repr(i)}) is not an int.")
                if self.parameters.variable._user_specified:
                    try:
                        self.parameters.variable.default_value[i]
                    except IndexError:
                        raise FunctionError(f"Index ({i}) specified in {repr(ARRANGEMENT)} arg for "
                                            f"{self.name}{owner_str} is out of bounds for its {repr(DEFAULT_VARIABLE)} "
                                            f"arg (max index = {len(self.parameters.variable.default_value) - 1}).")

        # Check that SCALE and OFFSET are scalars.
        if SCALE in target_set and target_set[SCALE] is not None:
            scale = target_set[SCALE]
            if not is_numeric_scalar(scale):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(SCALE, self.name, scale))

        if OFFSET in target_set and target_set[OFFSET] is not None:
            offset = target_set[OFFSET]
            if not is_numeric_scalar(offset):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(OFFSET, self.name, offset))

    def _instantiate_attributes_before_function(self, function=None, context=None):
        """Insure all items of arrangement are tuples and compatibility with default_variable

        If arrangement is specified, convert all items to tuples
        If default_variable is NOT specified, assign with length in outer dimension = max index in arrangement
        If default_variable IS _user_specified, compatibility with arrangement is checked in _validate_params
        """

        arrangement = self.parameters.arrangement.get()

        if arrangement is not None:
            # Insure that all items are tuples
            self.parameters.arrangement.set([item if isinstance(item,tuple) else tuple([item]) for item in arrangement])

            # Only reshape when an arrangement was given: _indices iterates arrangement
            # and would raise a TypeError on None (the Concatenate-equivalent case).
            if not self.parameters.variable._user_specified:
                # Reshape variable.default_value to match maximum index specified in arrangement
                self.parameters.variable.default_value = np.zeros((max(self._indices) + 1, 1))

        super()._instantiate_attributes_before_function(function, context)

    @property
    def _indices(self):
        """Return the flat list of all indices referenced in arrangement."""
        arrangement = list(self.parameters.arrangement.get())
        items = [list(item) if isinstance(item, tuple) else [item] for item in arrangement]
        indices = []
        for item in items:
            indices.extend(item)
        return indices

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Rearrange items in outer dimension (axis 0) of variable according to `arrangement <Rearrange.arrangement>`.

        Arguments
        ---------

        variable : list or np.array : default class_defaults.variable
           a list or np.array of numeric values.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        Rearranged items of outer dimension (axis 0) of **variable** : array
            in a 2d array.
        """
        variable = np.atleast_2d(variable)

        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)
        arrangement = self.parameters.arrangement.get(context)

        if arrangement is None:
            # No arrangement: behave like Concatenate.
            result = np.hstack(variable) * scale + offset

        else:
            try:
                result = []
                for item in arrangement:
                    stack = []
                    for index in item:
                        stack.append(variable[index])
                    result.append(np.hstack(tuple(stack)))
                result = convert_to_np_array(result) * scale + offset
            except IndexError:
                assert False, f"PROGRAM ERROR: Bad index specified in {repr(ARRANGEMENT)} arg -- " \
                    f"should have been caught in _validate_params or _instantiate_attributes_before_function"

        return self.convert_output_type(result, FunctionOutputType.NP_2D_ARRAY)
624

625

626
class Reduce(TransformFunction):  # ------------------------------------------------------------------------
    # FIX: CONFIRM THAT 1D WEIGHTS USES EACH ELEMENT TO SCALE CORRESPONDING VECTOR IN VARIABLE
    # FIX  CONFIRM THAT LINEAR TRANSFORMATION (OFFSET, SCALE) APPLY TO THE RESULTING ARRAY
    # FIX: CONFIRM RETURNS LIST IF GIVEN LIST, AND SIMILARLY FOR NP.ARRAY
    """
    Reduce(                                       \
         default_variable=class_defaults.variable, \
         weights=None,                            \
         exponents=None,                          \
         operation=SUM,                           \
         scale=1.0,                               \
         offset=0.0,                              \
         params=None,                             \
         owner=None,                              \
         prefs=None,                              \
    )

    .. _Reduce:

    Combines values in each of one or more arrays into a single value for each array, with optional weighting and/or
    exponentiation of each item within an array prior to combining, and scaling and/or offset of result after combining.

    `function <Reduce.function>` returns an array of scalar values, one for each array in `variable <Reduce.variable>`.

    COMMENT:
        IMPLEMENTATION NOTE: EXTEND TO MULTIDIMENSIONAL ARRAY ALONG ARBITRARY AXIS
    COMMENT

    Arguments
    ---------

    default_variable : list or np.array : default class_defaults.variable
        specifies a template for the value to be transformed and its default value;  all entries must be numeric.

    weights : 1d or 2d np.array : default None
        specifies values used to multiply the elements of each array in `variable  <LinearCombination.variable>`.
        If it is 1d, its length must equal the number of items in `variable <LinearCombination.variable>`;
        if it is 2d, the length of each item must be the same as those in `variable <LinearCombination.variable>`,
        and there must be the same number of items as there are in `variable <LinearCombination.variable>`
        (see `weights <LinearCombination.weights>` for details)

    exponents : 1d or 2d np.array : default None
        specifies values used to exponentiate the elements of each array in `variable  <LinearCombination.variable>`.
        If it is 1d, its length must equal the number of items in `variable <LinearCombination.variable>`;
        if it is 2d, the length of each item must be the same as those in `variable <LinearCombination.variable>`,
        and there must be the same number of items as there are in `variable <LinearCombination.variable>`
        (see `exponents <LinearCombination.exponents>` for details)

    operation : SUM or PRODUCT : default SUM
        specifies whether to sum or multiply the elements in `variable <Reduce.function.variable>` of
        `function <Reduce.function>`.

    scale : float
        specifies a value by which to multiply each element of the output of `function <Reduce.function>`
        (see `scale <Reduce.scale>` for details)

    offset : float
        specifies a value to add to each element of the output of `function <Reduce.function>`
        (see `offset <Reduce.offset>` for details)

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    default_variable : list or np.array
        contains array(s) to be reduced.

    operation : SUM or PRODUCT
        determines whether elements of each array in `variable <Reduce.function.variable>` of
        `function <Reduce.function>` are summed or multiplied.

    scale : float
        value is applied multiplicatively to each element of the array after applying the `operation <Reduce.operation>`
        (see `scale <Reduce.scale>` for details);  this is done before applying the `offset <Reduce.offset>`
        (if it is specified).

    offset : float
        value is added to each element of the array after applying the `operation <Reduce.operation>`
        and `scale <Reduce.scale>` (if it is specified).

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict : Function.classPreferences
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """
    componentName = REDUCE_FUNCTION


    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                exponents
                    see `exponents <Reduce.exponents>`

                    :default value: None
                    :type:

                changes_shape
                    see `changes_shape <Function_Base.changes_shape>`

                    :default value: True
                    :type: bool

                offset
                    see `offset <Reduce.offset>`

                    :default value: 0.0
                    :type: ``float``

                operation
                    see `operation <Reduce.operation>`

                    :default value: `SUM`
                    :type: ``str``

                scale
                    see `scale <Reduce.scale>`

                    :default value: 1.0
                    :type: ``float``

                weights
                    see `weights <Reduce.weights>`

                    :default value: None
                    :type:
        """
        weights = None
        exponents = None
        operation = SUM
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        offset = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
        # Reduce collapses each inner array to a scalar, so output shape differs from input shape
        changes_shape = Parameter(True, stateful=False, loggable=False, pnl_internal=True)

        # Validators report failure by returning an error-message string (None means valid),
        # matching the convention used by the other validators in this Parameters class.
        def _validate_scale(self, scale):
            if not is_numeric_scalar(scale):
                return "scale must be a scalar"

        def _validate_offset(self, offset):
            if not is_numeric_scalar(offset):
                return "vector offset is not supported"


    @check_user_specified
    @beartype
    def __init__(self,
                 # weights:  Optional[ValidParamSpecType] = None,
                 # exponents:  Optional[ValidParamSpecType] = None,
                 weights=None,
                 exponents=None,
                 default_variable=None,
                 operation: Optional[Literal['sum', 'product']] = None,
                 scale: Optional[ValidParamSpecType] = None,
                 offset: Optional[ValidParamSpecType] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):
        """Construct a Reduce function; all arguments are forwarded unchanged to the superclass constructor."""
        super().__init__(
            default_variable=default_variable,
            weights=weights,
            exponents=exponents,
            operation=operation,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
        )

    def _validate_variable(self, variable, context=None):
        """Insure that list or array is 1d and that all elements are numeric.

        Raises FunctionError (with an owner-specific message when an owner is assigned)
        if any element of variable is non-numeric.

        Args:
            variable:
            context:
        """
        variable = super()._validate_variable(variable=variable, context=context)
        if not is_numeric(variable):
            if self.owner:
                err_msg = f"{self.__class__.__name__} function of {repr(self.owner.name)} " \
                          f"passed variable ({variable}) with non-scalar element."
            else:
                err_msg = f"All elements of variable ({variable}) for {self.__class__.__name__} must be scalar values."
            raise FunctionError(err_msg)
        return variable

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate weights, exponents, scale and offset parameters.

        Check that WEIGHTS and EXPONENTS are lists or np.arrays of numbers with length equal to variable.
        Check that SCALE and OFFSET are scalars.

        Note: the checks of compatibility with variable are only performed for validation calls during execution
              (i.e., from check_args(), since during initialization or COMMAND_LINE assignment,
              a parameter may be re-assigned before the variable is known
        """

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if WEIGHTS in target_set and target_set[WEIGHTS] is not None:
            self._validate_parameter_spec(target_set[WEIGHTS], WEIGHTS, numeric_only=True)
            target_set[WEIGHTS] = np.atleast_1d(target_set[WEIGHTS])
            # Length check is deferred to execution/learning, when the variable's shape is settled
            if context.execution_phase & (ContextFlags.EXECUTING | ContextFlags.LEARNING):
                if len(target_set[WEIGHTS]) != len(self.defaults.variable):
                    raise FunctionError("Number of weights ({0}) is not equal to number of elements in variable ({1})".
                                        format(len(target_set[WEIGHTS]), len(self.defaults.variable)))

        if EXPONENTS in target_set and target_set[EXPONENTS] is not None:
            self._validate_parameter_spec(target_set[EXPONENTS], EXPONENTS, numeric_only=True)
            target_set[EXPONENTS] = np.atleast_1d(target_set[EXPONENTS])
            if context.execution_phase & (ContextFlags.EXECUTING | ContextFlags.LEARNING):
                if len(target_set[EXPONENTS]) != len(self.defaults.variable):
                    raise FunctionError("Number of exponents ({0}) does not equal number of elements in variable ({1})".
                                        format(len(target_set[EXPONENTS]), len(self.defaults.variable)))

        if SCALE in target_set and target_set[SCALE] is not None:
            scale = target_set[SCALE]
            if not is_numeric_scalar(scale):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(SCALE, self.name, scale))

        if OFFSET in target_set and target_set[OFFSET] is not None:
            offset = target_set[OFFSET]
            if not is_numeric_scalar(offset):
                raise FunctionError("{} param of {} ({}) must be a scalar".format(OFFSET, self.name, offset))

    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Reduce each array in variable to a scalar by summing or multiplying its elements.

        Elements are first exponentiated (if `exponents <Reduce.exponents>` is specified) and
        weighted (if `weights <Reduce.weights>` is specified), then combined along axis 1,
        and finally transformed linearly by `scale <Reduce.scale>` and `offset <Reduce.offset>`.

        Arguments
        ---------

        variable : list or np.array : default class_defaults.variable
           a list or np.array of numeric values.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.


        Returns
        -------

        Sum or product of arrays in variable : array
            in an array that is one dimension less than `variable <Reduce.variable>`.


        """
        weights = self._get_current_parameter_value(WEIGHTS, context)
        exponents = self._get_current_parameter_value(EXPONENTS, context)
        operation = self._get_current_parameter_value(OPERATION, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        # FIX FOR EFFICIENCY: CHANGE THIS AND WEIGHTS TO TRY/EXCEPT // OR IS IT EVEN NECESSARY, GIVEN VALIDATION ABOVE??
        # Apply exponents if they were specified
        if exponents is not None:
            # Avoid divide by zero warning:
            #    make sure there are no zeros for an element that is assigned a negative exponent
            # Allow during initialization because 0s are common in default_variable argument
            if self.is_initializing:
                # Promote the warning to an error so the fallback below can trap it
                with np.errstate(divide='raise'):
                    try:
                        variable = variable ** exponents
                    except FloatingPointError:
                        # 0 ** negative exponent during initialization: substitute a benign all-ones array
                        variable = np.ones_like(variable)
            else:
                # if this fails with FloatingPointError it should not be caught outside of initialization
                variable = variable ** exponents

        # Apply weights if they were specified
        if weights is not None:
            variable = variable * weights

        # Calculate using relevant aggregation operation and return
        # (axis=1 collapses each inner array to a scalar, one per item of variable)
        if operation == SUM:
            # result = np.sum(np.atleast_2d(variable), axis=0) * scale + offset
            result = np.sum(np.atleast_2d(variable), axis=1) * scale + offset
        elif operation == PRODUCT:
            result = np.prod(np.atleast_2d(variable), axis=1) * scale + offset
        else:
            raise FunctionError("Unrecognized operator ({0}) for Reduce function".
                                format(self._get_current_parameter_value(OPERATION, context)))

        return self.convert_output_type(result)

    def _get_input_struct_type(self, ctx):
        """Return the LLVM struct type for this function's input, forced to at least 2d."""
        # FIXME: Workaround a special case of simple array.
        #        It should just pass through to modifiers, which matches what
        #        single element 2d array does
        default_var = np.atleast_2d(self.defaults.variable)
        return ctx.convert_python_struct_to_llvm_ir(default_var)

    def _gen_llvm_combine(self, builder, index, ctx, vi, vo, params):
        """Emit LLVM IR that reduces row `index` of vi (sum or product of its
        exponentiated/weighted elements, then scale and offset) into vo[index]."""
        scale = self._gen_llvm_load_param(ctx, builder, params, SCALE, index, 1.0)
        # -0.0 is the identity element for IEEE fadd (0.0 + -0.0 == 0.0, and it preserves -0.0 results)
        offset = self._gen_llvm_load_param(ctx, builder, params, OFFSET, index, -0.0)

        # assume operation does not change dynamically
        operation = self.parameters.operation.get()
        if operation == SUM:
            val = ctx.float_ty(-0.0)
            comb_op = "fadd"
        elif operation == PRODUCT:
            val = ctx.float_ty(1.0)
            comb_op = "fmul"
        else:
            assert False, "Unknown operation: {}".format(operation)

        # Accumulator lives in a stack slot so the loop body can load/modify/store it
        val_p = builder.alloca(val.type, name="reduced_result")
        builder.store(val, val_p)

        pow_f = ctx.get_builtin("pow", [ctx.float_ty])

        vi = builder.gep(vi, [ctx.int32_ty(0), index])
        with pnlvm.helpers.array_ptr_loop(builder, vi, "reduce") as (b, idx):
            ptri = b.gep(vi, [ctx.int32_ty(0), idx])
            in_val = b.load(ptri)

            exponent = self._gen_llvm_load_param(ctx, b, params, EXPONENTS,
                                                 index, 1.0)
            # Vector of vectors (even 1-element vectors)
            if isinstance(exponent.type, pnlvm.ir.ArrayType):
                assert len(exponent.type) == 1 # FIXME: Add support for matrix weights
                exponent = b.extract_value(exponent, [0])
            # FIXME: Remove this micro-optimization,
            #        it should be handled by the compiler
            if not isinstance(exponent, pnlvm.ir.Constant) or exponent.constant != 1.0:
                in_val = b.call(pow_f, [in_val, exponent])

            # Try per element weights first
            weight = self._gen_llvm_load_param(ctx, b, params, WEIGHTS,
                                               idx, 1.0)

            # Vector of vectors (even 1-element vectors)
            if isinstance(weight.type, pnlvm.ir.ArrayType):
                # Per-element load returned an array; re-load using the row index instead
                weight = self._gen_llvm_load_param(ctx, b, params, WEIGHTS,
                                                   index, 1.0)
                assert len(weight.type) == 1 # FIXME: Add support for matrix weights
                weight = b.extract_value(weight, [0])

            in_val = b.fmul(in_val, weight)

            val = b.load(val_p)
            val = getattr(b, comb_op)(val, in_val)
            b.store(val, val_p)

        # NOTE(review): uses the loop's builder `b` after the `with` block; this appears to
        # rely on array_ptr_loop yielding the same builder object, repositioned after the
        # loop body on exit — TODO confirm against pnlvm.helpers.array_ptr_loop.
        val = b.load(val_p)
        val = builder.fmul(val, scale)
        val = builder.fadd(val, offset)

        ptro = builder.gep(vo, [ctx.int32_ty(0), index])
        builder.store(val, ptro)

1008

1009
class LinearCombination(
1✔
1010
    TransformFunction):  # ------------------------------------------------------------------------
1011
    """
1012
    LinearCombination(     \
1013
         default_variable, \
1014
         weights=None,     \
1015
         exponents=None,   \
1016
         operation=SUM,    \
1017
         scale=None,       \
1018
         offset=None,      \
1019
         params=None,      \
1020
         owner=None,       \
1021
         name=None,        \
1022
         prefs=None        \
1023
         )
1024

1025
    .. _LinearCombination:
1026

1027
    Linearly combine arrays of values, optionally weighting and/or exponentiating each array prior to combining,
1028
    and scaling and/or offsetting after combining.
1029

1030
    `function <LinearCombination.function>` combines the arrays in the outermost dimension (axis 0) of `variable
1031
    <LinearCombination.variable>` either additively or multiplicatively (as specified by `operation
1032
    <LinearCombination.operation>`), applying `weights <LinearCombination.weights>` and/or `exponents
1033
    <LinearCombination.exponents>` (if specified) to each array prior to combining them, and applying `scale
1034
    <LinearCombination.scale>` and/or `offset <LinearCombination.offset>` (if specified) to the result after
1035
    combining, and returns an array of the same length as the operand arrays.
1036

1037
    COMMENT:
1038
        Description:
1039
            Combine corresponding elements of arrays in variable arg, using arithmetic operation determined by OPERATION
1040
            Use optional SCALE and OFFSET parameters to linearly transform the resulting array
1041
            Returns a list or 1D array of the same length as the individual ones in the variable
1042

1043
            Notes:
1044
            * If variable contains only a single array, it is simply linearly transformed using SCALE and OFFSET
1045
            * If there is more than one array in variable, they must all be of the same length
1046
            * WEIGHTS and EXPONENTS can be:
1047
                - 1D: each array in variable is scaled by the corresponding element of WEIGHTS or EXPONENTS
1048
                - 2D: each array in variable is scaled by (Hadamard-wise) corresponding array of WEIGHTS or EXPONENTS
1049
        Initialization arguments:
1050
         - variable (value, np.ndarray or list): values to be combined;
1051
             can be a list of lists, or a 1D or 2D np.array;  a 1D np.array is always returned
1052
             if it is a list, it must be a list of numbers, lists, or np.arrays
1053
             all items in the list or 2D np.array must be of equal length
1054
             + WEIGHTS (list of numbers or 1D np.array): multiplies each item of variable before combining them
1055
                  (default: [1,1])
1056
             + EXPONENTS (list of numbers or 1D np.array): exponentiates each item of variable before combining them
1057
                  (default: [1,1])
1058
         - params (dict) can include:
1059
             + WEIGHTS (list of numbers or 1D np.array): multiplies each variable before combining them (default: [1,1])
1060
             + OFFSET (value): added to the result (after the arithmetic operation is applied; default is 0)
1061
             + SCALE (value): multiples the result (after combining elements; default: 1)
1062
             + OPERATION (Operation Enum) - method used to combine terms (default: SUM)
1063
                  SUM: element-wise sum of the arrays in variable
1064
                  PRODUCT: Hadamard Product of the arrays in variable
1065

1066
        LinearCombination.function returns combined values:
1067
        - single number if variable was a single number
1068
        - list of numbers if variable was list of numbers
1069
        - 1D np.array if variable was a single np.variable or np.ndarray
1070
    COMMENT
1071

1072
    Arguments
1073
    ---------
1074

1075
    variable : 1d or 2d np.array : default class_defaults.variable
1076
        specifies a template for the arrays to be combined.  If it is 2d, all items must have the same length.
1077

1078
    weights : scalar or 1d or 2d np.array : default None
1079
        specifies values used to multiply the elements of each array in **variable**.
1080
        If it is 1d, its length must equal the number of items in `variable <LinearCombination.variable>`;
1081
        if it is 2d, the length of each item must be the same as those in `variable <LinearCombination.variable>`,
1082
        and there must be the same number of items as there are in `variable <LinearCombination.variable>`
1083
        (see `weights <LinearCombination.weights>` for details of how weights are applied).
1084

1085
    exponents : scalar or 1d or 2d np.array : default None
1086
        specifies values used to exponentiate the elements of each array in `variable  <LinearCombination.variable>`.
1087
        If it is 1d, its length must equal the number of items in `variable <LinearCombination.variable>`;
1088
        if it is 2d, the length of each item must be the same as those in `variable <LinearCombination.variable>`,
1089
        and there must be the same number of items as there are in `variable <LinearCombination.variable>`
1090
        (see `exponents <LinearCombination.exponents>` for details of how exponents are applied).
1091

1092
    operation : SUM, PRODUCT or CROSS_ENTROPY : default SUM
1093
        specifies whether the `function <LinearCombination.function>` takes the elementwise (Hadamard)
1094
        sum, product or cross entropy of the arrays in `variable  <LinearCombination.variable>`.
1095

1096
    scale : float or np.ndarray : default None
1097
        specifies a value by which to multiply each element of the result of `function <LinearCombination.function>`
1098
        (see `scale <LinearCombination.scale>` for details)
1099

1100
    offset : float or np.ndarray : default None
1101
        specifies a value to add to each element of the result of `function <LinearCombination.function>`
1102
        (see `offset <LinearCombination.offset>` for details)
1103

1104
    params : Dict[param keyword: param value] : default None
1105
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
1106
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
1107
        arguments of the constructor.
1108

1109
    owner : Component
1110
        `component <Component>` to which to assign the Function.
1111

1112
    name : str : default see `name <Function.name>`
1113
        specifies the name of the Function.
1114

1115
    prefs : PreferenceSet or specification dict : default Function.classPreferences
1116
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
1117

1118
    Attributes
1119
    ----------
1120

1121
    variable : 1d or 2d np.array
1122
        contains the arrays to be combined by `function <LinearCombination>`.  If it is 1d, the array is simply
1123
        linearly transformed by and `scale <LinearCombination.scale>` and `offset <LinearCombination.scale>`.
1124
        If it is 2d, the arrays (all of which must be of equal length) are weighted and/or exponentiated as
1125
        specified by `weights <LinearCombination.weights>` and/or `exponents <LinearCombination.exponents>`
1126
        and then combined as specified by `operation <LinearCombination.operation>`.
1127

1128
    weights : scalar or 1d or 2d np.array
1129
        if it is a scalar, the value is used to multiply all elements of all arrays in `variable
1130
        <LinearCombination.variable>`; if it is a 1d array, each element is used to multiply all elements in the
1131
        corresponding array of `variable <LinearCombination.variable>`;  if it is a 2d array, then each array is
1132
        multiplied elementwise (i.e., the Hadamard Product is taken) with the corresponding array of `variable
1133
        <LinearCombination.variable>`. All `weights` are applied before any exponentiation (if it is specified).
1134

1135
    exponents : scalar or 1d or 2d np.array
1136
        if it is a scalar, the value is used to exponentiate all elements of all arrays in `variable
1137
        <LinearCombination.variable>`; if it is a 1d array, each element is used to exponentiate the elements of the
1138
        corresponding array of `variable <LinearCombination.variable>`;  if it is a 2d array, the element of each
1139
        array is used to exponentiate the corresponding element of the corresponding array of `variable
1140
        <LinearCombination.variable>`. In either case, all exponents are applied after application of the `weights
1141
        <LinearCombination.weights>` (if any are specified).
1142

1143
    operation : SUM or PRODUCT
1144
        determines whether the `function <LinearCombination.function>` takes the elementwise (Hadamard) sum,
1145
        product, or cross entropy of the arrays in `variable  <LinearCombination.variable>`.
1146

1147
    scale : float or np.ndarray
1148
        value is applied multiplicatively to each element of the array after applying the
1149
        `operation <LinearCombination.operation>` (see `scale <LinearCombination.scale>` for details);
1150
        this done before applying the `offset <LinearCombination.offset>` (if it is specified).
1151

1152
    offset : float or np.ndarray
1153
        value is added to each element of the array after applying the `operation <LinearCombination.operation>`
1154
        and `scale <LinearCombination.scale>` (if it is specified).
1155

1156
    COMMENT:
1157
    function : function
1158
        applies the `weights <LinearCombination.weights>` and/or `exponents <LinearCombination.exponents>` to the
1159
        arrays in `variable <LinearCombination.variable>`, then takes their sum or product (as specified by
1160
        `operation <LinearCombination.operation>`), and finally applies `scale <LinearCombination.scale>` and/or
1161
        `offset <LinearCombination.offset>`.
1162

1163
    enable_output_type_conversion : Bool : False
1164
        specifies whether `function output type conversion <Function_Output_Type_Conversion>` is enabled.
1165

1166
    output_type : FunctionOutputType : None
1167
        used to specify the return type for the `function <Function_Base.function>`;  `functionOuputTypeConversion`
1168
        must be enabled and implemented for the class (see `FunctionOutputType <Function_Output_Type_Conversion>`
1169
        for details).
1170
    COMMENT
1171

1172
    owner : Component
1173
        `component <Component>` to which the Function has been assigned.
1174

1175
    name : str
1176
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
1177
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
1178

1179
    prefs : PreferenceSet or specification dict : Function.classPreferences
1180
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
1181
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
1182
        for details).
1183
    """
1184

1185
    # Registry name for this Function class
    componentName = LINEAR_COMBINATION_FUNCTION

    # Class-level preferences: suppress per-execution output reporting by default
    classPreferences = {
        PREFERENCE_SET_NAME: 'LinearCombinationCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }
1191

1192
    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                exponents
                    see `exponents <LinearCombination.exponents>`

                    :default value: None
                    :type:

                offset
                    see `offset <LinearCombination.offset>`

                    :default value: 0.0
                    :type: ``float``

                operation
                    see `operation <LinearCombination.operation>`

                    :default value: `SUM`
                    :type: ``str``

                scale
                    see `scale <LinearCombination.scale>`

                    :default value: 1.0
                    :type: ``float``

                weights
                    see `weights <LinearCombination.weights>`

                    :default value: None
                    :type:
        """
        operation = SUM

        # weights and exponents are modulable but have no identity default (None means "not applied")
        weights = Parameter(None, modulable=True)
        exponents = Parameter(None, modulable=True)
        # scale/offset are the canonical multiplicative/additive modulation targets
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        offset = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
1233

1234
    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 # weights:  Optional[ValidParamSpecType] = None,
                 # exponents:  Optional[ValidParamSpecType] = None,
                 weights=None,
                 exponents=None,
                 operation: Optional[Literal['sum', 'product', 'cross-entropy']] = None,
                 scale=None,
                 offset=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):
        """Construct a LinearCombination function.

        All arguments are forwarded verbatim to the TransformFunction constructor;
        no additional processing happens here.
        """
        constructor_args = dict(
            default_variable=default_variable,
            weights=weights,
            exponents=exponents,
            operation=operation,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
        )
        super().__init__(**constructor_args)
1260

1261
    def _validate_variable(self, variable, context=None):
        """Insure that all items of list or np.array in variable are of the same length

        Args:
            variable:
            context:
        """
        variable = super()._validate_variable(variable=variable, context=context)
        # FIX: CONVERT TO AT LEAST 1D NP ARRAY IN INIT AND EXECUTE, SO ALWAYS NP ARRAY
        # FIX: THEN TEST THAT SHAPES OF EVERY ELEMENT ALONG AXIS 0 ARE THE SAME
        # FIX; PUT THIS IN DOCUMENTATION
        if isinstance(variable, (list, np.ndarray)):
            # A 0d array has no items along axis 0 to compare
            if isinstance(variable, np.ndarray) and not variable.ndim:
                return variable
            # Walk adjacent pairs of items; a bare number counts as length 1.
            for prev_item, curr_item in zip(variable, variable[1:]):
                prev_len = 1 if isinstance(prev_item, numbers.Number) else len(prev_item)
                if curr_item is None:
                    owner_str = f"'{self.owner.name}' " if self.owner else ''
                    raise FunctionError(f"One of the elements of variable for {self.__class__.__name__} function "
                                        f"of {owner_str}is None; variable: {variable}.")
                if isinstance(curr_item, numbers.Number):
                    curr_len = 1
                else:
                    curr_len = len(curr_item)
                if prev_len != curr_len:
                    owner_str = f"'{self.owner.name }' " if self.owner else ''
                    raise FunctionError(f"Length of all arrays in variable for {self.__class__.__name__} function "
                                        f"of {owner_str}must be the same; variable: {variable}.")
        return variable
    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate weights, exponents, scale and offset parameters

        Check that WEIGHTS and EXPONENTS are lists or np.arrays of numbers with length equal to variable
        Check that SCALE and OFFSET are either scalars or np.arrays of numbers with length and shape equal to variable

        Note: the checks of compatibility with variable are only performed for validation calls during execution
              (i.e., from check_args(), since during initialization or COMMAND_LINE assignment,
              a parameter may be re-assigned before variable assigned during is known
        """

        # FIX: MAKE SURE THAT IF OPERATION IS SUBTRACT OR DIVIDE, THERE ARE ONLY TWO VECTORS

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if WEIGHTS in target_set and target_set[WEIGHTS] is not None:
            self._validate_parameter_spec(target_set[WEIGHTS], WEIGHTS, numeric_only=True)
            # NOTE(review): this gate uses EXECUTING|LEARNING whereas the EXPONENTS
            # check below uses PROCESSING|LEARNING — confirm whether the asymmetry
            # is intentional.
            if context.execution_phase & (ContextFlags.EXECUTING | ContextFlags.LEARNING):
                if np.array(target_set[WEIGHTS]).shape != self.defaults.variable.shape:
                    raise FunctionError("Number of weights ({0}) is not equal to number of items in variable ({1})".
                                        format(len(target_set[WEIGHTS]), len(self.defaults.variable)))

        if EXPONENTS in target_set and target_set[EXPONENTS] is not None:
            self._validate_parameter_spec(target_set[EXPONENTS], EXPONENTS, numeric_only=True)
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                if np.array(target_set[EXPONENTS]).shape != self.defaults.variable.shape:
                    raise FunctionError("Number of exponents ({0}) does not equal number of items in variable ({1})".
                                        format(len(target_set[EXPONENTS]), len(self.defaults.variable)))

        if SCALE in target_set and target_set[SCALE] is not None:
            scale = target_set[SCALE]
            if isinstance(scale, numbers.Number):
                # scalar scale needs no shape check
                pass
            elif isinstance(scale, np.ndarray):
                target_set[SCALE] = np.array(scale)
            # Non-scalar (Hadamard) scale must match variable's shape
            # (or the shape of one of its items).
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                if not is_numeric_scalar(scale):
                    err_msg = "Scale is using Hadamard modulation but its shape and/or size (scale shape: {}, size:{})" \
                              " do not match the variable being modulated (variable shape: {}, size: {})". \
                        format(scale.shape, scale.size, self.defaults.variable.shape,
                               self.defaults.variable.size)
                    if len(self.defaults.variable.shape) == 0:
                        raise FunctionError(err_msg)
                    if (scale.shape != self.defaults.variable.shape) and \
                            (scale.shape != self.defaults.variable.shape[1:]):
                        raise FunctionError(err_msg)

        if OFFSET in target_set and target_set[OFFSET] is not None:
            offset = target_set[OFFSET]
            if isinstance(offset, numbers.Number):
                # scalar offset needs no shape check
                pass
            elif isinstance(offset, np.ndarray):
                target_set[OFFSET] = np.array(offset)

            # Same Hadamard shape check as for scale above.
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                if not is_numeric_scalar(offset):
                    err_msg = "Offset is using Hadamard modulation but its shape and/or size (offset shape: {}, size:{})" \
                              " do not match the variable being modulated (variable shape: {}, size: {})". \
                        format(offset.shape, offset.size, self.defaults.variable.shape,
                               self.defaults.variable.size)
                    if len(self.defaults.variable.shape) == 0:
                        raise FunctionError(err_msg)
                    if (offset.shape != self.defaults.variable.shape) and \
                            (offset.shape != self.defaults.variable.shape[1:]):
                        raise FunctionError(err_msg)

                        # if not operation:
                        #     raise FunctionError("Operation param missing")
                        # if not operation == self.Operation.SUM and not operation == self.Operation.PRODUCT:
                        #     raise FunctionError("Operation param ({0}) must be Operation.SUM or Operation.PRODUCT".
                        #     format(operation))
    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Combine the arrays in variable using the configured operation, scale and offset.

        Arguments
        ---------

        variable : 1d or 2d np.array : default class_defaults.variable
           a single numeric array, or multiple arrays to be combined; if it is 2d, all arrays must have the same length.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        -------

        combined array : 1d array
            the result of linearly combining the arrays in `variable <LinearCombination.variable>`.

        Raises
        ------
        FunctionError
            if `operation <LinearCombination.operation>` is not one of SUM, PRODUCT or CROSS_ENTROPY.
        """
        weights = self._get_current_parameter_value(WEIGHTS, context)
        exponents = self._get_current_parameter_value(EXPONENTS, context)
        # if self.initialization_status == ContextFlags.INITIALIZED:
        #     if weights is not None and weights.shape != variable.shape:
        #         weights = weights.reshape(variable.shape)
        #     if exponents is not None and exponents.shape != variable.shape:
        #         exponents = exponents.reshape(variable.shape)
        operation = self._get_current_parameter_value(OPERATION, context)
        scale = self._get_current_parameter_value(SCALE, context)
        offset = self._get_current_parameter_value(OFFSET, context)

        # QUESTION:  WHICH IS LESS EFFICIENT:
        #                A) UNECESSARY ARITHMETIC OPERATIONS IF SCALE AND/OR OFFSET ARE 1.0 AND 0, RESPECTIVELY?
        #                   (DOES THE COMPILER KNOW NOT TO BOTHER WITH MULT BY 1 AND/OR ADD 0?)
        #                B) EVALUATION OF IF STATEMENTS TO DETERMINE THE ABOVE?
        # IMPLEMENTATION NOTE:  FOR NOW, ASSUME B) ABOVE, AND ASSIGN DEFAULT "NULL" VALUES TO offset AND scale
        if offset is None:
            offset = 0.0

        if scale is None:
            scale = 1.0

        # IMPLEMENTATION NOTE: CONFIRM: SHOULD NEVER OCCUR, AS _validate_variable NOW ENFORCES 2D np.ndarray
        # If variable is 0D or 1D:
        if np_array_less_than_2d(variable):
            return self.convert_output_type((variable * scale) + offset)

        # FIX FOR EFFICIENCY: CHANGE THIS AND WEIGHTS TO TRY/EXCEPT // OR IS IT EVEN NECESSARY, GIVEN VALIDATION ABOVE??
        # Apply exponents if they were specified
        if exponents is not None:
            # Avoid divide by zero warning:
            #    make sure there are no zeros for an element that is assigned a negative exponent
            # Allow during initialization because 0s are common in default_variable argument
            if self.is_initializing:
                with np.errstate(divide='raise'):
                    try:
                        variable = variable ** exponents
                    except FloatingPointError:
                        variable = np.ones_like(variable)
            else:
                # if this fails with FloatingPointError it should not be caught outside of initialization
                variable = variable ** exponents

        # Apply weights if they were specified
        if weights is not None:
            variable = variable * weights

        # CW 3/19/18: a total hack, e.g. to make scale=[4.] turn into scale=4. Used b/c the `scale` ParameterPort
        # changes scale's format: e.g. if you write c = pnl.LinearCombination(scale = 4), print(c.scale) returns [4.]
        # Don't use try_extract_0d_array_item because that will only
        # handle 0d arrays, not 1d.
        try:
            scale = scale.item()
        except (AttributeError, ValueError):
            pass
        try:
            offset = offset.item()
        except (AttributeError, ValueError):
            pass

        # CALCULATE RESULT USING RELEVANT COMBINATION OPERATION AND MODULATION
        if operation == SUM:
            combination = np.sum(variable, axis=0)
        elif operation == PRODUCT:
            combination = np.prod(variable, axis=0)
        elif operation == CROSS_ENTROPY:
            v1 = variable[0]
            v2 = variable[1]
            # Treat 0*log(0) as 0 (both elements zero) to avoid log-of-zero warnings
            both_zero = np.logical_and(v1 == 0, v2 == 0)
            combination = v1 * np.where(both_zero, 0.0, np.log(v2, where=np.logical_not(both_zero)))
        else:
            # BUG FIX: previously formatted with `operation.self.Operation.SUM`, which
            # raised AttributeError instead of the intended FunctionError.
            raise FunctionError("Unrecognized operator ({0}) for LinearCombination function".
                                format(operation))
        if isinstance(scale, numbers.Number):
            # scalar scale
            product = combination * scale
        else:
            # Hadamard scale
            product = np.prod([combination, scale], axis=0)

        if isinstance(offset, numbers.Number):
            # scalar offset
            result = product + offset
        else:
            # Hadamard offset
            result = np.sum([product, offset], axis=0)

        return self.convert_output_type(result)
    @handle_external_context()
    def derivative(self, input=None, output=None, covariates=None, context=None):
        """
        derivative(input)

        Derivative of `function <LinearCombination._function>` at **input**.

        Arguments
        ---------

        output : 1d np.array : default class_defaults.variable[0]
            value of the input to the Linear transform at which derivative is to be taken.
           a single numeric array or multiple arrays being combined, and at which derivative is to be taken.

           .. technical_note::
              output arg is used for consistency with other derivatives used by BackPropagation, and is ignored.

        covariates : 2d np.array : default class_defaults.variable[1:]
            the input(s) to the LinearCombination function other than the one for which the derivative is being
            computed;  these are used to calculate the Jacobian of the LinearCombination function.

        Returns
        -------

        Scale :  number (if input is 1d) or array (if input is 2d)

        """
        # For SUM (or when no covariates are given) the derivative w.r.t. any one
        # input is just scale; for PRODUCT it is scale times the elementwise
        # product of the remaining (covariate) inputs.
        scale = self._get_current_parameter_value(SCALE, context)
        if covariates is not None and self.operation != SUM:
            jacobian = scale * np.prod(np.vstack(covariates), axis=0)
        else:
            jacobian = scale

        # Diagonal Jacobian: each output element depends only on its own input element.
        return np.eye(len(output)) * jacobian
    def _get_input_struct_type(self, ctx):
        # FIXME: Workaround a special case of simple array.
        #        It should just pass through to modifiers, which matches what
        #        single element 2d array does
        # Promote the default variable to 2d before converting to an LLVM struct type.
        return ctx.convert_python_struct_to_llvm_ir(np.atleast_2d(self.defaults.variable))
    def _gen_llvm_combine(self, builder, index, ctx, vi, vo, params):
        """Emit LLVM IR computing one output element: combine vi[:, index] using the
        configured operation, then apply scale and offset, storing into vo[index]."""
        scale = self._gen_llvm_load_param(ctx, builder, params, SCALE, index, 1.0)
        offset = self._gen_llvm_load_param(ctx, builder, params, OFFSET, index, -0.0)

        # assume operation does not change dynamically
        operation = self.parameters.operation.get()
        if operation == SUM:
            # -0.0 is the additive identity that preserves IEEE sign semantics
            val = ctx.float_ty(-0.0)
            comb_op = "fadd"
        elif operation == PRODUCT:
            val = ctx.float_ty(1.0)
            comb_op = "fmul"
        elif operation == CROSS_ENTROPY:
            raise FunctionError(f"LinearCombination Function does not (yet) support CROSS_ENTROPY operation.")
            # FIX: THIS NEEDS TO BE REPLACED TO GENERATE A VECTOR WITH HADAMARD CROSS-ENTROPY OF vi AND vo
            # ptr1 = builder.gep(vi, [index])
            # ptr2 = builder.gep(vo, [index])
            # val1 = builder.load(ptr1)
            # val2 = builder.load(ptr2)
            # log_f = ctx.get_builtin("log", [ctx.float_ty])
            # log = builder.call(log_f, [val2])
            # prod = builder.fmul(val1, log)
        else:
            assert False, "Unknown operation: {}".format(operation)

        # Stack slot holding the running accumulator across loop iterations
        val_p = builder.alloca(val.type, name="combined_result")
        builder.store(val, val_p)

        pow_f = ctx.get_builtin("pow", [ctx.float_ty])

        # Loop over the items of vi (axis 0), accumulating weight * item**exponent
        with pnlvm.helpers.array_ptr_loop(builder, vi, "combine") as (b, idx):
            ptri = b.gep(vi, [ctx.int32_ty(0), idx, index])
            in_val = b.load(ptri)

            exponent = self._gen_llvm_load_param(ctx, b, params, EXPONENTS,
                                                 idx, 1.0)
            # Vector of vectors (even 1-element vectors)
            if isinstance(exponent.type, pnlvm.ir.ArrayType):
                assert len(exponent.type) == 1 # FIXME: Add support for matrix weights
                exponent = b.extract_value(exponent, [0])
            # FIXME: Remove this micro-optimization,
            #        it should be handled by the compiler
            if not isinstance(exponent, pnlvm.ir.Constant) or exponent.constant != 1.0:
                in_val = b.call(pow_f, [in_val, exponent])

            weight = self._gen_llvm_load_param(ctx, b, params, WEIGHTS,
                                               idx, 1.0)
            # Vector of vectors (even 1-element vectors)
            if isinstance(weight.type, pnlvm.ir.ArrayType):
                assert len(weight.type) == 1 # FIXME: Add support for matrix weights
                weight = b.extract_value(weight, [0])

            in_val = b.fmul(in_val, weight)

            # Fold this item into the accumulator with fadd/fmul per operation
            val = b.load(val_p)
            val = getattr(b, comb_op)(val, in_val)
            b.store(val, val_p)

        # result = accumulator * scale + offset
        val = builder.load(val_p)
        val = builder.fmul(val, scale)
        val = builder.fadd(val, offset)

        ptro = builder.gep(vo, [ctx.int32_ty(0), index])
        builder.store(val, ptro)
    def _gen_pytorch_fct(self, device, context=None):
        """Return a PyTorch callable implementing this function's SUM or PRODUCT operation.

        The returned lambda reduces over dim 1 of its input (dim 0 is batch),
        after an optional elementwise multiply by `weights`.

        Raises
        ------
        AutodiffCompositionError
            if `operation <LinearCombination.operation>` is neither SUM nor PRODUCT.
        """
        weights = self._get_pytorch_fct_param_value('weights', device, context)
        if weights is not None:
            weights = torch.tensor(weights, device=device).double()
        # Note: the first dimension of x is batch, aggregate over the second dimension
        if self.operation == SUM:
            if weights is not None:
                return lambda x: torch.sum(x * weights, 1)
            else:
                return lambda x: torch.sum(x, 1)
        elif self.operation == PRODUCT:
            if weights is not None:
                return lambda x: torch.prod(x * weights, 1)
            else:
                return lambda x: torch.prod(x, 1)
        else:
            from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
            # BUG FIX: previously referenced the undefined name `function`
            # (`function.componentName`), which raised NameError instead of the
            # intended AutodiffCompositionError.
            raise AutodiffCompositionError(f"The 'operation' parameter of {self.componentName} is not supported "
                                           f"by AutodiffComposition; use 'SUM' or 'PRODUCT' if possible.")

1613
# **********************************************************************************************************************
1614
#                                                 MatrixTransform
1615
# **********************************************************************************************************************
1616

1617
class MatrixTransform(TransformFunction):  # -------------------------------------------------------------------------------
1✔
1618
    """
1619
    MatrixTransform(            \
1620
         default_variable,      \
1621
         matrix=None,           \
1622
         operation=DOT_PRODUCT, \
1623
         normalize=False,       \
1624
         params=None,           \
1625
         owner=None,            \
1626
         name=None,             \
1627
         prefs=None             \
1628
         )
1629

1630
    .. _MatrixTransform:
1631

1632
    Matrix transform of `variable <MatrixTransform.variable>`.
1633

1634
    `function <MatrixTransform._function>` returns a matrix transform of `variable <MatrixTransform.variable>`
1635
     based on the **operation** argument.
1636

1637
    **operation** = *DOT_PRODUCT*:
1638

1639
        Returns the dot (inner) product of `variable <MatrixTransform.variable>` and `matrix <MatrixTransform.matrix>`:
1640

1641
        .. math::
1642
            {variable} \\bullet |matrix|
1643

1644
        If **normalize** =True, the result is normalized by the product of the norms of the variable and matrix:
1645

1646
        .. math::
1647
            \\frac{variable \\bullet matrix}{\\|variable\\| \\cdot \\|matrix\\|}
1648

1649
        .. note::
1650
           For **normalize** =True, the result is the same as the cosine of the angle between pairs of vectors.
1651

1652
    **operation** = *L0*:
1653

1654
        Returns the absolute value of the difference between `variable <MatrixTransform.variable>` and `matrix
1655
        <MatrixTransform.matrix>`:
1656

1657
        .. math::
1658
            |variable - matrix|
1659

1660
        If **normalize** =True, the result is normalized by the norm of the sum of differences between the variable and
1661
        matrix, which is then subtracted from 1:
1662

1663
        .. math::
1664
            1 - \\frac{|variable - matrix|}{\\|variable - matrix\\|}
1665

1666
        .. note::
1667
           For **normalize** =True, the result has the same effect as the normalized *DOT_PRODUCT* operation,
1668
           with more similar pairs of vectors producing larger values (closer to 1).
1669

1670
        .. warning::
1671
           For **normalize** =False, the result is smaller (closer to 0) for more similar pairs of vectors,
1672
           which is **opposite** the effect of the *DOT_PRODUCT* and normalized *L0* operations.  If the desired
1673
           result is that more similar pairs of vectors produce larger values, set **normalize** =True or
1674
           use the *DOT_PRODUCT* operation.
1675

1676

1677
    COMMENT:  [CONVERT TO FIGURE]
1678
        ----------------------------------------------------------------------------------------------------------
1679
        MATRIX FORMAT <shape: (3,5)>
1680
                                         INDICES:
1681
                                     Output elements:
1682
                              0       1       2       3       4
1683
                         0  [0,0]   [0,1]   [0,2]   [0,3]   [0,4]
1684
        Input elements:  1  [1,0]   [1,1]   [1,2]   [1,3]   [1,4]
1685
                         2  [2,0]   [2,1]   [2,2]   [2,3]   [2,4]
1686

1687
        matrix.shape => (input/rows, output/cols)
1688

1689
        ----------------------------------------------------------------------------------------------------------
1690
        ARRAY FORMAT
1691
                                                                            INDICES
1692
                                          [ [      Input 0 (row0)       ], [       Input 1 (row1)      ]... ]
1693
                                          [ [ out0,  out1,  out2,  out3 ], [ out0,  out1,  out2,  out3 ]... ]
1694
        matrix[input/rows, output/cols]:  [ [ row0,  row0,  row0,  row0 ], [ row1,  row1,  row1,  row1 ]... ]
1695
                                          [ [ col0,  col1,  col2,  col3 ], [ col0,  col1,  col2,  col3 ]... ]
1696
                                          [ [[0,0], [0,1], [0,2], [0,3] ], [[1,0], [1,1], [1,2], [1,3] ]... ]
1697

1698
        ----------------------------------------------------------------------------------------------------------
1699
    COMMENT
1700

1701

1702
    Arguments
1703
    ---------
1704

1705
    variable : list or 1d array : default class_defaults.variable
1706
        specifies a template for the value to be transformed; length must equal the number of rows of `matrix
1707
        <MatrixTransform.matrix>`.
1708

1709
    matrix : number, list, 1d or 2d np.ndarray, function, or matrix keyword : default IDENTITY_MATRIX
1710
        specifies matrix used to transform `variable <MatrixTransform.variable>`
1711
        (see `matrix <MatrixTransform.matrix>` for specification details).
1712

1713
        When MatrixTransform is the `function <Projection_Base.function>` of a projection:
1714

1715
            - the matrix specification must be compatible with the variables of the `sender <Projection_Base.sender>`
1716
              and `receiver <Projection_Base.receiver>`
1717

1718
            - a matrix keyword specification generates a matrix based on the sender and receiver shapes
1719

1720
        When MatrixTransform is instantiated on its own, or as the function of a `Mechanism <Mechanism>` or `Port`:
1721

1722
            - the matrix specification must be compatible with the function's own `variable <MatrixTransform.variable>`
1723

1724
            - if matrix is not specified, a square identity matrix is generated based on the number of columns in
1725
              `variable <MatrixTransform.variable>`
1726

1727
            - matrix keywords are not valid matrix specifications
1728

1729
    operation : DOT_PRODUCT or L0 : default DOT_PRODUCT
1730
        specifies whether to take the dot product or difference of `variable <MatrixTransform.variable>`
1731
        and `matrix <MatrixTransform.matrix>`.
1732

1733
    normalize : bool : default False
1734
        specifies whether to normalize the result of `function <LinearCombination.function>` by dividing it by the
1735
        norm of `variable <MatrixTransform.variable>` x the norm of `matrix <MatrixTransform.matrix>`;  this cannot
1736
        be used if `variable <MatrixTransform.variable>` is a scalar (i.e., has only one element), and **operation**
1737
        is set to *L0* (since it is not needed, and can produce a divide by zero error).
1738

1739
    params : Dict[param keyword: param value] : default None
1740
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
1741
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
1742
        arguments of the constructor.
1743

1744
    owner : Component
1745
        `component <Component>` to which to assign the Function.
1746

1747
    name : str : default see `name <Function.name>`
1748
        specifies the name of the Function.
1749

1750
    prefs : PreferenceSet or specification dict : default Function.classPreferences
1751
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
1752

1753
    Attributes
1754
    ----------
1755

1756
    variable : 1d array
1757
        contains value to be transformed.
1758

1759
    matrix : 2d array
1760
        matrix used to transform `variable <MatrixTransform.variable>`.
1761
        Can be specified as any of the following:
1762
            * number - used as the filler value for all elements of the :keyword:`matrix` (call to np.fill);
1763
            * list of arrays, 2d array - assigned as the value of :keyword:`matrix`;
1764
            * matrix keyword - see `MatrixKeywords` for list of options.
1765
        Rows correspond to elements of the input array (outer index), and
1766
        columns correspond to elements of the output array (inner index).
1767

1768
    operation : DOT_PRODUCT or L0 : default DOT_PRODUCT
1769
        determines whether dot product or difference of `variable <MatrixTransform.variable>` and `matrix
1770
        <MatrixTransform.matrix>` is taken.  If the length of `variable <MatrixTransform.variable>` is greater
1771
        than 1 and L0 is specified, the `variable <MatrixTransform.variable>` array is subtracted from each
1772
        array of `matrix <MatrixTransform.matrix>` and the resulting array is summed, to produce the corresponding
1773
        element of the array returned by the function.
1774

1775
    normalize : bool
1776
        determines whether the result of `function <LinearCombination.function>` is normalized, by dividing it by the
1777
        norm of `variable <MatrixTransform.variable>` x the norm of `matrix <MatrixTransform.matrix>`.
1778

1779
    owner : Component
1780
        `component <Component>` to which the Function has been assigned.
1781

1782
    name : str
1783
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
1784
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
1785

1786
    prefs : PreferenceSet or specification dict : Function.classPreferences
1787
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
1788
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `PreferenceSet`
1789
        for details).
1790
    """
1791

1792
    componentName = MATRIX_TRANSFORM_FUNCTION

    # Filler used when a matrix is specified by a single number
    DEFAULT_FILLER_VALUE = 0

    # MDF/ONNX name used when exporting this function to a model spec
    _model_spec_generic_type_name = 'onnx::MatMul'
    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                matrix
                    see `matrix <MatrixTransform.matrix>`

                    :default value: None
                    :type:

                operation
                    see `operation <MatrixTransform.operation>`

                    :default value: DOT_PRODUCT
                    :type: ``str``

                normalize
                    see `normalize <MatrixTransform.normalize>`

                    :default value: False
                    :type: ``bool``
        """
        variable = Parameter(np.array([0]), read_only=True, pnl_internal=True, constructor_argument='default_variable', mdf_name='A')
        matrix = Parameter(None, modulable=True, mdf_name='B')
        # stateful=False: operation is assumed fixed for the life of the Function
        operation = Parameter(DOT_PRODUCT, stateful=False)
        normalize = Parameter(False)
    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 matrix=None,
                 operation=None,
                 normalize=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):

        # Note: this calls _validate_variable and _validate_params which are overridden below;
        #       the latter implements the matrix if required
        # super(MatrixTransform, self).__init__(default_variable=default_variable,
        super().__init__(
            default_variable=default_variable,
            matrix=matrix,
            operation=operation,
            normalize=normalize,
            params=params,
            owner=owner,
            prefs=prefs,
        )

        # After validation, replace the raw matrix specification (keyword, number,
        # list, etc.) with the fully-instantiated 2d matrix; skip_log avoids
        # recording this internal bookkeeping assignment.
        self.parameters.matrix.set(
            self.instantiate_matrix(self.parameters.matrix.get()),
            skip_log=True,
        )
    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate params and assign to targets

        This overrides the class method, to perform more detailed type checking (see explanation in class method).
        Note: this method (or the class version) is called only if the parameter_validation attribute is `True`

        :param request_set: (dict) - params to be validated
        :param target_set: (dict) - destination of validated params
        :param context: (str)
        :return none:
        """

        super()._validate_params(request_set, target_set, context)

        param_set = target_set
        # proxy for checking whether the owner is a projection
        if hasattr(self.owner, 'receiver'):
            sender = self.defaults.variable
            # length of the sender vector = number of columns of the (at-least-2d) variable
            sender_len = np.size(np.atleast_2d(self.defaults.variable), 1)

            # Check for and validate receiver first, since it may be needed to validate and/or construct the matrix
            # First try to get receiver from specification in params
            if RECEIVER in param_set:
                self.receiver = param_set[RECEIVER]
                # Check that specification is a list of numbers or an array
                if ((isinstance(self.receiver, list) and all(
                        isinstance(elem, numbers.Number) for elem in self.receiver)) or
                        isinstance(self.receiver, np.ndarray)):
                    self.receiver = np.atleast_1d(self.receiver)
                else:
                    raise FunctionError("receiver param ({0}) for {1} must be a list of numbers or an np.array".
                                        format(self.receiver, self.name))
            # No receiver, so use sender as template (assuming square -- e.g., IDENTITY -- matrix)
            else:
                if (self.owner and self.owner.prefs.verbosePref) or self.prefs.verbosePref:
                    print("Identity matrix requested but 'receiver' not specified; sender length ({0}) will be used".
                          format(sender_len))
                self.receiver = param_set[RECEIVER] = sender

            receiver_len = len(self.receiver)

            # Check rest of params
            for param_name, param_value in param_set.items():

                # receiver param already checked above
                if param_name == RECEIVER:
                    continue

                # Not currently used here
                if param_name in function_keywords:
                    continue

                if param_name == HAS_INITIALIZERS:
                    continue

                # matrix specification param
                elif param_name == MATRIX:

                    # A number (to be used as a filler), so OK
                    if isinstance(param_value, numbers.Number):
                        continue

                    # np.matrix or np.ndarray provided, so validate that it is numeric and check dimensions
                    elif isinstance(param_value, (list, np.ndarray, np.matrix)):
                        # get dimensions specified by:
                        #   variable (sender): width/cols/outer index
                        #   kwReceiver param: height/rows/inner index

                        weight_matrix = np.atleast_2d(param_value)
                        # a dtype whose repr contains 'U' is a unicode (string) array -> non-numeric
                        if 'U' in repr(weight_matrix.dtype):
                            raise FunctionError("Non-numeric entry in MATRIX "
                                                "specification ({}) for the {} "
                                                "function of {}".format(param_value,
                                                                        self.name,
                                                                        self.owner_name))

                        if weight_matrix.ndim != 2:
                            # (fixed: format args were previously ordered (ndim, name, owner_name),
                            #  scrambling the message)
                            raise FunctionError("The matrix provided for the {} function of {} must be 2d (it is {}d)".
                                                format(self.name, self.owner_name, weight_matrix.ndim))

                        matrix_rows = weight_matrix.shape[0]
                        matrix_cols = weight_matrix.shape[1]

                        # Check that number of rows equals length of sender vector (variable)
                        if matrix_rows != sender_len:
                            raise FunctionError("The number of rows ({}) of the "
                                                "matrix provided for {} function "
                                                "of {} does not equal the length "
                                                "({}) of the sender vector "
                                                "(variable)".format(matrix_rows,
                                                                    self.name,
                                                                    self.owner_name,
                                                                    sender_len))

                    # Auto, full or random connectivity matrix requested (using keyword):
                    # Note:  assume that these will be properly processed by caller
                    #        (e.g., MappingProjection._instantiate_receiver)
                    elif is_matrix_keyword(param_value):
                        continue

                    # Identity matrix requested (using keyword), so check send_len == receiver_len
                    elif param_value in {IDENTITY_MATRIX, HOLLOW_MATRIX}:
                        # Receiver length doesn't equal sender length
                        if not (self.receiver.shape == sender.shape and self.receiver.size == sender.size):
                            raise FunctionError("{} requested for the {} function of {}, "
                                                "but length of receiver ({}) does not match length of sender ({})".
                                                format(param_value, self.name, self.owner_name, receiver_len,
                                                       sender_len))
                        continue

                    # list used to describe matrix, so convert to 2D array and pass to validation of matrix below
                    # NOTE(review): unreachable -- lists are already consumed by the
                    # isinstance(param_value, (list, np.ndarray, np.matrix)) branch above
                    elif isinstance(param_value, list):
                        try:
                            param_value = np.atleast_2d(param_value)
                        except (ValueError, TypeError) as error_msg:
                            raise FunctionError(
                                "Error in list specification ({}) of matrix for the {} function of {}: {})".
                                    format(param_value, self.name, self.owner_name, error_msg))

                    # string used to describe matrix, so convert to np.array and pass to validation of matrix below
                    elif isinstance(param_value, str):
                        try:
                            param_value = np.atleast_2d(param_value)
                        except (ValueError, TypeError) as error_msg:
                            raise FunctionError("Error in string specification ({}) of the matrix "
                                                "for the {} function of {}: {})".
                                                format(param_value, self.name, self.owner_name, error_msg))

                    # function so:
                    # - assume it uses random.rand()
                    # - call with two args as place markers for cols and rows
                    # -  validate that it returns an array
                    elif isinstance(param_value, types.FunctionType):
                        test = param_value(1, 1)
                        if not isinstance(test, np.ndarray):
                            raise FunctionError("A function is specified for the matrix of the {} function of {}: {}) "
                                                "that returns a value ({}) that is not an array".
                                                format(self.name, self.owner_name, param_value, test))

                    elif param_value is None:
                        raise FunctionError("TEMP ERROR: param value is None.")

                    else:
                        raise FunctionError("Value of {} param ({}) for the {} function of {} "
                                            "must be a matrix, a number (for filler), or a matrix keyword ({})".
                                            format(param_name,
                                                   param_value,
                                                   self.name,
                                                   self.owner_name,
                                                   MATRIX_KEYWORD_NAMES))
                else:
                    continue

        # owner is a mechanism, state
        # OR function was defined on its own (no owner)
        else:
            if MATRIX in param_set:
                param_value = param_set[MATRIX]

                # numeric value specified; verify that it is compatible with variable
                if isinstance(param_value, (float, list, np.ndarray, np.matrix)):
                    param_size = np.size(np.atleast_2d(param_value), 0)
                    param_shape = np.shape(np.atleast_2d(param_value))
                    variable_size = np.size(np.atleast_2d(self.defaults.variable), 1)
                    variable_shape = np.shape(np.atleast_2d(self.defaults.variable))
                    if param_size != variable_size:
                        raise FunctionError("Specification of matrix and/or default_variable for {} is not valid. The "
                                            "shapes of variable {} and matrix {} are not compatible for multiplication".
                                            format(self.name, variable_shape, param_shape))

                # keyword matrix specified - not valid outside of a projection
                elif is_matrix_keyword(param_value):
                    raise FunctionError("{} is not a valid specification for the matrix parameter of {}. Keywords "
                                        "may only be used to specify the matrix parameter of a Projection's "
                                        "MatrixTransform function. When the MatrixTransform function is implemented in a "
                                        "mechanism, such as {}, the correct matrix cannot be determined from a "
                                        "keyword. Instead, the matrix must be fully specified as a float, list, "
                                        "np.ndarray".
                                        format(param_value, self.name, self.owner.name))

                # The only remaining valid option is matrix = None (sorted out in instantiate_attribs_before_fn)
                elif param_value is not None:
                    raise FunctionError("Value of the matrix param ({}) for the {} function of {} "
                                        "must be a matrix, a number (for filler), or a matrix keyword ({})".
                                        format(param_value,
                                               self.name,
                                               self.owner_name,
                                               MATRIX_KEYWORD_NAMES))
2054

2055
    def _instantiate_attributes_before_function(self, function=None, context=None):
        # replicates setting of receiver in _validate_params
        if isinstance(self.owner, Projection):
            self.receiver = copy_parameter_value(self.defaults.variable)

        matrix = self.parameters.matrix._get(context)

        # Standalone use (owner has no 'receiver', i.e. not a Projection) with no
        # matrix specified: default to an identity matrix sized to the variable.
        if matrix is None and not hasattr(self.owner, "receiver"):
            variable_length = np.size(np.atleast_2d(self.defaults.variable), 1)
            matrix = np.identity(variable_length)
        self.parameters.matrix._set(self.instantiate_matrix(matrix), context)
1✔
2066

2067
    def instantiate_matrix(self, specification, context=None):
        """Implements matrix indicated by specification

         Specification is derived from MATRIX param (passed to self.__init__ or self._function)

         Specification (validated in _validate_params):
            + single number (used to fill self.matrix)
            + matrix keyword (see get_matrix)
            + 2D list or np.ndarray of numbers

        :return matrix: (2D list)
        """
        from psyneulink.core.components.projections.projection import Projection
        if isinstance(self.owner, Projection):
            # Matrix provided (and validated in _validate_params); convert to array
            if isinstance(specification, np.matrix):
                return np.array(specification)

            sender = copy_parameter_value(self.defaults.variable)
            sender_len = sender.shape[0]
            try:
                receiver = self.receiver
            # Only a missing 'receiver' attribute means "not specified"; the bare
            # except previously used here masked unrelated errors (incl. KeyboardInterrupt).
            except AttributeError:
                raise FunctionError("Can't instantiate matrix specification ({}) for the {} function of {} "
                                    "since its receiver has not been specified".
                                    format(specification, self.name, self.owner_name))
            receiver_len = receiver.shape[0]

            matrix = get_matrix(specification, rows=sender_len, cols=receiver_len, context=context)

            # This should never happen (should have been picked up in validate_param or above)
            if matrix is None:
                raise FunctionError("MATRIX param ({}) for the {} function of {} must be a matrix, a function "
                                    "that returns one, a matrix specification keyword ({}), or a number (filler)".
                                    format(specification, self.name, self.owner_name, MATRIX_KEYWORD_NAMES))
            else:
                return matrix
        else:
            return np.array(specification)
2107

2108

2109
    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        # Emit LLVM IR computing vector-times-matrix via the __pnl_builtin_vxm builtin.
        # Restrict to 1d arrays
        if self.defaults.variable.ndim != 1:
            warnings.warn("Shape mismatch: {} (in {}) got 2D input: {}".format(
                          self, self.owner, self.defaults.variable),
                          pnlvm.PNLCompilerWarning)
            # drop the outer dimension: point at the first (only) row
            arg_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
        if self.defaults.value.ndim != 1:
            warnings.warn("Shape mismatch: {} (in {}) has 2D output: {}".format(
                          self, self.owner, self.defaults.value),
                          pnlvm.PNLCompilerWarning)
            arg_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])

        matrix = ctx.get_param_or_state_ptr(builder, self, MATRIX, param_struct_ptr=params, state_struct_ptr=state)
        normalize = ctx.get_param_or_state_ptr(builder, self, NORMALIZE, param_struct_ptr=params)

        # Convert array pointer to pointer to the first element
        matrix = builder.gep(matrix, [ctx.int32_ty(0), ctx.int32_ty(0)])
        vec_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(0)])
        vec_out = builder.gep(arg_out, [ctx.int32_ty(0), ctx.int32_ty(0)])

        input_length = ctx.int32_ty(arg_in.type.pointee.count)
        output_length = ctx.int32_ty(arg_out.type.pointee.count)

        # NOTE(review): 'normalize' is fetched above but never used -- the
        # normalization path below is still unimplemented (see FIX comments).
        # if normalize:
        #     if vec_in is not zeros:
        #     # FIX: NORMALIZE vec_in and matrix here
        #         vec_in_sum = fsum(builder, vec_in)
        #         vec_in = fdiv(builder, vec_in, vec_in_sum)
        #     if matrix is not zeros:
        #     # FIX: NORMALIZE matrix here

        builtin = ctx.import_llvm_function("__pnl_builtin_vxm")
        builder.call(builtin, [vec_in, matrix, input_length, output_length, vec_out])
        return builder
2144

2145
    def _gen_pytorch_fct(self, device, context=None):
        """Return a torch-compatible callable implementing this function's operation.

        The returned callable takes (vector, matrix) and mirrors the numpy
        implementation in `_function` for the DOT_PRODUCT and L0 operations.

        Raises AutodiffCompositionError for any other operation.
        """
        operation = self._get_pytorch_fct_param_value('operation', device, context)
        normalize = self._get_pytorch_fct_param_value('normalize', device, context)

        def dot_product_with_normalization(vector, matrix):
            # Normalize each (non-zero) operand before taking the dot product
            if torch.any(vector):
                vector = vector / torch.norm(vector)
            if torch.any(matrix):
                matrix = matrix / torch.norm(matrix)
            return torch.matmul(vector, matrix)

        def diff_with_normalization(vector, matrix):
            # NOTE(review): unlike the numpy version in _function (which uses 'or 1'),
            # this does not guard against a zero denominator when vector == matrix -- TODO confirm
            normalization = torch.sum(torch.abs(vector - matrix))
            return torch.sum((1 - torch.abs(vector - matrix) / normalization), axis=0, keepdim=True)

        # use == (not 'is') for keyword comparison, consistent with _function
        if operation == DOT_PRODUCT:
            if normalize:
                return dot_product_with_normalization
            else:
                return lambda x, y: torch.matmul(x, y)

        elif operation == L0:
            if normalize:
                return diff_with_normalization
            else:
                return lambda x, y: torch.sum(torch.abs(x - y), axis=0)

        else:
            from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
            # (fixed: previously referenced an undefined name 'function', which would
            #  have raised NameError instead of this error)
            raise AutodiffCompositionError(f"The 'operation' parameter of {self.componentName} is not supported "
                                           f"by AutodiffComposition; use 'DOT_PRODUCT' or 'L0'.")
2176

2177
    def _function(self,
                 variable=None,
                 context=None,
                 params=None):
        """

        Arguments
        ---------
        variable : list or 1d array
            array to be transformed;  length must equal the number of rows of `matrix <MatrixTransform.matrix>`.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.

        Returns
        ---------

        dot product of or difference between variable and matrix : 1d array
            length of the array returned equals the number of columns of `matrix <MatrixTransform.matrix>`.

        """
        vector = np.array(variable)
        matrix = self._get_current_parameter_value(MATRIX, context)
        operation = self._get_current_parameter_value(OPERATION, context)
        normalize = self._get_current_parameter_value(NORMALIZE, context)

        if operation == DOT_PRODUCT:
            if normalize:
                # normalize the (non-zero) vector and the matrix columns before the dot product
                if np.any(vector):
                    vector = vector / np.linalg.norm(vector)
                if np.any(matrix):
                    # FIX: the axis along which norming is carried out should probably be a parameter
                    #      Also need to deal with column- (or row-) wise zeros which cause div by zero
                    #      Replace columns (if norming axis 0) or rows (if norming axis 1) of zeros with 1's
                    # matrix = matrix / np.linalg.norm(matrix,axis=-1,keepdims=True)
                    matrix = matrix / np.linalg.norm(matrix, axis=0, keepdims=True)
            result = np.dot(vector, matrix)

        elif operation == L0:
            if normalize:
                # 'or 1' guards against division by zero when vector == matrix
                normalization = np.sum(np.abs(vector - matrix)) or 1
                result = np.sum((1 - (np.abs(vector - matrix)) / normalization), axis=0)
            else:
                result = np.sum((np.abs(vector - matrix)), axis=0)

        else:
            # (fixed: an unrecognized operation previously fell through and raised
            #  UnboundLocalError on 'result'; raise an informative error instead)
            raise FunctionError("Unrecognized operation ({}) for the {} function; "
                                "must be DOT_PRODUCT or L0.".format(operation, self.name))

        return self.convert_output_type(result)
2225

2226
    @staticmethod
    def keyword(obj, keyword):
        """Return the matrix specified by a keyword, sized for ``obj`` when it is a MappingProjection.

        Raises FunctionError if ``keyword`` is not recognized by get_matrix.
        """
        from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
        rows = None
        cols = None
        # use of variable attribute here should be ok because it's using it as a format/type
        if isinstance(obj, MappingProjection):
            # rows come from the sender's value, cols from the receiver's input width
            if isinstance(obj.sender.defaults.value, numbers.Number):
                rows = 1
            else:
                rows = len(obj.sender.defaults.value)
            if isinstance(obj.receiver.defaults.variable, numbers.Number):
                cols = 1
            else:
                cols = obj.receiver.socket_width
        matrix = get_matrix(keyword, rows, cols)

        if matrix is None:
            raise FunctionError("Unrecognized keyword ({}) specified for the {} function of {}".
                                format(keyword, obj.name, obj.owner_name))
        else:
            return matrix
2249

2250
    def param_function(owner, function):
        # Call `function` with the sender/receiver lengths of `owner` to produce a matrix.
        # NOTE(review): defined like an instance method but the first parameter is the
        # owning Projection, not self -- presumably invoked unbound; verify against callers.
        sender_len = len(owner.sender.defaults.value)
        receiver_len = len(owner.receiver.defaults.variable)
        return function(sender_len, receiver_len)
2254

2255
    def _is_identity(self, context=None, defaults=False):
        """Return True iff the current (or, if `defaults`, the default) matrix is an identity matrix."""
        matrix = self.defaults.matrix if defaults else self.parameters.matrix._get(context)

        # Anything that is not an array with at least one dimension cannot be an identity matrix.
        try:
            dim = matrix.shape[0]
        except (AttributeError, IndexError):
            return False

        # Build the reference identity matrix from the first dimension; a
        # non-square matrix will simply fail the equality comparison below.
        reference = np.identity(dim)
        # numpy has deprecated == comparisons of arrays
        try:
            return np.array_equal(matrix, reference)
        except TypeError:
            return matrix == reference
2277

2278
# def is_matrix_spec(m):
2279
#     if m is None:
2280
#         return True
2281
#     if isinstance(m, (list, np.ndarray, types.FunctionType)):
2282
#         return True
2283
#     if m in MATRIX_KEYWORD_VALUES:
2284
#         return True
2285
#     return False
2286

2287

2288

2289
class CombineMeans(TransformFunction):  # ------------------------------------------------------------------------
1✔
2290
    # FIX: CONFIRM THAT 1D KWEIGHTS USES EACH ELEMENT TO SCALE CORRESPONDING VECTOR IN VARIABLE
2291
    # FIX  CONFIRM THAT LINEAR TRANSFORMATION (OFFSET, SCALE) APPLY TO THE RESULTING ARRAY
2292
    # FIX: CONFIRM RETURNS LIST IF GIVEN LIST, AND SIMILARLY FOR NP.ARRAY
2293
    """
2294
    CombineMeans(            \
2295
         default_variable, \
2296
         weights=None,     \
2297
         exponents=None,   \
2298
         operation=SUM,    \
2299
         scale=None,       \
2300
         offset=None,      \
2301
         params=None,      \
2302
         owner=None,       \
2303
         name=None,        \
2304
         prefs=None        \
2305
         )
2306

2307
    .. _CombineMeans:
2308

2309
    Calculate and combine mean(s) for arrays of values, optionally weighting and/or exponentiating each mean prior to
2310
    combining, and scaling and/or offsetting after combining.
2311

2312
    `function <CombineMeans.function>` takes the mean of each array in the outermost dimension (axis 0) of `variable
2313
    <CombineMeans.variable>`, and combines them either additively or multiplicatively (as specified by `operation
2314
    <CombineMeans.operation>`), applying `weights <LinearCombination.weights>` and/or `exponents
2315
    <LinearCombination.exponents>` (if specified) to each mean prior to combining them, and applying `scale
2316
    <LinearCombination.scale>` and/or `offset <LinearCombination.offset>` (if specified) to the result after combining,
2317
    and returns a scalar value.
2318

2319
    COMMENT:
2320
        Description:
2321
            Take means of elements of each array in variable arg,
2322
                and combine using arithmetic operation determined by OPERATION
2323
            Use optional SCALE and OFFSET parameters to linearly transform the resulting array
2324
            Returns a scalar
2325

2326
            Notes:
2327
            * WEIGHTS and EXPONENTS can be:
2328
                - 1D: each array in variable is scaled by the corresponding element of WEIGHTS or EXPONENTS
2329
                - 2D: each array in variable is scaled by (Hadamard-wise) corresponding array of WEIGHTS or EXPONENTS
2330
        Initialization arguments:
2331
         - variable (value, np.ndarray or list): values to be combined;
2332
             can be a list of lists, or a 1D or 2D np.array;  a scalar is always returned
2333
             if it is a list, it must be a list of numbers, lists, or np.arrays
2334
             if WEIGHTS or EXPONENTS are specified, their length along the outermost dimension (axis 0)
2335
                 must equal the number of items in the variable
2336
         - params (dict) can include:
2337
             + WEIGHTS (list of numbers or 1D np.array): multiplies each item of variable before combining them
2338
                  (default: [1,1])
2339
             + EXPONENTS (list of numbers or 1D np.array): exponentiates each item of variable before combining them
2340
                  (default: [1,1])
2341
             + OFFSET (value): added to the result (after the arithmetic operation is applied; default is 0)
2342
             + SCALE (value): multiples the result (after combining elements; default: 1)
2343
             + OPERATION (Operation Enum) - method used to combine the means of the arrays in variable (default: SUM)
2344
                  SUM: sum of the means of the arrays in variable
2345
                  PRODUCT: product of the means of the arrays in variable
2346

2347
        CombineMeans.function returns a scalar value
2348
    COMMENT
2349

2350
    Arguments
2351
    ---------
2352

2353
    variable : 1d or 2d np.array : default class_defaults.variable
2354
        specifies a template for the arrays to be combined.  If it is 2d, all items must have the same length.
2355

2356
    weights : 1d or 2d np.array : default None
2357
        specifies values used to multiply the elements of each array in `variable  <CombineMeans.variable>`.
2358
        If it is 1d, its length must equal the number of items in `variable <CombineMeans.variable>`;
2359
        if it is 2d, the length of each item must be the same as those in `variable <CombineMeans.variable>`,
2360
        and there must be the same number of items as there are in `variable <CombineMeans.variable>`
2361
        (see `weights <CombineMeans.weights>` for details)
2362

2363
    exponents : 1d or 2d np.array : default None
2364
        specifies values used to exponentiate the elements of each array in `variable  <CombineMeans.variable>`.
2365
        If it is 1d, its length must equal the number of items in `variable <CombineMeans.variable>`;
2366
        if it is 2d, the length of each item must be the same as those in `variable <CombineMeans.variable>`,
2367
        and there must be the same number of items as there are in `variable <CombineMeans.variable>`
2368
        (see `exponents <CombineMeans.exponents>` for details)
2369

2370
    operation : SUM or PRODUCT : default SUM
2371
        specifies whether the `function <CombineMeans.function>` takes the sum or product of the means of the arrays in
2372
        `variable  <CombineMeans.variable>`.
2373

2374
    scale : float or np.ndarray : default None
2375
        specifies a value by which to multiply the result of `function <CombineMeans.function>`
2376
        (see `scale <CombineMeans.scale>` for details)
2377

2378
    offset : float or np.ndarray : default None
2379
        specifies a value to add to the result of `function <CombineMeans.function>`
2380
        (see `offset <CombineMeans.offset>` for details)
2381

2382
    params : Dict[param keyword: param value] : default None
2383
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
2384
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
2385
        arguments of the constructor.
2386

2387
    owner : Component
2388
        `component <Component>` to which to assign the Function.
2389

2390
    name : str : default see `name <Function.name>`
2391
        specifies the name of the Function.
2392

2393
    prefs : PreferenceSet or specification dict : default Function.classPreferences
2394
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).
2395

2396
    Attributes
2397
    ----------
2398

2399
    variable : 1d or 2d np.array
2400
        contains the arrays to be combined by `function <CombineMeans>`.  If it is 1d, the array is simply
2401
        linearly transformed by `scale <CombineMeans.scale>` and `offset <CombineMeans.offset>`.
2402
        If it is 2d, the arrays (all of which must be of equal length) are weighted and/or exponentiated as
2403
        specified by `weights <CombineMeans.weights>` and/or `exponents <CombineMeans.exponents>`
2404
        and then combined as specified by `operation <CombineMeans.operation>`.
2405

2406
    weights : 1d or 2d np.array : default None
2407
        if it is 1d, each element is used to multiply all elements in the corresponding array of
2408
        `variable <CombineMeans.variable>`;    if it is 2d, then each array is multiplied elementwise
2409
        (i.e., the Hadamard Product is taken) with the corresponding array of `variable <CombineMeans.variable>`.
2410
        All :keyword:`weights` are applied before any exponentiation (if it is specified).
2411

2412
    exponents : 1d or 2d np.array : default None
2413
        if it is 1d, each element is used to exponentiate the elements of the corresponding array of
2414
        `variable <CombineMeans.variable>`;  if it is 2d, the element of each array is used to exponentiate
2415
        the corresponding element of the corresponding array of `variable <CombineMeans.variable>`.
2416
        In either case, exponentiating is applied after application of the `weights <CombineMeans.weights>`
2417
        (if any are specified).
2418

2419
    operation : SUM or PRODUCT : default SUM
2420
        determines whether the `function <CombineMeans.function>` takes the elementwise (Hadamard) sum or
2421
        product of the arrays in `variable  <CombineMeans.variable>`.
2422

2423
    scale : float or np.ndarray : default None
2424
        value is applied multiplicatively to each element of the array after applying the
2425
        `operation <CombineMeans.operation>` (see `scale <CombineMeans.scale>` for details);
2426
        this done before applying the `offset <CombineMeans.offset>` (if it is specified).
2427

2428
    offset : float or np.ndarray : default None
2429
        value is added to each element of the array after applying the `operation <CombineMeans.operation>`
2430
        and `scale <CombineMeans.scale>` (if it is specified).
2431

2432
    COMMENT:
2433
    function : function
2434
        applies the `weights <CombineMeans.weights>` and/or `exponents <CombineMeans.exponents>` to the
2435
        arrays in `variable <CombineMeans.variable>`, then takes their sum or product (as specified by
2436
        `operation <CombineMeans.operation>`), and finally applies `scale <CombineMeans.scale>` and/or
2437
        `offset <CombineMeans.offset>`.
2438

2439
    enable_output_type_conversion : Bool : False
2440
        specifies whether `function output type conversion <Function_Output_Type_Conversion>` is enabled.
2441

2442
    output_type : FunctionOutputType : None
2443
        used to specify the return type for the `function <Function_Base.function>`;  `functionOuputTypeConversion`
2444
        must be enabled and implemented for the class (see `FunctionOutputType <Function_Output_Type_Conversion>`
2445
        for details).
2446
    COMMENT
2447

2448
    owner : Component
2449
        `component <Component>` to which the Function has been assigned.
2450

2451
    name : str
2452
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
2453
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).
2454

2455
    prefs : PreferenceSet or specification dict : Function.classPreferences
2456
        the `PreferenceSet` for function; if it is not specified in the **prefs** argument of the Function's
2457
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
2458
        for details).
2459
    """
2460

2461
    # Registry name for this Function subclass
    componentName = COMBINE_MEANS_FUNCTION

    # Class-level preference defaults (suppress per-execution output reports)
    classPreferences = {
        PREFERENCE_SET_NAME: 'CombineMeansCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }
2467

2468
    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                exponents
                    see `exponents <CombineMeans.exponents>`

                    :default value: None
                    :type:

                offset
                    see `offset <CombineMeans.offset>`

                    :default value: 0.0
                    :type: ``float``

                operation
                    see `operation <CombineMeans.operation>`

                    :default value: `SUM`
                    :type: ``str``

                scale
                    see `scale <CombineMeans.scale>`

                    :default value: 1.0
                    :type: ``float``

                weights
                    see `weights <CombineMeans.weights>`

                    :default value: None
                    :type:
        """
        # weights/exponents: optional per-item multipliers/exponents applied
        # to the means before they are combined (None -> step is skipped)
        weights = None
        exponents = None
        # SUM or PRODUCT: how the means are combined in _function
        operation = SUM
        # scale/offset are modulable, and are aliased as the multiplicative
        # and additive params targeted by ModulatoryProjections
        scale = Parameter(1.0, modulable=True, aliases=[MULTIPLICATIVE_PARAM])
        offset = Parameter(0.0, modulable=True, aliases=[ADDITIVE_PARAM])
2508

2509
    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 weights=None,
                 exponents=None,
                 operation: Optional[Literal['sum', 'product']] = None,
                 scale=None,
                 offset=None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):
        """Initialize CombineMeans, forwarding all arguments to the base class,
        then normalize weights/exponents to column-vector form."""
        super().__init__(
            default_variable=default_variable,
            weights=weights,
            exponents=exponents,
            operation=operation,
            scale=scale,
            offset=offset,
            params=params,
            owner=owner,
            prefs=prefs,
        )

        # Reshape any user-supplied weights/exponents to column vectors so
        # each entry lines up with one item (one mean) of the variable.
        for attr_name in ('weights', 'exponents'):
            current = getattr(self, attr_name)
            if current is not None:
                setattr(self, attr_name, np.atleast_2d(current).reshape(-1, 1))
2540

2541
    def _validate_variable(self, variable, context=None):
        """Delegate validation to the base class and return the result.

        (A per-item numeric check is currently handled upstream.)
        """
        return super()._validate_variable(variable=variable, context=context)
2548

2549
    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate weights, exponents, scale and offset parameters.

        Check that WEIGHTS and EXPONENTS are lists or np.arrays of numbers with
        length equal to the number of items in variable.
        Check that SCALE and OFFSET are either scalars or np.arrays of numbers
        with size and shape equal to variable.

        Note: the checks of compatibility with variable are only performed for
              validation calls during execution (i.e., from check_args()),
              since during initialization or COMMAND_LINE assignment a
              parameter may be re-assigned before its variable is known.
        """

        # FIX: MAKE SURE THAT IF OPERATION IS SUBTRACT OR DIVIDE, THERE ARE ONLY TWO VECTORS

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if WEIGHTS in target_set and target_set[WEIGHTS] is not None:
            # Normalize to a column vector: one weight per item of variable
            target_set[WEIGHTS] = np.atleast_2d(target_set[WEIGHTS]).reshape(-1, 1)
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                if len(target_set[WEIGHTS]) != len(self.defaults.variable):
                    # BUG FIX: message previously formatted
                    # len(self.defaults.variable.shape) (the number of
                    # dimensions) instead of the number of items tested above
                    raise FunctionError("Number of weights ({0}) is not equal to number of items in variable ({1})".
                                        format(len(target_set[WEIGHTS]), len(self.defaults.variable)))

        if EXPONENTS in target_set and target_set[EXPONENTS] is not None:
            # Normalize to a column vector: one exponent per item of variable
            target_set[EXPONENTS] = np.atleast_2d(target_set[EXPONENTS]).reshape(-1, 1)
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                if len(target_set[EXPONENTS]) != len(self.defaults.variable):
                    # BUG FIX: same wrong format argument as for WEIGHTS above
                    raise FunctionError("Number of exponents ({0}) does not equal number of items in variable ({1})".
                                        format(len(target_set[EXPONENTS]), len(self.defaults.variable)))

        if SCALE in target_set and target_set[SCALE] is not None:
            scale = target_set[SCALE]
            if isinstance(scale, numbers.Number):
                pass
            elif isinstance(scale, np.ndarray):
                target_set[SCALE] = np.array(scale)
            else:
                raise FunctionError("{} param of {} ({}) must be a scalar or an np.ndarray".
                                    format(SCALE, self.name, scale))
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                # An array scale is Hadamard (elementwise) modulation, so it
                # must match variable's geometry exactly
                if (isinstance(scale, np.ndarray) and
                        (scale.size != self.defaults.variable.size or
                         scale.shape != self.defaults.variable.shape)):
                    raise FunctionError("Scale is using Hadamard modulation "
                                        "but its shape and/or size (shape: {}, size:{}) "
                                        "do not match the variable being modulated (shape: {}, size: {})".
                                        format(scale.shape, scale.size, self.defaults.variable.shape,
                                               self.defaults.variable.size))

        if OFFSET in target_set and target_set[OFFSET] is not None:
            offset = target_set[OFFSET]
            if isinstance(offset, numbers.Number):
                pass
            elif isinstance(offset, np.ndarray):
                target_set[OFFSET] = np.array(offset)
            else:
                raise FunctionError("{} param of {} ({}) must be a scalar or an np.ndarray".
                                    format(OFFSET, self.name, offset))
            if context.execution_phase & (ContextFlags.PROCESSING | ContextFlags.LEARNING):
                # Same Hadamard-geometry requirement as for scale above
                if (isinstance(offset, np.ndarray) and
                        (offset.size != self.defaults.variable.size or
                         offset.shape != self.defaults.variable.shape)):
                    raise FunctionError("Offset is using Hadamard modulation "
                                        "but its shape and/or size (shape: {}, size:{}) "
                                        "do not match the variable being modulated (shape: {}, size: {})".
                                        format(offset.shape, offset.size, self.defaults.variable.shape,
                                               self.defaults.variable.size))

2624
    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Take the mean of each item of `variable <CombineMeans.variable>` and
        combine the means by `operation <CombineMeans.operation>`, applying
        exponents and weights first (if specified) and scale/offset last.

        Arguments
        ---------

        variable : 1d or 2d np.array : default class_defaults.variable
           a single numeric array, or multiple arrays to be combined; if it is 2d, all arrays must have the same length.

        params : Dict[param keyword: param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
            function.  Values specified for parameters in the dictionary override any assigned to those parameters in
            arguments of the constructor.


        Returns
        -------

        combined means : number
            the result of taking the means of each array in `variable <CombineMeans.variable>` and combining them.

        """
        exponents = self._get_current_parameter_value(EXPONENTS, context)
        weights = self._get_current_parameter_value(WEIGHTS, context)
        operation = self._get_current_parameter_value(OPERATION, context)
        offset = self._get_current_parameter_value(OFFSET, context)
        scale = self._get_current_parameter_value(SCALE, context)

        # QUESTION:  WHICH IS LESS EFFICIENT:
        #                A) UNECESSARY ARITHMETIC OPERATIONS IF SCALE AND/OR OFFSET ARE 1.0 AND 0, RESPECTIVELY?
        #                   (DOES THE COMPILER KNOW NOT TO BOTHER WITH MULT BY 1 AND/OR ADD 0?)
        #                B) EVALUATION OF IF STATEMENTS TO DETERMINE THE ABOVE?
        # IMPLEMENTATION NOTE:  FOR NOW, ASSUME B) ABOVE, AND ASSIGN DEFAULT "NULL" VALUES TO offset AND scale
        if offset is None:
            offset = 0.0

        if scale is None:
            scale = 1.0

        # IMPLEMENTATION NOTE: CONFIRM: SHOULD NEVER OCCUR, AS _validate_variable NOW ENFORCES 2D np.ndarray
        # If variable is 0D or 1D:
        # if np_array_less_than_2d(variable):
        #     return (variable * scale) + offset

        # One mean per item of variable -> 1d array of length len(variable)
        means = convert_all_elements_to_np_array([np.mean(item) for item in variable])

        # FIX FOR EFFICIENCY: CHANGE THIS AND WEIGHTS TO TRY/EXCEPT // OR IS IT EVEN NECESSARY, GIVEN VALIDATION ABOVE??
        # Apply exponents if they were specified
        # NOTE(review): __init__/_validate_params reshape exponents and weights
        # to (n, 1) column vectors, so `means ** exponents` and
        # `means * weights` broadcast a 1d `means` to an (n, n) array rather
        # than staying elementwise — looks unintended given the docstring says
        # a scalar is returned; TODO confirm intended shapes before relying on
        # these branches (coverage shows they are currently unexercised).
        if exponents is not None:
            # Avoid divide by zero warning:
            #    make sure there are no zeros for an element that is assigned a negative exponent
            if (self.is_initializing and
                    any(not any(i) and j < 0 for i, j in zip(variable, exponents))):
                means = np.ones_like(means)
            else:
                means = means ** exponents

        # Apply weights if they were specified
        if weights is not None:
            means = means * weights

        # CALCULATE RESULT USING RELEVANT COMBINATION OPERATION AND MODULATION

        if operation == SUM:
            result = np.sum(means, axis=0) * scale + offset

        elif operation == PRODUCT:
            result = np.prod(means, axis=0) * scale + offset

        else:
            raise FunctionError("Unrecognized operator ({0}) for CombineMeans function".
                                format(self._get_current_parameter_value(OPERATION, context)))

        return self.convert_output_type(result)
2702

2703
    @property
    def offset(self):
        """Legacy accessor for the additive parameter; None until assigned."""
        return getattr(self, '_offset', None)

    @offset.setter
    def offset(self, val):
        self._offset = val
2713

2714
    @property
    def scale(self):
        """Legacy accessor for the multiplicative parameter; None until assigned."""
        return getattr(self, '_scale', None)

    @scale.setter
    def scale(self, val):
        self._scale = val
2724

2725

2726
# Keyword for PredictionErrorDeltaFunction's temporal-discount parameter;
# used to look the parameter up via _get_current_parameter_value in _function.
GAMMA = 'gamma'
2727

2728

2729
class PredictionErrorDeltaFunction(TransformFunction):
1✔
2730
    """
2731
    Calculate temporal difference prediction error.
2732

2733
    `function <PredictionErrorDeltaFunction.function>` returns the prediction error using arrays in `variable
2734
    <PredictionErrorDeltaFunction.variable>`:
2735

2736
    .. math::
2737
        \\delta(t) = r(t) + \\gamma sample(t) - sample(t - 1)
2738

2739
    """
2740
    # Registry name for this Function subclass
    componentName = PREDICTION_ERROR_DELTA_FUNCTION

    # Class-level preference defaults (suppress per-execution output reports)
    classPreferences = {
        PREFERENCE_SET_NAME: 'PredictionErrorDeltaCustomClassPreferences',
        REPORT_OUTPUT_PREF: PreferenceEntry(False, PreferenceLevel.INSTANCE),
    }
2746

2747
    class Parameters(TransformFunction.Parameters):
        """
            Attributes
            ----------

                variable
                    see `variable <PredictionErrorDeltaFunction.variable>`

                    :default value: numpy.array([[1], [1]])
                    :type: ``numpy.ndarray``

                gamma
                    see `gamma <PredictionErrorDeltaFunction.gamma>`

                    :default value: 1.0
                    :type: ``float``
        """
        # variable holds the two input arrays used by _function
        # (item 0 = sample, item 1 = reward); internal-only, set via the
        # constructor's default_variable argument
        variable = Parameter(np.array([[1], [1]]), pnl_internal=True, constructor_argument='default_variable')
        # gamma: discount factor applied to sample(t) in the TD-delta formula
        gamma = Parameter(1.0, modulable=True)
2766

2767
    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 gamma: Optional[float] = None,
                 params=None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None):
        """Construct a PredictionErrorDeltaFunction.

        All arguments are forwarded unchanged to the base class constructor.
        """
        super().__init__(
            default_variable=default_variable,
            gamma=gamma,
            params=params,
            owner=owner,
            prefs=prefs,
        )
2783

2784
    def _validate_variable(self, variable, context=None):
        """
        Insure that all items of variable have the same length.

        A scalar item counts as length 1.  Raises FunctionError if any two
        adjacent items differ in length.

        Parameters
        ----------
        variable
        context

        Returns
        -------
        variable, if all items have matching lengths
        """
        variable = super()._validate_variable(variable=variable, context=context)

        if isinstance(variable, (list, np.ndarray)):
            # A 0d array has no items to compare
            if isinstance(variable, np.ndarray) and not variable.ndim:
                return variable
            # Compare each item's length to its predecessor's.
            # (Cleanup: removed an unused `length = 0` accumulator and an
            # unreachable `if i == 0: continue` inside `range(1, ...)`.)
            for i in range(1, len(variable)):
                prev_len = 1 if isinstance(variable[i - 1], numbers.Number) else len(variable[i - 1])
                curr_len = 1 if isinstance(variable[i], numbers.Number) else len(variable[i])
                if prev_len != curr_len:
                    raise FunctionError("Length of all arrays in variable {} "
                                        "for {} must be the same".format(variable,
                                                                         self.__class__.__name__))
        return variable
2819

2820
    def _validate_params(self, request_set, target_set=None, context=None):
        """
        Checks that GAMMA is numeric and that WEIGHTS, if specified, is a list
        or np.array of numbers with length equal to variable.

        Note: the checks of compatibility with variable are only performed for
        validation calls during execution (i.e. from `check_args()`), since
        during initialization or COMMAND_LINE assignment a parameter may be
        re-assigned before its variable is known.

        Parameters
        ----------
        request_set
        target_set
        context

        Returns
        -------
        None
        """
        super()._validate_params(request_set,
                                 target_set=target_set,
                                 context=context)

        if GAMMA in target_set and target_set[GAMMA] is not None:
            self._validate_parameter_spec(target_set[GAMMA], GAMMA, numeric_only=True)

        if WEIGHTS in target_set and target_set[WEIGHTS] is not None:
            self._validate_parameter_spec(target_set[WEIGHTS], WEIGHTS, numeric_only=True)
            # Normalize to a column vector: one weight per item of variable
            target_set[WEIGHTS] = np.atleast_2d(target_set[WEIGHTS]).reshape(-1, 1)
            if context.execution_phase & (ContextFlags.EXECUTING):
                if len(target_set[WEIGHTS]) != len(self.defaults.variable):
                    # BUG FIX: message previously formatted
                    # len(self.defaults.variable.shape) (the number of
                    # dimensions) instead of the number of items tested above
                    raise FunctionError("Number of weights {} is not equal to "
                                        "number of items in variable {}".format(
                                            len(target_set[WEIGHTS]),
                                            len(self.defaults.variable)))

2858
    def _function(self,
                 variable=None,
                 context=None,
                 params=None,
                 ):
        """Compute temporal-difference delta values.

        Arguments
        ----------
        variable : 2d np.array : default class_defaults.variable
            a 2d array whose first item holds the sample values and whose
            second item holds the reward (target) values used to calculate
            the temporal difference delta values. Both arrays must have the
            same length.

        params : Dict[param keyword, param value] : default None
            a `parameter dictionary <ParameterPort_Specification>` that
            specifies the parameters for the function. Values specified for
            parameters in the dictionary override any assigned to those
            parameters in arguments of the constructor.


        Returns
        -------
        delta values : 1d np.array

        """
        discount = self._get_current_parameter_value(GAMMA, context).item()
        sample, reward = variable[0], variable[1]

        # delta(0) stays 0; each later step applies the TD-error formula
        #     delta(t) = reward(t) + gamma * sample(t) - sample(t - 1)
        delta = np.zeros(sample.shape)
        for t in range(1, len(sample)):
            delta[t] = reward[t] + discount * sample[t] - sample[t - 1]

        return self.convert_output_type(delta)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc