/psyneulink/library/compositions/grucomposition/pytorchGRUwrappers.py
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

# ********************************************* PytorchComponent *************************************************

"""PyTorch wrapper for GRUComposition"""

import numpy as np
import graph_scheduler
import torch
from typing import Union, Optional, Literal, Tuple

from psyneulink.core.compositions.composition import NodeRole
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.core.components.projections.projection import DuplicateProjectionError
from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition
from psyneulink.library.compositions.pytorchwrappers import PytorchCompositionWrapper, PytorchMechanismWrapper, \
    PytorchProjectionWrapper, PytorchFunctionWrapper, ENTER_NESTED, EXIT_NESTED, SUBCLASS_WRAPPERS
from psyneulink.core.globals.context import Context, handle_external_context
from psyneulink.core.globals.utilities import convert_to_list
from psyneulink.core.globals.keywords import (
    ALL, CONTEXT, INPUT, INPUTS, LEARNING, NODE_VALUES, RUN, SHOW_PYTORCH, SYNCH, SYNCH_WITH_PNL_OPTIONS)
from psyneulink.core.globals.log import LogCondition

__all__ = ['PytorchGRUCompositionWrapper']
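
# Illustrative usage sketch (not part of this module's API): a PytorchGRUCompositionWrapper is normally
# created for a GRUComposition when it is run or trained with ExecutionMode.PyTorch, at which point it
# becomes the GRUComposition's pytorch_representation. The constructor arguments shown here
# (input_size, hidden_size, bias) and the inputs format are assumptions based on their use in this
# module, not a verified example.
#
#     import psyneulink as pnl
#     gru_comp = pnl.GRUComposition(input_size=3, hidden_size=5, bias=False)
#     gru_comp.run(inputs={gru_comp.input_node: [[0.1, 0.2, 0.3]]},
#                  execution_mode=pnl.ExecutionMode.PyTorch)
#     # gru_comp.pytorch_representation is now a PytorchGRUCompositionWrapper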

class PytorchGRUCompositionWrapper(PytorchCompositionWrapper):
    """Wrapper for GRUComposition as a Pytorch Module
    Manage the exchange of the Composition's Projection `Matrices <MappingProjection_Matrix>`
    and the Pytorch GRU Module's parameters, and return its output value.
    """
    def __init__(self,
                 composition,
                 device,
                 outer_creator=None,
                 dtype=None,
                 subclass_components=None,
                 context=None,
                 base_context=Context(execution_id=None),
                 ):

        self._early_init(composition, device)

        _node_wrapper_pairs = self._instantiate_GRU_pytorch_mechanism_wrappers(composition, device, context)
        gru_pytorch_node = _node_wrapper_pairs[0][1]
        torch_gru = gru_pytorch_node.function.function
        _projection_wrapper_pairs = self._instantiate_GRU_pytorch_projection_wrappers(torch_gru, device, context)
        execution_sets = [{gru_pytorch_node}]

        super().__init__(composition=composition,
                         device=device,
                         outer_creator=outer_creator,
                         subclass_components=(_node_wrapper_pairs,
                                              _projection_wrapper_pairs,
                                              execution_sets,
                                              Context()),
                         context=context,
                         base_context=base_context,
                         )

        # The following have to be assigned after super(), so that they can be set as attributes of torch.nn.Module

        # IMPLEMENTATION NOTE:
        #    This is needed for access by subcomponents to PytorchGRUCompositionWrapper when GRUComposition is nested,
        #    and so _build_pytorch_representation is called on the outer Composition but not GRUComposition itself;
        #    access must be provided via GRUComposition's pytorch_representation, rather than directly assigning
        #    PytorchGRUCompositionWrapper as an attribute on the subcomponents, since doing the latter introduces a
        #    recursion when torch.nn.Module.state_dict() is called on any wrapper in the hierarchy.
        if self.composition.pytorch_representation is None:
            self.composition.pytorch_representation = self
        self.torch_gru = torch_gru
        self.gru_pytorch_node = gru_pytorch_node

        # Note: this has to be done after the call to super(), so that projections_map has been populated
        self.copy_weights_to_torch_gru(context)

        self.torch_dtype = dtype or torch.float64
        self.numpy_dtype = torch.tensor([10], dtype=self.torch_dtype).numpy().dtype

    def _instantiate_GRU_pytorch_mechanism_wrappers(self, gru_comp, device, context):
        """Instantiate PytorchMechanismWrapper for GRU Node"""
        gru_mech = gru_comp.gru_mech
        pytorch_node = PytorchGRUMechanismWrapper(mechanism=gru_mech,
                                                  composition=gru_comp,
                                                  component_idx=0,
                                                  use=[LEARNING, SHOW_PYTORCH],
                                                  dtype=self.torch_dtype,
                                                  device=device,
                                                  context=context)

        # Check whether there is a source Node for the InputPort of the GRUComposition's input_CIM
        source = gru_comp.input_CIM._get_source_node_for_input_CIM(gru_comp.input_node.afferents[0].sender)
        if source is None or not gru_comp.is_nested:
            # If either the GRUComposition is not nested,
            # or it does not receive any Projections from the outer Composition,
            # then treat it as an INPUT Node (i.e., one that receives the inputs to the outer Composition
            # in collect_afferents())
            gru_mech._is_input = True
            pytorch_node._is_input = True
            pytorch_node.afferents = INPUT
        destination = gru_comp.output_CIM._get_destination_info_for_output_CIM(gru_comp.output_node.efferents[
                                                                                   0].receiver)
        if destination is None or not gru_comp.is_nested:
            pytorch_node._is_output = True

        return [(gru_mech, pytorch_node)]

    def _instantiate_GRU_pytorch_projection_wrappers(self, torch_gru, device, context):
        """Create PytorchGRUProjectionWrappers for each learnable Projection of GRUComposition
        For each PytorchGRUProjectionWrapper, assign the current weight matrix of the PNL Projection
        to the corresponding part of the tensor in the parameter of the Pytorch GRU module.
        """

        pnl = self.composition
        self.torch_gru_parameters = torch_gru.parameters

        _projection_wrapper_pairs = []

        # Pytorch parameter info
        hid_len = pnl.hidden_size
        z_idx = hid_len
        n_idx = 2 * hid_len

        w_ih = torch_gru.state_dict()['weight_ih_l0']
        w_hh = torch_gru.state_dict()['weight_hh_l0']
        torch_gru_wts_indices = [(w_ih, slice(None, z_idx)), (w_ih, slice(z_idx, n_idx)), (w_ih, slice(n_idx, None)),
                                 (w_hh, slice(None, z_idx)), (w_hh, slice(z_idx, n_idx)), (w_hh, slice(n_idx, None))]
        pnl_proj_wts = [pnl.wts_ir, pnl.wts_iu, pnl.wts_in, pnl.wts_hr, pnl.wts_hu, pnl.wts_hn]
        for pnl_proj, torch_matrix in zip(pnl_proj_wts, torch_gru_wts_indices):
            _projection_wrapper_pairs.append((pnl_proj,
                                              PytorchGRUProjectionWrapper(projection=pnl_proj,
                                                                          torch_parameter=torch_matrix,
                                                                          use=SYNCH,
                                                                          composition=self.composition,
                                                                          device=device)))
        self._pnl_refs_to_torch_params_map = {'w_ih': w_ih, 'w_hh': w_hh}

        if pnl.bias:
            from psyneulink.library.compositions.grucomposition.grucomposition import GRU_NODE
            assert torch_gru.bias, f"PROGRAM ERROR: '{pnl.name}' has bias=True but {GRU_NODE}.bias=False."
            b_ih = torch_gru.state_dict()['bias_ih_l0']
            b_hh = torch_gru.state_dict()['bias_hh_l0']
            torch_gru_bias_indices = [(b_ih, slice(None, z_idx)), (b_ih, slice(z_idx, n_idx)), (b_ih, slice(n_idx, None)),
                                      (b_hh, slice(None, z_idx)), (b_hh, slice(z_idx, n_idx)), (b_hh, slice(n_idx, None))]
            pnl_biases = [pnl.bias_ir, pnl.bias_iu, pnl.bias_in, pnl.bias_hr, pnl.bias_hu, pnl.bias_hn]
            for pnl_bias_proj, torch_bias in zip(pnl_biases, torch_gru_bias_indices):
                _projection_wrapper_pairs.append((pnl_bias_proj,
                                                  PytorchGRUProjectionWrapper(projection=pnl_bias_proj,
                                                                              torch_parameter=torch_bias,
                                                                              use=SYNCH,
                                                                              composition=pnl,
                                                                              device=device)))
            self._pnl_refs_to_torch_params_map.update({'b_ih': b_ih, 'b_hh': b_hh})

        return _projection_wrapper_pairs

    def _flatten_for_pytorch(self,
                             pnl_proj,
                             sndr_mech,
                             rcvr_mech,
                             nested_port,
                             nested_mech,
                             outer_comp,
                             outer_comp_pytorch_rep,
                             access,
                             context,
                             base_context=Context(execution_id=None),
                             ) -> Tuple:
        """Return PytorchProjectionWrappers for Projections to/from the GRUComposition when it is nested in an outer Composition
        Replace GRUComposition's nodes with gru_mech and projections to and from it.
        """

        direct_proj = None
        use = [LEARNING, SYNCH]

        if access == ENTER_NESTED:
            sndr_mech_wrapper = outer_comp_pytorch_rep.nodes_map[sndr_mech]
            rcvr_mech_wrapper = self.nodes_map[self.composition.gru_mech]
            try:
                direct_proj = MappingProjection(name="Projection to GRU COMP",
                                                sender=pnl_proj.sender,
                                                receiver=self.composition.gru_mech,
                                                learnable=pnl_proj.learnable)
            except DuplicateProjectionError:
                direct_proj = self.composition.gru_mech.afferents[0]
            else:
                direct_proj._initialize_from_context(context, base_context)
            # Index of input_CIM.output_ports for which pnl_proj is an efferent
            sender_port_idx = pnl_proj.sender.owner.output_ports.index(pnl_proj.sender)

        elif access == EXIT_NESTED:
            sndr_mech_wrapper = self.nodes_map[self.composition.gru_mech]
            rcvr_mech_wrapper = outer_comp_pytorch_rep.nodes_map[rcvr_mech]
            try:
                direct_proj = MappingProjection(name="Projection from GRU COMP",
                                                sender=self.composition.gru_mech,
                                                receiver=pnl_proj.receiver,
                                                learnable=pnl_proj.learnable)
            except DuplicateProjectionError:
                direct_proj = self.composition.gru_mech.efferents[0]
            else:
                direct_proj._initialize_from_context(context, base_context)
            # gru_mech has only one output_port
            sender_port_idx = 0

        else:
            assert False, f"PROGRAM ERROR: access must be ENTER_NESTED or EXIT_NESTED, not {access}"

        if direct_proj:
            component_idx = list(outer_comp._inner_projections).index(pnl_proj)
            proj_wrapper = PytorchProjectionWrapper(projection=direct_proj,
                                                    pnl_proj=pnl_proj,
                                                    component_idx=component_idx,
                                                    sender_port_idx=sender_port_idx,
                                                    use=[SHOW_PYTORCH],
                                                    device=self.device,
                                                    sender_wrapper=sndr_mech_wrapper,
                                                    receiver_wrapper=rcvr_mech_wrapper,
                                                    context=context)
            outer_comp_pytorch_rep.projection_wrappers.append(proj_wrapper)
            outer_comp_pytorch_rep.projections_map[direct_proj] = proj_wrapper
            outer_comp_pytorch_rep.composition._pytorch_projections.append(direct_proj)

        return pnl_proj, sndr_mech_wrapper, rcvr_mech_wrapper, use

    @handle_external_context()
    def forward(self, inputs, optimization_num, synch_with_pnl_options, context=None)->dict:
        """Forward method of the model for PyTorch mode

        This is called only when the GRUComposition is run as a standalone Composition.
        Otherwise, the node's execute() method is called directly (i.e., it is treated as a single node).
        Returns a dictionary {output_node: value} with the output value of the torch GRU module (which is used
        by the collect_afferents method(s) of the other node(s) that receive Projections from the GRUComposition).

        """

        self._set_synch_with_pnl(synch_with_pnl_options)

        # Get input from GRUComposition's INPUT_NODE
        inputs = inputs[self.composition.input_node]

        # Execute GRU Node
        output = self.gru_pytorch_node.execute(inputs, optimization_num, synch_with_pnl_options, context)

        # Set GRUComposition's OUTPUT Node to the output of the GRU Node
        self.composition.output_node.parameters.value._set(output.detach().cpu().numpy(), context)
        self.composition.gru_mech.parameters.value._set(output.detach().cpu().numpy(), context)

        return {self.composition.gru_mech: output}

    def _set_synch_with_pnl(self, synch_with_pnl_options):
        if (NODE_VALUES in synch_with_pnl_options and synch_with_pnl_options[NODE_VALUES] == RUN):
            self.gru_pytorch_node.synch_with_pnl = True
        else:
            self.gru_pytorch_node.synch_with_pnl = False

    def copy_weights_to_torch_gru(self, context=None):
        for projection, proj_wrapper in self.projections_map.items():
            if SYNCH in proj_wrapper._use:
                proj_wrapper._copy_pnl_proj_to_torch_gru_parameter(context, self.torch_dtype)

    @staticmethod
    def get_parameters_from_torch_gru(torch_gru)->Tuple:
        """Get parameters from PyTorch GRU module corresponding to GRUComposition's Projections.
        Format tensors:
          - transpose all weight and bias tensors;
          - reformat biases as 2d
        Return formatted tensors, which are used:
         - in set_weights_from_torch_gru(), where they are converted to numpy arrays
         - for forward computation in pytorchGRUwrappers._copy_pytorch_node_outputs_to_pnl_values()
        """
        hid_len = torch_gru.hidden_size
        z_idx = hid_len
        n_idx = 2 * hid_len

        wts_ih = torch_gru.state_dict()['weight_ih_l0']
        wts_ir = wts_ih[:z_idx].T.detach().cpu().numpy().copy()
        wts_iu = wts_ih[z_idx:n_idx].T.detach().cpu().numpy().copy()
        wts_in = wts_ih[n_idx:].T.detach().cpu().numpy().copy()
        wts_hh = torch_gru.state_dict()['weight_hh_l0']
        wts_hr = wts_hh[:z_idx].T.detach().cpu().numpy().copy()
        wts_hu = wts_hh[z_idx:n_idx].T.detach().cpu().numpy().copy()
        wts_hn = wts_hh[n_idx:].T.detach().cpu().numpy().copy()
        weights = (wts_ir, wts_iu, wts_in, wts_hr, wts_hu, wts_hn)

        biases = None
        if torch_gru.bias:
            # Transpose 1d bias Tensors using permute instead of .T (per PyTorch warning)
            b_ih = torch_gru.state_dict()['bias_ih_l0']
            b_ir = torch.atleast_2d(b_ih[:z_idx].permute(*torch.arange(b_ih.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            b_iu = torch.atleast_2d(b_ih[z_idx:n_idx].permute(*torch.arange(b_ih.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            b_in = torch.atleast_2d(b_ih[n_idx:].permute(*torch.arange(b_ih.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            b_hh = torch_gru.state_dict()['bias_hh_l0']
            b_hr = torch.atleast_2d(b_hh[:z_idx].permute(*torch.arange(b_hh.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            b_hu = torch.atleast_2d(b_hh[z_idx:n_idx].permute(*torch.arange(b_hh.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            b_hn = torch.atleast_2d(b_hh[n_idx:].permute(*torch.arange(b_hh.ndim - 1, -1, -1))).detach().cpu().numpy().copy()
            biases = (b_ir, b_iu, b_in, b_hr, b_hu, b_hn)
        return weights, biases
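
    # Illustrative usage (an assumption based on how this method is called on the class in
    # _calculate_torch_gru_internal_state_values below, not a verified example):
    #     gru = torch.nn.GRU(input_size=3, hidden_size=5, bias=True)
    #     weights, biases = PytorchGRUCompositionWrapper.get_parameters_from_torch_gru(gru)
    #     # weights: 6 arrays -- (3, 5) for the input-to-hidden sets, (5, 5) for hidden-to-hidden
    #     # biases:  6 arrays of shape (1, 5), or None if the module was created with bias=False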

    def log_weights(self):
        for proj_wrapper in self.projection_wrappers:
            proj_wrapper.log_matrix()

    def log_values(self):
        for node_wrapper in [n for n in self.node_wrappers if not isinstance(n, PytorchCompositionWrapper)]:
            node_wrapper.log_value()


class PytorchGRUMechanismWrapper(PytorchMechanismWrapper):
    """Wrapper for Pytorch GRU Node
    Handling of hidden_state: uses GRUComposition's HIDDEN_NODE.value to cache state of hidden layer:
    - gets input to function for hidden state from GRUComposition's HIDDEN_NODE.value
    - sets GRUComposition's HIDDEN_NODE.value to return value for hidden state
    """

    def __init__(self,
                 mechanism,
                 composition,
                 component_idx,
                 use,
                 dtype,
                 device,
                 context):

        super().__init__(mechanism=mechanism,
                         composition=composition,
                         component_idx=component_idx,
                         use=use,
                         dtype=dtype,
                         device=device,
                         subclass_specifies_function=True,
                         context=context)

        self._assign_GRU_pytorch_function(mechanism, device, context)

        self.synch_with_pnl = False

    def _assign_GRU_pytorch_function(self, mechanism, device, context):
        # Assign PytorchGRUFunctionWrapper of Pytorch GRU module as function of GRU Node
        input_size = self.composition.parameters.input_size.get(context)
        hidden_size = self.composition.parameters.hidden_size.get(context)
        bias = self.composition.parameters.bias.get(context)
        torch_GRU = torch.nn.GRU(input_size=input_size,
                                 hidden_size=hidden_size,
                                 bias=bias).to(dtype=self.torch_dtype)
        self.hidden_state = torch.zeros(1, 1, hidden_size, dtype=self.torch_dtype).to(device)

        function_wrapper = PytorchGRUFunctionWrapper(torch_GRU, device, context)
        self.function = function_wrapper
        mechanism.function = function_wrapper.function

        # Assign input_port functions of GRU Node to PytorchGRUFunctionWrapper
        self.input_ports = [PytorchFunctionWrapper(input_port.function, device, context)
                            for input_port in mechanism.input_ports]

    def execute(self, variable, optimization_num, synch_with_pnl_options, context=None)->torch.Tensor:
        """Execute GRU Node with input variable and return output value.
        Override to set GRU Node's synch_with_pnl option if GRUComposition is a nested Composition.
        This is called directly (rather than the forward method) if the GRUComposition is in a nested Composition.
        Treats GRUComposition as a single node in the PytorchCompositionWrapper's graph, with inputs
          received from the other node(s) that project to the GRUComposition, and its outputs used by the
          collect_afferents method(s) of the other node(s) that receive Projections from the GRUComposition.
        """
        # Get hidden state from GRUComposition's HIDDEN_NODE.value
        from psyneulink.library.compositions.grucomposition.grucomposition import HIDDEN_LAYER

        self.composition.pytorch_representation._set_synch_with_pnl(synch_with_pnl_options)

        self.input = variable

        hidden_state = self.composition.nodes[HIDDEN_LAYER].parameters.value.get(context)
        self.hidden_state = torch.tensor(hidden_state).unsqueeze(1)
        # Save starting hidden_state for re-computing current values in _copy_pytorch_node_outputs_to_pnl_values()
        self.previous_hidden_state = self.hidden_state.detach()

        if self.synch_with_pnl:
            self.torch_gru_internal_state_values = \
                self._calculate_torch_gru_internal_state_values(self.input[0][0], self.hidden_state.detach())

        # Execute torch GRU module with input (variable) and hidden state
        self.output, self.hidden_state = self.function(*[self.input, self.hidden_state])
        # self.output, self.hidden_state = self.function.function(*[input, self.hidden_state])

        # Set GRUComposition's HIDDEN_NODE.value to GRU Node's hidden state
        # Note: this must be done in case the GRUComposition is run after learning
        self.composition.hidden_layer_node.output_port.parameters.value._set(
            self.hidden_state.detach().cpu().numpy().squeeze(), context)

        return self.output

    def collect_afferents(self, batch_size, port=None, inputs:dict=None)->torch.Tensor:
        """
        Return afferent projections for input_port(s) of the Mechanism.
        If there is only one input_port, return the sum of its afferents (for those in the Composition).
        If there are multiple input_ports, return a tensor (or list of tensors if input ports are ragged) of shape:

        (batch, input_port, projection, ...)

        where the ellipsis represents 1 or more dimensions for the values of the projected afferent.

        """

        if self.afferents == INPUT:
            # GRUComposition is nested in an outer Composition, and GRU is INPUT Node of that Composition,
            # so get the input specified for GRUComposition.input_node from the inputs dict provided in the learn() method
            assert self.mechanism._is_input, \
                f"PROGRAM ERROR: No afferents found for '{self.mechanism.name}' in AutodiffComposition"
            input_port = self.composition.input_node.input_port
            curr_val = inputs[input_port]
            if type(curr_val) == torch.Tensor:
                ip_res = [curr_val[:, 0, ...]]
            else:
                val = [batch_elem[0] for batch_elem in curr_val]
                val = torch.stack(val)
                ip_res = [val]
            res = []

        else:
            proj_wrapper = self.afferents[0]

            curr_val = proj_wrapper.sender_wrapper.output
            if curr_val is not None:
                if type(curr_val) == torch.Tensor:
                    proj_wrapper._curr_sender_value = curr_val[:, proj_wrapper._value_idx, ...]
                else:
                    val = [batch_elem[proj_wrapper._value_idx] for batch_elem in curr_val]
                    val = torch.stack(val)
                    proj_wrapper._curr_sender_value = val
            else:
                val = torch.tensor(proj_wrapper.default_value)

                # We need to add the batch dimension to default values.
                val = val[None, ...].expand(batch_size, *val.shape)

                proj_wrapper._curr_sender_value = val

            proj_wrapper._curr_sender_value = torch.atleast_1d(proj_wrapper._curr_sender_value)

            res = []
            input_port = self.mechanism.input_port
            ip_res = [proj_wrapper.execute(proj_wrapper._curr_sender_value)]

        # Stack the results for this input port on the second dimension; we want to preserve
        # the first dimension as the batch
        ip_res = torch.stack(ip_res, dim=1)
        res.append(ip_res)

        try:
            # Now stack the results for all input ports on the second dimension again; this keeps batch
            # first again. We should now have a 4D tensor: (batch, input_port, projection, values)
            res = torch.stack(res, dim=1)
        except (RuntimeError, TypeError):
            # Result is ragged; ports will be handled individually during execute.
            # We still need to reshape things so batch size is the first dimension.
            batch_size = res[0].shape[0]
            res = [[inp[b] for inp in res] for b in range(batch_size)]

        return res

    def execute_input_ports(self, variable)->torch.Tensor:
        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
        assert type(variable) == torch.Tensor, (f"PROGRAM ERROR: Input to GRUComposition in ExecutionMode.Pytorch "
                                                f"should be a torch.Tensor, but is {type(variable)}.")
        # Return the input for the port for all items in the batch
        return variable[:, 0, ...]

    def _calculate_torch_gru_internal_state_values(self, input, hidden_state)->dict:
        """Manually calculate and store the internal state values of the torch GRU prior to the backward pass.
        These are needed for assigning to the corresponding nodes in the GRUComposition.
        Returns r_t, z_t, n_t, h_t: the current reset, update, new, and hidden state values, respectively.
        """
        torch_gru_parameters = PytorchGRUCompositionWrapper.get_parameters_from_torch_gru(self.function.function)

        # Get weights
        torch_weights = list(torch_gru_parameters[0])
        for i, weight in enumerate(torch_weights):
            torch_weights[i] = torch.tensor(weight, dtype=self.torch_dtype)
        w_ir, w_iz, w_in, w_hr, w_hz, w_hn = torch_weights

        # Get biases
        pnl_comp = self.composition
        if pnl_comp.bias:
            assert len(torch_gru_parameters) > 1, \
                (f"PROGRAM ERROR: '{pnl_comp.name}' has bias set to True, "
                 f"but no bias weights were returned for torch_gru_parameters.")
            b_ir, b_iz, b_in, b_hr, b_hz, b_hn = torch_gru_parameters[1]
        else:
            b_ir = b_iz = b_in = b_hr = b_hz = b_hn = 0.0

        # Do calculations for internal state values
        x = input.detach()
        h = hidden_state
        r_t = torch.sigmoid(torch.matmul(x, w_ir) + b_ir + torch.matmul(h, w_hr) + b_hr)
        z_t = torch.sigmoid(torch.matmul(x, w_iz) + b_iz + torch.matmul(h, w_hz) + b_hz)
        n_t = torch.tanh(torch.matmul(x, w_in) + b_in + r_t * (torch.matmul(h, w_hn) + b_hn))
        h_t = (1 - z_t) * n_t + z_t * h

        from psyneulink.library.compositions.grucomposition.grucomposition import GRU_INTERNAL_STATE_NAMES
        return {k: v for k, v in zip(GRU_INTERNAL_STATE_NAMES, [n_t, r_t, z_t, h_t])}

    def set_pnl_variable_and_values(self,
                                    set_variable:bool=False,
                                    set_value:bool=True,
                                    # FIX: 3/15/25 - ADD SUPPORT FOR THESE
                                    # set_output_values:bool=None,
                                    # execute_mech:bool=True,
                                    context=None):

        if set_variable:
            assert False, \
                "PROGRAM ERROR: copying variables to GRUComposition from pytorch execution is not currently supported."

        if set_value:
            n_t, r_t, z_t, h_t = list(self.torch_gru_internal_state_values.values())
            try:
                # Ensure that the manually-calculated state values match the output of the actual call to the PyTorch module
                np.testing.assert_allclose(h_t.detach().numpy(),
                                           self.output.detach().numpy(),
                                           atol=1e-8)
            except (AssertionError, ValueError):
                assert False, f"PROGRAM ERROR: Problem with calculation of internal states of {self.composition.name} GRU Node."

            # Set values of nodes in pnl gru_comp to the result of the corresponding computations in the PyTorch module
            pnl_comp = self.composition
            pnl_comp.reset_node.output_port.parameters.value._set(r_t.detach().cpu().numpy().squeeze(), context)
            pnl_comp.update_node.output_ports[0].parameters.value._set(z_t.detach().cpu().numpy().squeeze(), context)
            pnl_comp.update_node.output_ports[1].parameters.value._set(z_t.detach().cpu().numpy().squeeze(), context)
            pnl_comp.new_node.output_port.parameters.value._set(n_t.detach().cpu().numpy().squeeze(), context)
            pnl_comp.output_node.output_port.parameters.value._set(h_t.detach().cpu().numpy().squeeze(), context)
            # Note: no need to set hidden_layer since it was already done when the GRU Node executed
            # pnl_comp.hidden_layer_node.output_port.parameters.value._set(h_t.detach().cpu().numpy().squeeze(), context)

            # # KEEP THIS FOR REFERENCE IN CASE hidden_layer_node IS REPLACED WITH RecurrentTransferMechanism
            # # If pnl_node's function is Stateful, assign value to its previous_value parameter
            # #   so that if Python implementation is run it picks up where PyTorch execution left off
            # if isinstance(pnl_node.function, StatefulFunction):
            #     pnl_node.function.parameters.previous_value._set(torch_gru_output, context)

    def log_value(self):
        # FIX: LOG HIDDEN STATE OF COMPOSITION MECHANISM
        if self.mechanism.parameters.value.log_condition != LogCondition.OFF:
            detached_value = self.output.detach().cpu().numpy()
            self.mechanism.output_port.parameters.value._set(detached_value, self._context)
            self.mechanism.parameters.value._set(detached_value, self._context)

    def log_matrix(self):
        if self.projection.parameters.matrix.log_condition != LogCondition.OFF:
            detached_matrix = self.matrix.detach().cpu().numpy()
            self.projection.parameters.matrix._set(detached_matrix, context=self._context)
            self.projection.parameter_ports['matrix'].parameters.value._set(detached_matrix, context=self._context)


class PytorchGRUProjectionWrapper(PytorchProjectionWrapper):
    """Wrapper for a Projection of the GRUComposition

    One is created for each Projection of the GRUComposition that is learnable.
    Sets of three of these correspond to the Parameters of the torch GRU module:

    PyTorch GRU parameter:  GRUComposition Projections:
         weight_ih_l0       wts_ir, wts_iu, wts_in
         weight_hh_l0       wts_hr, wts_hu, wts_hn
         bias_ih_l0         bias_ir, bias_iu, bias_in
         bias_hh_l0         bias_hr, bias_hu, bias_hn

    Attributes
    ----------
    projection:  MappingProjection
        the `Projection` of the GRUComposition being wrapped

    composition : AutodiffComposition
        the `AutodiffComposition` to which the `Projection` being wrapped belongs
        (and for which the PytorchCompositionWrapper -- to which the PytorchProjectionWrapper
        belongs -- is the `pytorch_representation <AutodiffComposition.pytorch_representation>`).

    torch_parameter: Pytorch parameter
        the torch.nn.Parameter corresponding to the matrix of the Projection

    matrix_indices: slice
        a slice specifying the part of the Pytorch parameter corresponding to the GRUComposition Projection's matrix.

    """
    def __init__(self,
                 projection:MappingProjection,
                 torch_parameter:Tuple,
                 use:Union[list, Literal[LEARNING, SYNCH, SHOW_PYTORCH]],
                 composition:AutodiffComposition,
                 device:str):
        self.name = f"PytorchProjectionWrapper[{projection.name}]"
        # GRUComposition Projection being wrapped:
        self.projection = projection  # PNL Projection being wrapped
        self._pnl_proj = projection
        # Assign parameter and tensor indices of the Pytorch GRU module parameter corresponding to the Projection's matrix:
        self.torch_parameter, self.matrix_indices = torch_parameter
        # Projections for GRUComposition are not included in autodiff; matrices are set directly in Pytorch GRU module:
        self.projection.exclude_in_autodiff = True
        self._use = convert_to_list(use)
        self.composition = composition
        self.device = device

    def _copy_pnl_proj_to_torch_gru_parameter(self, context, dtype):
        """Set relevant part of tensor for parameter of Pytorch GRU module from GRUComposition's Projections."""
        matrix = self.projection.parameters.matrix._get(context).T
        torch_tensor = self.torch_parameter[self.matrix_indices]
        self.composition.copy_projection_matrix_to_torch_param(projection=self.projection,
                                                               torch_param=torch_tensor,
                                                               validate=False,
                                                               context=context)

    def _copy_torch_params_to_pnl_proj(self, context):
        """Override to deal with indexed tensor of Pytorch GRU module Parameter"""
        torch_parameter = self.torch_parameter
        torch_indices = self.matrix_indices
        matrix = torch_parameter[torch_indices].detach().cpu()
        self.composition.copy_torch_param_to_projection_matrix(torch_param=matrix,
                                                               projection=self.projection,
                                                               validate=False,
                                                               context=context)

    def log_matrix(self):
        if self.projection.parameters.matrix.log_condition != LogCondition.OFF:
            detached_matrix = self.matrix.detach().cpu().numpy()
            self.projection.parameters.matrix._set(detached_matrix, context=self._context)
            self.projection.parameter_ports['matrix'].parameters.value._set(detached_matrix, context=self._context)


# class PytorchGRUFunctionWrapper(PytorchFunctionWrapper):
class PytorchGRUFunctionWrapper(torch.nn.Module):
    def __init__(self, function, device, context=None):
        super().__init__()
        self.name = "PytorchFunctionWrapper[GRU NODE]"
        self._context = context
        self._pnl_function = function
        self.function = function

    def __repr__(self):
        return "PytorchWrapper for: " + self._pnl_function.__repr__()

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)