/psyneulink/library/compositions/grucomposition/grucomposition.py
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.


# ********************************************* GRUComposition *************************************************

"""
Contents
--------

  * `GRUComposition_Overview`
  * `GRUComposition_Creation`
  * `GRUComposition_Structure`
  * `GRUComposition_Execution`
     - `Processing <GRUComposition_Processing>`
     - `Learning <GRUComposition_Learning>`
  * `GRUComposition_Examples`
  * `GRUComposition_Class_Reference`

.. _GRUComposition_Overview:

Overview
--------

The GRUComposition is a subclass of `AutodiffComposition` that implements a single-layered gated recurrent network,
which uses a set of `GatingMechanisms <GatingMechanism>` to implement gates that modulate the flow of information
through its `hidden_layer_node <GRUComposition.hidden_layer_node>`. This implements the exact same computations as
a PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module, which is used to implement
it when its `learn <GRUComposition.learn>` method is called.  When it is executed in Python mode, it functions
in the same way as a `GRUCell <https://pytorch.org/docs/stable/generated/torch.nn.GRUCell.html>`_ module, processing
its input one stimulus at a time.  However, when used for `learning <GRUComposition_Learning>`, it is executed as
a PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module, so that it can be used to
process an entire sequence of stimuli at once, and learn to predict the next stimulus in the sequence.

.. _GRUComposition_Creation:

Creation
--------

A GRUComposition is created by calling its constructor.  When its `learn <AutodiffComposition.learn>`
method is called, it automatically creates a PytorchGRUCompositionWrapper that implements the GRUComposition
using the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module, which is trained
using PyTorch. Its constructor takes the following arguments that are in addition to, or handled differently
than, those of `AutodiffComposition`:

**input_size** (int) specifies the length of the input array to the GRUComposition, and the size
of the `input_node <GRUComposition.input_node>`, which can be different from **hidden_size**.

**hidden_size** (int) specifies the length of the internal ("hidden") state of the GRUComposition,
and the size of the `hidden_layer_node <GRUComposition.hidden_layer_node>` and all nodes other
than the `input_node <GRUComposition.input_node>`, which can be different from **input_size**.

**bias** (bool) specifies whether the GRUComposition includes `BIAS <NodeRole.BIAS>` `Nodes <Composition_Nodes>`
and, correspondingly, whether the `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module uses
bias vectors in its computations.

.. _GRUComposition_Learning_Arguments:

**enable_learning** (bool) specifies whether learning is enabled for the GRUComposition;  if it is False,
no learning will occur, even when its `learn <AutodiffComposition.learn>` method is called.

**learning_rate** (float or bool): specifies the default learning_rate for the parameters of the PyTorch `GRU
<https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that are not specified for individual
parameters in the **optimizer_params** argument of the GRUComposition's constructor or in the call to its `learn
<AutodiffComposition.learn>` method. If it is an int or a float, that is used as the default learning rate for the
GRUComposition; if it is None or True, the GRUComposition's default `learning_rate <GRUComposition.learning_rate>`
(.001) is used; if it is False, then learning will occur only for parameters for which an explicit learning_rate
has been specified in the **optimizer_params** argument of the GRUComposition's constructor.
COMMENT: FIX CORRECT?
or in the call to its `learn <AutodiffComposition.learn>` method
COMMENT

.. _GRUComposition_Individual_Learning_Rates:

**optimizer_params** (dict): used to specify parameter-specific learning rates, which supersede the value of the
GRUComposition's `learning_rate <GRUComposition.learning_rate>`. Keys of the dict must reference parameters of the
`GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module, and values their learning_rates,
as described below.

  **Keys** for specifying individual parameters in the **optimizer_params** dict:

    - *`w_ih`*: learning rate for the ``weight_ih_l0`` parameter of the PyTorch `GRU
      <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that corresponds to the weights of the
      efferent projections from the `input_node <GRUComposition.input_node>` of the GRUComposition: `wts_in
      <GRUComposition.wts_in>`, `wts_iu <GRUComposition.wts_iu>`, and `wts_ir <GRUComposition.wts_ir>`; its value
      is stored in the `w_ih_learning_rate <GRUComposition.w_ih_learning_rate>` attribute of the GRUComposition;

    - *`w_hh`*: learning rate for the ``weight_hh_l0`` parameter of the PyTorch `GRU
      <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that corresponds to the weights of the
      efferent projections from the `hidden_layer_node <GRUComposition.hidden_layer_node>` of the GRUComposition:
      `wts_hn <GRUComposition.wts_hn>`, `wts_hu <GRUComposition.wts_hu>`, `wts_hr <GRUComposition.wts_hr>`; its
      value is stored in the `w_hh_learning_rate <GRUComposition.w_hh_learning_rate>` attribute of the GRUComposition;

    - *`b_ih`*: learning rate for the ``bias_ih_l0`` parameter of the PyTorch `GRU
      <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that corresponds to the biases of the
      efferent projections from the `input_node <GRUComposition.input_node>` of the GRUComposition: `bias_ir
      <GRUComposition.bias_ir>`, `bias_iu <GRUComposition.bias_iu>`, `bias_in <GRUComposition.bias_in>`; its value
      is stored in the `b_ih_learning_rate <GRUComposition.b_ih_learning_rate>` attribute of the GRUComposition;

    - *`b_hh`*: learning rate for the ``bias_hh_l0`` parameter of the PyTorch `GRU
      <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that corresponds to the biases of the
      efferent projections from the `hidden_layer_node <GRUComposition.hidden_layer_node>` of the GRUComposition:
      `bias_hr <GRUComposition.bias_hr>`, `bias_hu <GRUComposition.bias_hu>`, `bias_hn <GRUComposition.bias_hn>`; its
      value is stored in the `b_hh_learning_rate <GRUComposition.b_hh_learning_rate>` attribute of the GRUComposition.

  **Values** for specifying an individual parameter's learning_rate in the **optimizer_params** dict:

    - *int or float*: the value is used as the learning_rate;

    - *True or None*: the value of the GRUComposition's `learning_rate <GRUComposition.learning_rate>` is used;

    - *False*: the parameter is not learned.
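
For example, the following is a minimal sketch (with hypothetical values, shown for illustration only) of how
the keys and values described above might be used to specify parameter-specific learning rates::

    gru = GRUComposition(input_size=3, hidden_size=5, bias=True,
                         learning_rate=.001,                # default for all other parameters
                         optimizer_params={'w_ih': .01,     # faster learning for input weights
                                           'b_hh': False})  # hidden bias parameters are not learned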

.. _GRUComposition_Structure:

Structure
---------

The GRUComposition assigns a node to each of the computations of the PyTorch `GRU
<https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module, and a Projection to each of
its weight and bias parameters, as shown in the figure below:

.. figure:: _static/GRUComposition_fig.svg
   :alt: GRU Composition
   :width: 400
   :align: center

   **Structure of a GRUComposition** -- can be seen in more detail using the Composition's `show_graph
   <ShowGraph.show_graph>` method with its **show_node_structure** argument set to ``True`` or ``ALL``;
   can also be seen with biases added by setting the **show_bias** argument to ``True`` in the constructor.

The `input_node <GRUComposition.input_node>` receives the input to the GRUComposition, and passes it to the
`hidden_layer_node <GRUComposition.hidden_layer_node>`, which implements the recurrence and integration function of
a GRU.  The `reset_node <GRUComposition.reset_node>` gates the input to the `new_node <GRUComposition.new_node>`. The
`update_node <GRUComposition.update_node>` gates the input to the `hidden_layer_node <GRUComposition.hidden_layer_node>`
from the `new_node <GRUComposition.new_node>` (current input) and the prior state of the `hidden_layer_node
<GRUComposition.hidden_layer_node>` (i.e., the input it receives from its recurrent Projection).  The `output_node
<GRUComposition.output_node>` receives the current state of the `hidden_layer_node
<GRUComposition.hidden_layer_node>`, which is provided as the output of the GRUComposition.  The `reset_node
<GRUComposition.reset_node>` and `update_node <GRUComposition.update_node>` are `GatingMechanisms <GatingMechanism>`,
while the other nodes are all `Processing Mechanisms <ProcessingMechanism>`.
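
For example, assuming a GRUComposition has been constructed as ``gru``, its structure can be displayed with::

    gru.show_graph(show_node_structure=True)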

.. note::
   The GRUComposition is limited to a single-layer GRU at present, thus the ``num_layers`` argument is not
   implemented.  Similarly, the ``dropout`` and ``bidirectional`` arguments are not yet implemented.  These will
   be added in a future version.

COMMENT:
FIX: ADD EXPLANATION OF THE FOLLOWING
.. technical_note::
   gru_mech
   target_node
   PytorchGRUProjectionWrappers for nested case
COMMENT

.. _GRUComposition_Execution:

Execution
---------

.. _GRUComposition_Processing:

*Processing*
~~~~~~~~~~~~

The GRUComposition implements the following computations by its `reset <GRUComposition.reset_node>`, `update
<GRUComposition.update_node>`, `new <GRUComposition.new_node>`, and `hidden_layer <GRUComposition.hidden_layer_node>`
`Nodes <Composition_Nodes>` when it is executed:

    `reset <GRUComposition.reset_node>`\\(t) = `Logistic`\\[(`wts_ir <GRUComposition.wts_ir>` *
    `input <GRUComposition.input_node>`) + `bias_ir <GRUComposition.bias_ir>` +
    (`wts_hr <GRUComposition.wts_hr>` * `hidden_layer <GRUComposition.hidden_layer_node>`\\(t-1)) +
    `bias_hr <GRUComposition.bias_hr>`]

    `update <GRUComposition.update_node>`\\(t) = `Logistic`\\[(`wts_iu <GRUComposition.wts_iu>` *
    `input <GRUComposition.input_node>`) + `bias_iu <GRUComposition.bias_iu>` + (`wts_hu <GRUComposition.wts_hu>` *
    `hidden_layer <GRUComposition.hidden_layer_node>`\\(t-1)) + `bias_hu <GRUComposition.bias_hu>`]

    `new <GRUComposition.new_node>`\\(t) = :math:`tanh`\\[(`wts_in <GRUComposition.wts_in>` *
    `input <GRUComposition.input_node>`) + `bias_in <GRUComposition.bias_in>` +
    (`reset <GRUComposition.reset_node>`\\(t) * (`wts_hn <GRUComposition.wts_hn>` *
    `hidden_layer <GRUComposition.hidden_layer_node>`\\(t-1) + `bias_hn <GRUComposition.bias_hn>`))]

    `hidden_layer <GRUComposition.hidden_layer_node>`\\(t) = [(1 - `update <GRUComposition.update_node>`\\(t)) *
    `new <GRUComposition.new_node>`\\(t)] + [`update <GRUComposition.update_node>`\\(t) * `hidden_layer
    <GRUComposition.hidden_layer_node>`\\(t-1)]

COMMENT:
where:
    r(t) = reset gate

    z(t) = update gate

    n(t) = new gate

    h(t) = hidden layer

    x(t) = input

    W_ir, W_iz, W_in, W_hr, W_hz, W_hn = input, update, and reset weights

    b_ir, b_iz, b_in, b_hr, b_hz, b_hn = input, update, and reset biases
COMMENT

This corresponds to the computations of the `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module:

.. math::

   &reset = Logistic(wts\\_ir \\cdot input + bias\\_ir + wts\\_hr \\cdot hidden + bias\\_hr)

   &update = Logistic(wts\\_iu \\cdot input + bias\\_iu + wts\\_hu \\cdot hidden + bias\\_hu)

   &new = Tanh(wts\\_in \\cdot input + bias\\_in + reset \\odot (wts\\_hn \\cdot hidden + bias\\_hn))

   &hidden = (1 - update) \\odot new + update \\odot hidden

where :math:`\\cdot` is the dot product, :math:`\\odot` is the Hadamard product, and all values are for the
current execution of the Composition *(t)* except for hidden, which uses the value from the prior execution *(t-1)*
(see `Cycles <Composition_Cycle>` for handling of recurrence and cycles).
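
The following is a minimal NumPy sketch of this single-step computation (for illustration only; the function and
argument names are hypothetical, and it is not part of the PsyNeuLink API)::

    import numpy as np

    def logistic(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step(x, h_prev, wts, biases):
        # wts and biases are dicts keyed by the suffixes of the Projections above (e.g., 'ir' for wts_ir)
        r = logistic(wts['ir'] @ x + biases['ir'] + wts['hr'] @ h_prev + biases['hr'])       # reset(t)
        u = logistic(wts['iu'] @ x + biases['iu'] + wts['hu'] @ h_prev + biases['hu'])       # update(t)
        n = np.tanh(wts['in'] @ x + biases['in'] + r * (wts['hn'] @ h_prev + biases['hn']))  # new(t)
        return (1 - u) * n + u * h_prev                                                      # hidden_layer(t)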

.. technical_note::
    The `full Composition <GRUComposition_Structure>` is executed when its `run <Composition.run>` method is
    called with **execution_mode** set to `ExecutionMode.Python`, or if ``torch_available`` is False.  Otherwise, and
    always in a call to `learn <AutodiffComposition.learn>`, the GRUComposition is executed using the PyTorch `GRU
    <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module with values of the individual
    computations copied back to Nodes of the full GRUComposition at times determined by the value of the
    `synch_node_values_with_torch <AutodiffComposition.synch_node_values_with_torch>` option.


.. _GRUComposition_Learning:

*Learning*
~~~~~~~~~~

Learning is executed using the `learn` method in the same way as for a standard `AutodiffComposition`.  For learning
to occur, the following conditions must be met:

  - `enable_learning <GRUComposition.enable_learning>` must be set to `True` (the default);

  - the GRUComposition's `learning_rate <GRUComposition.learning_rate>` must not be False and/or the
    `learning_rate of individual parameters <GRUComposition_Individual_Learning_Rates>` must not all be False;

  - the **execution_mode** argument of the `learn <AutodiffComposition.learn>` method must be `ExecutionMode.PyTorch`
    (the default).

  .. note:: Because a GRUComposition uses the PyTorch `GRU
     <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module to implement its computations during
     learning, its `learn <AutodiffComposition.learn>` method can only be called with the **execution_mode**
     argument set to `ExecutionMode.PyTorch` (the default).

The GRUComposition uses the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
to implement its computations during learning. After learning, the values of the module's parameters are copied
to the weight `matrices <MappingProjection.matrix>` of the corresponding `MappingProjections <MappingProjection>`,
and results of computations are copied to the `values <Mechanism_Base.value>` of the corresponding `Nodes
<Composition_Nodes>` in the GRUComposition at times determined by the value of the `synch_node_values_with_torch
<AutodiffComposition.synch_node_values_with_torch>` option.
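
For example, the following is a minimal sketch of a call to `learn <AutodiffComposition.learn>` (with hypothetical
input values, and assuming the standard Composition learning interface)::

    from psyneulink.core.llvm import ExecutionMode

    gru = GRUComposition(input_size=3, hidden_size=5)
    gru.learn(inputs={gru.input_node: [[.1, .2, .3], [.4, .5, .6]]},
              execution_mode=ExecutionMode.PyTorch)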

COMMENT:
.. _GRUComposition_Examples:

Examples
--------

The following are examples of how to configure and initialize a GRUComposition:
COMMENT

.. _GRUComposition_Class_Reference:

Class Reference
---------------
"""
import numpy as np
import warnings
from typing import Union
# from sympy.stats import Logistic
from collections import deque

import psyneulink.core.scheduling.condition as conditions
from psyneulink.core.components.functions.nonstateful.transformfunctions import LinearCombination
from psyneulink.core.components.functions.nonstateful.transferfunctions import Linear, Logistic, Tanh
from psyneulink.core.components.functions.nonstateful.transformfunctions import MatrixTransform
from psyneulink.core.components.functions.function import (
    DEFAULT_SEED, get_matrix, _random_state_getter, _seed_setter)
from psyneulink.core.components.ports.inputport import InputPort
from psyneulink.core.components.ports.outputport import OutputPort
from psyneulink.core.compositions.composition import CompositionError, NodeRole
from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition, torch_available
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism
from psyneulink.core.components.mechanisms.modulatory.control.gating.gatingmechanism import GatingMechanism
from psyneulink.core.components.ports.modulatorysignals.gatingsignal import GatingSignal
from psyneulink.core.components.projections.projection import DuplicateProjectionError
from psyneulink.core.components.projections.modulatory.gatingprojection import GatingProjection
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.keywords import (
    CONTEXT, FULL_CONNECTIVITY_MATRIX, GRU_COMPOSITION, IDENTITY_MATRIX, OUTCOME, SUM)
from psyneulink.core import llvm as pnlvm
from psyneulink.core.llvm import ExecutionMode

__all__ = ['GRUComposition', 'GRUCompositionError',
           'INPUT_NODE', 'HIDDEN_LAYER', 'RESET_NODE',
           'UPDATE_NODE', 'NEW_NODE', 'OUTPUT_NODE', 'GRU_INTERNAL_STATE_NAMES', 'GRU_NODE', 'GRU_TARGET_NODE']

# Node names
INPUT_NODE = 'INPUT'
NEW_NODE = 'NEW'
RESET_NODE = 'RESET'
UPDATE_NODE = 'UPDATE'
HIDDEN_LAYER = 'HIDDEN\nLAYER'
OUTPUT_NODE = 'OUTPUT'
GRU_INTERNAL_STATE_NAMES = [NEW_NODE, RESET_NODE, UPDATE_NODE, HIDDEN_LAYER]
GRU_NODE = 'PYTORCH GRU NODE'
GRU_TARGET_NODE = 'GRU TARGET NODE'


class GRUCompositionError(CompositionError):
    pass


class GRUComposition(AutodiffComposition):
    """
    GRUComposition(                         \
        name="GRU Composition",             \
        input_size=1,                       \
        hidden_size=1,                      \
        bias=False,                         \
        enable_learning=True,               \
        learning_rate=.001,                 \
        optimizer_params=None               \
        )

    Subclass of `AutodiffComposition` that implements a single-layered gated recurrent network.

    See `GRUComposition_Structure` and the technical_note under `GRUComposition_Execution`
    for a description of when the full Composition is constructed and used for execution
    vs. when the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_
    module is used.

    Note: all exposed methods, attributes and `Parameters <Parameter>` of the GRUComposition are
          PsyNeuLink elements; all PyTorch-specific elements belong to `pytorch_representation
          <AutodiffComposition.pytorch_representation>` which, for a GRUComposition, is of class
          `PytorchGRUCompositionWrapper`.

    Constructor takes the following arguments in addition to those of `AutodiffComposition`:

    Arguments
    ---------

    input_size : int : default 1
        specifies the length of the input array to the GRUComposition, and the size of the `input_node
        <GRUComposition.input_node>`.

    hidden_size : int : default 1
        specifies the length of the internal state of the GRUComposition, and the size of the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` and all nodes other than the `input_node <GRUComposition.input_node>`.

    bias : bool : default False
        specifies whether the GRUComposition uses bias vectors in its computations.

    COMMENT:
    num_layers : int : default 1
    batch_first : bool : default False
    dropout : float : default 0.0
    bidirectional : bool : default False
    COMMENT

    enable_learning : bool : default True
        specifies whether learning is enabled for the GRUComposition (see `Learning Arguments
        <GRUComposition_Learning_Arguments>` for additional details).

    learning_rate : float : default .001
        specifies the learning_rate for the GRUComposition (see `Learning Arguments
        <GRUComposition_Learning_Arguments>` for additional details).

    optimizer_params : Dict[str: value]
        specifies parameters for the optimizer used for learning by the GRUComposition
        (see `Learning Arguments <GRUComposition_Learning_Arguments>` for details of specification).

    Attributes
    ----------

    input_size : int
        determines the length of the input array to the GRUComposition and size of the `input_node
        <GRUComposition.input_node>`.

    hidden_size : int
        determines the size of the `hidden_layer_node` and all other `INTERNAL` `Nodes <Composition_Nodes>`
        of the GRUComposition.

    bias : bool
        determines whether the GRUComposition uses bias vectors in its computations.

    COMMENT:
    num_layers : int : default 1
    batch_first : bool : default False
    dropout : float : default 0.0
    bidirectional : bool : default False
    COMMENT

    enable_learning : bool
        determines whether learning is enabled for the GRUComposition
        (see `Learning Arguments <GRUComposition_Learning_Arguments>` for additional details).

    learning_rate : float
        determines the default learning_rate for the parameters of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module that are not specified
        for individual parameters in the **optimizer_params** argument of the GRUComposition's
        constructor or in the call to its `learn <GRUComposition.learn>` method (see `Learning Arguments
        <GRUComposition_Learning_Arguments>` for additional details).

    w_ih_learning_rate : float or bool
        determines the learning rate specifically for the weights of the `efferent projections
        <Mechanism_Base.efferents>` from the `input_node <GRUComposition.input_node>`
        of the GRUComposition: `wts_in <GRUComposition.wts_in>`, `wts_iu <GRUComposition.wts_iu>`,
        and `wts_ir <GRUComposition.wts_ir>`; corresponds to the ``weight_ih_l0`` parameter of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `Learning Arguments <GRUComposition_Learning_Arguments>` for additional details).

    w_hh_learning_rate : float or bool
        determines the learning rate specifically for the weights of the `efferent projections
        <Mechanism_Base.efferents>` from the `hidden_layer_node <GRUComposition.hidden_layer_node>`
        of the GRUComposition: `wts_hn <GRUComposition.wts_hn>`, `wts_hu <GRUComposition.wts_hu>`,
        `wts_hr <GRUComposition.wts_hr>`; corresponds to the ``weight_hh_l0`` parameter of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `Learning Arguments <GRUComposition_Learning_Arguments>` for additional details).

    b_ih_learning_rate : float or bool
        determines the learning rate specifically for the biases influencing the `efferent projections
        <Mechanism_Base.efferents>` from the `input_node <GRUComposition.input_node>` of the GRUComposition:
        `bias_ir <GRUComposition.bias_ir>`, `bias_iu <GRUComposition.bias_iu>`, `bias_in <GRUComposition.bias_in>`;
        corresponds to the ``bias_ih_l0`` parameter of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module (see `Learning Arguments
        <GRUComposition_Learning_Arguments>` for additional details).

    b_hh_learning_rate : float or bool
        determines the learning rate specifically for the biases influencing the `efferent projections
        <Mechanism_Base.efferents>` from the `hidden_layer_node <GRUComposition.hidden_layer_node>` of
        the GRUComposition: `bias_hr <GRUComposition.bias_hr>`, `bias_hu <GRUComposition.bias_hu>`,
        `bias_hn <GRUComposition.bias_hn>`; corresponds to the ``bias_hh_l0`` parameter of the PyTorch
        `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module (see `Learning Arguments
        <GRUComposition_Learning_Arguments>` for additional details).

    input_node : ProcessingMechanism
        `INPUT <NodeRole.INPUT>` `Node <Composition_Nodes>` that receives the input to the GRUComposition and passes
        it to the `hidden_layer_node <GRUComposition.hidden_layer_node>`; corresponds to input *(i)* of the PyTorch
        `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    new_node : ProcessingMechanism
        `ProcessingMechanism` that provides the `hidden_layer_node <GRUComposition.hidden_layer_node>`
        with the input from the `input_node <GRUComposition.input_node>`, gated by the `reset_node
        <GRUComposition.reset_node>`; corresponds to new gate *(n)* of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    hidden_layer_node : ProcessingMechanism
        `ProcessingMechanism` that implements the recurrent layer of the GRUComposition; corresponds to
        hidden layer *(h)* of the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    reset_node : GatingMechanism
        `GatingMechanism` that gates the input to the `new_node <GRUComposition.new_node>`; corresponds to reset
        gate *(r)* of the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    update_node : GatingMechanism
        `GatingMechanism` that gates the inputs to the hidden layer from the `new_node <GRUComposition.new_node>`
        and the prior state of the `hidden_layer_node <GRUComposition.hidden_layer_node>` itself (i.e., the input
        it receives from its recurrent Projection); corresponds to update gate *(z)* of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    output_node : ProcessingMechanism
        `OUTPUT <NodeRole.OUTPUT>` `Node <Composition_Nodes>` that receives the output of the `hidden_layer_node
        <GRUComposition.hidden_layer_node>`; corresponds to the result of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    learnable_projections : List[MappingProjection]
        list of the `MappingProjections <MappingProjection>` in the GRUComposition that have
        `matrix <MappingProjection.matrix>` parameters that can be learned; these correspond to the learnable
        parameters of the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module.

    wts_in : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that projects
        from the `input_node <GRUComposition.input_node>` to the `new_node <GRUComposition.new_node>`; corresponds to
        the :math:`W_{in}` term in the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_
        module's computation (see `GRUComposition_Structure` for additional information).

    wts_iu : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that projects
        from the `input_node <GRUComposition.input_node>` to the `update_node <GRUComposition.update_node>`; corresponds
        to the :math:`W_{iz}` term in the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_
        module's computation (see `GRUComposition_Structure` for additional information).

    wts_ir : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that projects
        from the `input_node <GRUComposition.input_node>` to the `reset_node <GRUComposition.reset_node>`; corresponds
        to the :math:`W_{ir}` term in the PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_
        module's computation (see `GRUComposition_Structure` for additional information).

    wts_nh : MappingProjection
        `MappingProjection` with fixed `matrix <MappingProjection.matrix>` ("connection weights") that projects
        from the `new_node <GRUComposition.new_node>` to the `hidden_layer_node <GRUComposition.hidden_layer_node>`
        (see `GRUComposition_Structure` for additional information).

    wts_hr : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights")
        that projects from the `hidden_layer_node <GRUComposition.hidden_layer_node>` to the
        `reset_node <GRUComposition.reset_node>`; corresponds to the :math:`W_{hr}` term in the PyTorch
        `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module's computation
        (see `GRUComposition_Structure` for additional information).

    wts_hu : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights")
        that projects from the `hidden_layer_node <GRUComposition.hidden_layer_node>` to the
        `update_node <GRUComposition.update_node>`; corresponds to the :math:`W_{hz}` term in the PyTorch
        `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module's computation
        (see `GRUComposition_Structure` for additional information).

    wts_hn : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights")
        that projects from the `hidden_layer_node <GRUComposition.hidden_layer_node>` to the `new_node
        <GRUComposition.new_node>`; corresponds to the :math:`W_{hn}` term in the PyTorch
        `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module's computation
        (see `GRUComposition_Structure` for additional information).

    wts_hh : MappingProjection
        `MappingProjection` with fixed `matrix <MappingProjection.matrix>` ("connection weights") that projects
        from the `hidden_layer_node <GRUComposition.hidden_layer_node>` to itself (i.e., the recurrent Projection)
        (see `GRUComposition_Structure` for additional information).

    wts_ho : MappingProjection
        `MappingProjection` with fixed `matrix <MappingProjection.matrix>` ("connection weights") that projects from
        the `hidden_layer_node <GRUComposition.hidden_layer_node>` to the `output_node <GRUComposition.output_node>`
        (see `GRUComposition_Structure` for additional information).

    reset_gate : GatingProjection
        `GatingProjection` that gates the input to the `new_node <GRUComposition.new_node>` from the
        `hidden_layer_node <GRUComposition.hidden_layer_node>`; its `value <GatingProjection.value>` is used in the
        Hadamard product with that input to produce the new (external) input to the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` (see `GRUComposition_Structure` for additional information).

    new_gate : GatingProjection
        `GatingProjection` that gates the input to the `hidden_layer_node <GRUComposition.hidden_layer_node>` from the
        `new_node <GRUComposition.new_node>`; its `value <GatingProjection.value>` is used in the Hadamard product
        with the (external) input to the `hidden_layer_node <GRUComposition.hidden_layer_node>` from the `new_node
        <GRUComposition.new_node>`, which determines how much of the `hidden_layer_node
        <GRUComposition.hidden_layer_node>`\\'s new state is determined by the external input vs. its prior state
        (see `GRUComposition_Structure` for additional information).

    recurrent_gate : GatingProjection
        `GatingProjection` that gates the input to the `hidden_layer_node <GRUComposition.hidden_layer_node>` from its
        recurrent projection (`wts_hh <GRUComposition.wts_hh>`); its `value <GatingProjection.value>` is used in the
        Hadamard product with the recurrent input to the `hidden_layer_node <GRUComposition.hidden_layer_node>`,
        which determines how much of the `hidden_layer_node <GRUComposition.hidden_layer_node>`\\'s
        new state is determined by its prior state vs. its external input
        (see `GRUComposition_Structure` for additional information).

    bias_ir_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_ir <GRUComposition.bias_ir>`) provides
        the bias to the weights (`wts_ir <GRUComposition.wts_ir>`) from the `input_node <GRUComposition.input_node>`
        to the `reset_node <GRUComposition.reset_node>` (see `GRUComposition_Structure` for additional information).

    bias_iu_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_iu <GRUComposition.bias_iu>`) provides
        the bias to the weights (`wts_iu <GRUComposition.wts_iu>`) from the `input_node <GRUComposition.input_node>`
        to the `update_node <GRUComposition.update_node>` (see `GRUComposition_Structure` for additional information).

    bias_in_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_in <GRUComposition.bias_in>`) provides
        the bias to the weights (`wts_in <GRUComposition.wts_in>`) from the `input_node <GRUComposition.input_node>`
        to the `new_node <GRUComposition.new_node>` (see `GRUComposition_Structure` for additional information).

    bias_hr_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_hr <GRUComposition.bias_hr>`) provides
        the bias to the weights (`wts_hr <GRUComposition.wts_hr>`) from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `reset_node <GRUComposition.reset_node>`
        (see `GRUComposition_Structure` for additional information).

    bias_hu_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_hu <GRUComposition.bias_hu>`) provides
        the bias to the weights (`wts_hu <GRUComposition.wts_hu>`) from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `update_node <GRUComposition.update_node>`
        (see `GRUComposition_Structure` for additional information).

    bias_hn_node : ProcessingMechanism
        `BIAS` `Node <Composition_Nodes>`, the Projection from which (`bias_hn <GRUComposition.bias_hn>`) provides
        the bias to the weights (`wts_hn <GRUComposition.wts_hn>`) from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `new_node <GRUComposition.new_node>`
        (see `GRUComposition_Structure` for additional information).

    biases : List[MappingProjection]
        list of the `MappingProjections <MappingProjection>` from the `BIAS <NodeRole.BIAS>` `Nodes
        <Composition_Nodes>` of the GRUComposition, all of which have `matrix <MappingProjection.matrix>`
        parameters if `bias <GRUComposition.bias>` is True; these correspond to the learnable biases of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_ir : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that provides
        the bias to the weights, `wts_ir <GRUComposition.wts_ir>`, from the `input_node <GRUComposition.input_node>`
        to the `reset_node <GRUComposition.reset_node>`; corresponds to the :math:`b_{ir}` bias parameter of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_iu : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that provides
        the bias to the weights, `wts_iu <GRUComposition.wts_iu>`, from the `input_node <GRUComposition.input_node>`
        to the `update_node <GRUComposition.update_node>`; corresponds to the :math:`b_{iz}` bias parameter of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_in : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that provides
        the bias to the weights, `wts_in <GRUComposition.wts_in>`, from the `input_node <GRUComposition.input_node>`
        to the `new_node <GRUComposition.new_node>`; corresponds to the :math:`b_{in}` bias parameter of the
        PyTorch `GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_hr : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights")
        that provides the bias to the weights, `wts_hr <GRUComposition.wts_hr>`, from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `reset_node <GRUComposition.reset_node>`;
        corresponds to the :math:`b_{hr}` bias parameter of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_hu : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that provides
        the bias to the weights, `wts_hu <GRUComposition.wts_hu>`, from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `update_node <GRUComposition.update_node>`;
        corresponds to the :math:`b_{hz}` bias parameter of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).

    bias_hn : MappingProjection
        `MappingProjection` with learnable `matrix <MappingProjection.matrix>` ("connection weights") that provides
        the bias to the weights, `wts_hn <GRUComposition.wts_hn>`, from the `hidden_layer_node
        <GRUComposition.hidden_layer_node>` to the `new_node <GRUComposition.new_node>`; corresponds to the
        :math:`b_{hn}` bias parameter of the PyTorch `GRU
        <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_ module
        (see `GRUComposition_Structure` for additional information).
    """

    componentCategory = GRU_COMPOSITION

    if torch_available:
        from psyneulink.library.compositions.grucomposition.pytorchGRUwrappers import \
            PytorchGRUCompositionWrapper, PytorchGRUMechanismWrapper
        pytorch_composition_wrapper_type = PytorchGRUCompositionWrapper
        pytorch_mechanism_wrapper_type = PytorchGRUMechanismWrapper

    class Parameters(AutodiffComposition.Parameters):
        """
            Attributes
            ----------

                bias
                    see `bias <GRUComposition.bias>`

                    :default value: False
                    :type: ``bool``

                enable_learning
                    see `enable_learning <GRUComposition.enable_learning>`

                    :default value: True
                    :type: ``bool``

                gru_mech
                    see `gru_mech <GRUComposition.gru_mech>`

                    :default value: None
                    :type: ``ProcessingMechanism``

                hidden_biases_learning_rate
                    see `hidden_biases_learning_rate <GRUComposition.hidden_biases_learning_rate>`

                    :default value: True
                    :type: ``bool``

                hidden_size
                    see `hidden_size <GRUComposition.hidden_size>`

                    :default value: 1
                    :type: ``int``

                hidden_state

                    :default value: None
                    :type: ``ndarray``

                hidden_weights_learning_rate
                    see `hidden_weights_learning_rate <GRUComposition.hidden_weights_learning_rate>`

                    :default value: True
                    :type: ``bool``

                input_biases_learning_rate
                    see `input_biases_learning_rate <GRUComposition.input_biases_learning_rate>`

                    :default value: True
                    :type: ``bool``

                input_size
                    see `input_size <GRUComposition.input_size>`

                    :default value: 1
                    :type: ``int``

                input_weights_learning_rate
                    see `input_weights_learning_rate <GRUComposition.input_weights_learning_rate>`

                    :default value: True
                    :type: ``bool``

                learning_rate
                    see `learning_rate <GRUComposition.learning_rate>`

                    :default value: 0.001
                    :type: ``float``

                random_state
                    see `random_state <NormalDist.random_state>`

                    :default value: None
                    :type: ``numpy.random.RandomState``

        """
        input_size = Parameter(1, structural=True, stateful=False)
        hidden_size = Parameter(1, structural=True, stateful=False)
        bias = Parameter(False, structural=True, stateful=False)
        gru_mech = Parameter(None, structural=True, stateful=False)
        enable_learning = Parameter(True, structural=True)
        learning_rate = Parameter(.001, modulable=True)
        input_weights_learning_rate = Parameter(True, structural=True)
        hidden_weights_learning_rate = Parameter(True, structural=True)
        input_biases_learning_rate = Parameter(True, structural=True)
        hidden_biases_learning_rate = Parameter(True, structural=True)
        random_state = Parameter(None, loggable=False, getter=_random_state_getter, dependencies='seed')
        seed = Parameter(DEFAULT_SEED(), modulable=True, setter=_seed_setter)

        def _validate_input_size(self, size):
            if not (isinstance(size, np.ndarray) and isinstance(size.tolist(), int)):
                return 'must be an integer'

        def _validate_hidden_size(self, size):
            if not (isinstance(size, np.ndarray) and isinstance(size.tolist(), int)):
                return 'must be an integer'

        def _validate_bias(self, bias):
            if not isinstance(bias, bool):
                return 'must be a boolean'

        def _validate_input_weights_learning_rate(self, rate):
            if not isinstance(rate, (float, bool)):
                return 'must be a float or a boolean'

        def _validate_hidden_weights_learning_rate(self, rate):
            if not isinstance(rate, (float, bool)):
                return 'must be a float or a boolean'

        def _validate_input_biases_learning_rate(self, rate):
            if not isinstance(rate, (float, bool)):
                return 'must be a float or a boolean'

        def _validate_hidden_biases_learning_rate(self, rate):
            if not isinstance(rate, (float, bool)):
                return 'must be a float or a boolean'

    @check_user_specified
    def __init__(self,
                 input_size=None,
                 hidden_size=None,
                 bias=None,
                 # num_layers:int=1,
                 # batch_first:bool=False,
                 # dropout:float=0.0,
                 # bidirectional:bool=False,
                 enable_learning:bool=True,
                 learning_rate:float=None,
                 optimizer_params:dict=None,
                 random_state=None,
                 seed=None,
                 name="GRU Composition",
                 **kwargs):

        # Instantiate Composition -------------------------------------------------------------------------

        super().__init__(name=name,
                         input_size=input_size,
                         hidden_size=hidden_size,
                         bias=bias,
                         # num_layers=num_layers,
                         # batch_first=batch_first,
                         # dropout=dropout,
                         # bidirectional=bidirectional,
                         enable_learning=enable_learning,
                         learning_rate=learning_rate,
                         optimizer_params=optimizer_params,
                         random_state=random_state,
                         seed=seed,
                         **kwargs
                         )

        input_size = self.input_size
        hidden_size = self.hidden_size

        self._construct_pnl_composition(input_size, hidden_size,
                                        context=Context(source=ContextFlags.COMMAND_LINE, string='FROM GRU'))

        self._assign_gru_specific_attributes(input_size, hidden_size)


    # *****************************************************************************************************************
    # ******************************  Nodes and Pathway Construction Methods  *****************************************
    # *****************************************************************************************************************
    #region
    # Construct Nodes --------------------------------------------------------------------------------

    def _construct_pnl_composition(self, input_size, hidden_size, context):
        """Construct Nodes and Projections for GRUComposition"""
        hidden_shape = np.ones(hidden_size)

        self.input_node = ProcessingMechanism(name=INPUT_NODE,
                                              input_shapes=input_size)

        # Two input_ports are used to separately gate the input from its recurrent Projection and from the new_node
        # LinearCombination function of each InputPort is explicitly specified to allow for gating by a vector
        self.hidden_layer_node = ProcessingMechanism(name=HIDDEN_LAYER,
                                                     input_shapes=[hidden_size, hidden_size],
                                                     input_ports=[
                                                         InputPort(name='NEW INPUT',
                                                                   function=LinearCombination(scale=hidden_shape)),
                                                         InputPort(name='RECURRENT',
                                                                   function=LinearCombination(scale=hidden_shape))],
                                                     function=LinearCombination(operation=SUM))

        # Two input_ports are used to allow the input from the hidden_layer_node to be gated but not the input_node
        # The node's LinearCombination function is then used to combine the two inputs
        # And then Tanh is assigned as the function of the OutputPort to do the nonlinear transform
        self.new_node = ProcessingMechanism(name=NEW_NODE,
                                            input_shapes=[hidden_size, hidden_size],
                                            input_ports=['FROM INPUT',
                                                         InputPort(name='FROM HIDDEN',
                                                                   function=LinearCombination(scale=hidden_shape))],
                                            function=LinearCombination,
                                            output_ports=[OutputPort(name='TO HIDDEN LAYER INPUT',
                                                                     function=Tanh)])

        # Gates input to hidden_layer_node from its recurrent Projection and from new_node
        self.update_node = GatingMechanism(name=UPDATE_NODE,
                                           default_allocation=hidden_shape,
                                           function=Logistic,
                                           gating_signals=[
                                               GatingSignal(name='RECURRENT GATING SIGNAL',
                                                            default_allocation=hidden_shape,
                                                            gate=self.hidden_layer_node.input_ports['RECURRENT']),
                                               GatingSignal(name='NEW GATING SIGNAL',
                                                            default_allocation=hidden_shape,
                                                            transfer_function=Linear(scale=-1, offset=1),
                                                            gate=self.hidden_layer_node.input_ports['NEW INPUT'])])
        self.new_gate = self.update_node.gating_signals['NEW GATING SIGNAL'].efferents[0]
        self.new_gate.name = 'NEW GATE'
        self.recurrent_gate = self.update_node.gating_signals['RECURRENT GATING SIGNAL'].efferents[0]
        self.recurrent_gate.name = 'RECURRENT GATE'

        self.reset_node = GatingMechanism(name=RESET_NODE,
                                          default_allocation=hidden_shape,
                                          function=Logistic,
                                          gating_signals=[
                                              GatingSignal(name='RESET GATING SIGNAL',
                                                           default_allocation=hidden_shape,
                                                           gate=self.new_node.input_ports['FROM HIDDEN'])])
        self.reset_gate = self.reset_node.gating_signals['RESET GATING SIGNAL'].efferents[0]
        self.reset_gate.name = 'RESET GATE'

        self.output_node = ProcessingMechanism(name=OUTPUT_NODE,
                                               input_shapes=hidden_size,
                                               function=Linear)

        self.add_nodes([self.input_node, self.new_node, self.reset_node,
                        self.update_node, self.output_node, self.hidden_layer_node],
                       context=context)

        def init_wts(sender_size, receiver_size):
            """Initialize weights for Projections"""
            sqrt_val = np.sqrt(hidden_size)
            return np.random.uniform(-sqrt_val, sqrt_val, (sender_size, receiver_size))

        # Learnable: wts_in, wts_iu, wts_ir, wts_hn, wts_hu,, wts_hr
891
        self.wts_in = MappingProjection(name='INPUT TO NEW WEIGHTS',
1✔
892
                                        sender=self.input_node,
893
                                        receiver=self.new_node.input_ports['FROM INPUT'],
894
                                        learnable=True,
895
                                        matrix=init_wts(input_size, hidden_size))
896

897
        self.wts_iu = MappingProjection(name='INPUT TO UPDATE WEIGHTS',
1✔
898
                                        sender=self.input_node,
899
                                        receiver=self.update_node.input_ports[OUTCOME],
900
                                        learnable=True,
901
                                        matrix=init_wts(input_size, hidden_size))
902

903
        self.wts_ir = MappingProjection(name='INPUT TO RESET WEIGHTS',
1✔
904
                                        sender=self.input_node,
905
                                        receiver=self.reset_node.input_ports[OUTCOME],
906
                                        learnable=True,
907
                                        matrix=init_wts(input_size, hidden_size))
908

909
        self.wts_nh = MappingProjection(name='NEW TO HIDDEN WEIGHTS',
1✔
910
                                        sender=self.new_node,
911
                                        receiver=self.hidden_layer_node.input_ports['NEW INPUT'],
912
                                        learnable=False,
913
                                        matrix=IDENTITY_MATRIX)
914

915
        self.wts_hh = MappingProjection(name='HIDDEN RECURRENT WEIGHTS',
1✔
916
                                        sender=self.hidden_layer_node,
917
                                        receiver=self.hidden_layer_node.input_ports['RECURRENT'],
918
                                        learnable=False,
919
                                        matrix=IDENTITY_MATRIX)
920

921
        self.wts_hn = MappingProjection(name='HIDDEN TO NEW WEIGHTS',
1✔
922
                                        sender=self.hidden_layer_node,
923
                                        receiver=self.new_node.input_ports['FROM HIDDEN'],
924
                                        learnable=True,
925
                                        matrix=init_wts(hidden_size, hidden_size))
926

927
        self.wts_hr = MappingProjection(name='HIDDEN TO RESET WEIGHTS',
1✔
928
                                        sender=self.hidden_layer_node,
929
                                        receiver=self.reset_node.input_ports[OUTCOME],
930
                                        learnable=True,
931
                                        matrix=init_wts(hidden_size, hidden_size))
932

933
        self.wts_hu = MappingProjection(name='HIDDEN TO UPDATE WEIGHTS',
1✔
934
                                        sender=self.hidden_layer_node,
935
                                        receiver=self.update_node.input_ports[OUTCOME],
936
                                        learnable=True,
937
                                        matrix=init_wts(hidden_size, hidden_size))
938

939
        self.wts_ho = MappingProjection(name='HIDDEN TO OUTPUT WEIGHTS',
1✔
940
                                        sender=self.hidden_layer_node,
941
                                        receiver=self.output_node,
942
                                        learnable=False,
943
                                        matrix=IDENTITY_MATRIX)
944

945
        self.learnable_projections = [self.wts_in, self.wts_iu, self.wts_ir,
1✔
946
                                      self.wts_hn, self.wts_hr, self.wts_hu]
947

948
        self.add_projections([self.wts_in, self.wts_iu, self.wts_ir, self.wts_nh,
1✔
949
                              self.wts_hh, self.wts_hn, self.wts_hr, self.wts_hu, self.wts_ho],
950
                             context=context)
951
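        # For reference, the Projections and gate nodes above implement the standard GRU
        # computations (as specified in the PyTorch GRU docs), where x is the input, h the
        # previous hidden state, sigma the logistic sigmoid, and * elementwise multiplication:
        #     r  = sigma(x @ wts_ir + b_ir + h @ wts_hr + b_hr)       # reset gate
        #     u  = sigma(x @ wts_iu + b_iu + h @ wts_hu + b_hu)       # update gate (z in PyTorch)
        #     n  = tanh(x @ wts_in + b_in + r * (h @ wts_hn + b_hn))  # new (candidate) state
        #     h' = (1 - u) * n + u * h                                # next hidden state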

        if self.bias:
            self.bias_in_node = ProcessingMechanism(name='BIAS NODE IN', default_variable=[1])
            self.bias_in = MappingProjection(name='BIAS IN',
                                             sender=self.bias_in_node,
                                             receiver=self.new_node.input_ports['FROM INPUT'],
                                             learnable=True)

            self.bias_iu_node = ProcessingMechanism(name='BIAS NODE IU', default_variable=[1])
            self.bias_iu = MappingProjection(name='BIAS IU',
                                             sender=self.bias_iu_node,
                                             receiver=self.update_node.input_ports[OUTCOME],
                                             learnable=True)

            self.bias_ir_node = ProcessingMechanism(name='BIAS NODE IR', default_variable=[1])
            self.bias_ir = MappingProjection(name='BIAS IR',
                                             sender=self.bias_ir_node,
                                             receiver=self.reset_node.input_ports[OUTCOME],
                                             learnable=True)

            self.bias_hn_node = ProcessingMechanism(name='BIAS NODE HN', default_variable=[1])
            self.bias_hn = MappingProjection(name='BIAS HN',
                                             sender=self.bias_hn_node,
                                             receiver=self.new_node.input_ports['FROM HIDDEN'],
                                             learnable=True)

            self.bias_hr_node = ProcessingMechanism(name='BIAS NODE HR', default_variable=[1])
            self.bias_hr = MappingProjection(name='BIAS HR',
                                             sender=self.bias_hr_node,
                                             receiver=self.reset_node.input_ports[OUTCOME],
                                             learnable=True)

            self.bias_hu_node = ProcessingMechanism(name='BIAS NODE HU', default_variable=[1])
            self.bias_hu = MappingProjection(name='BIAS HU',
                                             sender=self.bias_hu_node,
                                             receiver=self.update_node.input_ports[OUTCOME],
                                             learnable=True)

            self.add_nodes([(self.bias_ir_node, NodeRole.BIAS),
                            (self.bias_iu_node, NodeRole.BIAS),
                            (self.bias_in_node, NodeRole.BIAS),
                            (self.bias_hr_node, NodeRole.BIAS),
                            (self.bias_hu_node, NodeRole.BIAS),
                            (self.bias_hn_node, NodeRole.BIAS)],
                           context=Context(source=ContextFlags.COMMAND_LINE, string='FROM GRU'))

            self.biases = [self.bias_ir, self.bias_iu, self.bias_in,
                           self.bias_hr, self.bias_hu, self.bias_hn]
            self.add_projections(self.biases, context=context)
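            # Note: each BIAS node emits a constant [1], so the learnable matrix of its
            # MappingProjection (a 1 x hidden_size row vector) serves as the corresponding
            # bias term (b_ir, b_iu, b_in, b_hr, b_hu, b_hn) in the GRU computations above.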

        # Enforce serial execution order (reset -> update -> new -> hidden) for the gates
        # and hidden layer within each pass
        self.scheduler.add_condition(self.update_node, conditions.AfterNodes(self.reset_node))
        self.scheduler.add_condition(self.new_node, conditions.AfterNodes(self.update_node))
        self.scheduler.add_condition(self.hidden_layer_node, conditions.AfterNodes(self.new_node))

        self._set_learning_attributes()

        self._analyze_graph()

    def _assign_gru_specific_attributes(self, input_size, hidden_size):
        for node in self.nodes:
            node.exclude_from_show_graph = True
        self.gru_mech = ProcessingMechanism(name=GRU_NODE,
                                            input_shapes=input_size,
                                            function=MatrixTransform(
                                                default_variable=np.zeros(input_size),
                                                matrix=get_matrix(FULL_CONNECTIVITY_MATRIX, input_size, hidden_size)))
        self._input_comp_nodes_to_pytorch_nodes_map = {self.input_node: self.gru_mech}
        self._trained_comp_nodes_to_pytorch_nodes_map = {self.output_node: self.gru_mech}
        self.target_node = ProcessingMechanism(default_variable=np.zeros_like(self.gru_mech.value),
                                               name=GRU_TARGET_NODE)

    def _set_learning_attributes(self):
        """Set learning-related attributes for Nodes and Projections"""
        learning_rate = self.enable_learning

        for projection in self.learnable_projections:

            if learning_rate is False:
                projection.learnable = False
                continue

            elif learning_rate is True:
                # Default (the GRUComposition's learning_rate) is used for all learnable Projections:
                learning_rate = self.learning_rate

            assert isinstance(learning_rate, (int, float)), \
                (f"PROGRAM ERROR: learning_rate for {projection.sender.owner.name} is not a valid value.")

            projection.learnable = True
            if projection.learning_mechanism:
                projection.learning_mechanism.learning_rate = learning_rate
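    # In brief: enable_learning=False freezes all learnable Projections; True assigns the
    # GRUComposition's learning_rate to each of them; a numeric value is used as the
    # learning rate directly.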

    def get_weights(self, context=None):
        """Return the matrices of the learnable Projections, in the order
        (input->reset, input->update, input->new, hidden->reset, hidden->update, hidden->new)."""
        wts_ir = self.wts_ir.parameters.matrix.get(context)
        wts_iu = self.wts_iu.parameters.matrix.get(context)
        wts_in = self.wts_in.parameters.matrix.get(context)
        wts_hr = self.wts_hr.parameters.matrix.get(context)
        wts_hu = self.wts_hu.parameters.matrix.get(context)
        wts_hn = self.wts_hn.parameters.matrix.get(context)
        return wts_ir, wts_iu, wts_in, wts_hr, wts_hu, wts_hn
    #endregion

    @handle_external_context()
    def set_weights(self, weights:Union[list, np.ndarray], biases:Union[list, np.ndarray], context=None):
        """Set weights for Projections to input_node and hidden_layer_node."""

        # MODIFIED 2/16/25 NEW:
        # FIX: CHECK IF TORCH GRU EXISTS YET (CHECK FOR pytorch_representation != None; i.e., LEARNING HAS OCCURRED;
        #      IF SO, ADD CALL TO PytorchGRUProjectionWrapper HELPER METHOD TO SET TORCH GRU PARAMETERS
        for wts, proj in zip(weights,
                             [self.wts_ir, self.wts_iu, self.wts_in, self.wts_hr, self.wts_hu, self.wts_hn]):
            valid_shape = self._get_valid_weights_shape(proj)
            assert wts.shape == valid_shape, \
                (f"PROGRAM ERROR: Shape of weights in 'weights' arg of '{self.name}.set_weights' "
                 f"({wts.shape}) does not match required shape ({valid_shape}).")
            proj.parameters.matrix._set(wts, context)
            proj.parameter_ports['matrix'].parameters.value._set(wts, context)
        # MODIFIED 3/11/25 END

        if biases:
            for torch_bias, pnl_bias in zip(biases, [self.bias_ir, self.bias_iu, self.bias_in,
                                                     self.bias_hr, self.bias_hu, self.bias_hn]):
                valid_shape = self._get_valid_weights_shape(pnl_bias)
                assert torch_bias.shape == valid_shape, \
                    (f"PROGRAM ERROR: Shape of biases in 'bias' arg of '{self.name}.set_weights' "
                     f"({torch_bias.shape}) does not match required shape ({valid_shape}).")
                pnl_bias.parameters.matrix._set(torch_bias, context)
                pnl_bias.parameter_ports['matrix'].parameters.value._set(torch_bias, context)
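    # A minimal sketch of loading parameters from a trained torch.nn.GRU into this
    # composition via set_weights (the names `torch_gru` and `gru_comp` are hypothetical;
    # assumes matching input_size/hidden_size and bias=True). PyTorch stacks the reset,
    # update and new submatrices row-wise in weight_ih_l0 (shape (3*hidden_size, input_size))
    # and weight_hh_l0 (shape (3*hidden_size, hidden_size)), so each block is split out and
    # transposed to PsyNeuLink's (sender_size, receiver_size) orientation:
    #
    #     w_ir, w_iu, w_in = (w.T for w in np.split(torch_gru.weight_ih_l0.detach().numpy(), 3))
    #     w_hr, w_hu, w_hn = (w.T for w in np.split(torch_gru.weight_hh_l0.detach().numpy(), 3))
    #     b_ir, b_iu, b_in = (b.reshape(1, -1) for b in np.split(torch_gru.bias_ih_l0.detach().numpy(), 3))
    #     b_hr, b_hu, b_hn = (b.reshape(1, -1) for b in np.split(torch_gru.bias_hh_l0.detach().numpy(), 3))
    #     gru_comp.set_weights([w_ir, w_iu, w_in, w_hr, w_hu, w_hn],
    #                          [b_ir, b_iu, b_in, b_hr, b_hu, b_hn])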

    @handle_external_context()
    def infer_backpropagation_learning_pathways(self, execution_mode, context=None)->list:
        if execution_mode is not pnlvm.ExecutionMode.PyTorch:
            raise GRUCompositionError(f"Learning in {self.componentCategory} "
                                      f"is not supported for {execution_mode.name}.")

        # Create Mechanism, the function of which will be the PyTorch GRU module
        # Note:  function is a placeholder, to induce proper variable and value dimensions;
        #        it will be replaced by the PyTorch GRU function in PytorchGRUMechanismWrapper
        target_mech = self.target_node

        # Add target Node to GRUComposition
        self.add_node(target_mech, required_roles=[NodeRole.TARGET, NodeRole.LEARNING],
                      context=Context(source=ContextFlags.METHOD, string='FROM GRU'))
        self.exclude_node_roles(target_mech, NodeRole.OUTPUT, context)

        for output_port in target_mech.output_ports:
            output_port.parameters.require_projection_in_composition.set(False, override=True)
        self.targets_from_outputs_map = {target_mech: self.gru_mech}
        self.outputs_to_targets_map = {self.gru_mech: target_mech}

        return [target_mech]

    def _get_pytorch_backprop_pathway(self, input_node, context)->list:
        return [[self.gru_mech]]

    # *****************************************************************************************************************
    # *********************************** Execution Methods  *********************************************************
    # *****************************************************************************************************************
    #region

    def _get_execution_mode(self, execution_mode):
        """Parse execution_mode argument and return a valid execution mode for the learn() method"""
        if execution_mode is None:
            if self.execution_mode_warned_about_default is False:
                warnings.warn(f"The execution_mode argument was not specified in the learn() method of {self.name}; "
                              f"ExecutionMode.PyTorch will be used by default.")
                self.execution_mode_warned_about_default = True
            execution_mode = ExecutionMode.PyTorch
        return execution_mode

    def _add_dependency(self,
                        sender:ProcessingMechanism,
                        projection:MappingProjection,
                        receiver:ProcessingMechanism,
                        dependency_dict:dict,
                        queue:deque,
                        comp:AutodiffComposition):
        """Override to implement direct pathway through gru_mech for pytorch backprop pathway.
        """
        # FIX: 3/9/25 CLEAN THIS UP: WRT ASSIGNMENT OF _pytorch_projections BELOW:
        if self._pytorch_projections:
            assert len(self._pytorch_projections) == 2, \
                (f"PROGRAM ERROR: {self.name}._pytorch_projections should have only two Projections, but has "
                 f"{len(self._pytorch_projections)}: {', '.join([proj.name for proj in self._pytorch_projections])}.")
            direct_proj_in = self._pytorch_projections[0]
            direct_proj_out = self._pytorch_projections[1]

        else:
            try:
                direct_proj_in = MappingProjection(name="Projection to GRU COMP",
                                                   sender=sender,
                                                   receiver=self.gru_mech,
                                                   learnable=projection.learnable)
                self._pytorch_projections.append(direct_proj_in)
            except DuplicateProjectionError:
                assert False, "PROGRAM ERROR: Duplicate Projection to GRU COMP"

            try:
                direct_proj_out = MappingProjection(name="Projection from GRU COMP",
                                                    sender=self.gru_mech,
                                                    receiver=self.output_CIM,
                                                    # receiver=self.output_CIM.input_ports[0],
                                                    learnable=False)
                self._pytorch_projections.append(direct_proj_out)
            except DuplicateProjectionError:
                assert False, "PROGRAM ERROR: Duplicate Projection from GRU COMP"

        # FIX: GET ALL EFFERENTS OF OUTPUT NODE HERE
        # output_node = self.output_CIM.output_port.efferents[0].receiver.owner
        # output_node = self.output_CIM.output_port
        output_node = self.output_CIM

        # GRU pathway: route sender -> gru_mech -> output_CIM in the dependency graph
        dependency_dict[direct_proj_in] = sender
        dependency_dict[self.gru_mech] = direct_proj_in
        dependency_dict[direct_proj_out] = self.gru_mech
        dependency_dict[output_node] = direct_proj_out

        # FIX : ADD ALL EFFERENTS OF OUTPUT NODE HERE:
        queue.append((self.gru_mech, self))

    def _identify_target_nodes(self, context):
        return [self.gru_mech]

    def add_node(self, node, required_roles=None, context=None):
        """Override to disallow modification of GRUComposition if called from the command line"""
        if context is None:
            raise CompositionError(f"Nodes cannot be added to a {self.componentCategory}: ('{self.name}').")
        super().add_node(node, required_roles, context)

    def add_projection(self, *args, **kwargs):
        """Override to disallow modification of GRUComposition if called from the command line"""
        if CONTEXT not in kwargs or kwargs[CONTEXT] is None:
            raise CompositionError(f"Projections cannot be added to a {self.componentCategory}: ('{self.name}').")
        return super().add_projection(*args, **kwargs)