• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

PrincetonUniversity / PsyNeuLink / 11992518143

12 Nov 2024 03:50AM UTC coverage: 83.719% (-1.2%) from 84.935%
11992518143

push

github

web-flow
Merge pull request #3071 from PrincetonUniversity/devel

Devel

9406 of 12466 branches covered (75.45%)

Branch coverage included in aggregate %.

3240 of 3767 new or added lines in 77 files covered. (86.01%)

120 existing lines in 26 files now uncovered.

32555 of 37655 relevant lines covered (86.46%)

0.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.38
/psyneulink/core/components/functions/stateful/statefulfunction.py
1
#
2
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
3
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
4
#     http://www.apache.org/licenses/LICENSE-2.0
5
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
6
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
7
# See the License for the specific language governing permissions and limitations under the License.
8
#
9
#
10
# *****************************************  STATEFUL FUNCTION *********************************************************
11
"""
12

13
* `StatefulFunction`
14
* `IntegratorFunctions`
15
* `MemoryFunctions`
16

17
"""
18

19
import abc
1✔
20
import collections
1✔
21
import copy
1✔
22
import numbers
1✔
23
import warnings
1✔
24

25
import numpy as np
1✔
26
from beartype import beartype
1✔
27

28
from psyneulink._typing import Mapping, Optional
1✔
29

30
from psyneulink.core import llvm as pnlvm
1✔
31
from psyneulink.core.components.component import DefaultsFlexibility, _has_initializers_setter, ComponentsMeta
1✔
32
from psyneulink.core.components.functions.nonstateful.distributionfunctions import DistributionFunction
1✔
33
from psyneulink.core.components.functions.function import Function_Base, FunctionError, _noise_setter
1✔
34
from psyneulink.core.globals.context import handle_external_context
1✔
35
from psyneulink.core.globals.keywords import STATEFUL_FUNCTION_TYPE, STATEFUL_FUNCTION, NOISE, RATE
1✔
36
from psyneulink.core.globals.parameters import Parameter, check_user_specified
1✔
37
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet
1✔
38
from psyneulink.core.globals.utilities import (
1✔
39
    contains_type,
40
    convert_all_elements_to_np_array,
41
    convert_to_np_array,
42
    fill_array,
43
    iscompatible,
44
    safe_len,
45
)
46

47
# Public API of this module: only the abstract StatefulFunction base class is exported.
__all__ = ['StatefulFunction']
1✔
48

49

