PrincetonUniversity / PsyNeuLink / build 15917088825

05 Jun 2025 04:18AM UTC coverage: 84.482% (+0.5%) from 84.017%

push (github, web-flow): Merge pull request #3271 from PrincetonUniversity/devel (Devel)

9909 of 12966 branches covered (76.42%); branch coverage is included in the aggregate %.
1708 of 1908 new or added lines in 54 files covered (89.52%).
25 existing lines in 14 files now uncovered.
34484 of 39581 relevant lines covered (87.12%); 0.87 hits per line.

Source file (90.58% covered): /psyneulink/library/compositions/pytorchwrappers.py
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

# ********************************************* PytorchComponent *************************************************

"""PyTorch wrappers for Composition, Mechanism, Projection, and Functions for use in AutodiffComposition"""
from psyneulink._typing import Optional, Literal, Union

import graph_scheduler
import numpy as np

# import torch
try:
    import torch
except (ImportError, ModuleNotFoundError):
    torch = None
else:
    import torch.nn as nn

from enum import Enum, auto

from psyneulink.core.components.functions.stateful import StatefulFunction
from psyneulink.core.components.mechanisms.mechanism import Mechanism
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism
from psyneulink.core.components.mechanisms.processing.transfermechanism import TransferMechanism
from psyneulink.core.components.ports.port import Port
from psyneulink.core.components.projections.projection import Projection, DuplicateProjectionError
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.core.compositions.composition import Composition, CompositionInterfaceMechanism, NodeRole
from psyneulink.library.compositions.pytorchllvmhelper import *
from psyneulink.library.compositions.compiledoptimizer import AdamOptimizer, SGDOptimizer
from psyneulink.library.compositions.compiledloss import MSELoss, CROSS_ENTROPYLoss
from psyneulink.core.globals.keywords import (AFTER, ALL, BEFORE, DEFAULT_VARIABLE, EPOCH, INPUTS,
                                              LEARNING, LEARNING_SCALE_LITERALS, Loss, MATRIX_WEIGHTS,
                                              NODE, NODE_VALUES, NODE_VARIABLES, OUTPUTS, RESULTS, RUN,
                                              SHOW_PYTORCH, SYNCH, TARGET_MECHANISM, )
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.utilities import convert_to_list, convert_to_np_array, get_deepcopy_with_shared
from psyneulink.core.globals.log import LogCondition
from psyneulink.core import llvm as pnlvm
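
# Note: if the guarded import above failed, `torch` is None here (and torch.nn was never
# imported); the wrapper classes below (e.g., PytorchCompositionWrapper's torch.nn.Module base
# class and its torch_dtype) assume torch imported successfully.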

__all__ = ['PytorchCompositionWrapper', 'PytorchMechanismWrapper', 'PytorchProjectionWrapper',
           'ENTER_NESTED', 'EXIT_NESTED', 'SUBCLASS_WRAPPERS']

SUBCLASS_WRAPPERS = 'subclass_wrappers'
ENTER_NESTED = 0
EXIT_NESTED = 1

class DataTypeEnum(Enum):

    TRAINED_OUTPUTS = 0
    TARGETS = auto()
    LOSSES = auto()


def _get_pytorch_function(obj, device, context):
    pytorch_fct = getattr(obj, '_gen_pytorch_fct', None)
    if pytorch_fct is None:
        from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
        raise AutodiffCompositionError(
            f"Function {obj} is not currently supported by AutodiffComposition"
        )
    else:
        return pytorch_fct(device, context)

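# A minimal sketch (an illustrative assumption, not from the source) of the protocol
# _get_pytorch_function relies on: a PsyNeuLink Function that supports AutodiffComposition
# exposes _gen_pytorch_fct(device, context), returning a torch-compatible callable, e.g.:
#
#     class MyLinear(TransferFunction):
#         def _gen_pytorch_fct(self, device, context):
#             slope = float(self.defaults.slope)  # hypothetical parameter lookup
#             return lambda x: x * slope
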
class PytorchCompositionWrapper(torch.nn.Module):
# NEEDED FOR torch MPS SUPPORT
# class PytorchCompositionWrapper(torch.jit.ScriptModule):
# END
    """Wrapper for a Composition as a Pytorch Module.

    Wraps an `AutodiffComposition` as a `PyTorch module
    <https://pytorch.org/docs/stable/generated/torch.nn.Module.html>`_, with each `Mechanism <Mechanism>` in
    the AutodiffComposition wrapped as a `PytorchMechanismWrapper`, each `Projection <Projection>` wrapped as
    a `PytorchProjectionWrapper`, and any nested Compositions wrapped as `PytorchCompositionWrapper`\\s. Each
    PytorchMechanismWrapper implements a Pytorch version of the `function(s) <Mechanism_Base.function>` of the wrapped
    `Mechanism`, which are executed in the PyTorchCompositionWrapper's `forward <PyTorchCompositionWrapper.forward>`
    method in the order specified by the AutodiffComposition's `scheduler <Composition.scheduler>`.  The `matrix
    <MappingProjection.matrix>` Parameters of each wrapped `Projection` are assigned as parameters of the
    `PytorchMechanismWrapper` Pytorch module and used, together with a Pytorch `matmul
    <https://pytorch.org/docs/main/generated/torch.matmul.html>`_ operation, to generate the input to each
    PyTorch function as specified by the `PytorchProjectionWrapper`\\'s `graph <Composition.graph>`.  The graph
    can be visualized using the AutodiffComposition's `show_graph <ShowGraph.show_graph>` method and setting its
    *show_pytorch* argument to True (see `PytorchShowGraph` for additional information).

    Two main responsibilities:

    1) Set up functions and parameters of PyTorch module required for its forward computation:
       - Handle nested compositions (flattened in infer_backpropagation_learning_pathways):
       - Deal with Projections into and/or out of a nested Composition as shown in figure below:
            (note: Projections in outer Composition to/from a nested Composition's CIMs are learnable,
                   and ones within a nested Composition from/to its CIMs are not)

         [      OUTER     ][                            NESTED                               ][     OUTER      ]
                \\learnable//   \\not learnable//                     \\not learnable//    \\learnable//
         ---> [Node] ----> [input_CIM] ~~~> [INPUT Node] ----> [OUTPUT Node] ~~~> [output_CIM] ----> [Node] --->
               sndr            rcvr          nested_rcvr         nested_sndr         sndr             rcvr
                ^--projection-->^                                                     ^---projection-->^
                ^----PytorchProjectionWrapper---->^                  ^----PytorchProjectionWrapper---->^
                          ENTER_NESTED                                            EXIT_NESTED

       .. _Mechanism_and_Projection_Uses:

       - The uses of Mechanisms and Projections in the pytorch_representation of an AutodiffComposition are
         determined, respectively, by its PytorchMechanismWrapper's `use <PytorchMechanismWrapper.use>` and
         PytorchProjectionWrapper's `use <PytorchProjectionWrapper.use>`, as follows:

         * Mechanisms:
           - used in Python execution but not Pytorch execution: *SYNCH*
           - used in PyTorch execution but not Python execution: *LEARNING*, *SHOW_PYTORCH*
           - used for both Python and Pytorch execution: *LEARNING*, *SYNCH*, *SHOW_PYTORCH*

         * Projections:
           - among (non-CIM) Mechanisms within the same Composition: same as Mechanisms (see above)
           - to an input_CIM of a nested Composition:  *LEARNING*, *SYNCH*, *SHOW_PYTORCH*
           - from an input_CIM: None
           - to an output_CIM: None
           - from an output_CIM:  *LEARNING*, *SYNCH*
           - directly between (to/from) a nested and outer Composition: *SHOW_PYTORCH*

    2) Handle coordination of passing data and outcomes back to PsyNeuLink objects, handled by two main methods:

       - synch_with_psyneulink()
            Copies matrix weights, node variables, node values, and/or autodiff results
            at user-specified intervals (LearningScale:  OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH, RUN);
            these are specified by the user in the following arguments to run() or learn():
                synch_projection_matrices_with_torch=RUN,
                synch_node_variables_with_torch=None,
                synch_node_values_with_torch=RUN,
                synch_results_with_torch=RUN,
            and consolidated in the synch_with_pnl_options dict used by synch_with_psyneulink

       - retain_for_psyneulink()
            Retains learning-specific data used and outcomes generated during execution of PyTorch model
            (TRAINED_OUTPUT_VALUES, corresponding TARGETS and LOSSES), that are copied to PsyNeuLink
            at the end of a call to learn(); these are specified by the user in the following arguments
            to learn():
                retain_torch_trained_outputs=MINIBATCH,
                retain_torch_targets=MINIBATCH,
                retain_torch_losses=MINIBATCH,
            and consolidated in the retain_in_pnl_options dict used by retain_for_psyneulink

       - Note: RESULTS is handled in an idiosyncratic way: it is specified along with the synchronization
               parameters, since it is a value ordinarily generated in the execution of a Composition;
               however its helper parallels the retain_for_psyneulink helper methods, and it is called
               from _update_results if TRIAL is specified, in order to integrate with the standard execution
               of a Composition.

    Arguments
    ---------


    Attributes
    ----------

    composition : AutodiffComposition
        The `AutodiffComposition` for which the PytorchCompositionWrapper is the `pytorch_representation
        <AutodiffComposition.pytorch_representation>`.

    node_wrappers : List[PytorchMechanismWrapper]
        list of nodes in the PytorchCompositionWrapper corresponding to the PyTorch functions that comprise the
        forward method of the Pytorch module implemented by the PytorchCompositionWrapper. Generally these are
        `Mechanisms <Mechanism>` wrapped in a `PytorchMechanismWrapper`; however, if the `AutodiffComposition` Node
        being wrapped is a nested Composition, then the wrapped node is itself a `PytorchCompositionWrapper` object.
        When the PyTorch model is executed, all of these are "flattened" into a single PyTorch module, corresponding
        to the outermost AutodiffComposition being wrapped, which can be visualized using that AutodiffComposition's
        `show_graph <ShowGraph.show_graph>` method and setting its *show_pytorch* argument to True (see
        `PytorchShowGraph` for additional information).

    nodes_map : Dict[Node: PytorchMechanismWrapper or PytorchCompositionWrapper]
        maps PsyNeuLink `Nodes <Composition_Nodes>` to PytorchCompositionWrapper nodes.

    projection_wrappers : List[PytorchProjectionWrapper]
        list of PytorchProjectionWrappers in the PytorchCompositionWrapper, each of which wraps a `Projection`
        in the AutodiffComposition being wrapped.

    projections_map : Dict[Projection: PytorchProjectionWrapper]
        maps `Projections <Projection>` in the AutodiffComposition being wrapped to `PytorchProjectionWrappers` in
        the PytorchCompositionWrapper.

    _nodes_to_execute_after_gradient_calc :  Dict[node : torch.Tensor]
        contains nodes specified as `exclude_from_gradient_calc` as keys, and their current variable as values

    optimizer : torch.optim.Optimizer
        assigned by AutodiffComposition after the wrapper is created, which passes the parameters to the optimizer

    device : torch.device
        device used to process torch Tensors in PyTorch functions

    params : nn.ParameterList()
        list of PyTorch parameters (connection weight matrices) in the PyTorch model.

    minibatch_loss : torch.Tensor
        accumulated loss over all trials (stimuli) within a batch.

    minibatch_loss_count : int
        count of losses (trials) within batch, used to calculate average loss per batch.

    retained_results : List[ndarray]
        list of the `output_values <Composition.output_values>` of the AutodiffComposition for every trial executed
        in a call to `run <AutoDiffComposition.run>` or `learn <AutoDiffComposition.learn>`.

    retained_trained_outputs : List[ndarray]
        values of the trained `OUTPUT <NodeRole.OUTPUT>` Nodes (i.e., ones associated with `TARGET <NodeRole.TARGET>`
        Nodes) for each trial executed in a call to `learn <AutoDiffComposition.learn>`.

    retained_targets : List[ndarray]
        values of the `TARGET <NodeRole.TARGET>` Nodes for each trial executed in a call to `learn
        <AutoDiffComposition.learn>`.

    retained_losses : List[ndarray]
        losses per batch, epoch or run accumulated over a call to learn()
    """

    torch_dtype = torch.float64

    def __init__(self,
                 composition,
                 device,
                 outer_creator=None,
                 dtype=None,
                 subclass_components=None,
                 context=None,
                 base_context=Context(execution_id=None),
                 ):

        super(PytorchCompositionWrapper, self).__init__()

        if subclass_components is None:
            self._early_init(composition, device)
            # Instantiate standard PytorchWrappers for Mechanisms and Projections, and execution_sets used in forward()
            _node_wrapper_pairs = self._instantiate_pytorch_mechanism_wrappers(composition, device, context)
            self._construct_node_wrapper_maps(_node_wrapper_pairs)
            _projection_wrapper_pairs = self._instantiate_pytorch_projection_wrappers(composition, device, context, base_context)
            self._construct_projection_wrapper_maps(_projection_wrapper_pairs)
            self.execution_sets, execution_context = self._get_execution_sets(composition, context)

        else:
            # Construct node_wrappers, projection_wrappers, and execution_sets from subclass components passed in
            _node_wrapper_pairs, _projection_wrapper_pairs, _execution_sets, execution_context = subclass_components
            self._validate_subclass_components(_node_wrapper_pairs, _projection_wrapper_pairs, _execution_sets)
            self._construct_node_wrapper_maps(_node_wrapper_pairs)
            self._construct_projection_wrapper_maps(_projection_wrapper_pairs)
            self.execution_sets = _execution_sets

        # Assign INPUT Nodes for outermost Composition (including any that are nested within it at any level)
        # Note: Pytorch representation is "flattened" (i.e., any nested Compositions are replaced by their Nodes)
        #   so if any nested Compositions are INPUT Nodes of the outermost Composition,
        #   *their* INPUT Nodes are assigned as INPUT Nodes of the outermost Composition
        if not composition.is_nested:
            def _assign_input_nodes(nodes):
                for pytorch_node in nodes:
                    if isinstance(pytorch_node, PytorchMechanismWrapper):
                        pytorch_node._is_input = pytorch_node.mechanism in composition._get_input_receivers(type=NODE)
                    else:
                        _assign_input_nodes(pytorch_node.node_wrappers)
            _assign_input_nodes(self.node_wrappers)

        # Flatten maps
        for node_wrapper in self.node_wrappers:
            if isinstance(node_wrapper, PytorchCompositionWrapper):
                # For copying weights back to PNL in AutodiffComposition.do_gradient_optimization
                self.projections_map.update(node_wrapper.projections_map)
                for k, v in node_wrapper.nodes_map.items():
                    self._add_node_to_nodes_map(k, v)
        # Purge nodes_map of entries for nested Compositions (their nodes are now in self.nodes_map)
        nodes_to_remove = [k for k, v in self.nodes_map.items() if isinstance(v, PytorchCompositionWrapper)]
        for node in nodes_to_remove:
            self._remove_node_from_nodes_map(node)

        self.output_nodes = self.composition.get_nested_output_nodes_at_all_levels()

        self.composition.parameters.pytorch_representation._set(self, context, skip_history=True, skip_log=True)

        # Get projections from flattened set, so that they are all in the outer Composition
        #   and visible by _regenerate_torch_parameter_list;
        #   needed for call to backward() in AutodiffComposition.do_gradient_optimization
        self.projection_wrappers = list(self.projections_map.values())

        composition.scheduler._delete_counts(execution_context.execution_id)

        self._regenerate_torch_parameter_list()
        assert 'DEBUGGING BREAKPOINT'

    def _early_init(self, composition, device):
        """Early initialization of PytorchCompositionWrapper"""
        # Assign attributes
        self.name = f"PytorchCompositionWrapper[{composition.name}]"
        self.device = device
        self.optimizer = None  # This gets assigned by self.composition after the wrapper is created,
                               #   as the latter is needed to pass the parameters to the optimizer
        self._optimizer_param_groups = []

        self.composition = composition
        self.node_wrappers = []  # can be PytorchMechanismWrapper or PytorchCompositionWrapper
        self._nodes_to_execute_after_gradient_calc = {}  # Nodes requiring execution after Pytorch forward/backward pass
        self._batch_size = 1  # Store the currently used batch size

        self.projection_wrappers = []  # PytorchProjectionWrappers
        self.projections_map = {}  # maps Projections -> PytorchProjectionWrappers
        self._pnl_refs_to_torch_params_map = {}  # API for PNL refs to PyTorch params (used by _parse_optimizer_params)

        self.minibatch_loss = torch.zeros(1, device=self.device).double()  # Accumulated losses within a batch
        self.minibatch_loss_count = 0  # Count of losses within batch

        # Data retained by the wrapper during execution and copied to pnl as specified by retain_for_psyneulink
        self.retained_results = []          # Values of all output NODES
        self.retained_trained_outputs = []  # Values of trained output NODES (i.e. associated with TARGETS)
        self.retained_targets = []          # Values of targets for all trials
        self.retained_losses = []           # Losses per trial or batch accumulated over a run

        # The following is a list of methods called in retain_for_psyneulink, indexed by keywords using DataTypeEnum
        # (this is constructed as a form of hash table for efficiency since that method can be called a lot;
        #  it is constructed here to avoid doing so in the retain_for_psyneulink method itself)
        self.retain_method = [None] * len(DataTypeEnum)
        self.retain_method[DataTypeEnum.TRAINED_OUTPUTS.value] = self.retain_trained_outputs
        self.retain_method[DataTypeEnum.TARGETS.value] = self.retain_targets
        self.retain_method[DataTypeEnum.LOSSES.value] = self.retain_losses

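    # Illustrative sketch (assumed, not from the source) of how the dispatch table built above is
    # meant to be used by retain_for_psyneulink, avoiding per-call attribute lookups:
    #
    #     self.retain_method[DataTypeEnum.TARGETS.value](targets)  # i.e., self.retain_targets(targets)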
    def _validate_subclass_components(self, _node_wrapper_pairs, _projection_wrapper_pairs, execution_sets):
        """Subclass instantiated nodes_map, projections_map and execution_sets, so validate these."""
        assert all(isinstance(item[0], (Mechanism, Composition)) for item in _node_wrapper_pairs), \
            (f"PROGRAM ERROR: Constructor for {self} passed non-Mechanism or Composition object(s) "
             f"as node(s) from subclass.")
        assert all(isinstance(item[1], (PytorchMechanismWrapper, PytorchCompositionWrapper))
                   for item in _node_wrapper_pairs), \
            (f"PROGRAM ERROR: Constructor for {self} passed non-PytorchMechanismWrapper or PytorchCompositionWrapper "
             f"object(s) as node wrapper(s) from subclass.")
        assert all(isinstance(item[0], Projection) for item in _projection_wrapper_pairs), \
            (f"PROGRAM ERROR: Constructor for {self} passed non-Projection object(s) as Projection(s) from subclass.")
        assert all(isinstance(item[1], PytorchProjectionWrapper) for item in _projection_wrapper_pairs), \
            (f"PROGRAM ERROR: Constructor for {self} passed non-PytorchProjectionWrapper object(s) as "
             f"projection wrapper(s) from subclass.")
        for exec_set in execution_sets:
            assert isinstance(exec_set, set), \
                f"PROGRAM ERROR: {self}.execution_sets contains non-ExecutionSet object(s)."
            for item in exec_set:
                assert isinstance(item, (PytorchMechanismWrapper, PytorchCompositionWrapper)), \
                    (f"PROGRAM ERROR: {self}.execution_sets contains a set with non-PytorchMechanismWrapper "
                     f"or PytorchCompositionWrapper object(s).")

    def _construct_node_wrapper_maps(self, _node_wrapper_pairs):
        self.nodes_map = {}  # maps Node (Mech | nested Comp) -> PytorchMechanismWrapper | PytorchCompositionWrapper
        self.node_wrappers = []
        self._modules_dict = torch.nn.ModuleDict()
        for node, pytorch_node_wrapper in _node_wrapper_pairs:
            self._add_node_to_nodes_map(node, pytorch_node_wrapper)

    def _construct_projection_wrapper_maps(self, _projection_wrapper_pairs):
        self.projections_map = {k: v for k, v in _projection_wrapper_pairs}
        self.projection_wrappers = list(self.projections_map.values())

    def _add_node_to_nodes_map(self, node, node_wrapper):
        """Keep nodes_map, node_wrappers and modules_dict in synch"""
        self.nodes_map[node] = node_wrapper
        # Membership test is on the wrapper (node_wrappers holds wrappers, not nodes)
        if node_wrapper not in self.node_wrappers:
            self.node_wrappers.append(node_wrapper)
        self._modules_dict[node.name] = node_wrapper
        self.state_dict()

    def _remove_node_from_nodes_map(self, node):
        """Keep nodes_map, node_wrappers and modules_dict in synch"""
        # Remove the node's wrapper (node_wrappers holds wrappers, not nodes)
        node_wrapper = self.nodes_map.pop(node)
        if node_wrapper in self.node_wrappers:
            self.node_wrappers.remove(node_wrapper)
        self._modules_dict.pop(node.name)

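    # Note: nodes_map (PNL Node -> wrapper), node_wrappers (ordered list of wrappers), and
    # _modules_dict (name -> wrapper, which also registers each wrapper as a torch submodule)
    # are three views of the same membership; the two helpers above keep them consistent.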
    def _instantiate_pytorch_mechanism_wrappers(self, composition, device, context) -> list:
        """Instantiate PytorchMechanismWrappers for Mechanisms in the Composition being wrapped"""
        from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition

        # Remove all learning-specific nodes
        nodes = list(set(composition.nodes) - set(composition.get_nodes_by_role(NodeRole.LEARNING)))

        # Remove nested nodes from nodes list (put there in flattening by infer_backpropagation_learning_pathways)
        #   so that they don't interfere with construction of execution_sets by scheduler
        # Will re-flatten execution sets below
        nodes = [n for n in nodes
                 # Leave nested Compositions
                 if (isinstance(n, AutodiffComposition)
                     # Needed since composition.nodes is flattened in infer_backpropagation_learning_pathways
                     or n not in [n[0] for n in self.composition._get_nested_nodes()])]

        _node_wrapper_pairs = []
        # Sort to be sure nested Compositions are processed last, as they need outer nodes that project in/out of them
        for node in sorted(nodes, key=lambda x: isinstance(x, AutodiffComposition)):
            # Wrap nested Composition
            if isinstance(node, AutodiffComposition):
                pytorch_node_wrapper = node.pytorch_composition_wrapper_type(composition=node,
                                                                             device=device,
                                                                             outer_creator=self,
                                                                             context=context)
            # Wrap Mechanism
            else:
                pytorch_node_wrapper = \
                    self.composition.pytorch_mechanism_wrapper_type(
                        mechanism=node,
                        composition=composition,
                        component_idx=self.composition._get_node_index(node),
                        use=[LEARNING, SYNCH, SHOW_PYTORCH],
                        dtype=self.torch_dtype,
                        device=device,
                        context=context)
                # pytorch_node._is_bias = all(input_port.default_input == DEFAULT_VARIABLE
                #                             for input_port in node.input_ports)
                pytorch_node_wrapper._is_bias = node in self.composition.get_nodes_by_role(NodeRole.BIAS)
            _node_wrapper_pairs.append((node, pytorch_node_wrapper))

        return _node_wrapper_pairs

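    # Note: the sort in the loop above relies on False < True, so plain Mechanisms (key False)
    # are wrapped before nested AutodiffCompositions (key True), ensuring the outer nodes that
    # project into or out of a nested Composition already have wrappers when it is processed.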
    def _instantiate_pytorch_projection_wrappers(self, composition, device, context, base_context=Context(execution_id=None)) -> list:
        """Instantiate PytorchProjectionWrappers for Projections in the Composition being wrapped
        Assign Projections for outermost Composition (including any that are nested within it at any level)
        Note: Pytorch representation is "flattened" (i.e., any nested Compositions are replaced by their Nodes)
        so if any nested Compositions have Projections to/from them, they are assigned to the outermost Composition
        See figure in module docstring for explanation of how Projections to/from nested Compositions are handled.
        """

        proj_wrappers_pairs = []
        # Instantiate PyTorch ProjectionWrappers (ignoring any from/to CIMs in the same composition)
        for projection in composition._inner_projections:
            sndr_mech = projection.sender.owner
            rcvr_mech = projection.receiver.owner

            # Rule out that Composition has parameter_CIM,
            #    since autodiff does not (yet) support those and they are not (yet) handled by flattening below
            assert not hasattr(self, '_parameter_CIM'), \
                (f"PROGRAM ERROR: {self} has a parameter_CIM, which should not currently be the case "
                 f"and is not handled by flattening in {self.__class__.__name__}.")

            # Ignore input_CIM and output_CIM within the same Composition (they are not learnable)
            if sndr_mech is composition.input_CIM or rcvr_mech is composition.output_CIM:
                continue

            # Handle projection to or from a nested Composition
            elif (isinstance(sndr_mech, CompositionInterfaceMechanism) or
                  isinstance(rcvr_mech, CompositionInterfaceMechanism)):
                pnl_proj, proj_sndr, proj_rcvr, use = self._handle_nested_comp(projection, context, base_context)
                # # use = [LEARNING, SYNCH, SHOW_PYTORCH]
                # use = [LEARNING, SYNCH]

            # Projection within composition
            elif all(sndr_and_recvr in self.nodes_map for sndr_and_recvr in {sndr_mech, rcvr_mech}):
                proj_sndr = self.nodes_map[sndr_mech]
                proj_rcvr = self.nodes_map[rcvr_mech]
                pnl_proj = projection
                use = [LEARNING, SYNCH, SHOW_PYTORCH]

            else:
                continue

            component_idx = list(self.composition._inner_projections).index(projection)
            sender_port_idx = projection.sender.owner.output_ports.index(projection.sender)
            pytorch_proj_wrapper = PytorchProjectionWrapper(projection=projection,
                                                            pnl_proj=pnl_proj,
                                                            component_idx=component_idx,
                                                            sender_port_idx=sender_port_idx,
                                                            use=use,
                                                            device=device,
                                                            sender_wrapper=proj_sndr,
                                                            receiver_wrapper=proj_rcvr,
                                                            composition=composition,
                                                            context=context)
            proj_sndr.add_efferent(pytorch_proj_wrapper)
            proj_rcvr.add_afferent(pytorch_proj_wrapper)

            proj_wrappers_pairs.append((projection, pytorch_proj_wrapper))

        return proj_wrappers_pairs

    def _handle_nested_comp(
        self,
        projection: MappingProjection,
        context: Context,
        base_context: Context = Context(execution_id=None),
    ) -> tuple:
        """Flatten nested Composition and assign Projections to/from it to outermost Composition
        This method is called when a Projection is to/from a CIM in a nested Composition that is not in the current
        Composition, and is needed for learning.
        It may be overridden by a subclass (e.g., GRUComposition) to handle flattening differently.
        See figure in module docstring for explanation of how Projections to/from nested Compositions are handled.
        """
        sndr_mech = projection.sender.owner
        rcvr_mech = projection.receiver.owner

        # ENTER_NESTED:
        # input_cim of nested Composition:
        #    - projection is to input_CIM that is not in current Composition so must be to a nested one;
        #    - needed for learning, so create map for Projection
        if (isinstance(rcvr_mech, CompositionInterfaceMechanism)
                and rcvr_mech.composition is not self
                and rcvr_mech is rcvr_mech.composition.input_CIM):
            # Replace rcvr_mech (input_CIM) with the node in the nested Composition that receives the projection
            nested_rcvr_port, nested_rcvr_mech, _ = \
                rcvr_mech._get_destination_info_from_input_CIM(projection.receiver)
            nested_pytorch_comp_wrapper = self.nodes_map[rcvr_mech.composition]
            proj, proj_sndr_wrapper, proj_rcvr_wrapper, use = (
                nested_pytorch_comp_wrapper._flatten_for_pytorch(projection,
                                                                 sndr_mech, rcvr_mech,
                                                                 nested_rcvr_port,
                                                                 nested_rcvr_mech,
                                                                 self.composition,
                                                                 self,
                                                                 ENTER_NESTED,
                                                                 context,
                                                                 base_context,
                                                                 )
            )
            if proj_sndr_wrapper is None:
                proj_sndr_wrapper = self.nodes_map[sndr_mech]

        # EXIT_NESTED
        # output_cim of nested Composition:
        #    - projection is from output_CIM that is not in current Composition so must be from a nested one;
        #    - needed for learning, so create map for Projection
        elif (isinstance(sndr_mech, CompositionInterfaceMechanism)
              and sndr_mech.composition is not self
              and sndr_mech is sndr_mech.composition.output_CIM):
            # Replace sndr_mech (output_CIM) with the node in the nested Composition that sends the projection
            nested_sndr_port, nested_sndr_mech, _ = \
                sndr_mech._get_source_info_from_output_CIM(projection.sender)
            nested_pytorch_comp_wrapper = self.nodes_map[sndr_mech.composition]
            proj, proj_sndr_wrapper, proj_rcvr_wrapper, use = (
                nested_pytorch_comp_wrapper._flatten_for_pytorch(projection,
                                                                 sndr_mech, rcvr_mech,
                                                                 nested_sndr_port,
                                                                 nested_sndr_mech,
                                                                 self.composition,
                                                                 self,
                                                                 EXIT_NESTED,
                                                                 context))
            if proj_rcvr_wrapper is None:
                proj_rcvr_wrapper = self.nodes_map[rcvr_mech]
        return proj, proj_sndr_wrapper, proj_rcvr_wrapper, use

    def _flatten_for_pytorch(self,
                             projection,
                             sndr_mech,
                             rcvr_mech,
                             nested_port,
                             nested_mech,
                             outer_comp,
                             outer_comp_pytorch_rep,
                             access,
                             context,
                             base_context=Context(execution_id=None),
                             ) -> tuple:
        proj_sndr_wrapper = None
        proj_rcvr_wrapper = None
        use = [LEARNING, SYNCH]

        if access == ENTER_NESTED:
            proj_rcvr_wrapper = self.nodes_map[nested_mech]
            # Assign Projection from input_CIM to nested_rcvr_port as pnl_proj (for use in forward())
            nested_comp = projection.receiver.owner.composition
            incoming_projections = [proj for proj in nested_comp.input_CIM.port_map[nested_port][1].efferents
                                    if proj in nested_comp.projections]
            assert len(incoming_projections) == 1, \
                (f"PROGRAM ERROR: There is more than one Projection registered in '{nested_comp.name}' "
                 f"from its input_CIM to '{nested_port.owner.name}'.")
            nested_port_afferents = [proj for proj in nested_port.path_afferents if proj in nested_comp.projections]
            pnl_proj = incoming_projections[0]
            if pnl_proj != nested_port.path_afferents[0]:
                from psyneulink.library.compositions.autodiffcomposition import AutodiffCompositionError
                raise AutodiffCompositionError(
                    f"First afferent Projection to '{nested_port.owner.name}' (which should be from "
                    f"'{nested_port.path_afferents[0].sender.owner.name}') is not the same as its "
                    f"Projection from the input_CIM of '{projection.receiver.owner.composition.name}'. "
                    f"One reason for this may be that these Components belong to different Compositions.")

            # Construct direct Projection from sender in outer Composition to receiver in nested Composition,
            #   and a PytorchProjectionWrapper for it that is assigned use=SHOW_PYTORCH,
            #   but don't add to either Composition as it is just used for show_graph(show_pytorch=True)
            destination_rcvr_port = rcvr_mech._get_destination_info_from_input_CIM(projection.receiver)[0]
            destination_rcvr_mech = rcvr_mech._get_destination_info_from_input_CIM(projection.receiver)[1]
            try:
                direct_proj = MappingProjection(name=f"Direct Projection from {projection.sender.owner.name} "
                                                     f"to {destination_rcvr_mech.name}",
                                                sender=projection.sender,
                                                receiver=destination_rcvr_port,
                                                learnable=projection.learnable)
            except DuplicateProjectionError:
                direct_proj = [proj for proj in projection.sender.efferents
                               if proj.receiver is destination_rcvr_port][0]
            else:
                direct_proj._initialize_from_context(context, base_context)

            if direct_proj not in self.projection_wrappers:
                proj_wrapper = PytorchProjectionWrapper(projection=direct_proj,
                                                        pnl_proj=pnl_proj,
                                                        component_idx=None,    # These are not needed since the wrapper
                                                        sender_port_idx=None,  # is only being used for SHOW_PYTORCH
                                                        use=[SHOW_PYTORCH],
                                                        device=self.device,
                                                        sender_wrapper=proj_sndr_wrapper,
                                                        receiver_wrapper=proj_rcvr_wrapper,
                                                        composition=self.composition,
                                                        context=context)
                outer_comp_pytorch_rep.projection_wrappers.append(proj_wrapper)
                outer_comp_pytorch_rep.projections_map[direct_proj] = proj_wrapper
                outer_comp_pytorch_rep.composition._pytorch_projections.append(direct_proj)

        elif access == EXIT_NESTED:
            proj_sndr_wrapper = self.nodes_map[nested_mech]

            # Assign Projection from nested_sndr_port to output_CIM as pnl_proj
            assert nested_port.efferents[0] == projection.sender.owner.port_map[nested_port][0].path_afferents[0], \
                (f"PROGRAM ERROR: First efferent Projection from '{nested_port.owner.name}' "
                 f"(to '{nested_port.efferents[0].receiver.owner.name}') is not the same as its "
                 f"Projection to '{projection.sender.owner.composition.name}.output_CIM'. "
                 f"One reason for this may be that these Components belong to different Compositions.")
            pnl_proj = projection

            # Construct direct Projection from sender in nested Composition to receiver in outer Composition,
            #   and a PytorchProjectionWrapper for it that is assigned use=SHOW_PYTORCH,
            #   but don't add to either Composition as it is just used for show_graph(show_pytorch=True)
            source_sndr_port = sndr_mech._get_source_info_from_output_CIM(projection.sender)[0]
            source_sndr_mech = sndr_mech._get_source_info_from_output_CIM(projection.sender)[1]
            try:
                direct_proj = MappingProjection(name=f"Direct Projection from {source_sndr_mech.name} "
                                                     f"to {rcvr_mech.name}",
                                                sender=source_sndr_port,
                                                receiver=projection.receiver,
                                                learnable=projection.learnable)
            except DuplicateProjectionError:
                direct_proj = [proj for proj in projection.receiver.path_afferents
                               if proj.sender is source_sndr_port][0]
            else:
                direct_proj._initialize_from_context(context, base_context)

            if direct_proj not in self.projection_wrappers:
                proj_wrapper = PytorchProjectionWrapper(projection=direct_proj,
                                                        pnl_proj=pnl_proj,
                                                        component_idx=None,    # These are not needed since the wrapper
                                                        sender_port_idx=None,  # is only being used for SHOW_PYTORCH
                                                        use=[SHOW_PYTORCH],
                                                        device=self.device,
                                                        sender_wrapper=proj_sndr_wrapper,
                                                        receiver_wrapper=proj_rcvr_wrapper,
                                                        composition=self.composition,
                                                        context=context)
                outer_comp_pytorch_rep.projection_wrappers.append(proj_wrapper)
                outer_comp_pytorch_rep.projections_map[direct_proj] = proj_wrapper
                outer_comp_pytorch_rep.composition._pytorch_projections.append(direct_proj)

        else:
            assert False, f"PROGRAM ERROR: access must be ENTER_NESTED or EXIT_NESTED, not {access}"

        return pnl_proj, proj_sndr_wrapper, proj_rcvr_wrapper, use

    def _parse_optimizer_params(self, context):
        """Assign parameter-specific optimizer param groups for PyTorch GRU module"""
        composition = self.composition

        # Replace pnl names with actual torch params as keys in optimizer_params
        optimizer_params = self.composition._optimizer_params
        for param_name in optimizer_params.copy():
            param = self._pnl_refs_to_torch_params_map.get(param_name, None)
            if param:
                optimizer_params[param] = optimizer_params.pop(param_name)

        # FIX: NOT ALL PROJECTIONS FOR WHICH learning_rate COULD BE SET ARE IN
        #      _pnl_refs_to_torch_params_map (SEE ABOVE) AND THEREFORE FINDABLE BELOW (INCLUDING IN state_dict())
        # Parse learning rate specs in optimizer_params
        for param, learning_rate in optimizer_params.items():
            assert any(param is state_param for state_param in self.state_dict().values()), \
                f"PROGRAM ERROR: {param} not in state_dict for '{self.name}'"
            if composition.enable_learning is False:
                param.requires_grad = False
            else:
                if learning_rate is not False:
                    # If learning_rate is True, use composition.learning_rate, else specified value
                    lr = composition.learning_rate if isinstance(learning_rate, bool) else learning_rate
                    param.requires_grad = True
                    self._optimizer_param_groups.append({'params': param, 'lr': lr})

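    # A hedged sketch of how the param groups assembled above could be consumed; the optimizer
    # type and construction site here are assumptions, not the library's actual code:
    #
    #     import torch.optim as optim
    #     optimizer = optim.Adam(self._optimizer_param_groups or self.parameters(),
    #                            lr=composition.learning_rate)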
    def _get_execution_sets(self, composition, base_context) -> tuple:
        """Return list of execution sets containing PytorchMechanismWrappers and/or PytorchCompositionWrappers"""
        execution_context = Context()
        try:
            composition.scheduler._init_counts(execution_id=execution_context.execution_id,
                                               base_execution_id=base_context.execution_id)
        except graph_scheduler.SchedulerError:
            # called from LLVM, no base context is provided
            composition.scheduler._init_counts(execution_id=execution_context.execution_id)

        # Setup execution sets
        # 1) Remove all learning-specific nodes
        execution_sets = [x - set(composition.get_nodes_by_role(NodeRole.LEARNING))
                          for x in composition.scheduler.run(context=execution_context)]
        # 2) Convert nodes to PytorchMechanismWrappers or PytorchCompositionWrappers
        execution_sets = [{self.nodes_map[comp] for comp in s if comp in self.nodes_map}
                          for s in execution_sets]
        # 3) Remove empty execution sets
        execution_sets = [x for x in execution_sets if len(x) > 0]

        # Flattening for forward() and AutodiffComposition.do_gradient_optimization

        # Flatten nested execution sets:
        nested_execution_sets = {}
        for exec_set in execution_sets:
            for node in exec_set:
                if isinstance(node, PytorchCompositionWrapper):
                    nested_execution_sets[node] = node.execution_sets
        for node, exec_sets in nested_execution_sets.items():
            index = execution_sets.index({node})
            # Remove nested Composition from execution sets
            execution_sets.remove({node})
            # Insert nested execution sets in place of nested Composition
            execution_sets[index:index] = exec_sets

        return execution_sets, execution_context

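    # Illustrative example (assumed values) of the flattening above: given
    #     execution_sets == [{A}, {nested_wrapper}, {B}]
    # where nested_wrapper.execution_sets == [{X}, {Y}], the singleton {nested_wrapper} is
    # replaced in place, yielding [{A}, {X}, {Y}, {B}].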
    __deepcopy__ = get_deepcopy_with_shared()

    def _regenerate_torch_parameter_list(self, base=None):
        """Add Projection matrices to Pytorch Module's parameter list"""

        # Register pytorch Parameters for ProjectionWrappers (since they are not already torch parameters)
        for proj_wrapper in [p for p in self.projection_wrappers if not p.projection.exclude_in_autodiff]:
            self.register_parameter(proj_wrapper.name, proj_wrapper.matrix)

    # generates llvm function for self.forward
    def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
        args = [ctx.get_state_struct_type(self.composition).as_pointer(),
                ctx.get_param_struct_type(self.composition).as_pointer(),
                ctx.get_data_struct_type(self.composition).as_pointer()
                ]
        builder = ctx.create_llvm_function(args, self)

        state, params, data = builder.function.args
        if "learning" in tags:
            self._gen_llvm_training_function_body(ctx, builder, state, params, data)
        else:
            model_input = builder.gep(data,
                                      [ctx.int32_ty(0),
                                       ctx.int32_ty(0),
                                       ctx.int32_ty(self.composition._get_node_index(self.composition.input_CIM))])
            self._gen_llvm_forward_function_body(ctx, builder, state, params, model_input, data)

        builder.ret_void()
        return builder.function

    def _gen_llvm_forward_function_body(self, ctx, builder, state, params, arg_in, data):
        z_values = {}  # dict storing the pre-activation (summed afferent input) value for each node
        for current_exec_set in self.execution_sets:
            for component in current_exec_set:
                mech_input_ty = ctx.get_input_struct_type(component.mechanism)
                variable = builder.alloca(mech_input_ty)
                z_values[component] = builder.alloca(mech_input_ty.elements[0].elements[0])
                builder.store(z_values[component].type.pointee(None), z_values[component])

                if NodeRole.INPUT in self.composition.get_roles_by_node(component.mechanism):
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(0)])
                    input_id = component._idx
                    mech_in = builder.gep(arg_in, [ctx.int32_ty(0), ctx.int32_ty(input_id)])
                    builder.store(builder.load(mech_in), input_ptr)
                for (proj_idx, proj) in enumerate(component.afferents):
                    input_ptr = builder.gep(
                        variable, [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(proj_idx)])
                    proj_output = proj._gen_llvm_execute(ctx, builder, state, params, data)
                    # store in input ports struct
                    builder.store(builder.load(proj_output), input_ptr)
                    # HACK: Add to z_values struct
                    gen_inject_vec_add(ctx, builder, proj_output, z_values[component], z_values[component])
                component._gen_llvm_execute(ctx, builder, state, params, variable, data)

        return z_values

776
    # generates a function responsible for a single epoch of the training
777
    def _gen_llvm_training_backprop(self, ctx, optimizer, loss):
1✔
778
        composition = self.composition
1✔
779
        args = [ctx.get_state_struct_type(composition).as_pointer(),
1✔
780
                ctx.get_param_struct_type(composition).as_pointer(),
781
                ctx.get_data_struct_type(composition).as_pointer(),
782
                optimizer._get_optimizer_struct_type(ctx).as_pointer(),
783
                ]
784
        name = self.composition.name + "_training_backprop"
1✔
785
        builder = ctx.create_llvm_function(args, self, name)
1✔
786
        llvm_func = builder.function
1✔
787
        for a in llvm_func.args:
1✔
788
            if isinstance(a.type, pnlvm.ir.PointerType):
1!
789
                a.attributes.add('noalias')
1✔
790

791
        state, params, data, optim_struct = llvm_func.args
1✔
792
        model_input = builder.gep(data, [ctx.int32_ty(0),
1✔
793
                                         ctx.int32_ty(0),
794
                                         ctx.int32_ty(self.composition._get_node_index(self.composition.input_CIM))])
795
        model_output = data
1✔
796
        # setup useful mappings
797
        input_nodes = set(self.composition.get_nodes_by_role(NodeRole.INPUT))
1✔
798

799
        # initialize optimizer params:
800
        delta_w = builder.gep(optim_struct, [ctx.int32_ty(0), ctx.int32_ty(optimizer._DELTA_W_NUM)])
1✔
801

802
        # 2) call forward computation
803
        z_values = self._gen_llvm_forward_function_body(
1✔
804
            ctx, builder, state, params, model_input, data)
805

806
        # 3) compute errors
807
        loss_fn = ctx.import_llvm_function(loss)
1✔
808
        total_loss = builder.alloca(ctx.float_ty)
1✔
809
        builder.store(total_loss.type.pointee(0), total_loss)
1✔
810

811
        error_dict = {}
1✔
812
        for exec_set in reversed(self.execution_sets):
1✔
813
            for node in exec_set:
1✔
814
                if node.mechanism in input_nodes:
1✔
815
                    continue
1✔
816

817
                node_z_value = z_values[node]
                activation_func_derivative = node._gen_llvm_execute_derivative_func(ctx, builder, state, params, node_z_value)
                error_val = builder.alloca(z_values[node].type.pointee)
                error_dict[node] = error_val

                if NodeRole.OUTPUT in self.composition.get_roles_by_node(node.mechanism):
                    # We handle the output layer here
                    # compute dC/da = a_l - y(x) (TODO: Allow other cost functions! This only applies to MSE)

                    # 1) Lookup desired target value
                    terminal_sequence = self.composition._terminal_backprop_sequences[node.mechanism]
                    target_idx = self.composition.get_nodes_by_role(NodeRole.INPUT).index(terminal_sequence[TARGET_MECHANISM])
                    node_target = builder.gep(model_input, [ctx.int32_ty(0), ctx.int32_ty(target_idx)])

                    # 2) Lookup desired output value
                    node_output = builder.gep(model_output,
                                              [ctx.int32_ty(0), ctx.int32_ty(0), ctx.int32_ty(node._idx), ctx.int32_ty(0)])

                    tmp_loss = loss.gen_inject_lossfunc_call(ctx, builder, loss_fn, node_output, node_target)

                    pnlvm.helpers.printf_float_array(ctx, builder, node_target, prefix=f"{node}\ttarget:\t", tags={"torch"})
                    pnlvm.helpers.printf_float_array(ctx, builder, node_output, prefix=f"{node}\tvalue:\t", tags={"torch"})

                    pnlvm.helpers.printf(ctx, builder, f"{node}\tloss:\t%f\n", tmp_loss, tags={"torch"})
                    builder.store(builder.fadd(builder.load(total_loss), tmp_loss), total_loss)
                    loss_derivative = loss._gen_inject_loss_differential(ctx, builder, node_output, node_target)

                    # compute δ_l = dC/da ⊙ σ'(z)
                    gen_inject_vec_hadamard(ctx, builder, activation_func_derivative, loss_derivative, error_val)

                else:
                    # We propagate error backwards from the next layer
                    for proj_idx, proj in enumerate(node.efferents):
                        efferent_node = proj.receiver_wrapper
                        efferent_node_error = error_dict[efferent_node]

                        weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, state, params)

                        if proj_idx == 0:
                            gen_inject_vxm_transposed(ctx, builder, efferent_node_error, weights_llvmlite, error_val)
                        else:
                            new_val = gen_inject_vxm_transposed(ctx, builder, efferent_node_error, weights_llvmlite)

                            gen_inject_vec_add(ctx, builder, new_val, error_val, error_val)

                    gen_inject_vec_hadamard(ctx, builder, activation_func_derivative, error_val, error_val)

                pnlvm.helpers.printf_float_array(ctx, builder, activation_func_derivative, prefix=f"{node}\tdSigma:\t", tags={"torch"})
                pnlvm.helpers.printf_float_array(ctx, builder, error_val, prefix=f"{node}\terror:\t", tags={"torch"})

        # 4) compute weight gradients
        for (node, err_val) in error_dict.items():
            if node in input_nodes:
                continue

            for proj in node.afferents:
                # get a_(l-1)
                afferent_node_activation = builder.gep(model_output, [ctx.int32_ty(0),
                                                                      ctx.int32_ty(0),
                                                                      ctx.int32_ty(proj.sender_wrapper._idx),
                                                                      ctx.int32_ty(0)])

                # get dimensions of weight matrix
                weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, state, params)
                pnlvm.helpers.printf_float_matrix(ctx,
                                                  builder,
                                                  weights_llvmlite,
                                                  prefix=f"{proj.sender_wrapper.mechanism} -> "
                                                         f"{proj.receiver_wrapper.mechanism}\n", tags={"torch"})
                # update delta_W
                node_delta_w = builder.gep(delta_w, [ctx.int32_ty(0), ctx.int32_ty(proj._idx)])

                dim_x, dim_y = proj.matrix.shape
                with pnlvm.helpers.for_loop_zero_inc(builder, ctx.int32_ty(dim_x),
                                                     "weight_update_loop_outer") as (b1, weight_row):
                    with pnlvm.helpers.for_loop_zero_inc(b1, ctx.int32_ty(dim_y),
                                                         "weight_update_loop_inner") as (b2, weight_column):
                        a_val = b2.load(b2.gep(afferent_node_activation,
                                               [ctx.int32_ty(0), weight_row]))
                        d_val = b2.load(b2.gep(err_val,
                                               [ctx.int32_ty(0), weight_column]))
                        old_val = b2.load(b2.gep(node_delta_w,
                                                 [ctx.int32_ty(0), weight_row, weight_column]))
                        new_val = b2.fadd(old_val, b2.fmul(a_val, d_val))
                        b2.store(new_val, b2.gep(node_delta_w,
                                                 [ctx.int32_ty(0), weight_row, weight_column]))

        pnlvm.helpers.printf(ctx, builder, "TOTAL LOSS:\t%.20f\n", builder.load(total_loss), tags={"torch"})
        builder.ret_void()

        return builder.function
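
    # Editor's sketch (illustrative only, plain NumPy; not part of the compiled code):
    # for a hidden layer, the injected vxm_transposed and hadamard calls above compute
    # δ_l = (W δ_(l+1)) ⊙ σ'(z_l), and the nested weight-update loops accumulate
    # delta_W += a_(l-1) ⊗ δ_l (an outer product):
    #
    #     import numpy as np
    #     W = np.ones((2, 3))                     # weights from this layer to the next
    #     delta_next = np.array([0.1, 0.2, 0.3])  # δ_(l+1), error of the next layer
    #     dsigma = np.array([0.5, 0.25])          # σ'(z_l) for this layer
    #     delta = (W @ delta_next) * dsigma       # δ_l
    #     a_prev = np.array([0.2, 0.5])           # a_(l-1), activations of preceding layer
    #     delta_w = np.zeros((2, 2))
    #     delta_w += np.outer(a_prev, delta)      # same as weight_update_loop_outer/inner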

    def _gen_llvm_training_function_body(self, ctx, builder, state, params, data):
        composition = self.composition

        optimizer = self._get_compiled_optimizer()
        # setup loss
        loss_type = self.composition.loss_spec
        if loss_type == Loss.MSE:
            loss = MSELoss()
        elif loss_type == Loss.CROSS_ENTROPY:
            loss = CROSS_ENTROPYLoss()
        else:
            raise Exception("LOSS TYPE", loss_type, "NOT SUPPORTED")

        optimizer_step_f = ctx.import_llvm_function(optimizer)
        optimizer_struct_idx = len(state.type.pointee.elements) - 1
        optimizer_struct = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(optimizer_struct_idx)])
        optimizer_zero_grad = ctx.import_llvm_function(optimizer.zero_grad(ctx).name)
        backprop = ctx.import_llvm_function(self._gen_llvm_training_backprop(ctx, optimizer, loss).name)

        # FIXME: converting this call to inlined code results in
        # significantly longer compilation times
        builder.call(optimizer_zero_grad, [optimizer_struct])
        builder.call(backprop, [state, params, data,
                                optimizer_struct])
        builder.call(optimizer_step_f, [optimizer_struct, state, params])
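
        # Editor's note (illustrative analogy, not the compiled path): the three calls
        # above mirror a standard PyTorch training step:
        #
        #     optimizer.zero_grad()   # optimizer_zero_grad
        #     loss.backward()         # backprop: compute and accumulate gradients
        #     optimizer.step()        # optimizer_step_f: apply the weight updates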

    def _get_compiled_optimizer(self):
        # setup optimizer
        optimizer_type = self.composition.optimizer_type
        if optimizer_type == 'adam':
            optimizer = AdamOptimizer(self, lr=self.composition.learning_rate)
        elif optimizer_type == 'sgd':
            optimizer = SGDOptimizer(self, lr=self.composition.learning_rate)
        else:
            raise Exception("OPTIMIZER TYPE", optimizer_type, "NOT SUPPORTED")
        return optimizer

    @handle_external_context()
    def forward(self, inputs, optimization_num, synch_with_pnl_options, context=None)->dict:
        """Forward method of the model for PyTorch and LLVM modes.
        Return a dictionary {output_node: value} of output values for the model.
        """

        # Store the batch_size we are currently using
        inp = inputs[list(inputs.keys())[0]]
        if type(inp) is torch.Tensor:
            self._batch_size = inp.shape[0]
        elif type(inp) is list:
            self._batch_size = len(inp)
        else:
            raise ValueError("Inputs to PytorchCompositionWrapper.forward must be either torch.Tensors or lists of "
                             "torch.Tensors")

        outputs = {}  # dict for storing values of terminal (output) nodes
        for current_exec_set in self.execution_sets:
            for node in current_exec_set:

                # If node is a nested Composition (wrapped in a PytorchCompositionWrapper),
                #    call its forward method recursively; there is no need to manage its outputs, as the
                #    Composition has been "flattened" (i.e., its nodes have been moved up into the outer
                #    Composition of the PyTorch representation) by _build_pytorch_representation, so its
                #    outputs are "consumed" by the collect_afferents() method of the MechanismWrappers
                #    to which it projects in the outer Composition.
                if isinstance(node, PytorchCompositionWrapper):
                    node.forward(inputs=None, optimization_num=optimization_num,
                                 synch_with_pnl_options=synch_with_pnl_options, context=context)
                    continue

                # Get input(s) to node
                elif node._is_input or node._is_bias:
                    # node is an INPUT to Composition
                    if node.mechanism in inputs:
                        # external input is specified for the Mechanism (i.e., Mechanism is a key in inputs dict)
                        if not node._is_bias:
                            # all input_ports receive external input, so use that
                            variable = inputs[node.mechanism]
                        else:
                            # node is also a BIAS node, so get input for each input_port individually
                            variable = []
                            for i, input_port in enumerate(node.mechanism.input_ports):
                                input = inputs[node.mechanism]
                                if not input_port.internal_only:
                                    # input_port receives external input, so get from inputs
                                    variable.append(input[i])
                                elif input_port.default_input == DEFAULT_VARIABLE:
                                    # input_port uses a bias, so get that (as a tensor, so it can be expanded below)
                                    val = torch.from_numpy(input_port.defaults.variable)

                                    # We need to add the batch dimension to default values.
                                    val = val[None, ...].expand(self._batch_size, *val.shape)

                                    variable.append(val)

                            # We now need to stack these so the batch dimension is first
                            try:
                                variable = torch.stack(variable, dim=1)
                            except (RuntimeError, TypeError):
                                # is ragged; need to reshape things so batch size is the first dimension
                                batch_size = variable[0].shape[0]
                                variable = [[inp[b] for inp in variable] for b in range(batch_size)]

                    # Input for the Mechanism is *not* explicitly specified, but its input_port(s) may have been
                    else:
                        # Get input for each input_port of the node
                        variable = []
                        for i, input_port in enumerate(node.mechanism.input_ports):
                            if input_port in inputs:
                                # input to input_port is specified in the inputs dict, so use that
                                variable.append(inputs[input_port])
                            elif input_port.default_input == DEFAULT_VARIABLE:
                                # input_port uses a bias, so get that
                                val = torch.from_numpy(input_port.defaults.variable)

                                # We need to add the batch dimension to default values.
                                val = val[None, ...].expand(self._batch_size, *val.shape)

                                variable.append(val)
                            elif not input_port.internal_only:
                                # otherwise, use the node's input_port's afferents
                                variable.append(node.collect_afferents(batch_size=self._batch_size,
                                                                       port=i,
                                                                       inputs=inputs))

                        # We now need to stack these so the batch dimension is first
                        try:
                            variable = torch.stack(variable, dim=1)
                        except (RuntimeError, TypeError):
                            # is ragged; need to reshape things so batch size is the first dimension
                            batch_size = variable[0].shape[0]
                            variable = [[inp[b] for inp in variable] for b in range(batch_size)]
                else:
                    # Node is not INPUT to Composition or BIAS, so get all input from its afferents
                    variable = node.collect_afferents(batch_size=self._batch_size, inputs=inputs)
                variable = node.execute_input_ports(variable)

                # Node is excluded from gradient calculations, so cache for later execution
                if node.exclude_from_gradient_calc:
                    if node.exclude_from_gradient_calc == AFTER:
                        # Cache variable for later execution
                        self._nodes_to_execute_after_gradient_calc[node] = variable
                        continue
                    elif node.exclude_from_gradient_calc == BEFORE:
                        assert False, 'PROGRAM ERROR: node.exclude_from_gradient_calc == BEFORE not yet implemented'
                    else:
                        assert False, \
                            (f'PROGRAM ERROR: Bad assignment to {node.name}.exclude_from_gradient_calc: '
                             f'{node.exclude_from_gradient_calc}; only {AFTER} is currently supported')

                # Execute the node (i.e., call its forward method) using the composition_wrapper of the Composition
                # to which it belongs; this supports overrides of the execute_node method by subclasses of
                # PytorchCompositionWrapper (such as EMComposition and GRUComposition).
                node.execute(variable, optimization_num, synch_with_pnl_options, context)

                assert 'DEBUGGING BREAK POINT'

                # Add entry to outputs dict for OUTPUT Nodes of pytorch representation
                #  note: these may be different than for the actual Composition, as they are flattened
                if node._is_output or node.mechanism in self.output_nodes:
                    outputs[node.mechanism] = node.output

        # NOTE: Context source needs to be set to COMMAND_LINE to force logs to update independently of timesteps
        # if not self.composition.is_nested:
        old_source = context.source
        context.source = ContextFlags.COMMAND_LINE
        self.log_values()
        self.log_weights()
        context.source = old_source

        # Return outputs of the outermost Composition
        return outputs
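
    # Editor's usage sketch (hypothetical names and shapes): forward expects a dict
    # mapping INPUT Mechanisms (or their InputPorts) to batched tensors, and returns
    # a dict mapping OUTPUT Mechanisms to their values:
    #
    #     inputs = {input_mech: torch.zeros(4, 1, 3, dtype=torch.double)}  # (batch, port, values)
    #     outputs = pytorch_rep.forward(inputs, optimization_num=0,
    #                                   synch_with_pnl_options=synch_options)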

    def synch_with_psyneulink(self,
                              synch_with_pnl_options:dict,
                              current_condition:LEARNING_SCALE_LITERALS,
                              context:Context,
                              params:Optional[list]=None):
        """Copy weights, variables, values, and/or results from Pytorch to PsyNeuLink at specified junctures.
        params can be used to restrict the copy to a specific (set of) param(s); if params is not specified,
        all are copied.
        """
        all = [MATRIX_WEIGHTS, NODE_VARIABLES, NODE_VALUES,
               # 3/15/25 FIX: ADD SUPPORT FOR THESE IN AutodiffComposition AND BELOW
               # NODE_OUTPUT_VALUES, EXECUTE_NODES,
               RESULTS]
        params = convert_to_list(params) or all
        illegal_params = [param for param in params if param not in all]
        assert not illegal_params, \
            f"PROGRAM ERROR: Illegal attributes ({', '.join(illegal_params)}) specified in call to synch_with_psyneulink"

        if MATRIX_WEIGHTS in params and synch_with_pnl_options[MATRIX_WEIGHTS] == current_condition:
            self._copy_weights_to_psyneulink(context)

        # If either NODE_VARIABLES or NODE_VALUES is specified, and the current condition is met, do the relevant copies
        if ((NODE_VARIABLES in params and synch_with_pnl_options[NODE_VARIABLES] == current_condition)
                or (NODE_VALUES in params and synch_with_pnl_options[NODE_VALUES] == current_condition)):
            self.copy_node_variables_and_values_to_psyneulink({k:v for k,v in synch_with_pnl_options.items()
                                                               if (k in {NODE_VARIABLES, NODE_VALUES} and
                                                                   v == current_condition)},
                                                              context)

        if RESULTS in params and synch_with_pnl_options[RESULTS] == current_condition:
            self.copy_results_to_psyneulink(current_condition, context)
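
    # Editor's sketch (hypothetical option values, using keywords imported above):
    # synch_with_pnl_options gates each copy on the current learning-scale condition,
    # e.g. copy weights and node values at the end of each EPOCH and results at end of RUN:
    #
    #     synch_options = {MATRIX_WEIGHTS: EPOCH, NODE_VARIABLES: None,
    #                      NODE_VALUES: EPOCH, RESULTS: RUN}
    #     pytorch_rep.synch_with_psyneulink(synch_options, current_condition=EPOCH,
    #                                       context=context)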

    def _copy_weights_to_psyneulink(self, context=None):
        for proj_wrapper in self.projections_map.values():
            if SYNCH in proj_wrapper._use:
                proj_wrapper._copy_torch_params_to_pnl_proj(context)

    def log_weights(self):
        for proj_wrapper in self.projection_wrappers:
            proj_wrapper.log_matrix()

    def copy_node_variables_and_values_to_psyneulink(self, options:dict, context=None):
        for pytorch_node in self.nodes_map.values():
            pytorch_node.set_pnl_variable_and_values(set_variable=True if NODE_VARIABLES in options else False,
                                                     set_value=True if NODE_VALUES in options else False,
                                                     # FIX: 3/15/25 - ADD SUPPORT FOR THESE
                                                     # set_output_values=True if OUTPUT_VALUES in options else False,
                                                     # execute_mech=True if EXECUTE_NODES in options else False,
                                                     context=context)

        # Update output_values of autodiff Composition by executing its output_CIM with pytorch_rep all_output_values
        if self.all_output_values is not None:
            # Execute the output_CIM on the last element of the batch to update the output ports
            self.composition.output_CIM.execute(self.all_output_values[-1, ...], context=context)

    def log_values(self):
        for node_wrapper in [n for n in self.node_wrappers if not isinstance(n, PytorchCompositionWrapper)]:
            node_wrapper.log_value()

    def copy_results_to_psyneulink(self, current_condition, context=None):
        """Append outputs of Pytorch forward() to AutodiffComposition.results attribute."""
        # IMPLEMENTATION NOTE: no need to do anything for TRIAL or MINIBATCH,
        #  as Composition's _update_results() method is called to do that locally
        if current_condition in {EPOCH, RUN}:
            results_param = self.composition.parameters.results
            prev_results = results_param._get(context)
            curr_results = convert_to_np_array(self.retained_results)
            if len(prev_results):
                new_results = np.append(prev_results, curr_results, 0)
            else:
                new_results = curr_results
            self.retained_results = []
            results_param._set(new_results, context)
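
    # Editor's sketch (hypothetical shapes): the accumulation above appends the newly
    # retained results to those already stored, along the first (trial) axis:
    #
    #     import numpy as np
    #     prev = np.zeros((2, 1, 3))          # results already stored on the Parameter
    #     curr = np.ones((1, 1, 3))           # convert_to_np_array(self.retained_results)
    #     np.append(prev, curr, 0).shape      # -> (3, 1, 3)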

    def retain_for_psyneulink(self,
                              data:dict,
                              retain_in_pnl_options:dict,
                              context):
        """Store outputs, targets, and losses from Pytorch execution for copying to PsyNeuLink at end of learn().

        Arguments
        ---------
        data : dict
            specifies local data available to retain (for copying to pnl at end of run);
            keys must be one or more of the keywords OUTPUTS, TARGETS, or LOSSES; values must be a torch.Tensor
        retain_in_pnl_options : dict
            specifies which data the user has requested be retained (and copied to pnl at end of run);
            keys must be OUTPUTS, TARGETS, or LOSSES; value must be a LearningScale.name or None (which suppresses copy)

        Note: does not actually copy data to pnl; that is done by _getter methods for the relevant autodiff Parameters
        """
        try:
            for data_type, data_val in data.items():
                try:
                    if retain_in_pnl_options[data_type]:
                        retain_method_idx = DataTypeEnum._member_map_[data_type.upper()].value
                        self.retain_method[retain_method_idx](data_val)
                except KeyError:
                    assert False, \
                        (f"PROGRAM ERROR: No entry for {data_type} found in retain_in_pnl_options "
                         f"in call to retain_for_psyneulink()")
        except KeyError:
            assert False, \
                (f"PROGRAM ERROR: Invalid key(s) specified in call to retain_for_psyneulink: {list(data.keys())}")

    def retain_results(self, results:list):
        """Track outputs and copy to AutodiffComposition.pytorch_outputs at end of learn()."""
        if len(results):
            self.retained_results.append(results)

    def retain_trained_outputs(self, trained_outputs:list):
        """Track trained outputs and copy to AutodiffComposition.pytorch_outputs at end of learn()."""
        self.retained_trained_outputs.append(trained_outputs)

    def retain_targets(self, targets:list):
        """Track targets and copy to AutodiffComposition.pytorch_targets at end of learn()."""
        self.retained_targets.append(targets)

    def retain_losses(self, loss:torch.Tensor):
        """Track losses and copy to AutodiffComposition.pytorch_losses at end of learn()."""
        self.retained_losses.append(loss.detach().cpu().numpy().copy().tolist())
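
    # Editor's sketch (runnable in isolation): the chain used above converts a torch
    # scalar to a detached, plain-Python value that is safe to store:
    #
    #     >>> import torch
    #     >>> loss = torch.tensor(0.25, requires_grad=True) * 2
    #     >>> loss.detach().cpu().numpy().copy().tolist()
    #     0.5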

    def detach_all(self):
        for projection in self.projections_map.values():
            projection.matrix.detach()


class PytorchMechanismWrapper(torch.nn.Module):
    """Wrapper for a Mechanism in a PytorchCompositionWrapper.
    These comprise the nodes of the PytorchCompositionWrapper, and generally correspond to functions in a Pytorch model.

    Attributes
    ----------

    mechanism : Mechanism
        the PsyNeuLink `Mechanism` being wrapped.

    composition : AutodiffComposition
        the `AutodiffComposition` to which the `Mechanism` being wrapped belongs
        (and for which the PytorchCompositionWrapper -- to which the PytorchMechanismWrapper
        belongs -- is the pytorch_representation).

    afferents : List[PytorchProjectionWrapper]
        list of `PytorchProjectionWrapper` objects that project to the PytorchMechanismWrapper.

    input : torch.Tensor
        most recent input to the PytorchMechanismWrapper.

    function : _gen_pytorch_fct
        Pytorch version of the Mechanism's function assigned in its __init__.

    integrator_function : _gen_pytorch_fct
        Pytorch version of the Mechanism's integrator_function assigned in its __init__ if the Mechanism
        has an integrator_function;  this assumes the Mechanism also has an integrator_mode attribute
        that is used to determine whether to execute the integrator_function first, and use its result
        as the input to its function.

    output : torch.Tensor
        most recent output of the PytorchMechanismWrapper.

    efferents : List[PytorchProjectionWrapper]
        list of `PytorchProjectionWrapper` objects that project from the PytorchMechanismWrapper.

    exclude_from_gradient_calc : bool or str[BEFORE | AFTER] : False
        used to prevent a node from being included in the Pytorch gradient calculation by excluding it from calls
        to forward() and backward().  If AFTER is specified, the node is executed at the end of the
        `update_learning_parameters` method.  BEFORE is not currently supported.

    _use : list[LEARNING, SYNCH, SHOW_PYTORCH]
        designates the uses of the Mechanism, specified by the following keywords (see
        PytorchCompositionWrapper `docstring <Mechanism_and_Projection_Uses>` for additional details):

        * *LEARNING*: inputs and `function <Mechanism_Base.function>` Parameters are used
          for actual execution of the corresponding Pytorch Module;

        * *SYNCH*: used to store results of executing a Pytorch module that are then transferred to
          the `value <Mechanism_Base.value>` Parameter of the PytorchMechanismWrapper\\s `mechanism
          <PytorchMechanismWrapper.mechanism>`;

        * *SHOW_PYTORCH*: `mechanism <PytorchMechanismWrapper.mechanism>` is included when the
          `AutoDiffComposition`\\s `show_graph <AutoDiffComposition.show_graph>` method is used with the
          ``show_pytorch`` option to display its `pytorch_representation <AutodiffComposition.pytorch_representation>`;
          if it is not specified, the `mechanism <PytorchMechanismWrapper.mechanism>` is not displayed when the
          `AutoDiffComposition`\\s `show_graph <AutoDiffComposition.show_graph>` method is called, even if the
          ``show_pytorch`` option is specified.
    """

    def __init__(self,
                 mechanism:ProcessingMechanism,                 # Mechanism to be wrapped
                 composition,                                   # one to which mech belongs (for nested executions)
                 component_idx:Optional[int],                   # index of the Mechanism in the Composition
                 use:Union[list, Literal[LEARNING, SYNCH, SHOW_PYTORCH]], # learning, synching of values and/or display
                 dtype:torch.dtype,                             # needed for Pytorch
                 device:str,                                    # needed for Pytorch
                 subclass_specifies_function:bool=False,        # used to determine whether to assign function here
                 context=None):
        super().__init__()
        self.name = f"PytorchMechanismWrapper[{mechanism.name}]"
        self.mechanism = mechanism
        self._idx = component_idx
        self._context = context
        self._is_input = False
        self._is_bias = False
        self._is_output = False
        self._use = use or [LEARNING, SYNCH, SHOW_PYTORCH]
        self._curr_sender_value = None # Used to assign initializer or default if value == None (i.e., not yet executed)
        self.exclude_from_gradient_calc = False # Used to execute node before or after forward/backward pass methods

        from psyneulink.library.compositions.autodiffcomposition import AutodiffComposition
        assert isinstance(composition, AutodiffComposition), \
            f"PROGRAM ERROR: {composition} must be an AutodiffComposition."
        self.composition = composition
        self.torch_dtype = dtype

        self.input = None
        self.output = None

        if mechanism.parameters.has_initializers._get(context) and mechanism.parameters.value.initializer:
            self.default_output = mechanism.parameters.value.initializer.get(context)
        else:
            self.default_output = mechanism.defaults.value
        self.afferents = []
        self.efferents = []

        if subclass_specifies_function is False:
            self._assign_pytorch_function(mechanism, device, context)

    def _assign_pytorch_function(self, mechanism, device, context):
        self.function = PytorchFunctionWrapper(mechanism.function, device, context)

        if hasattr(mechanism, 'integrator_function'):
            self.integrator_function = PytorchFunctionWrapper(mechanism.integrator_function, device, context)
            self.integrator_previous_value = mechanism.integrator_function._get_pytorch_fct_param_value('initializer', device, context)

        self.input_ports = [PytorchFunctionWrapper(input_port.function, device, context)
                            for input_port in mechanism.input_ports]

    def add_afferent(self, afferent):
        """Add ProjectionWrapper for afferent to MechanismWrapper.
        For use in call to collect_afferents.
        """
        assert afferent not in self.afferents
        self.afferents.append(afferent)

    def add_efferent(self, efferent):
        """Add ProjectionWrapper for efferent from MechanismWrapper.
        Implemented for completeness; not currently used.
        """
        assert efferent not in self.efferents
        self.efferents.append(efferent)

    def execute(self, variable, optimization_num, synch_with_pnl_options, context=None)->torch.Tensor:
        """Execute Mechanism's _gen_pytorch version of its function on variable.
        Enforce result to be 2d, and assign to self.output.
        """
        def execute_function(function, variable, fct_has_mult_args=False):
            """Execute _gen_pytorch_fct on variable, enforce result to be 2d, and return it.
            If fct_has_mult_args is True, treat each item in variable as an arg to the function;
            if False, compute function for each item in variable and return results in a list.
            """
            from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction
            if fct_has_mult_args:
                res = function(*variable)
            # variable is ragged
            elif isinstance(variable, list):
                res = [function(torch.stack([batch_elem[i] for batch_elem in variable])) for i in range(len(variable[0]))]

                # Reshape to batch dimension first
                batch_size = res[0].shape[0]
                res = [[inp[b] for inp in res] for b in range(batch_size)]

            else:
                # Functions handle batch dimensions, so just run the
                # function with the variable and get back a tensor.
                res = function(variable)
            # TransformFunction can reduce output to a single item from
            # multi-item input
            if isinstance(function._pnl_function, TransformFunction):
                res = res.unsqueeze(1)
            return res

        # If mechanism has an integrator_function and integrator_mode is True,
        #   execute it first and use result as input to the main function;
        #   assumes that if the PyTorch node has been assigned an integrator_function
        #   then its mechanism also has an integrator_mode attribute
        if hasattr(self, 'integrator_function') and self.mechanism.parameters.integrator_mode._get(context):
            variable = execute_function(self.integrator_function,
                                        [self.integrator_previous_value, variable],
                                        fct_has_mult_args=True)
            # Keep track of previous value in Pytorch node for use in next forward pass
            self.integrator_previous_value = variable

        self.input = variable

        # Compute main function of mechanism and return result
        self.output = execute_function(self.function, variable)
        return self.output
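
    # Editor's sketch (hypothetical values): with integrator_mode on, execute() composes
    # the two wrapped functions, roughly output = f(g(previous_value, variable)); e.g.,
    # for an adaptive integrator with rate r followed by a logistic function:
    #
    #     import torch
    #     r = 0.5
    #     prev = torch.zeros(1, 1, 2)           # (batch, port, values)
    #     x = torch.ones(1, 1, 2)
    #     integrated = (1 - r) * prev + r * x   # integrator_function
    #     out = torch.sigmoid(integrated)       # main function (e.g., Logistic)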

    def collect_afferents(self, batch_size:int, port:Optional[Port]=None, inputs:Optional[dict]=None):
        """
        Return afferent projections for input_port(s) of the Mechanism.
        If there is only one input_port, return the sum of its afferents (for those in Composition);
        if there are multiple input_ports, return a tensor (or list of tensors if input ports are ragged) of shape:

        (batch, input_port, projection, ...)

        where the ellipsis represents 1 or more dimensions for the values of the projected afferent.

        FIX: AUGMENT THIS TO SUPPORT InputPort's function
        """
        assert self.afferents,\
            f"PROGRAM ERROR: No afferents found for '{self.mechanism.name}' in AutodiffComposition"

        for proj_wrapper in self.afferents:
            curr_val = proj_wrapper.sender_wrapper.output
            if curr_val is not None:
                if type(curr_val) == torch.Tensor:
                    proj_wrapper._curr_sender_value = curr_val[:, proj_wrapper._value_idx, ...]
                else:
                    val = [batch_elem[proj_wrapper._value_idx] for batch_elem in curr_val]
                    val = torch.stack(val)
                    proj_wrapper._curr_sender_value = val

            else:
                val = torch.tensor(proj_wrapper.default_value)

                # We need to add the batch dimension to default values.
                val = val[None, ...].expand(batch_size, *val.shape)

                proj_wrapper._curr_sender_value = val

            proj_wrapper._curr_sender_value = torch.atleast_1d(proj_wrapper._curr_sender_value)

        # Specific port is specified
        if port is not None:
            res = [
                proj_wrapper.execute(proj_wrapper._curr_sender_value)
                for proj_wrapper in self.afferents
                if proj_wrapper._pnl_proj in self.mechanism.input_ports[port].path_afferents
            ]
        else:
            res = []
            for input_port in self.mechanism.input_ports:
                ip_res = []
                for proj_wrapper in self.afferents:
                    if proj_wrapper._pnl_proj in input_port.path_afferents:
                        ip_res.append(proj_wrapper.execute(proj_wrapper._curr_sender_value))

                # Stack the results for this input port on the second dimension;
                # we want to preserve the first dimension as the batch
                ip_res = torch.stack(ip_res, dim=1)
                res.append(ip_res)
        try:
            # Now stack the results for all input ports on the second dimension again; this keeps batch
            # first. We should now have a 4D tensor: (batch, input_port, projection, values)
            res = torch.stack(res, dim=1)
        except (RuntimeError, TypeError):
            # is ragged; ports will be handled individually during execute,
            # but we still need to reshape things so batch size is the first dimension
            batch_size = res[0].shape[0]
            res = [[inp[b] for inp in res] for b in range(batch_size)]

        return res
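
    # Editor's shape sketch (hypothetical sizes): for a Mechanism with 2 input_ports,
    # each receiving 1 projection of width 3, and a batch of 4, the non-ragged result
    # stacks to a 4D tensor:
    #
    #     res.shape == (4, 2, 1, 3)   # (batch, input_port, projection, values)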

    def execute_input_ports(self, variable):
        from psyneulink.core.components.functions.nonstateful.transformfunctions import TransformFunction

        if not isinstance(variable, torch.Tensor):
            try:
                variable = torch.stack(variable)
            except (RuntimeError, TypeError):
                # is ragged; need to reshape things so batch size is first dimension
                pass

        # must iterate over at least 1d input per port
        if type(variable) == torch.Tensor:
            variable = torch.atleast_2d(variable)

        res = []
        for i in range(len(self.input_ports)):
            if type(variable) == torch.Tensor:
                v = variable[:, i, ...] # Get the input for the port for all items in the batch
            else:
                v = [batch_elem[i] for batch_elem in variable]

                # We should be able to stack now, since the ragged structure is only on input ports
                v = torch.stack(v)

            if isinstance(self.input_ports[i]._pnl_function, TransformFunction):
                # Add the input port dimension back to account for input port dimension reduction; we should have
                # shape (batch, input_port, ... variable dimensions ...) or
                # (batch, input_port, projection, ... variable dimensions ...) if execute_input_ports is invoked
                # after collect_afferents.
                if len(v.shape) == 2:
                    v = v[:, None, ...]

            res.append(self.input_ports[i].function(v))

        try:
            res = torch.stack(res, dim=1) # Stack along the input port dimension; the first dimension is batch
        except (RuntimeError, TypeError):
            # is ragged; need to reshape things so batch size is first dimension
            batch_size = res[0].shape[0]
            res = [[inp[b] for inp in res] for b in range(batch_size)]

        return res

    def set_pnl_variable_and_values(self,
                                    set_variable:bool=False,
                                    set_value:bool=True,
                                    # FIX: 3/15/25 - ADD SUPPORT FOR THESE
                                    # set_output_values:bool=None,
                                    # execute_mech:bool=True,
                                    context=None):
        """Set the state of the PytorchMechanismWrapper's Mechanism.
        Note: execute_mech=True requires that set_variable=True.
        """
        if SYNCH not in self._use:
            return

        pnl_mech = self.mechanism

        if set_variable:
            # First get variable in numpy format
            if isinstance(self.input, list):
                variable = np.array([val.detach().cpu().numpy() for val in self.input], dtype=object)
            else:
                variable = self.input.detach().cpu().numpy()
            # Set pnl_mech's variable
            pnl_mech.parameters.variable._set(variable, context)

        if set_value:
            if self.output is None:
                assert self.exclude_from_gradient_calc, \
                    (f"PROGRAM ERROR: Value of PyTorch wrapper for {self.name} is None during forward pass, "
                     f"but it is not excluded from gradient calculation.")
                return

            # First get value in numpy format
            if isinstance(self.output, list):
                batch_size = len(self.output)
                num_outputs = len(self.output[0])
                value = np.empty((batch_size, num_outputs), dtype=object)
                for bi in range(batch_size):
                    for i in range(num_outputs):
                        value[bi, i] = self.output[bi][i].detach().cpu().numpy()

            else:
                value = self.output.detach().cpu().numpy()

            # Set pnl_mech's value
            pnl_mech.parameters.value._set(value, context)

            # If pnl_mech's function is Stateful, assign value to its previous_value parameter
            #   so that if the Python implementation is run it picks up where PyTorch execution left off
            if isinstance(pnl_mech.function, StatefulFunction):
                pnl_mech.function.parameters.previous_value._set(value, context)
            # Do same for integrator_function of TransferMechanism if it is in integrator_mode
            if isinstance(pnl_mech, TransferMechanism) and pnl_mech.integrator_mode:
                pnl_mech.integrator_function.parameters.previous_value._set(self.integrator_previous_value,
                                                                            context)

        # FIX: 3/15/25 - ADD SUPPORT FOR THESE
        # if output_values:
        #     for value, port in zip(output_values, self.mechanism.output_ports):
        #         port.parameters.value._set(value.detach().cpu().numpy().squeeze(), context)
        # if execute:
        #     if variable:
        #         self.execute(variable)
        else:
            assert False, "PROGRAM ERROR: set_state called but neither set_variable nor set_value is specified"

    def _gen_llvm_execute(self, ctx, builder, state, params, mech_input, data):
        mech_func = ctx.import_llvm_function(self.mechanism)

        mech_param = builder.gep(params, [ctx.int32_ty(0),
                                          ctx.int32_ty(0),
                                          ctx.int32_ty(self._idx)])

        mech_state = builder.gep(state, [ctx.int32_ty(0),
                                         ctx.int32_ty(0),
                                         ctx.int32_ty(self._idx)])

        mech_output = builder.gep(data, [ctx.int32_ty(0),
                                         ctx.int32_ty(0),
                                         ctx.int32_ty(self._idx)])

        builder.call(mech_func, [mech_param,
                                 mech_state,
                                 mech_input,
                                 mech_output])

        pnlvm.helpers.printf_float_array(ctx,
                                         builder,
                                         builder.gep(mech_output, [ctx.int32_ty(0), ctx.int32_ty(0)]),
                                         prefix=f"{self} output:\n",
                                         tags={"torch"})

        return mech_output

    def log_value(self):
        if self.mechanism.parameters.value.log_condition != LogCondition.OFF:
            detached_value = self.output.detach().cpu().numpy()
            self.mechanism.output_port.parameters.value._set(detached_value, self._context)
            self.mechanism.parameters.value._set(detached_value, self._context)

    def _gen_llvm_execute_derivative_func(self, ctx, builder, state, params, arg_in):
        # psyneulink functions expect a 2d input, where index 0 is the vector
        fun = ctx.import_llvm_function(self.mechanism.function, tags=frozenset({"derivative"}))
        fun_input_ty = fun.args[2].type.pointee

        mech_input = builder.alloca(fun_input_ty)
        mech_input_ptr = builder.gep(mech_input, [ctx.int32_ty(0),
                                                  ctx.int32_ty(0)])
        builder.store(builder.load(arg_in), mech_input_ptr)

        mech_params = builder.gep(params, [ctx.int32_ty(0),
                                           ctx.int32_ty(0),
                                           ctx.int32_ty(self._idx)])

        mech_state = builder.gep(state, [ctx.int32_ty(0),
                                         ctx.int32_ty(0),
                                         ctx.int32_ty(self._idx)])

        f_params, f_state = ctx.get_param_or_state_ptr(builder,
                                                       self.mechanism,
                                                       "function",
                                                       param_struct_ptr=mech_params,
                                                       state_struct_ptr=mech_state)

        f_params, builder = self.mechanism._gen_llvm_param_ports_for_obj(
                self.mechanism.function, f_params, ctx, builder, mech_params, mech_state, mech_input)

        output, _ = self.mechanism._gen_llvm_invoke_function(ctx, builder, self.mechanism.function,
                                                             f_params, f_state, mech_input, None,
                                                             tags=frozenset({"derivative"}))
        return builder.gep(output, [ctx.int32_ty(0),
                                    ctx.int32_ty(0)])

    def __repr__(self):
        return "PytorchWrapper for: " + self.mechanism.__repr__()


class PytorchProjectionWrapper():
    """Wrapper for a Projection in a PytorchCompositionWrapper

    The matrix of the wrapped `projection <PytorchProjectionWrapper.projection>` is assigned as a parameter
    (the set of connection weights) of the PyTorch Module that, coupled with a corresponding input and a `torch.matmul
    <https://pytorch.org/docs/main/generated/torch.matmul.html>`_ operation, provides the input to the Pytorch
    function associated with the `Node <Composition_Node>` of the AutodiffComposition that is the `receiver
    <Projection_Base.receiver>` of the wrapped Projection.

    .. note::
       In the case of a nested Composition, the sender and/or receiver attributes may be mapped to different Node(s)
       than the Mechanism(s) of the Projection's actual sender and/or receiver. This is because the sender and/or
       receiver of the Projection may be a nested Composition, in which case the actual sender and/or receiver of the
       Projection will be a `CompositionInterfaceMechanism` (CIM) for the nested Composition.  In that case, the sender
       and/or receiver of the PytorchProjectionWrapper will be assigned to the PytorchMechanismWrapper for the Node in
       the outer Composition that projects to/from the CIM and that is the source/destination of the Projection
       actually being learned; that Projection is referenced in the `PytorchCompositionWrapper.projections_map`
       (see `PytorchCompositionWrapper` for a descriptive figure and additional details), and the actual Projection
       is stored in pnl_proj.

    Attributes
    ----------

    projection : Projection
        PsyNeuLink `Projection` being wrapped.

    composition : AutodiffComposition
        the `AutodiffComposition` to which the `Projection` being wrapped belongs
        (and for which the PytorchCompositionWrapper -- to which the PytorchProjectionWrapper
        belongs -- is the `pytorch_representation <AutodiffComposition.pytorch_representation>`).

    matrix : torch.nn.Parameter
        Pytorch parameter for the matrix of the Projection.

    sender : PytorchMechanismWrapper
        the PytorchMechanismWrapper node from which the PytorchProjectionWrapper receives its variable.

    receiver : PytorchMechanismWrapper
        the PytorchMechanismWrapper node to which the PytorchProjectionWrapper sends its value.

    function : _gen_pytorch_fct
        Pytorch version of the Projection's function assigned in its __init__.

    .. technical_note::
        _use : list[LEARNING, SYNCH, SHOW_PYTORCH]
            designates the uses of the Projection, specified by the following keywords (see PytorchCompositionWrapper
            `docstring <Mechanism_and_Projection_Uses>` for additional details):

            * *LEARNING*: inputs and `function <MappingProjection.function>` Parameters are used for actual execution
              of the corresponding Pytorch Module;

            * *SYNCH*: store connection weights, for synching them between the `matrix
              <MappingProjection.matrix>` Parameter of its PsyNeuLink `projection <PytorchProjectionWrapper.projection>`
              and the corresponding parameters of a Pytorch module being used for learning;

            * *SHOW_PYTORCH*: `projection <PytorchProjectionWrapper.projection>` is included when the
              `AutoDiffComposition`\\s `show_graph <AutoDiffComposition.show_graph>` method is used with
              the ``show_pytorch`` option to display its `pytorch_representation
              <AutodiffComposition.pytorch_representation>`; if it is not specified, the `projection
              <PytorchProjectionWrapper.projection>` is not displayed when the `AutoDiffComposition`\\s
              `show_graph <AutoDiffComposition.show_graph>` method is called, even if the ``show_pytorch``
              option is specified.
    """

    def __init__(self,
                 projection:Projection,                      # Projection to be wrapped
                 pnl_proj:Projection,                        # one that directly projects to/from sender/receiver
                 component_idx:Optional[int],                # index of the Projection in the Composition
                 sender_port_idx:Optional[int],              # index in the sender's Mechanism.output_ports
                 use:Union[list, Literal[LEARNING, SYNCH, SHOW_PYTORCH]],
                 device:str,
                 sender_wrapper:PytorchMechanismWrapper=None,
                 receiver_wrapper:PytorchMechanismWrapper=None,
                 composition:Composition=None,
                 context=None):

        self.projection = projection  # Projection being wrapped (may *not* be the one being learned; see note above)
        self._pnl_proj = pnl_proj     # Projection to/from CIM that actually projects to/from sender/receiver
        self._use = convert_to_list(use) or [LEARNING, SYNCH, SHOW_PYTORCH]  # learn, synch, and/or display weights
        self._idx = component_idx     # Index of Projection in Composition's list of projections
        self._sender_port_idx = sender_port_idx  # Index of sender output_ports for which Projection is an efferent
        self._value_idx = 0           # Index of value in sender's value (used in collect_afferents)
        self._curr_sender_value = None

        self.name = f"PytorchProjectionWrapper[{projection.name}]"
        self.composition = composition            # Composition to which CompositionWrapper belongs
        self.sender_wrapper = sender_wrapper      # PytorchMechanismWrapper to which Projection's sender is mapped
        self.receiver_wrapper = receiver_wrapper  # PytorchMechanismWrapper to which Projection's receiver is mapped
        self._context = context

        if (
            projection.parameters.has_initializers._get(context)
            and projection.parameters.value.initializer
        ):
            self.default_value = projection.parameters.value.initializer.get(context)
        else:
            self.default_value = projection.defaults.value

        # Get item of value corresponding to OutputPort that is Projection's sender
        # Note: this may not be the same as _sender_port_idx if the sender Mechanism has OutputPorts for Projections
        #       that are not in the current Composition
        if context.composition and LEARNING in self._use:
            for i, output_port in enumerate(self.sender_wrapper.mechanism.output_ports):
                if all(p in context.composition.projections for p in output_port.efferents):
                    if self._pnl_proj in output_port.efferents:
                        self._value_idx = i
                        break
                    i += 1

        matrix = projection.parameters.matrix.get(context=context)
        if matrix is None:
            matrix = projection.parameters.matrix.get(context=None)
        # Create a Pytorch Parameter for the matrix
        self.matrix = torch.nn.Parameter(torch.tensor(matrix.copy(),
                                         device=device,
                                         dtype=torch.double))
        # Use Projection's name as key to align with name of torch Parameter
        self._pnl_refs_to_torch_params_map = {pnl_proj.name: self.matrix}
        # 2/16/25 - FIX: RECONCILE THIS WITH ANY SPECS FOR PROJECTION IN optimizer_params
        #           cf _parse_optimizer_params():
        if projection.learnable is False:
            self.matrix.requires_grad = False

        self.function = projection.function._gen_pytorch_fct(device, context)

    def execute(self, variable):
        # return torch.matmul(variable, self.matrix)
        return self.function(variable, self.matrix)
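
    # Editor's sketch (illustrative): for a MappingProjection the generated function is
    # a matrix multiply, so execute() maps a batched sender value through the weights:
    #
    #     import torch
    #     W = torch.nn.Parameter(torch.ones(3, 2, dtype=torch.double))  # (sender, receiver)
    #     x = torch.zeros(4, 3, dtype=torch.double)                     # (batch, sender)
    #     y = torch.matmul(x, W)                                        # -> shape (4, 2)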

    def _copy_torch_params_to_pnl_proj(self, context):
        composition = self.composition
        composition.copy_torch_param_to_projection_matrix(torch_param=self.matrix.detach().cpu().T,
                                                          projection=self.projection,
                                                          validate=False,
                                                          context=context)

    def log_matrix(self):
        if self.projection.parameters.matrix.log_condition != LogCondition.OFF:
            detached_matrix = self.matrix.detach().cpu().numpy()
            self.projection.parameters.matrix._set(detached_matrix, context=self._context)
            self.projection.parameter_ports['matrix'].parameters.value._set(detached_matrix, context=self._context)

    def _extract_llvm_matrix(self, ctx, builder, state, params):
        proj_params = builder.gep(params, [ctx.int32_ty(0), ctx.int32_ty(1), ctx.int32_ty(self._idx)])
        proj_state = builder.gep(state, [ctx.int32_ty(0), ctx.int32_ty(1), ctx.int32_ty(self._idx)])

        dim_x, dim_y = self.matrix.detach().numpy().shape

        func_p, func_s = ctx.get_param_or_state_ptr(builder,
                                                    self.projection,
                                                    self.projection.parameters.function,
                                                    param_struct_ptr=proj_params,
                                                    state_struct_ptr=proj_state)

        proj_matrix = ctx.get_param_or_state_ptr(builder,
                                                 self.projection.function,
                                                 self.projection.function.parameters.matrix,
                                                 param_struct_ptr=func_p,
                                                 state_struct_ptr=func_s)

        proj_matrix = builder.bitcast(proj_matrix, pnlvm.ir.types.ArrayType(
            pnlvm.ir.types.ArrayType(ctx.float_ty, dim_y), dim_x).as_pointer())

        return proj_matrix

    def _gen_llvm_execute(self, ctx, builder, state, params, data):
        proj_matrix = self._extract_llvm_matrix(ctx, builder, state, params)

        input_vec = builder.gep(data, [ctx.int32_ty(0),
                                       ctx.int32_ty(0),
                                       ctx.int32_ty(self.sender_wrapper._idx),
                                       ctx.int32_ty(self._sender_port_idx)])

        output_vec = gen_inject_vxm(ctx, builder, input_vec, proj_matrix)

        pnlvm.helpers.printf_float_array(ctx,
                                         builder,
                                         input_vec,
                                         prefix=f"{self.sender_wrapper.mechanism} "
                                                f"-> {self.receiver_wrapper.mechanism} input:\n",
                                         tags={"torch"})
        pnlvm.helpers.printf_float_matrix(ctx,
                                          builder,
                                          proj_matrix,
                                          prefix=f"{self.sender_wrapper.mechanism} "
                                                 f"-> {self.receiver_wrapper.mechanism} mat:\n",
                                          tags={"torch"})
        pnlvm.helpers.printf_float_array(ctx,
                                         builder,
                                         output_vec,
                                         prefix=f"{self.sender_wrapper.mechanism} "
                                                f"-> {self.receiver_wrapper.mechanism} output:\n",
                                         tags={"torch"})

        return output_vec

    def __repr__(self):
        return "PytorchWrapper for: " + self.projection.__repr__()


class PytorchFunctionWrapper(torch.nn.Module):
    def __init__(self, function, device, context=None):
        super().__init__()
        self.name = f"PytorchFunctionWrapper[{function.name}]"
        self._context = context
        self._pnl_function = function
        self.function = _get_pytorch_function(function, device, context)

    def __repr__(self):
        return "PytorchWrapper for: " + self._pnl_function.__repr__()

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)