PrincetonUniversity / PsyNeuLink / 15917088825

05 Jun 2025 04:18AM UTC coverage: 84.482% (+0.5%) from 84.017%
15917088825

push

github

web-flow
Merge pull request #3271 from PrincetonUniversity/devel

Devel

9909 of 12966 branches covered (76.42%)

Branch coverage included in aggregate %.

1708 of 1908 new or added lines in 54 files covered. (89.52%)

25 existing lines in 14 files now uncovered.

34484 of 39581 relevant lines covered (87.12%)

0.87 hits per line


81.57
/psyneulink/library/components/mechanisms/modulatory/learning/EMstoragemechanism.py
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.


# *************************************  EMStorageMechanism **********************************************

"""

Contents
--------

  * `EMStorageMechanism_Overview`
    - `EMStorageMechanism_Memory`
    - `EMStorageMechanism_Entry`
    - `EMStorageMechanism_Fields`
  * `EMStorageMechanism_Creation`
  * `EMStorageMechanism_Structure`
  * `EMStorageMechanism_Execution`
  * `EMStorageMechanism_Class_Reference`


.. _EMStorageMechanism_Overview:

Overview
--------

An EMStorageMechanism is a subclass of `LearningMechanism`, modified for use in an `EMComposition` to store a new
entry in its `memory <EMComposition.memory>` attribute each time it executes.

.. _EMStorageMechanism_Memory:

# FIX: NEEDS EDITING:

* **Memory** -- the `memory <EMComposition.memory>` attribute of an `EMComposition` is a list of entries, each of
    which is a 2d np.array with a shape that corresponds to one `entry <EMStorageMechanism_Entry>` (row) of the
    `memory_matrix <EMStorageMechanism.memory_matrix>` attribute of the EMStorageMechanism that stores it.  Each
    entry is stored field by field: each of its `fields <EMStorageMechanism_Fields>` is stored as a row or column of
    the `matrix <MappingProjection.matrix>` parameter of one of the `MappingProjections <MappingProjection>` to which
    the `LearningProjections <LearningProjection>` of the EMStorageMechanism project.  The `memory
    <EMComposition.memory>` attribute of the EMComposition is used by its `controller <EMComposition.controller>` to
    generate the `memory <EMMemoryMechanism.memory>` attribute of an `EMMemoryMechanism` that is used to retrieve
    entries from the `memory <EMComposition.memory>` attribute of the EMComposition.

.. _EMStorageMechanism_Entry:

* **Entry** -- an entry is a 2d np.array that contains one item for each `field <EMStorageMechanism_Fields>`, and
    corresponds to one row of the `memory_matrix <EMStorageMechanism.memory_matrix>` attribute of the
    EMStorageMechanism that stores it.  A new entry is stored by overwriting, in the `matrix
    <MappingProjection.matrix>` parameter of the `MappingProjection` for each field, the row or column that holds the
    corresponding field of the entry being replaced (see `EMStorageMechanism_Execution`).

.. _EMStorageMechanism_Fields:

* **Fields** -- an entry is composed of one or more fields, each of which is a 1d np.array.  The number of fields
    corresponds to the number of `fields <EMStorageMechanism.fields>` of the EMStorageMechanism that stores it, and
    the length of each field corresponds to the matching item of its `variable <EMStorageMechanism.variable>`.  Each
    field is designated as either a key or a value field by the `field_types <EMStorageMechanism.field_types>`
    attribute, and is stored as a row or column of the `matrix <MappingProjection.matrix>` parameter of the
    `MappingProjection` to which the corresponding `LearningProjection` of the EMStorageMechanism projects.

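As an illustration of how entries and fields relate, the regrouping performed by ``_memory_matrix_getter`` (defined
later in this module) can be sketched in plain Python. This is a minimal sketch, assuming that row *i* of each
field's matrix holds that field of entry *i* (which is how the getter indexes it); the helper
``entries_from_field_matrices`` and the sample values are hypothetical, and nested lists stand in for np.arrays.

```python
from typing import List

Field = List[float]        # one field of one entry
Matrix = List[Field]       # one field's memory: row i holds that field for entry i (assumed)


def entries_from_field_matrices(field_matrices: List[Matrix]) -> List[List[Field]]:
    # Hypothetical helper mirroring the regrouping done in _memory_matrix_getter.
    num_fields = len(field_matrices)
    # Capacity comes from the first matrix only: fields may differ in length,
    # so the per-field matrices together can form a ragged structure.
    memory_capacity = len(field_matrices[0])
    # Outer index = entry, inner index = field.
    return [[field_matrices[j][i] for j in range(num_fields)]
            for i in range(memory_capacity)]


key_field = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]   # 3 entries, key fields of length 2
value_field = [[9.0], [8.0], [7.0]]               # 3 entries, value fields of length 1
memory = entries_from_field_matrices([key_field, value_field])
print(memory[1])  # -> [[0.0, 1.0], [8.0]]
```

Each item of ``memory`` is then one entry, holding its key field followed by its value field.
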
.. _EMStorageMechanism_Creation:

Creating an EMStorageMechanism
--------------------------------------------

An EMStorageMechanism can be created directly by calling its constructor, but most commonly it is created
automatically when an `EMComposition` is created, as the `learning_mechanism <EMComposition.learning_mechanism>`
used to store entries in the `memory <EMComposition.memory>` of the EMComposition. The `memory_matrix
<EMStorageMechanism.memory_matrix>` must be specified (as a template for the shape of the entries to be stored, and
of the `matrix <MappingProjection.matrix>` parameters to which they are assigned). It must also have at least one,
and usually several, `fields <EMStorageMechanism.fields>` specifications that identify the `OutputPort`\\s of the
`ProcessingMechanism`\\s from which it receives its `fields <EMStorageMechanism_Fields>`, and a `field_types
<EMStorageMechanism.field_types>` specification that indicates whether each `field is a key or a value field
<EMStorageMechanism_Fields>`.

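The constraints that this module's documentation places on the ``fields``, ``field_types`` and ``field_weights``
specifications (at least one field; one 1 or 0 per field; if given, one weight per field, each in [0, 1]) can be
summarized with a small check. This is a hypothetical helper written for illustration under those stated rules, not
the validation that PsyNeuLink actually performs.

```python
from typing import List, Optional, Sequence


def check_field_specs(num_fields: int,
                      field_types: Sequence[int],
                      field_weights: Optional[Sequence[float]] = None) -> List[str]:
    # Hypothetical helper: returns a list of problems (empty if the specification is consistent).
    problems = []
    if num_fields < 1:
        problems.append('at least one field is required')
    if len(field_types) != num_fields:
        problems.append('field_types must have one item per field')
    if any(t not in (0, 1) for t in field_types):
        problems.append('field_types may contain only 1 (key) or 0 (value)')
    if field_weights is not None:
        if len(field_weights) != num_fields:
            problems.append('field_weights must have one item per field')
        if any(not 0 <= w <= 1 for w in field_weights):
            problems.append('field_weights must lie in [0, 1]')
    return problems


print(check_field_specs(3, [1, 1, 0]))              # -> []
print(check_field_specs(3, [1, 2, 0], [0.5, 0.5]))  # two problems: bad type code, too few weights
```

Collecting all problems, rather than stopping at the first, is only a convenience for the illustration.
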
.. _EMStorageMechanism_Structure:

Structure
---------

An EMStorageMechanism differs from a standard `LearningMechanism` in the following ways:

  * it has no `input_source <LearningMechanism.input_source>`, `output_source <LearningMechanism.output_source>`,
    or `error_source <LearningMechanism.error_source>` attributes;  instead, it has the `fields
    <EMStorageMechanism.fields>` and `field_types <EMStorageMechanism.field_types>` attributes described below.

  * it has as many `InputPorts <InputPort>` as there are `fields <EMStorageMechanism_Fields>` in an entry of its
    `memory_matrix <EMStorageMechanism.memory_matrix>` attribute;  these are listed in its `fields
    <EMStorageMechanism.fields>` attribute, and each receives a `MappingProjection` from the `OutputPort` of a
    `ProcessingMechanism`, the activity of which constitutes the corresponding `field <EMStorageMechanism_Fields>`
    of the `entry <EMStorageMechanism_Entry>` to be stored in its `memory_matrix
    <EMStorageMechanism.memory_matrix>` attribute.

  * it has a `field_types <EMStorageMechanism.field_types>` attribute that specifies whether each `field
    <EMStorageMechanism_Fields>` is a `key or a value field <EMStorageMechanism_Fields>`.

  * it has a `field_weights <EMStorageMechanism.field_weights>` attribute that specifies how the norm of each `field
    <EMStorageMechanism_Fields>` is weighted when determining the weakest `entry <EMStorageMechanism_Entry>` in
    `memory_matrix <EMStorageMechanism.memory_matrix>`.

  * it has a `memory_matrix <EMStorageMechanism.memory_matrix>` attribute that represents the full memory that the
    EMStorageMechanism is used to update.

  * it has a `concatenation_node <EMStorageMechanism.concatenation_node>` attribute used to access the concatenated
    inputs to the `key fields <EMStorageMechanism_Fields>` of the `entry <EMStorageMechanism_Entry>` to be stored in
    its `memory_matrix <EMStorageMechanism.memory_matrix>` attribute.

  * it has several *LEARNING_SIGNAL* `OutputPorts <OutputPort>`, each of which sends a `LearningProjection` to the
    `matrix <MappingProjection.matrix>` parameter of a `MappingProjection` that stores one `field
    <EMStorageMechanism_Fields>` of the `memory_matrix <EMStorageMechanism.memory_matrix>` attribute.

  * its `function <EMStorageMechanism.function>` is an `EMStorage` `LearningFunction` that takes as its `variable
    <Function_Base.variable>` a list or 1d np.array with the length of the corresponding *ACTIVATION_INPUT* InputPort,
    and returns a `learning_signal <LearningMechanism.learning_signal>` (a weight matrix assigned to one of the
    Mechanism's *LEARNING_SIGNAL* OutputPorts), but no `error_signal <LearningMechanism.error_signal>`.

  * the default form of `modulation <ModulatorySignal_Modulation>` for its `learning_signals
    <LearningMechanism.learning_signals>` is *OVERRIDE*, so that the `matrix <MappingProjection.matrix>` parameter of
    the `MappingProjection` to which the `LearningProjection` projects is replaced by the `value
    <LearningProjection.value>` of the `learning_signal <LearningMechanism.learning_signal>`.

  * it has a `decay_rate <EMStorageMechanism.decay_rate>`, a float in the interval [0,1] that is used to decay the
    `memory_matrix <EMStorageMechanism.memory_matrix>` before an `entry <EMStorageMechanism_Entry>` is stored.

  * it has a `storage_prob <EMStorageMechanism.storage_prob>`, a float in the interval [0,1] that is used in place of
    a LearningMechanism's `storage_prob <LearningMechanism.storage_prob>` to determine the probability that the
    Mechanism will store its `variable <EMStorageMechanism.variable>` in its `memory_matrix
    <EMStorageMechanism.memory_matrix>` attribute each time it executes.

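The interplay of ``decay_rate`` and ``storage_prob`` with a single field's matrix can be sketched as follows. This is
a minimal sketch, not the `EMStorage` implementation: it assumes that decay is multiplicative (each stored value is
scaled by ``1 - decay_rate``) and is applied on every execution, and it uses ``axis`` and ``storage_location`` (named
after the `function <EMStorageMechanism.function>` arguments documented in the class reference) to choose whether a
row or a column is overwritten. The helper name ``decay_then_maybe_store`` is hypothetical.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


def decay_then_maybe_store(field_memory: Matrix, new_item: List[float], storage_location: int,
                           axis: int = 0, decay_rate: float = 0.0, storage_prob: float = 1.0,
                           rng: Optional[random.Random] = None) -> Matrix:
    # Hypothetical helper; not the EMStorage API.
    rng = rng if rng is not None else random.Random()
    # Assumption: decay is multiplicative and is applied on every call, before any storage.
    updated = [[value * (1.0 - decay_rate) for value in row] for row in field_memory]
    # Store only with probability storage_prob (random() is in [0, 1), so 1.0 always stores).
    if rng.random() < storage_prob:
        if axis == 0:
            updated[storage_location] = list(new_item)         # overwrite a row
        else:
            for row, value in zip(updated, new_item):          # overwrite a column
                row[storage_location] = value
    return updated


memory = [[1.0, 1.0], [2.0, 2.0]]
print(decay_then_maybe_store(memory, [5.0, 5.0], 0, decay_rate=0.5))  # -> [[5.0, 5.0], [1.0, 1.0]]
print(decay_then_maybe_store(memory, [7.0, 8.0], 1, axis=1))          # -> [[1.0, 7.0], [2.0, 8.0]]
```

The input matrix is copied rather than modified in place, so the same ``memory`` can be reused across the calls.
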
.. _EMStorageMechanism_Execution:

Execution
---------

An EMStorageMechanism executes after all of the other Mechanisms in the `EMComposition` to which it belongs have
executed.  It executes in the same manner as a standard `LearningMechanism`; however, instead of modulating
the `matrix <MappingProjection.matrix>` Parameter of a `MappingProjection`, it replaces a row or column in each of
the `matrix <MappingProjection.matrix>` Parameters of the `MappingProjections <MappingProjection>` to which its
`LearningProjections <LearningProjection>` project with the item of its `variable <EMStorageMechanism.variable>` that
represents the corresponding `field <EMStorageMechanism.fields>`. The entry replaced is the one that has the lowest
norm computed across all `fields <EMStorageMechanism_Fields>` of the `entry <EMStorageMechanism_Entry>`, weighted by
the corresponding items of `field_weights <EMStorageMechanism.field_weights>` if that is specified.

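Selection of the entry to replace can be sketched in plain Python. The unweighted case follows the description
above (one norm computed across all fields at once); the weighted form, a sum of per-field Euclidean norms each
multiplied by its field weight, is an assumption made for illustration, as are the helper names ``entry_norm`` and
``store_in_weakest``.

```python
import math
from typing import List, Optional, Sequence, Tuple

Entry = List[List[float]]  # one list of numbers per field


def _norm(values: Sequence[float]) -> float:
    return math.sqrt(sum(v * v for v in values))


def entry_norm(entry: Entry, field_weights: Optional[Sequence[float]] = None) -> float:
    if field_weights is None:
        # No weights: a single norm over all fields at once (fields concatenated).
        return _norm([v for field in entry for v in field])
    # Assumed weighted form: per-field norms scaled by the field's weight, then summed.
    return sum(w * _norm(field) for w, field in zip(field_weights, entry))


def store_in_weakest(memory: List[Entry], new_entry: Entry,
                     field_weights: Optional[Sequence[float]] = None) -> Tuple[List[Entry], int]:
    # Hypothetical helper: replace every field of the lowest-norm entry with new_entry.
    norms = [entry_norm(entry, field_weights) for entry in memory]
    weakest = norms.index(min(norms))
    updated = [[list(field) for field in entry] for entry in memory]
    updated[weakest] = [list(field) for field in new_entry]
    return updated, weakest


memory = [
    [[3.0, 4.0], [0.0]],   # key norm 5, value norm 0
    [[1.0, 0.0], [10.0]],  # key norm 1, value norm 10
    [[0.0, 2.0], [1.0]],   # key norm 2, value norm 1
]
_, index = store_in_weakest(memory, [[9.0, 9.0], [9.0]])
print(index)  # -> 2 (unweighted norms: 5.0, about 10.05, about 2.24)
_, index = store_in_weakest(memory, [[9.0, 9.0], [9.0]], field_weights=[1.0, 0.0])
print(index)  # -> 1 (only the key field counts: 5.0, 1.0, 2.0)
```

With ``field_weights=[1.0, 0.0]`` only the key field contributes, so a different entry is judged weakest than in
the unweighted case.
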

.. _EMStorageMechanism_Class_Reference:

Class Reference
---------------

"""

import numpy as np
import re
from beartype import beartype

from psyneulink._typing import Optional, Union, Callable, Literal

from psyneulink.core.components.component import parameter_keywords
from psyneulink.core.components.functions.nonstateful.learningfunctions import EMStorage
from psyneulink.core.components.mechanisms.mechanism import Mechanism
from psyneulink.core.components.mechanisms.modulatory.learning.learningmechanism import \
    LearningMechanism, LearningMechanismError, LearningTiming, LearningType
from psyneulink.core.components.projections.projection import Projection, projection_keywords
from psyneulink.core.components.ports.parameterport import ParameterPort
from psyneulink.core.components.ports.outputport import OutputPort
from psyneulink.core.globals.context import ContextFlags
from psyneulink.core.globals.keywords import \
    (ADDITIVE, EM_STORAGE_MECHANISM, LEARNING, LEARNING_PROJECTION, LEARNING_SIGNALS, MULTIPLICATIVE,
     MULTIPLICATIVE_PARAM, MODULATION, NAME, OVERRIDE, OWNER_VALUE, PROJECTIONS, REFERENCE_VALUE, VARIABLE)
from psyneulink.core.globals.parameters import Parameter, ParameterNoValueError, check_user_specified, FunctionParameter, copy_parameter_value
from psyneulink.core.globals.preferences.basepreferenceset import ValidPrefSet
from psyneulink.core.globals.preferences.preferenceset import PreferenceLevel
from psyneulink.core.globals.utilities import convert_all_elements_to_np_array, is_numeric, all_within_range

__all__ = [
    'EMStorageMechanism', 'EMStorageMechanismError',
]

# Parameters:

parameter_keywords.update({LEARNING_PROJECTION, LEARNING})
projection_keywords.update({LEARNING_PROJECTION, LEARNING})

MEMORY_MATRIX = 'memory_matrix'
FIELDS = 'fields'
FIELD_TYPES = 'field_types'

def _memory_matrix_getter(owning_component=None, context=None)->list:
    """Return memory as an array in which each row (outer dimension) is an entry and each column is a field.
    These are derived from the `matrix <MappingProjection.matrix>` parameters of the `afferent
    <Mechanism_Base.afferents>` MappingProjections to each of the `retrieved_nodes <EMComposition.retrieved_nodes>`.
    """
    if owning_component.is_initializing:
        try:
            if owning_component.learning_signals is None or owning_component.input_ports is None:
                return None
        except ParameterNoValueError:
            return None

    num_fields = len(owning_component.input_ports)

    # Get learning_signals that project to retrieved_nodes
    num_learning_signals = len(owning_component.learning_signals)
    learning_signals_for_retrieved = owning_component.learning_signals[num_learning_signals - num_fields:]

    # Get memory from learning_signals that project to retrieved_nodes
    if owning_component.is_initializing:
        # If initializing, learning_signals are still MappingProjections used to specify them, so get from them
        memory = [retrieved_learning_signal.parameters.matrix._get(context)
                  for retrieved_learning_signal in learning_signals_for_retrieved]
    else:
        # Otherwise, get directly from the learning_signals
        memory = [retrieved_learning_signal.efferents[0].receiver.owner.parameters.matrix._get(context)
                  for retrieved_learning_signal in learning_signals_for_retrieved]

    # Get memory capacity from the length of the first matrix
    # (can't use the full set, since it might be a ragged array)
    memory_capacity = len(memory[0])

    # Reorganize memory so that each row is an entry and each column is a field
    return convert_all_elements_to_np_array([
        [memory[j][i] for j in range(num_fields)]
        for i in range(memory_capacity)
    ])


class EMStorageMechanismError(LearningMechanismError):
    pass


class EMStorageMechanism(LearningMechanism):
    """
    EMStorageMechanism(                       \
        variable,                             \
        fields,                               \
        field_types,                          \
        field_weights=None,                   \
        concatenation_node=None,              \
        memory_matrix,                        \
        function=EMStorage,                   \
        storage_prob=1.0,                     \
        decay_rate=0.0,                       \
        learning_signals,                     \
        modulation=OVERRIDE,                  \
        params=None,                          \
        name=None,                            \
        prefs=None)

    Implements a `LearningMechanism` that modifies the `matrix <MappingProjection.matrix>` parameters of
    `MappingProjections <MappingProjection>` that implement its `memory_matrix <EMStorageMechanism.memory_matrix>`.

    Arguments
    ---------

    variable : List or 2d np.array : default None
        each item of the 2d array specifies the shape of the corresponding `field <EMStorageMechanism_Fields>` of
        an `entry <EMStorageMechanism_Entry>`, and must be compatible (in number and type) with the `value
        <InputPort.value>` of the corresponding item of its `fields <EMStorageMechanism.fields>`
        attribute (see `variable <EMStorageMechanism.variable>` for additional details).

    fields : List[OutputPort, Mechanism, Projection, tuple[str, Mechanism, Projection] or dict] : default None
        specifies the `OutputPort`\\(s), the `value <OutputPort.value>`\\s of which are used as the
        corresponding `fields <EMStorageMechanism_Fields>` of the `memory_matrix <EMStorageMechanism.memory_matrix>`;
        used to construct the Mechanism's `InputPorts <InputPort>`; must be the same length as `variable
        <EMStorageMechanism.variable>`.

    field_types : List[int] : default None
        specifies whether each item of `variable <EMStorageMechanism.variable>` corresponds to a `key or value field
        <EMStorageMechanism_Fields>` (see `field_types <EMStorageMechanism.field_types>` for additional details);
        must contain only 1's (for keys) and 0's (for values), with the same number of these as there are items in
        the `variable <EMStorageMechanism.variable>` and `fields <EMStorageMechanism.fields>` arguments.

    field_weights : List[float] : default None
        specifies whether the norms for each field are weighted before determining the weakest `entry
        <EMStorageMechanism_Entry>` in `memory_matrix <EMStorageMechanism.memory_matrix>`. If None (the default),
        the norm of each `entry <EMStorageMechanism_Entry>` is calculated across all fields at once; if specified,
        it must contain only floats from 0 to 1, and be the same length as the `fields <EMStorageMechanism.fields>`
        argument (see `field_weights <EMStorageMechanism.field_weights>` for additional details).

    concatenation_node : OutputPort or Mechanism : default None
        specifies the `OutputPort` or `Mechanism` in which the `value <OutputPort.value>`\\s of the `key fields
        <EMStorageMechanism_Fields>` are concatenated (see `concatenate keys <EMComposition_Concatenate_Queries>`
        for additional details).

    memory_matrix : List or 2d np.array : default None
        specifies the shape of the `memory <EMStorageMechanism_Memory>` used to store an `entry
        <EMStorageMechanism_Entry>` (see `memory_matrix <EMStorageMechanism.memory_matrix>` for additional details).

    function : LearningFunction or function : default EMStorage
        specifies the function used to assign each item of the `variable <EMStorageMechanism.variable>` to the
        corresponding `field <EMStorageMechanism_Fields>` of the `memory_matrix <EMStorageMechanism.memory_matrix>`.
        It must take as its `variable <EMStorage.variable>` argument a list or 1d array of numeric values
        (the "activity vector"); a ``memory_matrix`` argument that is a 2d array to which
        the `variable <EMStorageMechanism.variable>` is assigned; ``axis`` and ``storage_location`` arguments that
        determine where in ``memory_matrix`` the `variable <EMStorageMechanism.variable>` is stored; and optional
        ``storage_prob`` and ``decay_rate`` arguments that determine the probability with which storage occurs and
        the rate at which the `memory_matrix <EMStorageMechanism.memory_matrix>` decays, respectively.  The function
        must return a list or 2d np.array for the corresponding `field <EMStorageMechanism_Fields>` of the
        `memory_matrix <EMStorageMechanism.memory_matrix>` that is updated (see `EMStorage` for additional details).

    learning_signals : List[ParameterPort, Projection, tuple[str, Projection] or dict] : default None
        specifies the `ParameterPort`\\(s) for the `matrix <MappingProjection.matrix>` parameter of the
        `MappingProjection`\\s that implement the `memory <EMStorageMechanism_Memory>` in which the `entry
        <EMStorageMechanism_Entry>` is stored; there must be the same number of these as `fields
        <EMStorageMechanism.fields>`, and they must be specified in the same order.

    modulation : str : default OVERRIDE
        specifies the form of `modulation <ModulatorySignal_Modulation>` that `learning_signals
        <EMStorageMechanism.learning_signals>` use to modify the `matrix <MappingProjection.matrix>` parameter of the
        `MappingProjections <MappingProjection>` that implement the `memory <EMStorageMechanism_Memory>` in which
        `entries <EMStorageMechanism_Entry>` are stored (see `modulation <EMStorageMechanism.modulation>` for
        additional details).

    storage_prob : float : default 1.0
        specifies the probability with which the current entry is stored in the EMStorageMechanism's `memory_matrix
        <EMStorageMechanism.memory_matrix>` (see `storage_prob <EMStorageMechanism.storage_prob>` for details).

    decay_rate : float : default 0.0
        specifies the rate at which `entries <EMStorageMechanism_Entry>` in the `memory_matrix
        <EMStorageMechanism.memory_matrix>` decay (see `decay_rate <EMStorageMechanism.decay_rate>` for additional
        details).

    Attributes
    ----------

    # FIX: FINISH EDITING:

    variable : 2d np.array
        each item of the 2d array is used as a template for the shape of each of the `fields
        <EMStorageMechanism_Fields>` that comprise an `entry <EMStorageMechanism_Entry>` in the `memory_matrix
        <EMStorageMechanism.memory_matrix>`, and must be compatible (in number and type) with the `value
        <OutputPort.value>` of the item specified by the corresponding item of its `fields <EMStorageMechanism.fields>`
        attribute. The values of the `variable <EMStorageMechanism.variable>` are assigned to the `memory_matrix
        <EMStorageMechanism.memory_matrix>` by the `function <EMStorageMechanism.function>`.

    fields : List[OutputPort, Mechanism, Projection, tuple[str, Mechanism, Projection] or dict] : default None
        the `OutputPort`\\(s) used to get the value for each `field <EMStorageMechanism_Fields>` of
        an `entry <EMStorageMechanism_Entry>` of the `memory_matrix <EMStorageMechanism.memory_matrix>` attribute.

    field_types : List[int or tuple[slice]]
        contains a list of indicators of whether each item of `variable <EMStorageMechanism.variable>`
        and the corresponding `fields <EMStorageMechanism.fields>` are key (1) or value (0) fields
        (see `fields <EMStorageMechanism_Fields>` for additional details).

    field_weights : List[float] or None
        determines whether the norms for each field are weighted before identifying the weakest `entry
        <EMStorageMechanism_Entry>` in `memory_matrix <EMStorageMechanism.memory_matrix>`. If it is None (the default),
        the norm of each `entry <EMStorageMechanism_Entry>` is calculated across all fields at once; if specified,
        it must contain only floats from 0 to 1, and be the same length as the `fields <EMStorageMechanism.fields>`
        attribute (see `EMStorageMechanism_Execution` for additional details).

    learned_projections : List[MappingProjection]
        list of the `MappingProjections <MappingProjection>`, the `matrix <MappingProjection.matrix>` Parameters of
        which are modified by the EMStorageMechanism.

    function : LearningFunction or function : default EMStorage
        the function used to assign the value of each `field <EMStorageMechanism.fields>` to the corresponding entry
        in `memory_matrix <EMStorageMechanism.memory_matrix>`.  It must take as its `variable <EMStorage.variable>`
        argument a list or 1d array of numeric values (an `entry <EMStorageMechanism_Entry>`) and return a list or
        2d np.array assigned to the corresponding `field <EMStorageMechanism_Fields>` of the
        `memory_matrix <EMStorageMechanism.memory_matrix>`.

    storage_prob : float
        determines the probability with which the current entry is stored in the EMStorageMechanism's `memory_matrix
        <EMStorageMechanism.memory_matrix>`.

    decay_rate : float : default 0.0
        determines the rate at which `entries <EMStorageMechanism_Entry>` in the `memory_matrix
        <EMStorageMechanism.memory_matrix>` decay;  the decay rate is applied to `memory_matrix
        <EMStorageMechanism.memory_matrix>` before it is updated with the new `entry <EMStorageMechanism_Entry>`.

    learning_signals : List[LearningSignal]
        list of all of the `LearningSignals <LearningSignal>` for the EMStorageMechanism, each of which
        sends a `LearningProjection` to the `ParameterPort`\\(s) for the `MappingProjections
        <MappingProjection>` that implement the `memory <EMStorageMechanism_Memory>` in which the `entry
        <EMStorageMechanism_Entry>` is stored.  The `value <LearningSignal.value>` of each LearningSignal is
        used by its `LearningProjection` to modify the `matrix <MappingProjection.matrix>` parameter of the
        MappingProjection to which it projects.

    learning_projections : List[LearningProjection]
        list of all of the `LearningProjections <LearningProjection>` from the EMStorageMechanism, listed
        in the order of the `LearningSignals <LearningSignal>` to which they belong (that is, in the order they are
        listed in the `learning_signals <EMStorageMechanism.learning_signals>` attribute).

    modulation : str
        determines the form of `modulation <ModulatorySignal_Modulation>` that `learning_signals
        <EMStorageMechanism.learning_signals>` use to modify the `matrix <MappingProjection.matrix>` parameter of the
        `MappingProjections <MappingProjection>` that implement the `memory <EMStorageMechanism_Memory>` in which
        `entries <EMStorageMechanism_Entry>` are stored.  *OVERRIDE* (the default) ensures that entries are stored
        exactly as specified by the `value <OutputPort.value>` of the `fields <EMStorageMechanism.fields>` of the
        `entry <EMStorageMechanism_Entry>`;  other values can have unpredictable consequences
        (see `ModulatorySignal_Types` for additional details).

    output_ports : ContentAddressableList[OutputPort]
        list of the EMStorageMechanism's `OutputPorts <OutputPort>`, beginning with its
        `learning_signals <EMStorageMechanism.learning_signals>`, and followed by any additional
        (user-specified) `OutputPorts <OutputPort>`.

    output_values : 2d np.array
        the first items are the `value <OutputPort.value>`\\(s) of the EMStorageMechanism's `learning_signals
        <EMStorageMechanism.learning_signals>`, followed by the `value <OutputPort.value>`\\(s)
        of any additional (user-specified) OutputPorts.

    """

    componentType = EM_STORAGE_MECHANISM
    className = componentType
    suffix = " " + className

    class Parameters(LearningMechanism.Parameters):
        """
            Attributes
            ----------

                concatenation_node
                    see `concatenation_node <EMStorageMechanism.concatenation_node>`

                    :default value: None
                    :type: ``Mechanism or OutputPort``
                    :read only: True

                decay_rate
                    see `decay_rate <EMStorageMechanism.decay_rate>`

                    :default value: 0.0
                    :type: ``float``

                fields
                    see `fields <EMStorageMechanism.fields>`

                    :default value: None
                    :type: ``list``
                    :read only: True

                field_types
                    see `field_types <EMStorageMechanism.field_types>`

                    :default value: None
                    :type: ``list``
                    :read only: True

                field_weights
                    see `field_weights <EMStorageMechanism.field_weights>`

                    :default value: None
                    :type: ``list or np.ndarray``

                memory_matrix
                    see `memory_matrix <EMStorageMechanism.memory_matrix>`

                    :default value: None
                    :type: ``np.ndarray``
                    :read only: True

                function
                    see `function <EMStorageMechanism.function>`

                    :default value: `EMStorage`
                    :type: `Function`

                input_ports
                    see `fields <EMStorageMechanism.fields>`

                    :default value: None
                    :type: ``list``
                    :read only: True

                learning_signals
                    see `learning_signals <EMStorageMechanism.learning_signals>`

                    :default value: []
                    :type: ``List[MappingProjection or ParameterPort]``
                    :read only: True

                modulation
                    see `modulation <EMStorageMechanism.modulation>`

                    :default value: OVERRIDE
                    :type: ModulationParam
                    :read only: True

                output_ports
                    see `learning_signals <EMStorageMechanism.learning_signals>`

                    :default value: None
                    :type: ``list``
                    :read only: True

                storage_prob
                    see `storage_prob <EMStorageMechanism.storage_prob>`

                    :default value: 1.0
                    :type: ``float``

        """
504
        input_ports = Parameter([], # FIX: SHOULD BE ABLE TO UE THIS WITH 'fields' AS CONSTRUCTOR ARGUMENT
1✔
505
                                stateful=False,
506
                                loggable=False,
507
                                read_only=True,
508
                                structural=True,
509
                                parse_spec=True,
510
                                constructor_argument='fields',
511
                                )
512
        fields = Parameter(
1✔
513
            [], stateful=False, loggable=False, read_only=True, structural=True
514
        )
515
        field_types = Parameter([], stateful=False,
                                loggable=False,
                                read_only=True,
                                structural=True,
                                parse_spec=True,
                                dependencies='fields')
        field_weights = Parameter(None,
                                  modulable=True,
                                  stateful=True,
                                  loggable=True,
                                  dependencies='fields')
        concatenation_node = Parameter(None,
                                       stateful=False,
                                       loggable=False,
                                       read_only=True,
                                       structural=True)
        function = Parameter(EMStorage, stateful=False, loggable=False)
        # storage_prob = Parameter(1.0, modulable=True, stateful=True)
        storage_prob = FunctionParameter(1.0,
                                         function_name='function',
                                         function_parameter_name='storage_prob',
                                         primary=True,
                                         modulable=True,
                                         aliases=[MULTIPLICATIVE_PARAM],
                                         stateful=True)
        decay_rate = Parameter(0.0, modulable=True, stateful=True)
        memory_matrix = Parameter(None, getter=_memory_matrix_getter, read_only=True, structural=True)
        modulation = OVERRIDE
        output_ports = Parameter([],
                                 stateful=False,
                                 loggable=False,
                                 read_only=True,
                                 structural=True,
                                 # constructor_argument='learning_signals'
                                 )
        learning_signals = Parameter([],
                                     stateful=False,
                                     loggable=False,
                                     read_only=True,
                                     structural=True)
        learning_type = LearningType.UNSUPERVISED
        # learning_type = LearningType.SUPERVISED
        # learning_timing = LearningTiming.LEARNING_PHASE
        learning_timing = LearningTiming.EXECUTION_PHASE

    def _validate_field_types(self, field_types):
        if not len(field_types) or len(field_types) != len(self.input_ports):
            return f"must be specified with a number of items equal to " \
                   f"the number of fields specified ({len(self.input_ports)})."
        if not all(item in {1,0} for item in field_types):
            return f"must be a list of 1s (for keys) and 0s (for values)."

    def _validate_field_weights(self, field_weights):
        # Use explicit None/len() checks: the truth value of a multi-element np.ndarray is ambiguous
        if field_weights is None or not len(field_weights) or len(field_weights) != len(self.input_ports):
            return f"must be specified with a number of items equal to " \
                   f"the number of fields specified ({len(self.input_ports)})."
        if not all(isinstance(item, (int, float)) and (0 <= item <= 1) for item in field_weights):
            return f"must be a list of floats from 0 to 1."

    def _validate_storage_prob(self, storage_prob):
        storage_prob = float(storage_prob)
        if not all_within_range(storage_prob, 0, 1):
            return f"must be a float in the interval [0,1]."

    def _validate_decay_rate(self, decay_rate):
        decay_rate = float(decay_rate)
        if not all_within_range(decay_rate, 0, 1):
            return f"must be a float in the interval [0,1]."


    classPreferenceLevel = PreferenceLevel.TYPE

    @check_user_specified
    @beartype
    def __init__(self,
                 default_variable: Union[list, np.ndarray],
                 fields: Union[list, tuple, dict, OutputPort, Mechanism, Projection] = None,
                 field_types: list = None,
                 field_weights: Optional[Union[list, np.ndarray]] = None,
                 concatenation_node: Optional[Union[OutputPort, Mechanism]] = None,
                 memory_matrix: Union[list, np.ndarray] = None,
                 function: Optional[Callable] = EMStorage,
                 learning_signals: Union[list, dict, ParameterPort, Projection, tuple] = None,
                 modulation: Optional[Literal[OVERRIDE, ADDITIVE, MULTIPLICATIVE]] = OVERRIDE,
                 decay_rate: Optional[Union[int, float, np.ndarray]] = 0.0,
                 storage_prob: Optional[Union[int, float, np.ndarray]] = 1.0,
                 params=None,
                 name=None,
                 prefs: Optional[ValidPrefSet] = None,
                 **kwargs
                 ):

        super().__init__(default_variable=default_variable,
                         fields=fields,
                         field_types=field_types,
                         concatenation_node=concatenation_node,
                         memory_matrix=memory_matrix,
                         function=function,
                         learning_signals=learning_signals,
                         modulation=modulation,
                         decay_rate=decay_rate,
                         storage_prob=storage_prob,
                         field_weights=field_weights,
                         params=params,
                         name=name,
                         prefs=prefs,
                         **kwargs)

    def _validate_variable(self, variable, context=None):
        """Validate that each item of variable (one per field) is a 1d array of numeric values.
        """

        # Skip LearningMechanism._validate_variable in call to super(), as it requires variable to have 3 items
        variable = super(LearningMechanism, self)._validate_variable(variable, context)

        # Items in variable should be 1d and have numeric values
        if not (all(np.array(variable)[i].ndim == 1 for i in range(len(variable))) and is_numeric(variable)):
            raise EMStorageMechanismError(f"Variable for {self.name} ({variable}) must be "
                                          f"a list or 2d np.array containing 1d arrays with only numbers.")
        return variable

    def _validate_params(self, request_set, target_set=None, context=None):
        """Validate relationship of memory_matrix, fields, field_types and learning_signals arguments"""

        # Ensure that the shape of variable is equivalent to an entry in memory_matrix
        if MEMORY_MATRIX in request_set:
            memory_matrix = request_set[MEMORY_MATRIX]
            # Items in variable should have the same shape as an entry (row) of memory_matrix
            if memory_matrix[0].shape != np.array(self.variable).shape:
                raise EMStorageMechanismError(f"The 'variable' arg for {self.name} ({self.variable}) must be "
                                              f"a list or 2d np.array containing entries that have the same shape "
                                              f"({memory_matrix[0].shape}) as an entry (row) in 'memory_matrix' arg.")

        # Ensure the number of fields is equal to the number of items in variable
        if FIELDS in request_set:
            fields = request_set[FIELDS]
            if len(fields) != len(self.variable):
                raise EMStorageMechanismError(f"The 'fields' arg for {self.name} ({fields}) must have the same "
                                              f"number of items as its variable arg ({len(self.variable)}).")

        # Ensure the number of field_types is equal to the number of fields
        if FIELD_TYPES in request_set:
            field_types = request_set[FIELD_TYPES]
            if len(field_types) != len(fields):
                raise EMStorageMechanismError(f"The 'field_types' arg for {self.name} ({field_types}) must have "
                                              f"the same number of items as its 'fields' arg ({len(fields)}).")

        num_keys = len([i for i in field_types if i==1])
        concatenate_queries = 'concatenation_node' in request_set and request_set['concatenation_node'] is not None

        # Ensure the number of learning_signals equals the number of fields plus the number of match fields
        #    (one per key field, or only one if queries are concatenated)
        if LEARNING_SIGNALS in request_set:
            learning_signals = request_set[LEARNING_SIGNALS]
            if concatenate_queries:
                num_match_fields = 1
            else:
                num_match_fields = num_keys
            if len(learning_signals) != num_match_fields + len(fields):
                raise EMStorageMechanismError(f"The number of 'learning_signals' ({len(learning_signals)}) specified "
                                              f"for {self.name} must equal the number of match fields "
                                              f"({num_match_fields}) plus the number of fields ({len(fields)}).")

        # Ensure shape of learning_signals matches shapes of matrices for match nodes (i.e., either keys or concatenate)
        key_indices = [i for i, field_type in enumerate(field_types) if field_type == 1]
        for i, learning_signal in enumerate(learning_signals[:num_match_fields]):
            learning_signal_shape = learning_signal.parameters.matrix._get(context).shape
            if concatenate_queries:
                memory_matrix_field_shape = np.array([np.concatenate(row, dtype=object).flatten()
                                                      for row in memory_matrix[:,0:num_keys]]).T.shape
            else:
                memory_matrix_field_shape = np.array(memory_matrix[:,key_indices[i]].tolist()).T.shape
            assert learning_signal_shape == memory_matrix_field_shape, \
                f"The shape ({learning_signal_shape}) of the matrix for the Projection {learning_signal.name} " \
                f"used to specify learning signal {i} of {self.name} does not match the shape " \
                f"of the corresponding field {i} of its 'memory_matrix' ({memory_matrix_field_shape})."
        # Ensure shape of learning_signals matches shapes of matrices for retrieval nodes (i.e., all input fields)
        for i, learning_signal in enumerate(learning_signals[num_match_fields:]):
            learning_signal_shape = learning_signal.parameters.matrix._get(context).shape
            memory_matrix_field_shape = np.array(memory_matrix[:,i].tolist()).shape
            assert learning_signal_shape == memory_matrix_field_shape, \
                f"The shape ({learning_signal_shape}) of the matrix for the Projection {learning_signal.name} " \
                f"used to specify learning signal {i} of {self.name} does not match the shape " \
                f"of the corresponding field {i} of its 'memory_matrix' ({memory_matrix_field_shape})."

    def _instantiate_input_ports(self, input_ports=None, reference_value=None, context=None):
        """Override LearningMechanism to instantiate an InputPort for each field"""
        input_ports = [{NAME: f"QUERY_INPUT_{i}" if self.field_types[i] == 1 else f"VALUE_INPUT_{i}",
                        VARIABLE: self.variable[i],
                        PROJECTIONS: field}
                       for i, field in enumerate(self.input_ports)]
        return super()._instantiate_input_ports(input_ports=input_ports, context=context)

    def _instantiate_output_ports(self, output_ports=None, reference_value=None, context=None):
        """Override LearningMechanism to instantiate a LearningSignal for each Projection specified in
        learning_signals, projecting to the matrix ParameterPort of that Projection"""
        learning_signal_dicts = []
        for i, learning_signal in enumerate(self.learning_signals):
            learning_signal_dicts.append({NAME: f"STORE TO {learning_signal.receiver.owner.name} MATRIX",
                                          VARIABLE: (OWNER_VALUE, i),
                                          REFERENCE_VALUE: self.value[i],
                                          MODULATION: self.modulation,
                                          PROJECTIONS: learning_signal.parameter_ports['matrix']})
        self.parameters.learning_signals._set(learning_signal_dicts, context)

        learning_signals = super()._instantiate_output_ports(context=context)


    def _parse_function_variable(self, variable, context=None):
        # Function expects a single field (one item of Mechanism's variable) at a time
        if self.initialization_status == ContextFlags.INITIALIZING:
            # During initialization, Mechanism's variable is its default_variable,
            # which contains input for all fields, so return only the first one here
            return variable[0]
        # During execution, _execute passes only one entry (item of variable) at a time,
        #    so it can just be passed along here
        return variable

    def _execute(self,
                 variable=None,
                 context=None,
                 runtime_params=None):
        """Execute EMStorageMechanism.function and return the updated matrices for its learning_signals

        For each node in query_input_nodes and value_input_nodes,
        assign its value to the afferent weights of the corresponding retrieved_node.
        - memory = matrix of entries made up of vectors for each field in each entry (row)
        - memory_full_vectors = matrix of entries made up of vectors concatenated across all fields (used for norm)
        - entry_to_store = query_input or value_input to store
        - field_memories = weights of Projections for each field

        DIVISION OF LABOR BETWEEN MECHANISM AND FUNCTION:
        EMStorageMechanism._execute:
         - compute norms to find weakest entry in memory
         - get storage_prob (passed to the function), which determines whether to store current entry in memory
         - call function for each LearningSignal to decay existing memory and assign input to weakest entry
        EMStorage function:
         - decay existing memories
         - assign input to weakest entry (given index passed from EMStorageMechanism)

        :return: list of 2d np.arrays, one updated memory matrix per LearningSignal
        """

        # FIX: SET LEARNING MODE HERE FOR SHOW_GRAPH

        decay_rate = self.parameters.decay_rate._get(context)      # modulable, so use getter
        storage_prob = self.parameters.storage_prob._get(context)  # modulable, so use getter
        field_weights = self.parameters.field_weights._get(context)  # modulable, so use getter
        concatenation_node = self.concatenation_node
        num_match_fields = 1 if concatenation_node else len([i for i in self.field_types if i==1])

        memory = self.parameters.memory_matrix._get(context)
        if memory is None or self.is_initializing:
            if self.is_initializing:
                # Return existing matrices for field_memories  # FIX: THE FOLLOWING DOESN'T TEST FUNCTION:
                return convert_all_elements_to_np_array([
                    learning_signal.receiver.path_afferents[0].parameters.matrix._get(context)
                    for learning_signal in self.learning_signals
                ])
            # Raise exception if not initializing and memory is not specified
            else:
                owner_string = ""
                if self.owner:
                    owner_string = " of " + self.owner.name
                raise EMStorageMechanismError(f"Call to {self.__class__.__name__} function{owner_string} "
                                              f"must include '{MEMORY_MATRIX}' in params arg.")

        # Get least used slot (i.e., weakest memory = row of matrix with the smallest weighted sum of field norms),
        #    computed across all fields
        field_norms = np.empty((len(memory),len(memory[0])))
        for row in range(len(memory)):
            for col in range(len(memory[0])):
                field_norms[row][col] = np.linalg.norm(memory[row][col])
        if field_weights is not None:
            field_norms *= field_weights
        row_norms = np.sum(field_norms, axis=1)
        # IMPLEMENTATION NOTE:
        #  the following will give the lowest index in case of a tie;
        #  this means that if memory is initialized with all zeros,
        #  it will be occupied in row order
        idx_of_weakest_memory = np.argmin(row_norms)

        value = []
        for i, field_projection in enumerate([learning_signal.efferents[0].receiver.owner
                                            for learning_signal in self.learning_signals]):
            if i < num_match_fields:
                # For match matrices,
                #   get entry to store from variable of Projection matrix (memory_field)
                #   to match_node in which memory will be stored (this is to accommodate concatenation_node)
                axis = 0
                entry_to_store = field_projection.parameters.variable._get(context)
                if concatenation_node is None:
                    assert np.all(entry_to_store == variable[i]),\
                        f"PROGRAM ERROR: misalignment between inputs and fields for storing them"
            else:
                # For retrieval matrices,
                #    get entry to store from variable (which has inputs to all fields)
                axis = 1
                entry_to_store = variable[i - num_match_fields]
            # Get matrix containing memories for the field from the Projection
            field_memory_matrix = field_projection.parameters.matrix._get(context)

            # Pass a copy of the field's memory matrix to the EMStorage function
            res = super(LearningMechanism, self)._execute(
                variable=entry_to_store,
                memory_matrix=copy_parameter_value(field_memory_matrix),
                axis=axis,
                storage_location=idx_of_weakest_memory,
                storage_prob=storage_prob,
                decay_rate=decay_rate,
                context=context,
                runtime_params=runtime_params
            )
            value.append(res)
            # Assign the modified memory matrix back to the Projection
            field_projection.parameters.matrix._set(res, context)
        return convert_all_elements_to_np_array(value)
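

# ---------------------------------------------------------------------------------------------------------------------
# ILLUSTRATIVE SKETCH (hypothetical helpers, not part of the PsyNeuLink API)
#
# The two functions below restate, in standalone NumPy, two rules implemented in EMStorageMechanism above,
# so that they can be checked in isolation:
#   - _sketch_expected_num_learning_signals: the count checked in _validate_params (one LearningSignal per
#     match field, plus one per field for the retrieval matrices; only one match field if queries are concatenated).
#   - _sketch_weakest_memory_index: the "weakest memory" search in _execute (weighted sum of field norms, argmin).
# Assumption: memory is a regular array of shape (num_entries, num_fields, field_length), i.e., every field has
# the same length; _execute itself handles fields of different lengths with an explicit loop.
# ---------------------------------------------------------------------------------------------------------------------
import numpy as np


def _sketch_expected_num_learning_signals(field_types, concatenate_queries=False):
    """Hypothetical helper: number of LearningSignals expected for field_types (1 = key, 0 = value)."""
    num_keys = sum(1 for field_type in field_types if field_type == 1)
    # A single concatenated match field replaces the separate per-key match fields
    num_match_fields = 1 if concatenate_queries else num_keys
    # One LearningSignal per match matrix, plus one per field for the retrieval matrices
    return num_match_fields + len(field_types)


def _sketch_weakest_memory_index(memory, field_weights=None):
    """Hypothetical helper: index of the entry (row) with the smallest weighted sum of field norms."""
    memory = np.asarray(memory, dtype=float)
    # L2 norm of each field of each entry -> shape (num_entries, num_fields)
    field_norms = np.linalg.norm(memory, axis=-1)
    if field_weights is not None:
        # Scale each field's norm by its weight (broadcast across entries)
        field_norms = field_norms * np.asarray(field_weights, dtype=float)
    row_norms = field_norms.sum(axis=1)
    # np.argmin returns the lowest index on ties, so an all-zero memory is filled in row order
    return int(np.argmin(row_norms))


# Example (values chosen for illustration):
#   _sketch_expected_num_learning_signals([1, 1, 0])                              -> 5
#   _sketch_expected_num_learning_signals([1, 1, 0], concatenate_queries=True)    -> 4
#   _sketch_weakest_memory_index([[[3, 4], [0, 0]], [[0, 0], [6, 8]]])            -> 0
#   _sketch_weakest_memory_index([[[3, 4], [0, 0]], [[0, 0], [6, 8]]], [1, 0.1])  -> 1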