50
class StatefulFunction(Function_Base): #  ---------------------------------------------------------------------
    """
    StatefulFunction(           \
        default_variable=None,  \
        initializer,            \
        rate=1.0,               \
        noise=0.0,              \
        params=None,            \
        owner=None,             \
        prefs=None,             \
        )

    .. _StatefulFunction:

    Abstract base class for Functions whose results depend on their `previous_value
    <StatefulFunction.previous_value>` attribute.

    COMMENT:
    NARRATIVE HERE THAT EXPLAINS:
    A) initializers and stateful_attributes
    B) initializer (note singular) is a pre-specified member of initializers
       that contains the value with which to initialize previous_value
    COMMENT


    Arguments
    ---------

    default_variable : number, list or array : default class_defaults.variable
        specifies a template for `variable <StatefulFunction.variable>`.

    initializer : float, list or 1d array : default 0.0
        specifies initial value for `previous_value <StatefulFunction.previous_value>`.  If it is a list or array,
        it must be the same length as `variable <StatefulFunction.variable>` (see `initializer
        <StatefulFunction.initializer>` for details).

    rate : float, list or 1d array : default 1.0
        specifies value used as a scaling parameter in a subclass-dependent way (see `rate <StatefulFunction.rate>` for
        details); if it is a list or array, it must be the same length as `variable <StatefulFunction.variable>`.

    noise : float, function, list or 1d array : default 0.0
        specifies random value added in each call to `function <StatefulFunction.function>`; if it is a list or
        array, it must be the same length as `variable <StatefulFunction.variable>` (see `noise
        <StatefulFunction.noise>` for details).

    params : Dict[param keyword: param value] : default None
        a `parameter dictionary <ParameterPort_Specification>` that specifies the parameters for the
        function.  Values specified for parameters in the dictionary override any assigned to those parameters in
        arguments of the constructor.

    owner : Component
        `component <Component>` to which to assign the Function.

    name : str : default see `name <Function.name>`
        specifies the name of the Function.

    prefs : PreferenceSet or specification dict : default Function.classPreferences
        specifies the `PreferenceSet` for the Function (see `prefs <Function_Base.prefs>` for details).

    Attributes
    ----------

    variable : number or array
        current input value.

    initializer : float or 1d array
        determines initial value assigned to `previous_value <StatefulFunction.previous_value>`. If `variable
        <StatefulFunction.variable>` is a list or array, and initializer is a float or has a single element, it is
        applied to each element of `previous_value <StatefulFunction.previous_value>`. If initializer is a list or
        array, each element is applied to the corresponding element of `previous_value
        <StatefulFunction.previous_value>`.

    previous_value : 1d array
        last value returned (i.e., for which state is being maintained).

    initializers : list
        stores the names of the initialization attributes for each of the stateful attributes of the function. The
        index i item in initializers provides the initialization value for the index i item in `stateful_attributes
        <StatefulFunction.stateful_attributes>`.

    stateful_attributes : list
        stores the names of each of the stateful attributes of the function. The index i item in stateful_attributes is
        initialized by the value of the initialization attribute whose name is stored in index i of `initializers
        <StatefulFunction.initializers>`. In most cases, the stateful_attributes, in that order, are the return values
        of the function.

    .. _Stateful_Rate:

    rate : float or 1d array
        on each call to `function <StatefulFunction.function>`, applied to `variable <StatefulFunction.variable>`,
        `previous_value <StatefulFunction.previous_value>`, neither, or both, depending on implementation by
        subclass.  If it is a float or has a single value, it is applied to all elements of its target(s);  if it has
        more than one element, each element is applied to the corresponding element of its target(s).

    .. _Stateful_Noise:

    noise : float, function, list, or 1d array
        random value added on each call to `function <StatefulFunction.function>`. If `variable
        <StatefulFunction.variable>` is a list or array, and noise is a float or function, it is applied
        for each element of `variable <StatefulFunction.variable>`. If noise is a function, it is executed and applied
        separately for each element of `variable <StatefulFunction.variable>`.  If noise is a list or array,
        it is applied elementwise (i.e., in Hadamard form).

        .. hint::
            To generate random noise that varies for every execution, a probability distribution function should be
            used (see `Distribution Functions <DistributionFunction>` for details), that generates a new noise value
            from its distribution on each execution. If noise is specified as a float, a function with a fixed
            output, or a list or array of either of these, then noise is simply an offset that remains the same
            across all executions.

        .. note::
            A ParameterPort for noise will only be generated, and the
            noise Parameter itself will only be stateful, if the value
            of noise is entirely numeric (contains no functions) at the
            time of Mechanism construction.

    owner : Component
        `component <Component>` to which the Function has been assigned.

    name : str
        the name of the Function; if it is not specified in the **name** argument of the constructor, a default is
        assigned by FunctionRegistry (see `Registry_Naming` for conventions used for default and duplicate names).

    prefs : PreferenceSet or specification dict
        the `PreferenceSet` for the Function; if it is not specified in the **prefs** argument of the Function's
        constructor, a default is assigned using `classPreferences` defined in __init__.py (see `Preferences`
        for details).
    """

    componentType = STATEFUL_FUNCTION_TYPE
    componentName = STATEFUL_FUNCTION

    # TODO: consider moving this to a Parameter attribute
    _mdf_stateful_parameter_indices = {
        'previous_value': None
    }

    class Parameters(Function_Base.Parameters):
        """
            Attributes
            ----------

                initializer
                    see `initializer <StatefulFunction.initializer>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``

                noise
                    see `noise <StatefulFunction.noise>`

                    :default value: 0.0
                    :type: ``float``

                previous_value
                    see `previous_value <StatefulFunction.previous_value>`

                    :default value: numpy.array([0])
                    :type: ``numpy.ndarray``

                rate
                    see `rate <StatefulFunction.rate>`

                    :default value: 1.0
                    :type: ``float``
        """
        noise = Parameter(0.0, modulable=True, setter=_noise_setter)
        rate = Parameter(1.0, modulable=True)
        # previous_value is the stateful attribute; its `initializer` names the Parameter from which it is
        # (re)initialized (see _instantiate_stateful_attributes and reset).
        previous_value = Parameter(np.array([0]), initializer='initializer')
        initializer = Parameter(np.array([0]), pnl_internal=True)
        has_initializers = Parameter(True, setter=_has_initializers_setter, pnl_internal=True)

        def _validate_noise(self, noise):
            # Returns an error message (rather than raising) when noise is an iterable that contains uninstantiated
            # Function classes; returns None otherwise.
            if (
                isinstance(noise, collections.abc.Iterable)
                # assume ComponentsMeta are functions
                and contains_type(noise, ComponentsMeta)
            ):
                # TODO: make this validation unnecessary by handling automatically?
                return 'functions in a list must be instantiated and have the desired noise variable shape'

    @handle_external_context()
    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable=None,
                 rate=None,
                 noise=None,
                 initializer=None,
                 params: Optional[Mapping] = None,
                 owner=None,
                 prefs:  Optional[ValidPrefSet] = None,
                 context=None,
                 **kwargs
                 ):

        # Subclasses may assign their own lists before calling this constructor; otherwise use the single
        # previous_value <- initializer pairing.
        if not hasattr(self, "initializers"):
            self.initializers = ["initializer"]

        if not hasattr(self, "stateful_attributes"):
            self.stateful_attributes = ["previous_value"]

        super().__init__(
            default_variable=default_variable,
            rate=rate,
            initializer=initializer,
            noise=noise,
            params=params,
            owner=owner,
            prefs=prefs,
            context=context,
            **kwargs
        )

    def _validate(self, context=None):
        # Check the default rate and all initializers against the default variable, then defer to the base class.
        self._validate_rate(self.defaults.rate)
        self._validate_initializers(self.defaults.variable, context=context)
        super()._validate(context=context)

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate the rate and noise specifications.

        A list/array **rate** with more than one element whose size differs from the default variable causes the
        default variable to be re-instantiated to match it when the variable shape is FLEXIBLE; otherwise a
        FunctionError is raised.  After base-class validation, a `DistributionFunction` given as noise in
        **target_set** is given this Function as owner and replaced by its ``execute`` method, and the resulting noise
        specification is checked by `_validate_noise`.
        """

        # Handle list or array for rate specification
        if RATE in request_set:
            rate = request_set[RATE]

            if isinstance(rate, (list, np.ndarray)) and not iscompatible(rate, self.defaults.variable):
                if safe_len(rate) != 1 and safe_len(rate) != np.array(self.defaults.variable).size:
                    # If the variable was not specified, then reformat it to match rate specification
                    #    and assign class_defaults.variable accordingly
                    # Note: this situation can arise when the rate is parametrized (e.g., as an array) in the
                    #       StatefulFunction's constructor, where that is used as specification for a function parameter
                    #       (e.g., for an IntegratorMechanism), whereas the input is specified as part of the
                    #       object to which the function parameter belongs (e.g., the IntegratorMechanism); in that
                    #       case, the StatefulFunction gets instantiated using its class_defaults.variable ([[0]])
                    #       before the object itself, thus does not see the array specification for the input.
                    if self._variable_shape_flexibility is DefaultsFlexibility.FLEXIBLE:
                        self._instantiate_defaults(variable=np.zeros_like(np.array(rate)), context=context)
                        if self.verbosePref:
                            warnings.warn(f"The length ({safe_len(rate)}) of the array specified for "
                                          f"the rate parameter ({rate}) of {self.name} must match the length "
                                          f"({np.array(self.defaults.variable).size}) of the default input "
                                          f"({self.defaults.variable}); the default input has been updated to match.")
                    else:
                        raise FunctionError(f"The length of the array specified for the rate parameter of {self.name} "
                                            f"({safe_len(rate)}) must match the length of the default input "
                                            f"({np.array(self.defaults.variable).size}).")

        super()._validate_params(request_set=request_set,
                                 target_set=target_set,
                                 context=context)

        if NOISE in target_set:
            noise = target_set[NOISE]
            if isinstance(noise, DistributionFunction):
                noise.owner = self
                target_set[NOISE] = noise.execute
            self._validate_noise(target_set[NOISE])

    def _validate_initializers(self, default_variable, context=None):
        """Check each initializer named in `initializers <StatefulFunction.initializers>`.

        Raises FunctionError if an initializer is neither a number nor a list/array, or if it is a list/array with
        more than one element whose (at least 2d) shape differs from that of **default_variable**.
        """
        for initial_value_name in self.initializers:

            initial_value = self._get_current_parameter_value(initial_value_name, context=context)

            if isinstance(initial_value, (list, np.ndarray)):
                # A 0-d array is a scalar (and has no len()); single-element specifications are broadcast,
                # so only longer specifications must match the shape of default_variable.
                is_scalar_array = isinstance(initial_value, np.ndarray) and initial_value.ndim == 0
                if not is_scalar_array and len(initial_value) != 1:
                    # np.atleast_2d may not be necessary here?
                    if np.shape(np.atleast_2d(initial_value)) != np.shape(np.atleast_2d(default_variable)):
                        raise FunctionError(f"{self.name}'s {initial_value_name} ({initial_value}) is incompatible "
                                            f"with its default_variable ({default_variable}).")
            # numbers.Number (as in _validate_rate) also accepts NumPy scalars such as np.int64
            elif not isinstance(initial_value, numbers.Number):
                raise FunctionError(f"{self.name}'s {initial_value_name} ({initial_value}) "
                                    f"must be a number or a list/array of numbers.")

    def _validate_rate(self, rate):
        """Check that **rate** is a number or a list/array compatible with the default variable.

        A multi-element rate whose size differs from the default variable either replaces the default variable with
        zeros shaped like rate (when the variable shape is FLEXIBLE) or raises a FunctionError.
        """
        # FIX: CAN WE JUST GET RID OF THIS?
        # kmantel: this duplicates much code in _validate_params above, but that calls _instantiate_defaults
        # which I don't think is the right thing to do here, but if you don't call it in _validate_params
        # then a lot of things don't get instantiated properly
        if rate is None:
            return

        if isinstance(rate, list):
            rate = np.asarray(rate)

        rate_type_msg = 'The rate parameter of {0} must be a number or an array/list of at most 1d (you gave: {1})'
        if isinstance(rate, np.ndarray):
            # kmantel: current test_gating test depends on 2d rate
            #   this should be looked at but for now this restriction is removed
            # if rate.ndim > 1:
            #     raise FunctionError(rate_type_msg.format(self.name, rate))
            pass
        elif not isinstance(rate, numbers.Number):
            raise FunctionError(rate_type_msg.format(self.name, rate))

        if isinstance(rate, np.ndarray) and not iscompatible(rate, self.defaults.variable):
            if safe_len(rate) != 1 and safe_len(rate) != np.array(self.defaults.variable).size:
                if self._variable_shape_flexibility is DefaultsFlexibility.FLEXIBLE:
                    self.defaults.variable = np.zeros_like(np.array(rate))
                    if self.verbosePref:
                        warnings.warn(f"The length ({safe_len(rate)}) of the array specified for the rate parameter "
                                      f"({rate}) of {self.name} must match the length "
                                      f"({np.array(self.defaults.variable).size}) of the default input "
                                      f"({self.defaults.variable}); the default input has been updated to match.")
                    self._instantiate_value()
                    self._variable_shape_flexibility = DefaultsFlexibility.INCREASE_DIMENSION
                else:
                    raise FunctionError(f"The length of the array specified for the rate parameter of {self.name} "
                                        f"({safe_len(rate)}) must match the length of the default input "
                                        f"({np.array(self.defaults.variable).size}).")

    # Ensure that the noise parameter makes sense with the input type and shape; flag any noise functions that will
    # need to be executed
    def _validate_noise(self, noise):
        """Check that **noise** is a scalar, function, or list/array compatible with the default variable.

        DistributionFunction elements of a list/array noise specification are replaced in place by their
        ``execute`` methods.  Raises FunctionError for an incompatibly shaped list/array or an invalid element.
        """
        # Noise must be a scalar, list, array or Distribution Function

        if isinstance(noise, DistributionFunction):
            noise = noise.execute

        if isinstance(noise, (np.ndarray, list)):
            if safe_len(noise) == 1:
                pass
            # Variable is a list/array
            elif (not iscompatible(np.atleast_2d(noise), self.defaults.variable)
                  and not iscompatible(np.atleast_1d(noise), self.defaults.variable) and len(noise) > 1):
                raise FunctionError(f"Noise parameter ({noise})  for '{self.name}' does not match default variable "
                                    f"({self.defaults.variable}); it must be specified as a float, a function, "
                                    f"or an array of the appropriate shape "
                                    f"({np.shape(np.array(self.defaults.variable))}).",
                    component=self)
            else:
                for i in range(len(noise)):
                    # replace distribution functions in place so they are executed later
                    if isinstance(noise[i], DistributionFunction):
                        noise[i] = noise[i].execute
                    if (not np.isscalar(noise[i]) and not callable(noise[i])
                            and not iscompatible(np.atleast_2d(noise[i]), self.defaults.variable[i])
                            and not iscompatible(np.atleast_1d(noise[i]), self.defaults.variable[i])):
                        raise FunctionError(f"The element '{noise[i]}' specified in 'noise' for {self.name} "
                                            f"is not valid; noise must be a list or array of floats or functions.")

    def _instantiate_attributes_before_function(self, function=None, context=None):
        # Unless the user specified an initializer, initialize previous_value to zeros shaped like the default variable
        if not self.parameters.initializer._user_specified:
            new_previous_value = copy.deepcopy(self.defaults.variable)
            fill_array(new_previous_value, 0)
            self._initialize_previous_value(new_previous_value, context)
        self._instantiate_stateful_attributes(self.stateful_attributes, self.initializers, context)
        super()._instantiate_attributes_before_function(function=function, context=context)

    def _instantiate_stateful_attributes(self, stateful_attributes:list, initializers:list, context) -> None:
        """Broadcast each initializer to the shape of the default variable, then set each stateful attribute to a copy
        of the current value of the initializer Parameter it references."""
        # use np.broadcast_to to guarantee that all initializer type attributes take on the same shape as variable
        if not np.isscalar(self.defaults.variable):
            for attr in initializers:
                param = getattr(self.parameters, attr)
                param._set(
                    np.broadcast_to(
                        param._get(context),
                        self.defaults.variable.shape
                    ).copy(),
                    context
                )

        # create all stateful attributes and initialize their values to the current values of their
        # corresponding initializer attributes
        for attr_name in stateful_attributes:
            initializer_value = getattr(self.parameters, getattr(self.parameters, attr_name).initializer)._get(context).copy()
            getattr(self.parameters, attr_name)._set(initializer_value, context)

    def _initialize_previous_value(self, initializer, context=None):
        """Set the default and current (in **context**) values of both `initializer` and `previous_value` to copies of
        **initializer** (converted to at least a 1d array), and return the converted array."""
        initializer = convert_to_np_array(initializer, dimension=1)

        self.defaults.initializer = initializer.copy()
        self.parameters.initializer._set(initializer.copy(), context)

        self.defaults.previous_value = initializer.copy()
        self.parameters.previous_value.set(initializer.copy(), context)

        return initializer

    @handle_external_context()
    def _update_default_variable(self, new_default_variable, context=None):
        # Keep previous_value/initializer shaped like the new default variable unless the user specified an initializer
        if not self.parameters.initializer._user_specified:
            new_default_variable = convert_all_elements_to_np_array(new_default_variable)
            self._initialize_previous_value(np.zeros_like(new_default_variable), context)

        super()._update_default_variable(new_default_variable, context=context)

    def _parse_value_order(self, **kwargs):
        """
            Returns:
                tuple: the values of the keyword arguments in the order
                in which they appear in this Component's `value
                <Component.value>`
        """
        return tuple(kwargs.values())

    @handle_external_context(fallback_most_recent=True)
    def reset(self, *args, context=None, **kwargs):
        """
            Resets `value <StatefulFunction.value>` and `previous_value <StatefulFunction.previous_value>`
            to the specified value(s).

            If arguments are passed into the reset method, then reset sets each of the attributes in
            `stateful_attributes <StatefulFunction.stateful_attributes>` to the value of the corresponding argument.
            Next, it sets the `value <StatefulFunction.value>` to a list containing each of the argument values.

            If reset is called without arguments, then it sets each of the attributes in `stateful_attributes
            <StatefulFunction.stateful_attributes>` to the value of the corresponding attribute in `initializers
            <StatefulFunction.initializers>`. Next, it sets the `value <StatefulFunction.value>` to a list containing
            the values of each of the attributes in `initializers <StatefulFunction.initializers>`.

            Often, the only attribute in `stateful_attributes <StatefulFunction.stateful_attributes>` is
            `previous_value <StatefulFunction.previous_value>` and the only attribute in `initializers
            <StatefulFunction.initializers>` is `initializer <StatefulFunction.initializer>`, in which case
            the reset method sets `previous_value <StatefulFunction.previous_value>` and `value
            <StatefulFunction.value>` to either the value of the argument (if an argument was passed into
            reset) or the current value of `initializer <StatefulFunction.initializer>`.

            For specific types of StatefulFunction functions, the reset method may carry out other
            reinitialization steps.

            Raises FunctionError if positional arguments are given when there is more than one stateful attribute,
            if the number of keyword arguments does not match the number of stateful attributes (when there is more
            than one), if an argument left as None does not name a Parameter that has an initializer, or if a
            modulated initializer's source of modulation is not a CompositionInterfaceMechanism or ControlMechanism.
        """
        num_stateful_attrs = len(self.stateful_attributes)
        if num_stateful_attrs >= 2:
            # old args specification can be supported only in subclasses
            # that explicitly define an order by overriding reset
            if len(args) > 0:
                raise FunctionError(f'{self}.reset has more than one stateful attribute'
                                    f' ({self.stateful_attributes}). You must specify reset values by keyword.')
            if len(kwargs) != num_stateful_attrs:
                type_name = type(self).__name__
                raise FunctionError(f'StatefulFunction.reset must receive a keyword argument for'
                                    f' each item in {type_name}.stateful_attributes in the order in'
                                    f' which they appear in {type_name}.value.')

        if num_stateful_attrs == 1:
            # A single stateful attribute may be given by keyword, positionally, or not at all (None -> initializer)
            attr_name = self.stateful_attributes[0]
            if attr_name not in kwargs:
                kwargs[attr_name] = args[0] if args else None

        invalid_args = []

        # iterates in order arguments are sent in function call, so it
        # will match their order in value as long as they are listed
        # properly in subclass reset method signatures
        for attr in kwargs:
            if kwargs[attr] is not None:
                # from before: unsure if conversion to 1d necessary
                kwargs[attr] = np.atleast_1d(kwargs[attr])
                continue

            # No value given: use the attribute's initializer (or its control allocation if it is modulated)
            try:
                initializer_ref = getattr(self.parameters, attr).initializer
                if not initializer_ref:
                    # An attribute without an initializer cannot be reset to one; report it rather than
                    # falling through with an unbound (or another attribute's) initializer.
                    invalid_args.append(attr)
                    continue
                initializer = getattr(self.parameters, initializer_ref)
                # FIX: ?NEED TO HANDLE initializer IF IT IS A NUMBER?
                if initializer is not None and initializer.port and initializer.port.mod_afferents:
                    # If the initializer is subject to control, get its control_allocation
                    initializer_mod_proj = initializer.port.mod_afferents[0]
                    mod_parameter_source = initializer_mod_proj.sender.owner
                    # imported locally (presumably to avoid circular imports)
                    from psyneulink.core.compositions.composition import CompositionInterfaceMechanism
                    from psyneulink.core.components.mechanisms.modulatory.control.controlmechanism \
                        import ControlMechanism
                    if isinstance(mod_parameter_source, CompositionInterfaceMechanism):
                        ctl_sig, _, _ = mod_parameter_source._get_source_of_modulation_for_parameter_CIM(
                            initializer_mod_proj.sender)
                    elif isinstance(mod_parameter_source, ControlMechanism):
                        ctl_sig = mod_parameter_source.control_signals[0]
                    else:
                        # raise rather than assert so the check survives python -O
                        raise FunctionError(f"Cannot reset {self.name} because "
                                            f"the source of modulation is not of correct type.")
                    kwargs[attr] = ctl_sig.parameters.value.get(context)
                else:
                    # Otherwise, just use the default (or user-assigned) initializer
                    kwargs[attr] = self._get_current_parameter_value(initializer, context=context)

            except AttributeError:
                invalid_args.append(attr)

        if len(invalid_args) > 0:
            raise FunctionError(f'Arguments {invalid_args} to reset are invalid because they do'
                                f" not correspond to any of {self}'s stateful_attributes.")

        # rebuilding value rather than simply returning reinitialization_values in case any of the stateful
        # attrs are modified during assignment
        value = []
        for attr, reset_value in kwargs.items():
            # FIXME: HACK: Do not reinitialize random_state
            if attr != "random_state":
                param = getattr(self.parameters, attr)
                param.set(reset_value, context, override=True)
                value.append(param._get(context))

        self.parameters.value.set(value, context, override=True)
        return value

    def _gen_llvm_function_reset(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        # Compiled reset: copy each stateful attribute's initializer (from the params struct) into the attribute's
        # slot in the state struct.
        assert "reset" in tags
        for a in self.stateful_attributes:
            initializer = getattr(self.parameters, a).initializer
            source_ptr = ctx.get_param_or_state_ptr(builder, self, initializer, param_struct_ptr=params)
            dest_ptr = ctx.get_param_or_state_ptr(builder, self, a, state_struct_ptr=state)
            builder.store(builder.load(source_ptr), dest_ptr)

        return builder

    @abc.abstractmethod
    def _function(self, *args, **kwargs):
        raise FunctionError("StatefulFunction is not meant to be called explicitly")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc