/psyneulink/library/compositions/autodiffcomposition.py
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.


# ********************************************* AutodiffComposition *************************************************

"""

Contents
--------

  * `AutodiffComposition_Overview`
  * `AutodiffComposition_Creation`
      - `AutodiffComposition`
          - `AutodiffComposition_Modulatory_Mechanisms`
          - `AutodiffComposition_Bias_Parameters`
          - `AutodiffComposition_Nesting`
          - `AutodiffComposition_Learning_Rates`
          - `AutodiffComposition_Exchange_With_Torch_Parameters`
          - `AutodiffComposition_Post_Construction_Modification`
  * `AutodiffComposition_Execution`
      - `AutodiffComposition_PyTorch`
      - `AutodiffComposition_LLVM`
      - `AutodiffComposition_Python`
      - `AutodiffComposition_Nested_Modulation`
      - `AutodiffComposition_Logging`
  * `AutodiffComposition_Examples`
  * `AutodiffComposition_Class_Reference`


.. _AutodiffComposition_Overview:

Overview
--------

AutodiffComposition is a subclass of `Composition` for constructing and training feedforward neural networks,
using either direct compilation (to LLVM) or automatic conversion to `PyTorch <https://pytorch.org/>`_,
both of which considerably accelerate training (by as much as three orders of magnitude) compared to the
`standard implementation of learning <Composition_Learning_Standard>` in a Composition.  Although an
AutodiffComposition is constructed and executed in much the same way as a standard Composition, it is largely
restricted to feedforward neural networks using `supervised learning <Composition_Learning_Supervised>`, and in
particular the `backpropagation learning algorithm <https://en.wikipedia.org/wiki/Backpropagation>`_, although it
can be used for some forms of `unsupervised learning <Composition_Learning_Unsupervised>` that are supported in
PyTorch (e.g., `self-organized maps <https://github.com/giannisnik/som>`_).


.. _AutodiffComposition_Creation:

Creating an AutodiffComposition
-------------------------------

An AutodiffComposition can be created by calling its constructor, and then adding `Components <Component>` using
the standard `Composition methods <Composition_Creation>` for doing so (e.g., `add_node <Composition.add_node>`,
`add_projection <Composition.add_projection>`, `add_linear_processing_pathway
<Composition.add_linear_processing_pathway>`, etc.). The constructor also includes a number of parameters that are
specific to the AutodiffComposition (see `AutodiffComposition_Class_Reference` for a list of these parameters,
and `examples <AutodiffComposition_Examples>` below). While an AutodiffComposition can generally be created using
the same methods as a standard Composition, there are a few restrictions that apply to its construction,
summarized below.


.. _AutodiffComposition_Restrictions:

*Only one OutputPort per Node*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The `Nodes <Composition_Nodes>` of an AutodiffComposition currently can have only *one* `OutputPort`, though that
can have more than one `efferent <Port_Base.efferents>` `MappingProjection`.  Nodes can also have more than one
`InputPort`, each of which can receive more than one `path_afferent <Port_Base.path_afferents>` Projection.

.. _AutodiffComposition_Modulatory_Mechanisms:

*No Modulatory Components*
~~~~~~~~~~~~~~~~~~~~~~~~~~

All of the Components in an AutodiffComposition must be able to be subjected to `learning <Composition_Learning>`,
which means that no `ModulatoryMechanisms <ModulatoryMechanism>` can be included in an AutodiffComposition.
Specifically, this precludes any `learning components <Composition_Learning_Components>`, `ControlMechanisms
<ControlMechanism>`, or a `controller <Composition_Controller>`.

.. _Autodiff_Learning_Components_Warning:

*Learning Components*.  An AutodiffComposition **cannot include any** `learning components
<Composition_Learning_Components>` (i.e., `LearningMechanisms <LearningMechanism>`, `LearningSignals
<LearningSignal>`, or `LearningProjections <LearningProjection>`, nor the `ComparatorMechanism`
or `ObjectiveMechanism` used to compute the loss for learning). These are constructed
automatically when learning is executed in `Python mode <AutodiffComposition_Python>` or `LLVM mode
<AutodiffComposition_LLVM>`, and PyTorch-compatible Components are constructed when it is executed in
`PyTorch mode <AutodiffComposition_PyTorch>`.

*Control Components*. An AutodiffComposition also cannot include any `ControlMechanisms <ControlMechanism>` or a
`controller <Composition_Controller>`.  However, it *can* include Mechanisms that are subject to modulatory control
(see `Figure <ModulatorySignal_Anatomy_Figure>`, and `modulation <ModulatorySignal_Modulation>`) by ControlMechanisms
*outside* the Composition, including the controller of a Composition within which the AutodiffComposition is nested.
That is, an AutodiffComposition can be `nested in a Composition <Composition_Nested>` that has other such Components
(see `AutodiffComposition_Nested_Modulation` below).

.. _AutodiffComposition_Bias_Parameters:

*No Bias Parameters*
~~~~~~~~~~~~~~~~~~~~

AutodiffComposition does not (currently) support the *automatic* construction of separate bias parameters.
Thus, when constructing the PyTorch version of an AutodiffComposition, the `bias
<https://www.pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ parameter of any PyTorch modules is set to False.
However, biases can be implemented using `Composition_Bias_Nodes`, as sketched below.
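
The following is a minimal sketch of this approach (the Component names used here are hypothetical; see
`Composition_Bias_Nodes` for the full convention): a separate Node supplies a constant input of 1, and the
learnable `matrix <MappingProjection.matrix>` of its efferent Projection serves as the bias weights::

    import numpy as np
    import psyneulink as pnl

    hidden = pnl.TransferMechanism(input_shapes=2, function=pnl.Logistic, name='HIDDEN')
    bias_node = pnl.ProcessingMechanism(default_variable=[[1]], name='BIAS NODE')  # constant input of 1
    bias_proj = pnl.MappingProjection(sender=bias_node, receiver=hidden,
                                      matrix=np.zeros((1, 2)))  # learnable bias weights

    autodiff = pnl.AutodiffComposition()
    autodiff.add_node(hidden)
    autodiff.add_node(bias_node)
    autodiff.add_projection(sender=bias_node, projection=bias_proj, receiver=hidden)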

.. _AutodiffComposition_Nesting:

*Nesting*
~~~~~~~~~

An AutodiffComposition can be `nested <Composition_Nested>` inside another Composition for learning, and there can
be any number of levels of such nesting.  However, all of the nested Compositions must be AutodiffCompositions.
Furthermore, all nested Compositions use the `learning_rate <AutodiffComposition.learning_rate>` specified for the
outermost Composition, whether that is specified in the call to its `learn <AutodiffComposition.learn>` method, in
its constructor, or by its default value (see `learning_rate <AutodiffComposition.learning_rate>` below for
additional details).

.. technical_note::
   Projections from `Nodes <Composition_Nodes>` in an immediately enclosing outer Composition to the `input_CIM
   <Composition.input_CIM>` of a nested Composition, and from its `output_CIM <Composition.output_CIM>` to Nodes
   in the outer Composition, are subject to learning;  however, those within the nested Composition itself (i.e.,
   from its input_CIM to its INPUT Nodes, and from its OUTPUT Nodes to its output_CIM) are *not* subject to
   learning, as they serve simply as conduits of information between the outer Composition and the nested one.

.. warning::
   Nested Compositions are supported for learning only in `PyTorch mode <AutodiffComposition_PyTorch>`, and will
   cause an error if the `learn <AutodiffComposition.learn>` method of an AutodiffComposition is executed in
   `Python mode <AutodiffComposition_Python>` or `LLVM mode <AutodiffComposition_LLVM>`.

.. _AutodiffComposition_Learning_Rates:

*Learning Rates and Optimizer Params*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The **optimizer_params** argument of the constructor can be used to specify parameters for the optimizer used for
learning by the AutodiffComposition. At present, this is restricted to overriding the `learning_rate
<AutodiffComposition.learning_rate>` Parameter of the Composition (used as the default by the `optimizer
<AutodiffComposition.optimizer>`) to assign individual learning rates to specific Projections. This is done by
specifying **optimizer_params** as a dict, in which each key is a reference to a learnable `MappingProjection`
in the AutodiffComposition, and the corresponding value specifies its learning_rate. Subclasses of
AutodiffComposition may involve different forms of specification and/or support other parameters for the optimizer.
Projections that are not specified in **optimizer_params** use, in order of precedence: the `learning_rate
<AutodiffComposition.learning_rate>` specified in the call to the AutodiffComposition's `learn
<AutodiffComposition.learn>` method, the **learning_rate** argument of its constructor, or the default value for
the AutodiffComposition, as illustrated below.
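
The following is a minimal sketch of this form of specification (it assumes Mechanisms ``input_mech``,
``hidden_mech`` and ``output_mech``, and Projections ``input_to_hidden`` and ``hidden_to_output``, have already
been constructed)::

    import psyneulink as pnl

    autodiff = pnl.AutodiffComposition(
        pathways=[input_mech, input_to_hidden, hidden_mech, hidden_to_output, output_mech],
        learning_rate=0.001,                         # default used by the optimizer
        optimizer_params={input_to_hidden: 0.01,     # overrides the default learning_rate for this Projection
                          hidden_to_output: False})  # False suppresses learning for this Projection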

.. _AutodiffComposition_Exchange_With_Torch_Parameters:

*Exchanging Parameters with PyTorch Modules*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The AutodiffComposition's `copy_torch_param_to_projection_matrix` and `copy_projection_matrix_to_torch_param`
methods can be used to exchange weight matrices between the parameters of a PyTorch module and the `matrix
<MappingProjection.matrix>` Parameter of a `MappingProjection` in the AutodiffComposition. PyTorch Parameters can
be referenced flexibly: either by the Parameter object itself, or by the module together with either the name or
the index of the Parameter in the module's state_dict or parameter list, respectively. Slices of PyTorch Parameters
can also be used, for cases in which the matrix of a Projection corresponds to only a subpart of the PyTorch
Parameter (e.g., for `GRUComposition`). Both methods return the item assigned.
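
As a rough sketch of the intent (the argument names and order shown here are illustrative only, not the methods'
actual signatures; ``autodiff`` and ``my_projection`` are assumed to have been constructed already)::

    import torch

    module = torch.nn.Linear(3, 2, bias=False)         # PyTorch module whose weights are to be exchanged
    param_name = list(module.state_dict().keys())[0]   # reference the Parameter by its name ('weight')

    # Copy the torch Parameter into the Projection's matrix (illustrative call):
    autodiff.copy_torch_param_to_projection_matrix(module, param_name, my_projection)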

.. _AutodiffComposition_Post_Construction_Modification:

*No Post-construction Modification*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

COMMENT:
IS THIS STILL TRUE?  TEST?
COMMENT
Mechanisms or Projections should not be added to or deleted from an AutodiffComposition after it has
been executed. Unlike an ordinary Composition, an AutodiffComposition does not support this functionality.


.. _AutodiffComposition_Execution:

Execution
---------

An AutodiffComposition's `run <Composition.run>`, `execute <Composition.execute>`, and `learn <Composition.learn>`
methods are the same as for a `Composition`.  However, the **execution_mode** argument of the `learn
<Composition.learn>` method has different effects than it does for a standard Composition, determining whether it
uses `LLVM compilation <AutodiffComposition_LLVM>` or `translation to PyTorch <AutodiffComposition_PyTorch>` to
execute learning.  These are each described in greater detail below, and summarized in this `table
<Composition_Compilation_Table>`, which provides a comparison of the different modes of execution for an
AutodiffComposition and a standard `Composition`.

.. _AutodiffComposition_PyTorch:

*PyTorch mode*
~~~~~~~~~~~~~~

COMMENT:
# 7/10/24 - FIX:
.. _AutodiffComposition_PyTorch_LearningScale:
   ADD DESCRIPTION OF HOW LearningScale SPECIFICATIONS MAP TO EXECUTION OF pytorch_rep:
      OPTIMIZATION STEP:
      for AutodiffCompositions, this corresponds to a single call to `forward()` and `backward()`
            methods of the Pytorch model
COMMENT

This is the default for an AutodiffComposition, but can be specified explicitly by setting **execution_mode** =
`ExecutionMode.PyTorch` in the `learn <Composition.learn>` method (see `example <BasicsAndPrimer_Rumelhart_Model>`
in `BasicsAndPrimer`, and the sketch below).  In this mode, the AutodiffComposition is automatically translated to a
`PyTorch <https://pytorch.org>`_ model for learning.  This is comparable in speed to `LLVM compilation
<AutodiffComposition_LLVM>`, but provides greater flexibility, including the ability to include nested
AutodiffCompositions in learning. Although it is best suited for use with `supervised learning
<Composition_Learning_Supervised>`, it can also be used for some forms of `unsupervised learning
<Composition_Learning_Unsupervised>` that are supported in PyTorch (e.g., `self-organized maps
<https://github.com/giannisnik/som>`_).
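
For example, learning can be run explicitly in PyTorch mode as follows (a sketch, assuming ``my_autodiff`` and
``input_dict`` are defined as in the `examples <AutodiffComposition_Examples>` below)::

    import psyneulink as pnl

    my_autodiff.learn(inputs=input_dict, execution_mode=pnl.ExecutionMode.PyTorch)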

    .. _AutodiffComposition_PyTorch_Note:

    .. note::
       While specifying `ExecutionMode.PyTorch` in the `learn <Composition.learn>` method of an AutodiffComposition
       causes it to use PyTorch for training, specifying this in the `run <Composition.run>` method causes it to be
       executed using the *Python* interpreter (and not PyTorch);  this is so that any modulation can take effect
       during execution (see `AutodiffComposition_Nested_Modulation` below), which is not supported by PyTorch.

    .. warning::
      * Specifying `ExecutionMode.LLVMRun` or `ExecutionMode.PyTorch` in the learn() method of a standard
        `Composition` causes an error.

COMMENT:
FIX: ADD MENTION OF TARGET NODES AND PYTORCH WRAPPERS
COMMENT

.. _AutodiffComposition_LLVM:

*LLVM mode*
~~~~~~~~~~~

This is specified by setting **execution_mode** = `ExecutionMode.LLVMRun` in the `learn <Composition.learn>` method
of an AutodiffComposition.  This provides the fastest performance, but is limited to `supervised learning
<Composition_Learning_Supervised>` using the `BackPropagation` algorithm. This can be run using standard forms of
loss, including mean squared error (MSE) and cross entropy, by specifying this in the **loss_spec** argument of
the constructor (see `AutodiffComposition <AutodiffComposition_Class_Reference>` for additional details, and
`Compilation Modes <Composition_Compiled_Modes>` for more information about executing a Composition in compiled
mode), as in the sketch below.
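
For example (a sketch, assuming ``my_autodiff`` and ``input_dict`` are defined as in the `examples
<AutodiffComposition_Examples>` below)::

    import psyneulink as pnl

    my_autodiff.learn(inputs=input_dict, execution_mode=pnl.ExecutionMode.LLVMRun)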

    .. note::
       Specifying `ExecutionMode.LLVMRun` in either the `learn <Composition.learn>` or `run <Composition.run>`
       method of an AutodiffComposition causes it to (attempt to) use compiled execution in both cases; this is
       because LLVM compilation supports the use of modulation in PsyNeuLink models (in contrast to `PyTorch mode
       <AutodiffComposition_PyTorch>`; see `note <AutodiffComposition_PyTorch_Note>` above).


COMMENT:
FIX: 8/13/23 - COMPLETE DOCS HERE
COMMENT

.. _AutodiffComposition_Python:

*Python mode*
~~~~~~~~~~~~~
An AutodiffComposition can also be run using the standard PsyNeuLink learning components.  However, this cannot
be used if the AutodiffComposition has any nested Compositions, irrespective of whether they are ordinary
Compositions or AutodiffCompositions.
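
This mode is specified by setting **execution_mode** = `ExecutionMode.Python` in the `learn <Composition.learn>`
method; for example (a sketch, assuming ``my_autodiff`` and ``input_dict`` are defined as in the `examples
<AutodiffComposition_Examples>` below)::

    import psyneulink as pnl

    my_autodiff.learn(inputs=input_dict, execution_mode=pnl.ExecutionMode.Python)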


.. _AutodiffComposition_Nested_Modulation:

*Nested Execution and Modulation*
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# FIX:
Like any other `Composition`, an AutodiffComposition may be `nested <Composition_Nested>` inside another
(see `example <AutodiffComposition_Nested_Example>` below).  However, during learning, none of the internal
Components of the AutodiffComposition (e.g., intermediate layers of a neural network model) are accessible to the
other Components of the outer Composition (e.g., as sources of information, or for `modulation
<ModulatorySignal_Modulation>`).  However, when
COMMENT:
learning turned off,
COMMENT
it is executed using its `run <Composition.run>` method, the AutodiffComposition functions like any other,
and all of its internal Components are accessible to other Components of the outer Composition. Thus, as long as
access to its internal Components is not needed during learning, an `AutodiffComposition` can be trained, and then
used to execute the trained Composition like any other.


.. _AutodiffComposition_Logging:

*Logging*
~~~~~~~~~

Logging in AutodiffCompositions follows the same procedure as `logging in a Composition <Log>`.
However, since an AutodiffComposition internally converts all of its Mechanisms either to LLVM
or to an equivalent PyTorch model, its inner components are not actually executed. This means that there is
limited support for logging parameters of components inside an AutodiffComposition; currently, the only supported
parameters are (see the sketch below for an example):

1) the `matrix` parameter of Projections

2) the `value` parameter of its inner components
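
For example, the `matrix <MappingProjection.matrix>` parameter of a Projection can be logged during learning as
follows (a sketch, assuming ``my_projection``, ``my_autodiff`` and ``input_dict`` are defined as in the `examples
<AutodiffComposition_Examples>` below)::

    my_projection.set_log_conditions('matrix')   # request logging of the matrix
    my_autodiff.learn(inputs=input_dict)
    my_projection.log.print_entries()            # display the logged entries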


.. _AutodiffComposition_Examples:

Examples
--------

.. _AutodiffComposition_Creation_Example:

The following is an example showing how to create a simple AutodiffComposition, specify its inputs and targets,
and run it with learning enabled and disabled:

    >>> import numpy as np
    >>> import psyneulink as pnl
    >>> # Set up PsyNeuLink Components
    >>> my_mech_1 = pnl.TransferMechanism(function=pnl.Linear, input_shapes = 3)
    >>> my_mech_2 = pnl.TransferMechanism(function=pnl.Linear, input_shapes = 2)
    >>> my_projection = pnl.MappingProjection(matrix=np.random.randn(3,2),
    ...                     sender=my_mech_1,
    ...                     receiver=my_mech_2)
    >>> # Create AutodiffComposition
    >>> my_autodiff = pnl.AutodiffComposition()
    >>> my_autodiff.add_node(my_mech_1)
    >>> my_autodiff.add_node(my_mech_2)
    >>> my_autodiff.add_projection(sender=my_mech_1, projection=my_projection, receiver=my_mech_2)
    >>> # Specify inputs and targets
    >>> my_inputs = {my_mech_1: [[1, 2, 3]]}
    >>> my_targets = {my_mech_2: [[4, 5]]}
    >>> input_dict = {"inputs": my_inputs, "targets": my_targets, "epochs": 2}
    >>> # Run Composition in learning mode
    >>> my_autodiff.learn(inputs = input_dict)
    >>> # Run Composition in test mode
    >>> my_autodiff.run(inputs = input_dict['inputs'])


.. _AutodiffComposition_Nested_Example:

The following shows how the AutodiffComposition created in the previous example can be nested and run inside another
Composition::

    >>> # Create outer composition
    >>> my_outer_composition = pnl.Composition()
    >>> my_outer_composition.add_node(my_autodiff)
    >>> # Specify dict containing inputs and targets for nested Composition
    >>> training_input = {my_autodiff: input_dict}
    >>> # Run in learning mode
    >>> result1 = my_outer_composition.learn(inputs=training_input)
    COMMENT:
    >>> # Run with learning disabled (and standard input format)
    >>> no_training_input = {my_autodiff: my_inputs}
    >>> result2 = my_outer_composition.run(inputs=no_training_input)
    COMMENT

.. _AutodiffComposition_Class_Reference:

Class Reference
---------------

"""
import logging
import os
import warnings
import numpy as np
from packaging import version
from pathlib import Path, PosixPath
from collections import deque
from typing import Union

try:
    import torch
    from torch import nn
    import torch.optim as optim
    torch_available = True
except ImportError:
    torch_available = False
else:
    from psyneulink.library.compositions.pytorchwrappers import PytorchCompositionWrapper
    from psyneulink.library.compositions.pytorchshowgraph import PytorchShowGraph

from psyneulink._typing import Iterable, Mapping, Optional
from psyneulink.core.components.component import Component
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism
from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism
from psyneulink.core.components.mechanisms.modulatory.modulatorymechanism import ModulatoryMechanism_Base
from psyneulink.core.components.projections.pathway.mappingprojection import MappingProjection
from psyneulink.core.components.projections.modulatory.modulatoryprojection import ModulatoryProjection_Base
from psyneulink.core.components.ports.inputport import InputPort
from psyneulink.core.compositions.composition import Composition, NodeRole, CompositionError
from psyneulink.core.compositions.report import (ReportOutput, ReportParams, ReportProgress, ReportSimulations,
                                                 ReportDevices, EXECUTE_REPORT, LEARN_REPORT, PROGRESS_REPORT)
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import (
    AUTODIFF_COMPOSITION, EXECUTION_MODE,
    LEARNING_SCALE_LITERALS, LEARNING_SCALE_NAMES, LEARNING_SCALE_VALUES,
    Loss, LOSSES, MATRIX_WEIGHTS, MINIBATCH, NODE_VALUES, NODE_VARIABLES,
    OPTIMIZATION_STEP, RESULTS, RUN, SOFT_CLAMP, SYNCH_WITH_PNL_OPTIONS,
    RETAIN_IN_PNL_OPTIONS, TARGETS, TRAINED_OUTPUTS, TRIAL, DEFAULT,
)
from psyneulink.core.globals.utilities import is_matrix_keyword, is_numeric_scalar, convert_to_np_array
from psyneulink.core.scheduling.scheduler import Scheduler
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.scheduling.time import TimeScale
from psyneulink.core import llvm as pnlvm


logger = logging.getLogger(__name__)


__all__ = [
    'AutodiffComposition'
]

# Getters for Parameters below that retrieve values retained by the PyTorch representation for the given context
def _get_torch_trained_outputs(owning_component=None, context=None):
    if not context.execution_id:
        return None
    pytorch_rep = owning_component.parameters.pytorch_representation._get(context)
    if not pytorch_rep:
        return None
    return np.array(pytorch_rep.retained_trained_outputs)

def _get_torch_targets(owning_component=None, context=None):
    if not context.execution_id:
        return None
    pytorch_rep = owning_component.parameters.pytorch_representation._get(context)
    if not pytorch_rep:
        return None
    return np.array(pytorch_rep.retained_targets)

def _get_torch_losses(owning_component, context):
    if not context.execution_id:
        return None
    pytorch_rep = owning_component.parameters.pytorch_representation._get(context)
    if not pytorch_rep:
        return None
    return np.array(pytorch_rep.retained_losses)

class AutodiffCompositionError(CompositionError):

    def __init__(self, error_value):
        self.error_value = error_value

    def __str__(self):
        return repr(self.error_value)


class AutodiffComposition(Composition):
    """
    AutodiffComposition(                        \
        optimizer_type='sgd',
        loss_spec=Loss.MSE,
        weight_decay=0,
        learning_rate=0.001,
        optimizer_params=None,
        disable_learning=False,
        synch_projection_matrices_with_torch=RUN,
        synch_node_variables_with_torch=None,
        synch_node_values_with_torch=RUN,
        synch_results_with_torch=RUN,
        retain_torch_trained_outputs=MINIBATCH,
        retain_torch_targets=MINIBATCH,
        retain_torch_losses=MINIBATCH,
        device=CPU
        )

    Subclass of `Composition` that trains models using either LLVM compilation or `PyTorch <https://pytorch.org>`_;
    see `Composition <Composition_Class_Reference>` for additional arguments and attributes of the constructor.

    Arguments
    ---------

    optimizer_type : str : default 'sgd'
        the kind of optimizer used in training. The current options are 'sgd' or 'adam'.

    loss_spec : Loss or PyTorch loss function : default Loss.MSE
        specifies the loss function for training; see `Loss` for arguments.

    weight_decay : float : default 0
        specifies the L2 penalty (which discourages large weights) used by the optimizer.

    learning_rate : float : default 0.001
        specifies the learning rate passed to the optimizer if none is specified in the `learn
        <AutodiffComposition.learn>` method of the AutodiffComposition;
        see `learning_rate <AutodiffComposition.learning_rate>` for additional details.

    optimizer_params : Dict[str: value]
        specifies parameters for the optimizer used for learning by the AutodiffComposition
        (see `AutodiffComposition_Learning_Rates` for details of specification).

    disable_learning : bool: default False
        specifies whether the AutodiffComposition should disable learning when run in `learning mode
        <Composition.learn>`.

    synch_projection_matrices_with_torch : `LearningScale` : default RUN
        specifies the default for the AutodiffComposition for when to copy PyTorch parameters to PsyNeuLink
        `Projection matrices <MappingProjection.matrix>` (connection weights), which can be overridden by specifying
        the **synch_projection_matrices_with_torch** argument in the `learn <Composition.learn>` method;
        see `synch_projection_matrices_with_torch <AutodiffComposition.synch_projection_matrices_with_torch>`
        for additional details.

    synch_node_variables_with_torch : `LearningScale` : default None
        specifies the default for the AutodiffComposition for when to copy the current input to PyTorch nodes
        to the PsyNeuLink `variable <Mechanism_Base.variable>` attribute of the corresponding PsyNeuLink `nodes
        <Composition_Node>`, which can be overridden by specifying the **synch_node_variables_with_torch** argument
        in the `learn <Composition.learn>` method; see `synch_node_variables_with_torch
        <AutodiffComposition.synch_node_variables_with_torch>` for additional details.

    synch_node_values_with_torch : `LearningScale` : default RUN
        specifies the default for the AutodiffComposition for when to copy the current output of PyTorch nodes to the
        PsyNeuLink `value <Mechanism_Base.value>` attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`,
        which can be overridden by specifying the **synch_node_values_with_torch** argument in the `learn
        <Composition.learn>` method; see `synch_node_values_with_torch
        <AutodiffComposition.synch_node_values_with_torch>` for additional details.

    synch_results_with_torch : `LearningScale` : default RUN
        specifies the default for the AutodiffComposition for when to copy the outputs of the PyTorch model
        to the AutodiffComposition's `results <Composition.results>` attribute, which can be overridden by
        specifying the **synch_results_with_torch** argument in the `learn <Composition.learn>` method.
        Note that this differs from **retain_torch_trained_outputs**, which specifies the frequency at which
        the outputs of the PyTorch model are tracked, all of which are stored in the AutodiffComposition's
        `torch_trained_outputs <AutodiffComposition.torch_trained_outputs>` attribute at the end of the run;
        see `synch_results_with_torch <AutodiffComposition.synch_results_with_torch>` for
        additional details.

    retain_torch_trained_outputs : `LearningScale` : default MINIBATCH
        specifies the default for the AutodiffComposition for the scale at which the outputs of the PyTorch
        model are tracked, all of which are stored in the AutodiffComposition's `torch_trained_outputs
        <AutodiffComposition.torch_trained_outputs>` attribute at the end of the run; this can be overridden
        by specifying the **retain_torch_trained_outputs** argument in the `learn <Composition.learn>` method.
        Note that this differs from **synch_results_with_torch**, which specifies the frequency with
        which values are copied to the AutodiffComposition's `results` attribute; see `retain_torch_trained_outputs
        <AutodiffComposition.retain_torch_trained_outputs>` for additional details.

    retain_torch_targets : `LearningScale` : default MINIBATCH
        specifies the default for the AutodiffComposition for when to copy the targets used for training the
        PyTorch model to the AutodiffComposition's `torch_targets <Composition.torch_targets>` attribute, which can
        be overridden by specifying the **retain_torch_targets** argument in the `learn <Composition.learn>` method;
        see `retain_torch_targets <AutodiffComposition.retain_torch_targets>` for additional details.

    retain_torch_losses : `LearningScale` : default MINIBATCH
        specifies the default for the AutodiffComposition for the scale at which the losses of the PyTorch model
        are tracked, all of which are stored in the AutodiffComposition's `torch_losses <Composition.torch_losses>`
        attribute at the end of the run; see `retain_torch_losses <AutodiffComposition.retain_torch_losses>` for
        additional details.

    device : torch.device : default device-dependent
        specifies the device on which the model is run. If None, the device is set to 'cuda' if available,
        then 'mps', otherwise 'cpu'.

    Attributes
    ----------

    pytorch_representation : PytorchCompositionWrapper : default None
        represents the PyTorch model of the AutodiffComposition, which is created when the AutodiffComposition is
        run in `PyTorch mode <AutodiffComposition_PyTorch>`.

    optimizer : PyTorch optimizer function
        the optimizer used for training. Depends on the **optimizer_type**, **learning_rate**, and **weight_decay**
        arguments from initialization.

    loss : PyTorch loss function
        the loss function used for training. Depends on the **loss_spec** argument from initialization.

    learning_rate : float or bool
        determines the default learning_rate passed to the optimizer, which is applied to all `Projections
        <Projection>` in the AutodiffComposition that are `learnable <MappingProjection.learnable>`, and for which
        individual rates have not been specified (for how to do the latter, see
        `AutodiffComposition_Learning_Rates`).

        .. note::
           At present, an outermost Composition's learning rate is applied to any `nested Compositions
           <AutodiffComposition_Nesting>`, whether it is specified in the call to its `learn
           <AutodiffComposition.learn>` method, in its constructor, or by its default value.

        .. hint::
           To disable updating of a particular `MappingProjection` in an AutodiffComposition, set either the
           **learnable** parameter of its constructor or its learning_rate specification in the **optimizer_params**
           argument of the AutodiffComposition's constructor to False (see `AutodiffComposition_Learning_Rates`);
           this applies to MappingProjections at any level of `nesting <AutodiffComposition_Nesting>`, as
           sketched below.
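
           For example (a sketch; the Components shown are hypothetical)::

               my_projection = pnl.MappingProjection(sender=my_mech_1,
                                                     receiver=my_mech_2,
                                                     learnable=False)  # excluded from updating during learning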

    synch_projection_matrices_with_torch : OPTIMIZATION_STEP, MINIBATCH, EPOCH or RUN
        determines when to copy PyTorch parameters to PsyNeuLink `Projection matrices <MappingProjection.matrix>`
        (connection weights) if this is not specified in the call to `learn <AutodiffComposition.learn>`. Copying more
        frequently keeps the PsyNeuLink representation more closely synchronized with parameter updates in PyTorch,
        but slows performance (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    synch_node_variables_with_torch : OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH, RUN or None
        determines when to copy the current input to PyTorch functions to the PsyNeuLink `variable
        <Mechanism_Base.variable>` attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`,
        if this is not specified in the call to `learn <AutodiffComposition.learn>`.
        COMMENT:
        8/8/24 - FIX: 3/15/25 ADD EXPLANATION OF WHY THIS IS NOT GENERALLY USEFUL ALONG THE LINES OF THE FOLLOWING
                 ALSO RELATE TO EXECUTE_NODES OPTION ONCE IMPLEMENTED
        This is supported for inspection and debugging, but is not generally useful, as PsyNeuLink uses `Lazy
        Evaluation <Component_Lazy_Updating>`, in which the variable of a node is determined by the input it receives
        during execution.
        COMMENT
        Copying more frequently keeps them more closely synchronized with parameter updates in PyTorch, but can
        slow performance (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    synch_node_values_with_torch : OPTIMIZATION_STEP, MINIBATCH, EPOCH or RUN
        determines when to copy the current output of PyTorch functions to the PsyNeuLink `value
        <Mechanism_Base.value>` attribute of the corresponding PsyNeuLink `nodes <Composition_Node>`,
        if this is not specified in the call to `learn <AutodiffComposition.learn>`. Copying more
        frequently keeps the PsyNeuLink representation more closely synchronized with parameter
        updates in PyTorch, but can also slow performance (see `AutodiffComposition_PyTorch_LearningScale`
        for information about settings).

    synch_results_with_torch : OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH or RUN
        determines when to copy the current outputs of PyTorch nodes to the PsyNeuLink `results
        <Composition.results>` attribute of the AutodiffComposition if this is not specified in
        the call to `learn <AutodiffComposition.learn>`. Copying more frequently keeps the PsyNeuLink
        representation more closely synchronized with parameter updates in PyTorch, but slows performance
        (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    retain_torch_trained_outputs : OPTIMIZATION_STEP, MINIBATCH, EPOCH, RUN or None
        determines the scale at which the outputs of the PyTorch model are tracked, all of which are stored in
        the AutodiffComposition's `results <Composition.results>` attribute at the end of the run if this is not
        specified in the call to `learn <AutodiffComposition.learn>` (see
        `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    retain_torch_targets : OPTIMIZATION_STEP, TRIAL, MINIBATCH, EPOCH, RUN or None
        determines the scale at which the targets used for training the PyTorch model are tracked, all of which
        are stored in the AutodiffComposition's `targets <Composition.targets>` attribute at the end of the run
        if this is not specified in the call to `learn <AutodiffComposition.learn>`
        (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    retain_torch_losses : OPTIMIZATION_STEP, MINIBATCH, EPOCH, RUN or None
        determines the scale at which the losses of the PyTorch model are tracked, all of which are stored in
        the AutodiffComposition's `torch_losses <Composition.torch_losses>` attribute at the end of the run
        if this is not specified in the call to `learn <AutodiffComposition.learn>`
        (see `AutodiffComposition_PyTorch_LearningScale` for information about settings).

    torch_trained_outputs : List[ndarray]
        stores the outputs (converted to np arrays) of the PyTorch model trained during learning, at the frequency
        specified by `retain_torch_trained_outputs <AutodiffComposition.retain_torch_trained_outputs>` if it is set
        to *MINIBATCH*, *EPOCH*, or *RUN*; see `retain_torch_trained_outputs
        <AutodiffComposition.retain_torch_trained_outputs>` for additional details.

    torch_targets : List[ndarray]
        stores the targets used for training the PyTorch model during learning at the frequency specified by
        `retain_torch_targets <AutodiffComposition.retain_torch_targets>` if it is set to *MINIBATCH*, *EPOCH*,
        or *RUN*; see `retain_torch_targets <AutodiffComposition.retain_torch_targets>` for additional details.

    torch_losses : list of floats
        stores the average loss after each weight update (i.e., each minibatch) during learning, at the frequency
        specified by `retain_torch_losses <AutodiffComposition.retain_torch_losses>` if it is set to *MINIBATCH*,
        *EPOCH*, or *RUN*; see `retain_torch_losses <AutodiffComposition.retain_torch_losses>` for additional
        details.

    COMMENT:  FIX: NOT CURRENTLY BEING POPULATED, BUT SEEMS TO BE USED BY _get_total_loss() and early_stopper
    trial_losses = Parameter([])
    COMMENT

    last_saved_weights : path
        path for file to which weights were last saved.

    last_loaded_weights : path
        path for file from which weights were last loaded.

    device : torch.device
        the device on which the model is run.
    """

    componentCategory = AUTODIFF_COMPOSITION
    if torch_available:
        from psyneulink.library.compositions.pytorchwrappers import PytorchCompositionWrapper, PytorchMechanismWrapper
        pytorch_composition_wrapper_type = PytorchCompositionWrapper
        pytorch_mechanism_wrapper_type = PytorchMechanismWrapper

    class Parameters(Composition.Parameters):
        pytorch_representation = None
        optimizer = None
        learning_rate = Parameter(.001, fallback_value=DEFAULT)
        synch_projection_matrices_with_torch = Parameter(RUN, fallback_value=DEFAULT)
        synch_node_variables_with_torch = Parameter(None, fallback_value=DEFAULT)
        synch_node_values_with_torch = Parameter(RUN, fallback_value=DEFAULT)
        synch_results_with_torch = Parameter(RUN, fallback_value=DEFAULT)
        retain_torch_trained_outputs = Parameter(MINIBATCH, fallback_value=DEFAULT)
        retain_torch_targets = Parameter(MINIBATCH, fallback_value=DEFAULT)
        retain_torch_losses = Parameter(MINIBATCH, fallback_value=DEFAULT)
        torch_trained_outputs = Parameter([], getter=_get_torch_trained_outputs)
        torch_targets = Parameter([], getter=_get_torch_targets)
        torch_losses = Parameter([], getter=_get_torch_losses)
        trial_losses = Parameter([]) # FIX <- related to early_stopper, but not getting assigned anywhere
        device = None

        # def _validate_memory_template(self, device):
        #     if isinstance(device, str) and device not in [CPU, CUDA, MPS]:
        #         raise AutodiffCompositionError(f"Device must be one of {CPU}, {CUDA}, or {MPS}")
        #
        def _validate_synch_projection_matrices_with_torch(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of 'synch_projection_matrices_with_torch' arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")

        def _validate_synch_node_variables_with_torch(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of 'synch_node_variables_with_torch' arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")

        def _validate_synch_node_values_with_torch(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of 'synch_node_values_with_torch' arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")

        def _validate_synch_results_with_torch(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of 'synch_results_with_torch' arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")
            if spec is OPTIMIZATION_STEP:
                arg_vals = LEARNING_SCALE_NAMES.copy()
                arg_vals.remove('OPTIMIZATION_STEP')
                raise AutodiffCompositionError(f"'OPTIMIZATION_STEP' can't be used with 'synch_results_with_torch'; "
                                               f"use another value from: {', '.join(arg_vals)}")


        def _validate_retain_torch_trained_outputs(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of `retain_torch_trained_outputs` arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")

        def _validate_retain_torch_targets(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of `retain_torch_targets` arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")

        def _validate_retain_torch_losses(self, spec):
            if spec is not None and spec not in LEARNING_SCALE_VALUES:
                raise AutodiffCompositionError(f"Value of `retain_torch_losses` arg "
                                               f"must be one of the following keywords: "
                                               f"{', '.join(LEARNING_SCALE_NAMES)}")


    # TODO (CW 9/28/18): add compositions to registry so default arg for name is no longer needed
    @check_user_specified
    def __init__(self,
                 pathways=None,
                 optimizer_type='sgd',
                 loss_spec=Loss.MSE,
                 weight_decay=0,
                 learning_rate=None,
                 optimizer_params:dict=None,
                 disable_learning=False,
                 force_no_retain_graph=False,
                 refresh_losses=False,
                 synch_projection_matrices_with_torch:Optional[str]=RUN,
                 synch_node_variables_with_torch:Optional[str]=None,
                 synch_node_values_with_torch:Optional[str]=RUN,
                 synch_results_with_torch:Optional[str]=RUN,
                 retain_torch_trained_outputs:Optional[str]=MINIBATCH,
                 retain_torch_targets:Optional[str]=MINIBATCH,
                 retain_torch_losses:Optional[str]=MINIBATCH,
                 device=None,
                 disable_cuda=True,
                 cuda_index=None,
                 name="autodiff_composition",
                 **kwargs):

        # if not torch_available:
        #     raise AutodiffCompositionError('Pytorch python module (torch) is not installed. Please install it with '
        #                                    '`pip install torch` or `pip3 install torch`')
        #
        show_graph_attributes = kwargs.pop('show_graph_attributes', {})

        super(AutodiffComposition, self).__init__(
            name = name,
            pathways=pathways,
            optimizer_type = optimizer_type,
            loss_spec = loss_spec,
            weight_decay = weight_decay,
            learning_rate = learning_rate,
            synch_projection_matrices_with_torch = synch_projection_matrices_with_torch,
            synch_node_variables_with_torch = synch_node_variables_with_torch,
            synch_node_values_with_torch = synch_node_values_with_torch,
            synch_results_with_torch = synch_results_with_torch,
            retain_torch_trained_outputs = retain_torch_trained_outputs,
            retain_torch_targets = retain_torch_targets,
            retain_torch_losses = retain_torch_losses,
            **kwargs)

        self._built_pathways = False
        self.targets_from_outputs_map = {} # Map from TARGETS nodes to any OUTPUT nodes from which they receive input
        self.outputs_to_targets_map = {}   # Map from trained OUTPUT nodes to their TARGETS
        self._trained_comp_nodes_to_pytorch_nodes_map = None # Set by subclasses that replace trained OUTPUT Nodes
        self._input_comp_nodes_to_pytorch_nodes_map = None # Set by subclasses that replace INPUT Nodes
        self._pytorch_projections = []
        self.optimizer_type = optimizer_type
        self._optimizer_params = optimizer_params or {}
        self.loss_spec = loss_spec
        self._runtime_learning_rate = None
        self.force_no_retain_graph = force_no_retain_graph
        self.refresh_losses = refresh_losses
        self.weight_decay = weight_decay
        self.disable_learning = disable_learning
        self.loss_function = None
        self.last_saved_weights = None
        self.last_loaded_weights = None

        # keeps track of average loss per epoch
        self.losses = []

        # ordered execution sets for the pytorch model
        self.execution_sets = None

        # # MODIFIED 7/10/24 OLD:
        if not disable_cuda and torch.cuda.is_available():
            if cuda_index is None:
                self.device = torch.device('cuda')
            else:
                self.device = torch.device('cuda:' + str(cuda_index))
        elif torch_available:
            self.device = torch.device('cpu')
            self.torch_dtype = self.pytorch_composition_wrapper_type.torch_dtype
        else:
            self.device = device
            self.torch_dtype = None
        # # MODIFIED 7/10/24 NEW: NEEDED FOR torch MPS SUPPORT
        #  FIX: ADD AFTER USE OF utilities.get_torch_tensor() AND COMPATIBILITY WITH MPS IS VALIDATED
        # if device is None:
        #     # Try setting device by default
        #     if not disable_cuda and torch.cuda.is_available():
        #         if cuda_index is None:
        #             self.device = torch.device(CUDA)
        #         else:
        #             self.device = torch.device('cuda:' + str(cuda_index))
        #     elif torch_available:
        #         if torch.backends.mps.is_available():
        #             from psyneulink.core.components.functions.nonstateful.transferfunctions import Linear
        #             try:
        #                 self.device = torch.device(MPS)
        #                 test_pytorch_fct_with_mps = Linear()._gen_pytorch_fct(self.device, Context())
        #             except AssertionError:
        #                 self.device = torch.device(CPU)
        #         else:
        #             self.device = torch.device(CPU)
        # else:
        #     self.device = device
        # # MODIFIED 7/10/24 END

        # Set to True after first warning about failure to specify execution mode so warning is issued only once
        self.execution_mode_warned_about_default = False
        # torch params added when warned in copy_projection_matrix_to_torch_param() to avoid repeats for same param
        self.require_grad_warning = []
        # return self.infer_backpropagation_learning_pathways(pnlvm.ExecutionMode.PyTorch)

        # ShowGraph
        self.assign_ShowGraph(show_graph_attributes)
1✔
845
    def assign_ShowGraph(self, show_graph_attributes):
1✔
846
        """Override to replace assignment of ShowGraph class with PytorchShowGraph if torch is available"""
847
        show_graph_attributes = show_graph_attributes or {}
1✔
848
        if torch_available:
1!
849
            self._show_graph = PytorchShowGraph(self, **show_graph_attributes)
1✔
850
        else:
851
            from psyneulink.core.compositions.showgraph import ShowGraph
×
852
            self._show_graph = ShowGraph(self, **show_graph_attributes)
×
853

854
    @handle_external_context()
1✔
855
    def infer_backpropagation_learning_pathways(self, execution_mode, context=None)->list:
1✔
856
        """Create backpropagation learning pathways for every Input Node --> Output Node pathway
857
        Flattens nested compositions:
858
          - only includes the Projections in outer Composition to/from the CIMs of the nested Composition
859
            (i.e., to input_CIMs and from output_CIMs) -- the ones that should be learned;
860
          - excludes Projections from/to CIMs in the nested Composition
861
            (from input_CIMs and to output_CIMs), as those should remain identity Projections;
862
          see `PytorchCompositionWrapper` for table of how Projections are handled and further details.
863
        Returns list of target nodes for each pathway
864
        """
865

866
        # Construct a pathway(s) for each INPUT Node (including BIAS Nodes), except the TARGET Node)
867
        pathways = self._get_pytorch_backprop_pathways(context)
1✔
868

869
        if execution_mode is pnlvm.ExecutionMode.PyTorch:
1✔
870
            # For PyTorch mode, only need to construct dummy TARGET Nodes, to allow targets to be:
871
            #  - specified in the same way as for other execution_modes
872
            #  - trial-by-trial values kept aligned with inputs in batch / minibatch construction
873
            #  - tracked for logging (as mechs of a Composition)
874
            # IMPLEMENTATION NOTE:
875
            #    only add target nodes if not already present
876
            #    (to avoid duplication in multiple calls, including from command line;
877
            #     see test_xor_training_identicalness_standard_composition_vs_PyTorch_and_LLVM for example)
878
            # output_mechs_for_learning = self.get_nested_output_nodes_at_all_levels()
879
            # assert set([mech for mech in [pathway[-1] for pathway in pathways]]) == set(output_mechs_for_learning)
880
            pathway_terminal_nodes = [mech for mech in [pathway[-1] for pathway in pathways]]
1✔
881
            identified_target_nodes = self._identify_target_nodes(context)
1✔
882
            output_mechs_for_learning = [node for node in identified_target_nodes if node in pathway_terminal_nodes]
1✔
883
            target_mechs = [ProcessingMechanism(default_variable = np.array([np.zeros_like(value)
1✔
884
                                                                             for value in mech.value],
885
                                                                            dtype=object),
886
                                                name= 'TARGET for ' + mech.name)
887
                            for mech in output_mechs_for_learning if mech not in self.targets_from_outputs_map.values()]
888
            # Suppress warnings about role assignments
889
            context = Context(source=ContextFlags.METHOD)
1✔
890
            self.add_nodes(target_mechs, required_roles=[NodeRole.TARGET, NodeRole.LEARNING], context=context)
1✔
891
            for target_mech in target_mechs:
1✔
892
                self.exclude_node_roles(target_mech, NodeRole.OUTPUT, context)
1✔
893
                for output_port in target_mech.output_ports:
1✔
894
                    output_port.parameters.require_projection_in_composition.set(False, override=True)
1✔
895
            self.targets_from_outputs_map.update({target: output for target, output
1✔
896
                                           in zip(target_mechs, output_mechs_for_learning)})
897
        else:
898
            # Construct entire PNL backpropagation learning pathways for each INPUT Node
899
            for pathway in pathways:
1✔
900
                self.add_backpropagation_learning_pathway(pathway=pathway,
1✔
901
                                                          loss_spec=self.loss_spec)
902

903
        self.outputs_to_targets_map = {output: target for target, output in self.targets_from_outputs_map.items()}
1✔
904
        self._analyze_graph()
1✔
905
        return self.learning_components
1✔
906

907
    @handle_external_context()
1✔
908
    def _get_pytorch_backprop_pathways(self, context)->list:
1✔
909

910
        self._analyze_graph()
1✔
911
        return [pathway
1✔
912
                    for node in (self.get_nodes_by_role(NodeRole.INPUT) + self.get_nodes_by_role(NodeRole.BIAS))
913
                    if node not in self.get_nodes_by_role(NodeRole.TARGET)
914
                    for pathway in self._get_pytorch_backprop_pathway(node, context)]
915

916
    def _get_pytorch_backprop_pathway(self, input_node, context)->list:
        """Breadth-first search from input_node to find all input -> output pathways
        Uses a queue of (node, composition) pairs to traverse all nodes in the graph
        IMPLEMENTATION NOTE:  flattens nested Compositions, removing any CIMs in the nested Compositions
        Return a list of all pathways from input_node -> output node
        """

        pathways = []  # List of all feedforward pathways from INPUT Node to OUTPUT Node
        dependency_dict = {}  # Dictionary of previous component for each component in every pathway
        queue = deque([(input_node, self)])  # Queue of nodes to visit in breadth-first search

        def create_pathway(current_comp, node)->list:
            """Create pathway starting with node (presumably an OUTPUT Node) and working backward via dependency_dict"""
            pathway = []
            entry = node
            while entry in dependency_dict:
                # Prevent cycle from recurrent pathway
                if entry in pathway:
                    break
                pathway.insert(0, entry)
                entry = dependency_dict[entry]
            pathway.insert(0, entry)
            # Only consider pathways with 3 or more components (input -> projection -> ... -> output),
            #    since learning can't occur on a single Mechanism (len==1),
            #    and a pathway can't consist of just one Mechanism and one Projection (len==2)
            if len(pathway) >= 3:
                return pathway
            else:
                return []

        # Breadth-first search starting with input node
        while len(queue) > 0:
            node, current_comp = queue.popleft()

            # node is a nested Composition that is an INPUT Node of the immediately outer Composition,
            #   so put its receivers in the queue for processing in the next pass through the while loop
            if (isinstance(node, Composition) and node is not self
                    and any(isinstance(proj.sender.owner, CompositionInterfaceMechanism)
                            for proj in node.afferents)):
                for output_port in node.input_CIM.output_ports:
                    for proj in output_port.efferents:
                        queue.append((proj.receiver.owner, node))
                continue

            # node is output_CIM of outer Composition (i.e., end of pathway), which shouldn't happen yet
            if isinstance(node, CompositionInterfaceMechanism) and node is self.output_CIM:
                assert False, (f"PROGRAM ERROR: Got to output_CIM of outermost Composition ('{self.name}') "
                               f"without detecting OUTPUT NODE at end of pathway")

            # End of pathway: OUTPUT Node of outer Composition
            if current_comp == self and node in current_comp.get_nodes_by_role(NodeRole.OUTPUT):
                pathways.append(create_pathway(current_comp, node))
                continue

            # Get all efferent Projections of node,
            #   including direct projections out of a nested Composition implemented in PyTorchCompositionWrapper
            efferent_projs = [(p, p.receiver.owner) for p in node.efferents if p in current_comp.projections]
            if not efferent_projs:
                efferent_projs = [(p, p.receiver.owner) for p in node.efferents
                                  if p in current_comp._pytorch_projections]

            # Follow efferent Projection to next Node in pathway
            for efferent_proj, rcvr in efferent_projs:
                # Ignore efferent Projections that do not have a learnable attribute
                #   or are ModulatoryProjections (including LearningProjections)
                # Note: if learnable==False, it will be passed along to PyTorch in PytorchProjectionWrapper
                if not hasattr(efferent_proj, 'learnable') or isinstance(efferent_proj, ModulatoryProjection_Base):
                    continue

                # Deal with Projections to/from CIMs, since nested comps can be learned in PyTorch mode
                if isinstance(rcvr, CompositionInterfaceMechanism):

                    # Projection to input_CIM of a nested Composition
                    if rcvr == rcvr.composition.input_CIM:
                        assert rcvr.composition is not current_comp
                        rcvr_comp = rcvr.composition
                        # Get Node(s) in inner Composition to which node projects (via input_CIM)
                        receivers = rcvr._get_destination_info_from_input_CIM(efferent_proj.receiver)
                        for _, nested_rcvr, _ in [receivers] if isinstance(receivers, tuple) else receivers:
                            if rcvr_comp._input_comp_nodes_to_pytorch_nodes_map:
                                # If nested comp has _input_comp_nodes_to_pytorch_nodes_map, get nested_rcvr from it
                                nested_rcvr = rcvr_comp._input_comp_nodes_to_pytorch_nodes_map[nested_rcvr]
                            else:
                                # Otherwise, ensure that nested_rcvr is an INPUT Node of rcvr_comp
                                assert nested_rcvr in rcvr_comp.get_nodes_by_role(NodeRole.INPUT), \
                                    f"PROGRAM ERROR: '{nested_rcvr.name}' is not an INPUT Node of '{rcvr_comp.name}'"
                            # Assign efferent_proj (Projection to input_CIM), since it should be learned in PyTorch mode
                            rcvr_comp._add_dependency(node, efferent_proj, nested_rcvr,
                                                      dependency_dict, queue, rcvr_comp)

                    # rcvr is a nested Composition's output_CIM:
                    # Projection is to the output_CIM exiting from a nested Composition
                    elif rcvr == current_comp.output_CIM and current_comp is not self:

                        # Get output_CIM info for current efferent_proj
                        output_CIM_input_port = efferent_proj.receiver
                        output_CIM = output_CIM_input_port.owner
                        # Get port of output_CIM that efferent_proj sends to, for use in finding its receiver(s) below
                        if efferent_proj in current_comp.projections:
                            output_CIM_output_port = output_CIM.port_map[efferent_proj.sender][1]
                        elif efferent_proj in current_comp._pytorch_projections:
                            # FIX: 3/8/25 - THERE MUST BE AN EASIER WAY TO GET THIS MORE DIRECTLY
                            output_CIM_output_port = \
                                output_CIM.port_map[efferent_proj.receiver.path_afferents[0].sender][1]

                        # Get all Node(s) in outer Composition to which node projects (via output_CIM)
                        receivers = rcvr._get_destination_info_for_output_CIM(output_CIM_output_port)
                        # Replace efferent_proj(s) with one(s) from output_CIM to rcvr(s) in outer Composition,
                        #   since that (those) is (are) the one(s) that should be learned in PyTorch mode
                        # Note:  _get_destination_info_for_output_CIM returns list of destinations
                        #        in order of output_CIM.output_port.efferents
                        if receivers:
                            for efferent_idx, receiver in enumerate(receivers):
                                if receiver:
                                    _, rcvr, rcvr_comp = receiver
                                    assert rcvr_comp is not current_comp
                                efferent_proj = output_CIM_output_port.efferents[efferent_idx]
                                rcvr_comp._add_dependency(node, efferent_proj, rcvr,
                                                          dependency_dict, queue, rcvr_comp)
                        else:
                            pathways.append(create_pathway(current_comp, node))

                    # rcvr is the outermost Composition's output_CIM:
                    # End of pathway: direct Projection from output_CIM of nested comp to outer comp's output_CIM
                    elif rcvr is self.output_CIM:
                        # Assign node that projects to current node as OUTPUT Node for pathway
                        node_output_port = efferent_proj.sender
                        _, sender, _ = node._get_source_info_from_output_CIM(node_output_port)
                        pathway = create_pathway(current_comp, node)
                        if pathway:
                            queue.popleft()
                            pathways.append(pathway)

                    else:
                        assert False, f"PROGRAM ERROR:  Unrecognized CompositionInterfaceMechanism: {rcvr}"

                else:
                    if rcvr in current_comp.nodes:
                        # rcvr is still in nested Composition, so keep traversing that
                        current_comp._add_dependency(node, efferent_proj, rcvr, dependency_dict, queue, current_comp)
                        continue
                    elif rcvr in self.nodes:
                        # rcvr is in outer Composition (presumably a direct PyTorch Projection out of nested comp)
                        self._add_dependency(node, efferent_proj, rcvr, dependency_dict, queue, self)
                        continue
                    else:
                        assert False, \
                            f"PROGRAM ERROR:  Unrecognized receiver ('{rcvr.name}') of Projection from '{node.name}'."

        return pathways

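    # Illustrative sketch (not from the source): the pathways returned above are flat,
    # alternating sequences of Mechanisms and Projections, e.g., for input -> hidden -> output:
    #     [input_mech, in_to_hid_proj, hidden_mech, hid_to_out_proj, output_mech]
    # create_pathway() drops anything shorter than three components, since a lone Mechanism
    # (or a Mechanism plus a single Projection) has nothing that can be learned.
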
    def _add_dependency(self,
                        sender:ProcessingMechanism,
                        projection:MappingProjection,
                        receiver:ProcessingMechanism,
                        dependency_dict:dict,
                        queue:deque,
                        comp:Composition):
        """Add entries to dependency_dict, and next node to the queue, used in _get_pytorch_backprop_pathway()
        This uses the Projection from sender to receiver to implement the relevant dependencies for constructing the
        pathway; however, it can be overridden by a subclass of AutodiffComposition to implement a custom pathway
        (see example in GRUComposition).
        """
        dependency_dict[receiver] = projection
        dependency_dict[projection] = sender
        queue.append((receiver, comp))

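    # Illustrative sketch (not in the source): for a Projection A --proj--> B, the entries
    #     dependency_dict[B] = proj
    #     dependency_dict[proj] = A
    # let create_pathway(B) walk backward B -> proj -> A and emit [A, proj, B],
    # while (B, comp) is queued for the next pass of the breadth-first search.
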
    # CLEANUP: move some of what's done in the methods below to a "validate_params" type of method
    @handle_external_context()
    def _build_pytorch_representation(self, context=None, refresh=None, base_context=Context(execution_id=None)):
        """Builds a PyTorch representation of the AutodiffComposition"""
        if self.scheduler is None:
            self.scheduler = Scheduler(graph=self.graph_processing)
        if self.parameters.pytorch_representation._get(context=context, fallback_value=None) is None or refresh:
            model = self.pytorch_composition_wrapper_type(composition=self,
                                                          device=self.device,
                                                          context=context,
                                                          base_context=base_context,
                                                          )

        # Set up optimizer function
        learning_rate = self._runtime_learning_rate or self.learning_rate
        old_opt = self.parameters.optimizer._get(context)
        if (old_opt is None or refresh) and refresh is not False:
            self._instantiate_optimizer(refresh, learning_rate, context)
        # Set up loss function
        if self.loss_function is not None:
            logger.warning("Overwriting 'loss_function' for AutodiffComposition {}! Old loss function: {}".format(
                self, self.loss_function))
        if callable(self.loss_spec):
            self.loss_function = self.loss_spec
        else:
            self.loss_function = self._get_loss(self.loss_spec)

        return self.parameters.pytorch_representation._get(context)

    def _instantiate_optimizer(self, refresh, learning_rate, context):
        if not is_numeric_scalar(learning_rate):
            raise AutodiffCompositionError("Learning rate must be an integer or float value.")
        if self.optimizer_type not in ['sgd', 'adam']:
            raise AutodiffCompositionError("Invalid optimizer specified. Optimizer argument must be a string. "
                                           "Currently, Stochastic Gradient Descent and Adam are the only available "
                                           "optimizers (specified as 'sgd' or 'adam').")
        pytorch_rep = self.parameters.pytorch_representation._get(context)
        params = pytorch_rep.parameters()
        if self.optimizer_type == 'sgd':
            opt = optim.SGD(params, lr=learning_rate, weight_decay=self.weight_decay)
        else:
            opt = optim.Adam(params, lr=learning_rate, weight_decay=self.weight_decay)

        pytorch_rep._parse_optimizer_params(context)
        for param_group in pytorch_rep._optimizer_param_groups:
            opt.add_param_group(param_group)

        # Assign optimizer to AutodiffComposition and PytorchCompositionWrapper
        self.parameters.optimizer._set(opt, context, skip_history=True, skip_log=True)
        pytorch_rep.optimizer = opt

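    # Rough plain-PyTorch equivalent of the setup above (a hedged sketch; 'wrapper' and
    # 'extra_param_groups' are stand-ins for the PytorchCompositionWrapper internals):
    #     opt = optim.SGD(wrapper.parameters(), lr=learning_rate, weight_decay=weight_decay)
    #     for group in extra_param_groups:      # e.g., parameters with their own learning rates
    #         opt.add_param_group(group)
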
    def _get_loss(self, loss_spec):
        if not isinstance(self.loss_spec, (str, Loss)):
            return self.loss_spec
        elif loss_spec == Loss.MSE:
            return nn.MSELoss(reduction='mean')
        elif loss_spec == Loss.SSE:
            return nn.MSELoss(reduction='sum')
        elif loss_spec == Loss.CROSS_ENTROPY:
            if version.parse(torch.version.__version__) >= version.parse('1.12.0'):
                return nn.CrossEntropyLoss()
            # Cross entropy loss is used for multiclass categorization and needs inputs in shape
            # ((minibatch_size, C), targets), where C is a 1d vector of probabilities for each potential category
            # and target is a 1d vector of type long specifying the index of the target category. This
            # formatting is different from most other loss functions available to autodiff compositions,
            # and therefore requires a wrapper function to properly package inputs.
            return lambda x, y: nn.CrossEntropyLoss()(torch.atleast_2d(x), torch.atleast_2d(y.type(x.type())))
        elif loss_spec == Loss.BINARY_CROSS_ENTROPY:
            return nn.BCELoss()
        elif loss_spec == Loss.L1:
            return nn.L1Loss(reduction='sum')
        elif loss_spec == Loss.NLL:
            return nn.NLLLoss(reduction='sum')
        elif loss_spec == Loss.POISSON_NLL:
            return nn.PoissonNLLLoss(reduction='sum')
        elif loss_spec == Loss.KL_DIV:
            return nn.KLDivLoss(reduction='sum')
        else:
            raise AutodiffCompositionError(f"Loss type {loss_spec} not recognized. 'loss_function' argument must be a "
                                           f"Loss enum or function. Currently, the recognized loss types are: "
                                           f"MSE (mean squared error), SSE (sum squared error), CROSS_ENTROPY, "
                                           f"BINARY_CROSS_ENTROPY, L1, NLL (negative log likelihood), "
                                           f"POISSON_NLL (Poisson negative log likelihood), "
                                           f"and KL_DIV (KL divergence).")

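    # Illustrative use (sketch only): the callable returned above is applied in
    # autodiff_forward() as
    #     loss = self.loss_function(output_tensor, target_tensor)
    # e.g., Loss.MSE yields nn.MSELoss(reduction='mean'); note that Loss.SSE reuses
    # nn.MSELoss with reduction='sum' (i.e., the sum of squared errors).
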
    def get_target_nodes(self, execution_mode=pnlvm.ExecutionMode.PyTorch):
        """Return `TARGET` `Nodes <Composition_Nodes>` of the AutodiffComposition."""
        self.infer_backpropagation_learning_pathways(execution_mode=execution_mode)
        return super(AutodiffComposition, self).get_target_nodes()

    def autodiff_forward(self, inputs, targets,
                         synch_with_pnl_options, retain_in_pnl_options,
                         execution_mode, scheduler, context):
        """Perform forward pass of model and compute loss for a batch of trials in PyTorch mode.
        Losses are accumulated here; the error is then backpropagated by compositionrunner.run_learning()
        before the next time it calls run(), in a call to backward() made by do_gradient_optimization()
        in _batch_inputs() or _batch_function_inputs().
        """
        assert execution_mode is pnlvm.ExecutionMode.PyTorch
        pytorch_rep = self.parameters.pytorch_representation._get(context)

        # --------- Get current values of nodes  -------------------------------------------------

        #   should return 2d values for each component

        # Get value of INPUT nodes for current trial
        curr_tensors_for_inputs = {}
        for component in inputs.keys():
            if not isinstance(inputs[component], torch.Tensor):
                curr_tensors_for_inputs[component] = torch.tensor(inputs[component], device=self.device).double()
            else:
                curr_tensors_for_inputs[component] = inputs[component]

        # Execute PytorchCompositionWrapper to get value of all OUTPUT nodes for current trial
        curr_tensors_for_outputs = pytorch_rep.forward(curr_tensors_for_inputs, None, synch_with_pnl_options, context)

        # Get value of OUTPUT nodes that are being trained (i.e., for which there are TARGET nodes)
        curr_tensors_for_trained_outputs = {k: v for k, v in curr_tensors_for_outputs.items()
                                            if k in self.outputs_to_targets_map}

        # Get value of TARGET nodes for current trial
        curr_tensors_for_targets = {}
        for component, target in targets.items():
            if isinstance(target, torch.Tensor) or isinstance(target, np.ndarray):
                curr_tensors_for_targets[component] = [target[:, i, :] for i in range(target.shape[1])]
            else:
                # It's a list of lists of torch tensors, because it is ragged
                num_outputs = len(targets[component][0])
                curr_tensors_for_targets[component] = [torch.stack([batch_elem[i]
                                                                    for batch_elem in target])
                                                       for i in range(num_outputs)]

        # Map value of TARGET nodes to trained OUTPUT nodes
        curr_target_tensors_for_trained_outputs = {}
        for trained_output, target in self.outputs_to_targets_map.items():
            curr_target_tensors_for_trained_outputs[trained_output] = curr_tensors_for_targets[target]

        # --------- Compute the loss (TARGET-OUTPUT) for each trained OUTPUT node  ---------------------------

        # Calculate and track the loss over the trained OUTPUT nodes:
        #   curr_target_tensors_for_trained_outputs compared against curr_tensors_for_trained_outputs
        for component, outputs in curr_tensors_for_trained_outputs.items():
            trial_loss = 0
            targets = curr_target_tensors_for_trained_outputs[component]
            num_outputs = outputs.shape[1] if type(outputs) is torch.Tensor else len(outputs[0])
            for i in range(num_outputs):
                # loss only accepts 0d or 1d target; reshape assuming pytorch_rep.minibatch_loss dim is correct

                # Get the output: if it's a torch tensor we can slice; if it's a list of lists (i.e., ragged),
                # we need to index
                output = outputs[:, i, :] if type(outputs) is torch.Tensor else torch.stack([batch_elem[i] for batch_elem in outputs])

                comp_loss = self.loss_function(
                    output,
                    torch.atleast_1d(targets[i])
                )
                comp_loss = comp_loss.reshape_as(pytorch_rep.minibatch_loss)
                trial_loss += comp_loss
            pytorch_rep.minibatch_loss += trial_loss
        pytorch_rep.minibatch_loss_count += 1

        # --------- Return the values of output of trained nodes and all nodes  ---------------------------------------

        # IMPLEMENTATION NOTE: Need values in order corresponding to output_CIM Ports.

        # Get output Nodes, their output_ports and corresponding indices
        #     in order of outermost AutodiffComposition's output_CIM Ports
        outputs_idx_port_node_comp = []
        for port in self.output_CIM.input_ports:
            source_info = self.output_CIM._get_source_info_from_output_CIM(port)
            source_output_port_idx = source_info[1].output_ports.index(source_info[0])
            outputs_idx_port_node_comp.append(tuple((source_output_port_idx, *source_info)))

        # Assign values to trained_output_values and all_output_values
        trained_output_values = []
        all_output_values = []
        for item in outputs_idx_port_node_comp:
            idx, port, node, comp = item
            if comp._trained_comp_nodes_to_pytorch_nodes_map:
                node = comp._trained_comp_nodes_to_pytorch_nodes_map[node]
            outputs = curr_tensors_for_outputs[node]
            if type(outputs) is torch.Tensor:
                output = outputs[:, idx, ...]
            else:
                output = torch.stack([batch_elem[idx] for batch_elem in outputs])
            output = output.detach().cpu().numpy().copy().tolist()
            if self.targets_from_outputs_map.values():
                trained_output_values += [output]
            all_output_values += [output]

        # Turn into a numpy array, possibly ragged
        all_output_values = convert_to_np_array(all_output_values)

        # Swap the first two dimensions (output_port, batch) to (batch, output_port)
        all_output_values = all_output_values.swapaxes(0, 1)

        pytorch_rep.all_output_values = all_output_values

        # Get values of TARGET nodes
        target_values = [value[0].detach().cpu().numpy().copy().tolist()
                         for value in list(curr_tensors_for_targets.values())]
        pytorch_rep.target_values = target_values

        # Synchronize outcomes after every trial if specified
        # IMPLEMENTATION NOTE: RESULTS is not included here as it is handled in call to autodiff._update_results()
        pytorch_rep.synch_with_psyneulink(synch_with_pnl_options,
                                          [OPTIMIZATION_STEP, TRIAL],
                                          context,
                                          [NODE_VARIABLES, NODE_VALUES])
        pytorch_rep.retain_for_psyneulink({TRAINED_OUTPUTS: trained_output_values,
                                           TARGETS: target_values},
                                          retain_in_pnl_options,
                                          context)

        return trained_output_values, all_output_values

    def clear_losses(self, context=None):
        self.losses = []
        if self.pytorch_representation:
            self.pytorch_representation.retained_losses = []

    def do_gradient_optimization(self, retain_in_pnl_options, context, optimization_num=None):
        """Compute loss and use in call to autodiff_backward() to compute gradients and update PyTorch parameters.
        Update parameters (weights) based on trial(s) executed since last optimization,
        then reinitialize minibatch_loss and minibatch_loss_count.
        """
        pytorch_rep = self.parameters.pytorch_representation._get(context=context)
        minibatch_loss = pytorch_rep.minibatch_loss / pytorch_rep.minibatch_loss_count

        self.autodiff_backward(minibatch_loss, context)

        # Save loss for current round of optimization
        pytorch_rep.retain_for_psyneulink({LOSSES: minibatch_loss}, retain_in_pnl_options, context)

        # Reset minibatch_loss for next round of optimization
        pytorch_rep.minibatch_loss = torch.zeros(1, device=self.device).double()
        pytorch_rep.minibatch_loss_count = 0

    def autodiff_backward(self, minibatch_loss, context):
        """Calculate gradients and apply to PyTorch model parameters (weights)"""
        pytorch_rep = self.parameters.pytorch_representation._get(context=context)
        optimizer = pytorch_rep.optimizer

        # Gradient updates
        optimizer.zero_grad()
        # Backpropagate the average loss over all trials since last update
        minibatch_loss.backward(retain_graph=not self.force_no_retain_graph)
        # Update weights
        optimizer.step()

    def _gen_llvm_function(self, *, ctx:pnlvm.LLVMBuilderContext, tags:frozenset):
        if "run" in tags:
            return pnlvm.codegen.gen_composition_run(ctx, self, tags=tags)
        else:
            return pnlvm.codegen.gen_autodiffcomp_exec(ctx, self, tags=tags)

    def _get_total_loss(self, num_trials: int=1, context:Context=None):
        return sum(self.parameters.trial_losses._get(context)[-num_trials:]) / num_trials

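    # Worked example (hypothetical values): with trial_losses == [0.9, 0.5, 0.1] and
    # num_trials == 2, this returns (0.5 + 0.1) / 2 == 0.3, the mean over the last two trials.
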
    def _get_autodiff_inputs_values(self, input_dict: dict):
        """Remove TARGET Nodes, and return dict with values of INPUT Nodes for a single trial
        For nested Compositions, replace input to nested Composition with inputs to its INPUT Nodes
        For InputPorts, replace with owner

        Returns
        ---------
        A dict mapping INPUT Nodes -> input values for a single trial
        """
        autodiff_input_dict = {}
        for node, values in input_dict.items():
            mech = node.owner if isinstance(node, InputPort) else node
            if (mech in self.get_nested_input_nodes_at_all_levels()
                    and mech not in self.get_nodes_by_role(NodeRole.TARGET)):
                # Pass along inputs to all INPUT Nodes except TARGETS
                # (those are handled separately in _get_autodiff_targets_values)
                if torch_available:
                    # Convert to torch tensor of type expected by PytorchCompositionWrapper
                    # values = torch.tensor(values, dtype=self.torch_dtype, device=self.device)
                    values = values.type(self.torch_dtype)
                autodiff_input_dict[node] = values
        return autodiff_input_dict

    def _get_autodiff_targets_values(self, input_dict):
        """Return dict with input values for TARGET Nodes
        Get inputs to TARGET Nodes used for computation of loss in autodiff_forward().
        Uses input_dict to get input values for TARGET Nodes that are INPUT Nodes of the AutodiffComposition.
        If a TARGET Node is not an INPUT Node, it is assumed to be the target of a projection from an INPUT Node,
        and the value is determined by searching recursively for the INPUT Node that projects to the TARGET Node.

        Returns
        ---------
        A dict mapping TARGET Nodes -> target values
        """
        target_values = {}
        def get_target_value(target):
            if target in self.get_nodes_by_role(NodeRole.INPUT):
                return input_dict[target]
            if len(target.path_afferents) > 1:
                raise AutodiffCompositionError(f"TARGET Node '{target.name}' (for '{self.name}') "
                                               f"cannot have more than one afferent projection.")
            target = target.path_afferents[0].sender.owner
            return get_target_value(target)

        for target in self.targets_from_outputs_map:
            target_values[target] = get_target_value(target)
        return target_values

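    # Illustrative case (a sketch; names are assumptions): if TARGET_B is not itself an INPUT
    # Node but its only path_afferent comes from INPUT Node A, then get_target_value(TARGET_B)
    # recurses to A and returns input_dict[A]; more than one afferent raises an
    # AutodiffCompositionError.
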
    def _parse_learning_spec(self, inputs, targets, execution_mode, context):
        stim_input, num_input_trials = super()._parse_learning_spec(inputs, targets, execution_mode, context)

        if not callable(inputs):
            input_ports_for_INPUT_Nodes = self._get_input_receivers()
            nested_inputs = {}
            stim_input_copy = stim_input.copy()
            # Replace input to nested Composition with inputs to its INPUT Nodes (to accommodate flattened version)
            for node in stim_input_copy:
                # If node is a nested Composition
                if isinstance(node, Composition):
                    # If owner of input_port is a Node in the nested Composition, replace entry for nested Composition
                    #   in stim_input with entries for the input_ports of its INPUT Nodes
                    for elem, input_port in enumerate([p for p in input_ports_for_INPUT_Nodes if p.owner in node.nodes]):
                        nested_inputs[input_port] = [entry[elem] for entry in stim_input_copy[node]]
                    stim_input.pop(node)
                    stim_input.update(nested_inputs)

        return stim_input, num_input_trials

    def _check_nested_target_mechs(self):
        pass

    def _identify_target_nodes(self, context)->list:
        """Recursively call all nested AutodiffCompositions to assign TARGET nodes for learning"""
        # Default is to use OUTPUT Nodes
        target_nodes = [node for node in self.get_nodes_by_role(NodeRole.OUTPUT)
                        if not isinstance(node, Composition)]
        for node in self.nodes:
            if isinstance(node, AutodiffComposition):
                target_nodes.extend(node._identify_target_nodes(context))
        return target_nodes

    def _get_valid_weights_shape(self, projection):
        pnl_wt_matrix = projection.defaults.matrix
        if not isinstance(pnl_wt_matrix, np.ndarray):
            assert is_matrix_keyword(pnl_wt_matrix)
            pnl_wt_matrix = projection._get_matrix_from_keyword(pnl_wt_matrix)
        return pnl_wt_matrix.shape

    @handle_external_context()
    def set_weights(self, pnl_proj, weights:Union[list, np.ndarray], context=None):
        """Set weights for specified Projection."""
        valid_shape = self._get_valid_weights_shape(pnl_proj)
        assert weights.shape == valid_shape, \
            (f"PROGRAM ERROR: Shape of weights specified in 'weights' arg of '{self.name}.set_weights' "
             f"does not match required shape ({valid_shape}).")
        pnl_proj.parameters.matrix._set(weights, context)
        pnl_proj.parameter_ports['matrix'].parameters.value._set(weights, context)

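    # Hedged usage sketch (names and shapes are assumptions): for a MappingProjection 'proj'
    # with a 3x2 matrix,
    #     comp.set_weights(proj, np.ones((3, 2)))
    # replaces its weights; an array of any other shape trips the assert above.
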
    @handle_external_context(fallback_default=True)
    def learn(self,
              *args,
              synch_projection_matrices_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              synch_node_variables_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              synch_node_values_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              synch_results_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              retain_torch_trained_outputs:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              retain_torch_targets:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              retain_torch_losses:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
              context: Context = None,
              base_context: Context = Context(execution_id=None),
              skip_initialization: bool = False,
              **kwargs
              ) -> list:
        """Override to handle synch and retain args
        Note: defaults for synch and retain args are set to NotImplemented, so that the user can specify None if
              they want to locally override the default values for the AutodiffComposition (see docstrings for run()
              and _parse_synch_and_retain_args() for additional details).
        """
        execution_phase_at_entry = context.execution_phase
        context.execution_phase = ContextFlags.PREPARING

        execution_mode = self._get_execution_mode(kwargs.pop('execution_mode', None))
        context.execution_phase = execution_phase_at_entry

        any_nested_comps = [node for node in self.nodes if isinstance(node, Composition)]
        if any_nested_comps:
            # Can't learn in Python mode if there are any nested Compositions
            if execution_mode is not pnlvm.ExecutionMode.PyTorch:
                nested_comp_names = [f"'{comp.name}'" for comp in any_nested_comps]
                raise AutodiffCompositionError(f"Unable to execute learning in {pnlvm.ExecutionMode.Python.name} mode "
                                               f"for '{self.name}' because it contains one or more nested "
                                               f"Compositions: {', '.join(nested_comp_names)}.")

            # Can't learn if any nested comps are not AutodiffCompositions
            nested_comps = [f"'{comp.name}'" for comp in any_nested_comps if not isinstance(comp, AutodiffComposition)]
            if nested_comps:
                raise AutodiffCompositionError(f"Unable to execute learning for '{self.name}' "
                                               f"because it contains nested Composition(s) "
                                               f"that are not AutodiffCompositions: {', '.join(nested_comps)}.")

        if self._built_pathways is False:
            self.infer_backpropagation_learning_pathways(execution_mode, context=context)
            self._built_pathways = True

        synch_with_pnl_options, retain_in_pnl_options = (
            self._parse_synch_and_retain_args(synch_projection_matrices_with_torch,
                                              synch_node_variables_with_torch,
                                              synch_node_values_with_torch,
                                              synch_results_with_torch,
                                              retain_torch_trained_outputs,
                                              retain_torch_targets,
                                              retain_torch_losses,
                                              context=context,
                                              **kwargs))

        if execution_mode == pnlvm.ExecutionMode.PyTorch and not torch_available:
            raise AutodiffCompositionError(f"'{self.name}.learn()' has been called with ExecutionMode.PyTorch, "
                                           f"but the PyTorch module ('torch') is not installed. "
                                           f"Please install it with `pip install torch` or `pip3 install torch`")

        return super().learn(*args,
                             synch_with_pnl_options=synch_with_pnl_options,
                             retain_in_pnl_options=retain_in_pnl_options,
                             execution_mode=execution_mode,
                             context=context,
                             base_context=base_context,
                             skip_initialization=skip_initialization,
                             **kwargs)

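    # Hedged usage sketch (mechanism names and option values are illustrative, not defaults):
    #     comp.learn(inputs={input_mech: input_trials, target_mech: target_trials},
    #                execution_mode=pnlvm.ExecutionMode.PyTorch,
    #                synch_node_values_with_torch=RUN,    # copy node values back once per run
    #                retain_torch_losses=MINIBATCH)       # keep one loss value per minibatch
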
    def _parse_synch_and_retain_args(self,
                                     synch_projection_matrices_with_torch:Optional[LEARNING_SCALE_LITERALS],
                                     synch_node_variables_with_torch:Optional[LEARNING_SCALE_LITERALS],
                                     synch_node_values_with_torch:Optional[LEARNING_SCALE_LITERALS],
                                     synch_results_with_torch:Optional[LEARNING_SCALE_LITERALS],
                                     retain_torch_trained_outputs:Optional[LEARNING_SCALE_LITERALS],
                                     retain_torch_targets:Optional[LEARNING_SCALE_LITERALS],
                                     retain_torch_losses:Optional[LEARNING_SCALE_LITERALS],
                                     context: Context = None,
                                     **kwargs
                                     )->tuple:
        # Remove args from kwargs in case called from run() (they won't be there if called from learn())
        if synch_projection_matrices_with_torch == NotImplemented:
            synch_projection_matrices_with_torch = kwargs.pop('synch_projection_matrices_with_torch', NotImplemented)
            if synch_projection_matrices_with_torch == NotImplemented:
                synch_projection_matrices_with_torch = self.parameters.synch_projection_matrices_with_torch.default_value
        if synch_node_variables_with_torch == NotImplemented:
            synch_node_variables_with_torch = kwargs.pop('synch_node_variables_with_torch', NotImplemented)
            if synch_node_variables_with_torch == NotImplemented:
                synch_node_variables_with_torch = self.parameters.synch_node_variables_with_torch.default_value
        if synch_node_values_with_torch == NotImplemented:
            synch_node_values_with_torch = kwargs.pop('synch_node_values_with_torch', NotImplemented)
            if synch_node_values_with_torch == NotImplemented:
                synch_node_values_with_torch = self.parameters.synch_node_values_with_torch.default_value
        if synch_results_with_torch == NotImplemented:
            synch_results_with_torch = kwargs.pop('synch_results_with_torch', NotImplemented)
            if synch_results_with_torch == NotImplemented:
                synch_results_with_torch = self.parameters.synch_results_with_torch.default_value
        if retain_torch_trained_outputs == NotImplemented:
            retain_torch_trained_outputs = kwargs.pop('retain_torch_trained_outputs', NotImplemented)
            if retain_torch_trained_outputs == NotImplemented:
                retain_torch_trained_outputs = self.parameters.retain_torch_trained_outputs.default_value
        if retain_torch_targets == NotImplemented:
            retain_torch_targets = kwargs.pop('retain_torch_targets', NotImplemented)
            if retain_torch_targets == NotImplemented:
                retain_torch_targets = self.parameters.retain_torch_targets.default_value
        if retain_torch_losses == NotImplemented:
            retain_torch_losses = kwargs.pop('retain_torch_losses', NotImplemented)
            if retain_torch_losses == NotImplemented:
                retain_torch_losses = self.parameters.retain_torch_losses.default_value

        if self.minibatch_size > 1:
            args_str = []
            if retain_torch_trained_outputs in {OPTIMIZATION_STEP, TRIAL}:
                args_str.append('retain_torch_trained_outputs')
            if retain_torch_losses in {OPTIMIZATION_STEP, TRIAL}:
                args_str.append('retain_torch_losses')
            if retain_torch_targets in {OPTIMIZATION_STEP, TRIAL}:
                args_str.append('retain_torch_targets')
            if args_str:
                arg_args = 'arg' if len(args_str) == 1 else 'args'
                is_are = 'is' if len(args_str) == 1 else 'are'
                raise AutodiffCompositionError(f"The {', '.join(args_str)} {arg_args} in the learn() method for "
                                               f"'{self.name}' {is_are} specified as 'OPTIMIZATION' or 'TRIAL', but "
                                               f"'minibatch_size' ({self.minibatch_size}) != 1, so "
                                               f"{', '.join([arg.split('_')[-1] for arg in args_str])} "
                                               f"will be updated only at the end of a minibatch; "
                                               f"use 'MINIBATCH' for the {arg_args} to avoid this warning.")

        # Package options for synching and tracking into dictionaries as arguments to learning and exec methods
        synch_with_pnl_options = {MATRIX_WEIGHTS: synch_projection_matrices_with_torch
                                                  or self.parameters.synch_projection_matrices_with_torch._get(context),
                                  NODE_VARIABLES: synch_node_variables_with_torch
                                                  or self.parameters.synch_node_variables_with_torch._get(context),
                                  NODE_VALUES: synch_node_values_with_torch
                                               or self.parameters.synch_node_values_with_torch._get(context),
                                  RESULTS: synch_results_with_torch
                                           or self.parameters.synch_results_with_torch._get(context)}

        retain_in_pnl_options = {TRAINED_OUTPUTS: retain_torch_trained_outputs
                                                  or self.parameters.retain_torch_trained_outputs._get(context),
                                 TARGETS: retain_torch_targets or self.parameters.retain_torch_targets._get(context),
                                 LOSSES: retain_torch_losses or self.parameters.retain_torch_losses._get(context)}

        return synch_with_pnl_options, retain_in_pnl_options

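    # Shape of the returned dictionaries (a sketch of the packaging above):
    #     synch_with_pnl_options = {MATRIX_WEIGHTS: ..., NODE_VARIABLES: ..., NODE_VALUES: ..., RESULTS: ...}
    #     retain_in_pnl_options  = {TRAINED_OUTPUTS: ..., TARGETS: ..., LOSSES: ...}
    # where each value is a learning-scale specification (e.g., OPTIMIZATION_STEP, TRIAL,
    # MINIBATCH, or RUN).
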
    def _get_execution_mode(self, execution_mode):
        """Parse execution_mode argument and return a valid execution mode for the learn() method
        Can be overridden by subclasses to change the permitted and/or default execution mode for learning
        """
        if execution_mode is None:
            if self.execution_mode_warned_about_default is False:
                warnings.warn(f"The execution_mode argument was not specified in the learn() method of '{self.name}'; "
                              f"ExecutionMode.PyTorch will be used by default.")
                self.execution_mode_warned_about_default = True
            execution_mode = pnlvm.ExecutionMode.PyTorch

        return execution_mode

    @handle_external_context(fallback_default=True)
    def execute(self,
                inputs=None,
                num_trials=None,
                minibatch_size=1,
                optimizations_per_minibatch=1,
                do_logging=False,
                scheduler=None,
                termination_processing=None,
                call_before_minibatch=None,
                call_after_minibatch=None,
                call_before_time_step=None,
                call_before_pass=None,
                call_after_time_step=None,
                call_after_pass=None,
                reset_stateful_functions_to=None,
                context=None,
                base_context=Context(execution_id=None),
                clamp_input=SOFT_CLAMP,
                targets=None,
                runtime_params=None,
                execution_mode:pnlvm.ExecutionMode = pnlvm.ExecutionMode.PyTorch,
                skip_initialization=False,
                synch_with_pnl_options:Optional[Mapping]=None,
                retain_in_pnl_options:Optional[Mapping]=None,
                report_output:ReportOutput=ReportOutput.OFF,
                report_params:ReportOutput=ReportParams.OFF,
                report_progress:ReportProgress=ReportProgress.OFF,
                report_simulations:ReportSimulations=ReportSimulations.OFF,
                report_to_devices:ReportDevices=None,
                report=None,
                report_num=None,
                )->np.ndarray:
        """Override to execute autodiff_forward() in learning mode if execution_mode is not Python"""

        if (self._is_learning(context) and execution_mode is not pnlvm.ExecutionMode.PyTorch and
                any([isinstance(node, Composition) for node in self.nodes])):
            raise CompositionError(f"Must use execution_mode=ExecutionMode.PyTorch for learning "
                                   f"that includes nested AutodiffComposition(s).")

        if execution_mode is not pnlvm.ExecutionMode.Python:
            self._assign_execution_ids(context)
            context.composition = self
            context.source = ContextFlags.COMPOSITION

            if execution_mode is pnlvm.ExecutionMode.PyTorch and not torch_available:
                raise AutodiffCompositionError(f"'{self.name}.learn()' has been called with ExecutionMode.PyTorch, "
                                               f"but the PyTorch module ('torch') is not installed. "
                                               f"Please install it with `pip install torch` or `pip3 install torch`")

            if scheduler is None:
                scheduler = self.scheduler

            if self._is_learning(context):
                # TBI: How are we supposed to use base_context and statefulness here?
                # TBI: can we call _build_pytorch_representation in _analyze_graph so that pytorch
                # model may be modified between runs?

                autodiff_inputs = self._get_autodiff_inputs_values(inputs)
                autodiff_targets = self._get_autodiff_targets_values(inputs)

                # Begin reporting of learning TRIAL:
                report(self,
                       LEARN_REPORT,
                       # EXECUTE_REPORT,
                       report_num=report_num,
                       scheduler=scheduler,
                       content='trial_start',
                       context=context)

                self._build_pytorch_representation(context, base_context=base_context)
                trained_output_values, all_output_values = \
                                                self.autodiff_forward(inputs=autodiff_inputs,
                                                                      targets=autodiff_targets,
                                                                      synch_with_pnl_options=synch_with_pnl_options,
                                                                      retain_in_pnl_options=retain_in_pnl_options,
                                                                      execution_mode=execution_mode,
                                                                      scheduler=scheduler,
                                                                      context=context)
                execution_phase = context.execution_phase
                context.execution_phase = ContextFlags.PROCESSING
                context.execution_phase = execution_phase

                # Complete TRIAL Panel for output report, and report progress
                report(self,
                       # [LEARN_REPORT],
                       [EXECUTE_REPORT, PROGRESS_REPORT],
                       report_num=report_num,
                       scheduler=scheduler,
                       content='trial_end',
                       context=context)

                scheduler.get_clock(context)._increment_time(TimeScale.TRIAL)

                self.most_recent_context = context
                return all_output_values

        # Call Composition execute in Python mode
        return super(AutodiffComposition, self).execute(inputs=inputs,
                                                        scheduler=scheduler,
                                                        termination_processing=termination_processing,
                                                        call_before_time_step=call_before_time_step,
                                                        call_before_pass=call_before_pass,
                                                        call_after_time_step=call_after_time_step,
                                                        call_after_pass=call_after_pass,
                                                        reset_stateful_functions_to=reset_stateful_functions_to,
                                                        context=context,
                                                        base_context=base_context,
                                                        clamp_input=clamp_input,
                                                        runtime_params=runtime_params,
                                                        execution_mode=execution_mode,
                                                        report=report,
                                                        report_num=report_num
                                                        )

    @handle_external_context(fallback_default=True)
    def run(self, *args,
            synch_projection_matrices_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            synch_node_variables_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            synch_node_values_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            synch_results_with_torch:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            retain_torch_trained_outputs:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            retain_torch_targets:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            retain_torch_losses:Optional[LEARNING_SCALE_LITERALS]=NotImplemented,
            batched_results:bool=False,
            context: Context = None,
            **kwargs):
        """Override to handle synch and retain args if called directly from run() rather than learn()
        Note: defaults for synch and retain args are NotImplemented, so that the user can specify None if they want
              to locally override the default values for the AutodiffComposition (see _parse_synch_and_retain_args()
              for details). This is distinct from the user assigning the Parameter default_value(s), which is done
              in the AutodiffComposition constructor and handled by the Parameter._specify_none attribute.
        """

        # Store whether we need to return results list with a batch dimension, or flatten it
        self.batched_results = batched_results

        if not (SYNCH_WITH_PNL_OPTIONS in kwargs and RETAIN_IN_PNL_OPTIONS in kwargs):
            # No synch_with_pnl_options and retain_in_pnl_options dicts:
            # - so must have been called from run directly rather than learn
            # - therefore, must validate, parse and package options into those dicts
            if synch_results_with_torch is NotImplemented:
                # IMPLEMENTATION NOTE:
                #     If synch_results_with_torch is not specified by the user in call from run(), set it to
                #     MINIBATCH (rather than RUN, which is the default_value for calls from AutodiffComposition);
                #     this is required for calling _update_results() from Composition.run(), which does not itself
                #     know about synch and retain options, and the expected default behavior of which is to update
                #     results on every trial in a call to run().
                synch_results_with_torch = MINIBATCH
            synch_with_pnl_options, retain_in_pnl_options = (
                self._parse_synch_and_retain_args(synch_projection_matrices_with_torch,
                                                  synch_node_variables_with_torch,
                                                  synch_node_values_with_torch,
                                                  synch_results_with_torch,
                                                  retain_torch_trained_outputs,
                                                  retain_torch_targets,
                                                  retain_torch_losses,
                                                  context=context,
                                                  **kwargs))
            kwargs[SYNCH_WITH_PNL_OPTIONS] = synch_with_pnl_options
            kwargs[RETAIN_IN_PNL_OPTIONS] = retain_in_pnl_options

        results = super(AutodiffComposition, self).run(*args, context=context, **kwargs)
        if EXECUTION_MODE in kwargs and kwargs[EXECUTION_MODE] is pnlvm.ExecutionMode.PyTorch:
            # Synchronize specified outcomes at end of run
            pytorch_rep = self.parameters.pytorch_representation.get(context)
            if pytorch_rep:
                pytorch_rep.synch_with_psyneulink(kwargs[SYNCH_WITH_PNL_OPTIONS], RUN, context)

        return results

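    # Hedged usage sketch (mechanism names are assumptions):
    #     results = comp.run(inputs={input_mech: input_trials},
    #                        batched_results=True)   # keep the leading batch dimension
    # When called via learn(), the synch/retain option dicts arrive pre-packaged in kwargs
    # and the parsing branch above is skipped.
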
    def _update_results(self, results, trial_output, execution_mode, synch_with_pnl_options, context):
        """Track results at specified frequency during learning"""
        if execution_mode is pnlvm.ExecutionMode.PyTorch:

            # Check if the trial_output is at least 3d
            is_output_3d = trial_output.ndim >= 3 or (trial_output.ndim == 2 and len(trial_output) > 0 and
                                                      isinstance(trial_output[0, 0], (np.ndarray, list)))

            if (RESULTS in synch_with_pnl_options
                    and synch_with_pnl_options[RESULTS] in {TRIAL, MINIBATCH}):
                # Use Composition's own _update_results method, since there are no savings when done trial-by-trial
                if not self.batched_results and is_output_3d:
                    for out in trial_output:
                        super()._update_results(results, out, execution_mode, synch_with_pnl_options, context)
                else:
                    super()._update_results(results, trial_output, execution_mode, synch_with_pnl_options, context)

            elif (RESULTS in synch_with_pnl_options
                  and synch_with_pnl_options[RESULTS] == RUN):
                # Use pytorch_rep's method to keep a local list of results that are copied to autodiff.results after run
                pytorch_rep = self.parameters.pytorch_representation._get(context)
                if not self.batched_results and is_output_3d:
                    for out in trial_output:
                        pytorch_rep.retain_results(out)
                else:
                    pytorch_rep.retain_results(trial_output)
        else:
            super()._update_results(results, trial_output, execution_mode, synch_with_pnl_options, context)

    @handle_external_context(fallback_most_recent=True)
    def save(self, path:PosixPath=None, directory:str=None, filename:str=None, context=None):
        """Saves all weight matrices for all MappingProjections in the AutodiffComposition

        Arguments
        ---------
        path: Path, PosixPath or str : default None
            path specification; must be a legal path specification in the filesystem.
        directory: str : default ``current working directory``
            directory where `matrices <MappingProjection.matrix>` for all MappingProjections
            in the AutodiffComposition are saved.
        filename: str : default ``<name of AutodiffComposition>_matrix_wts.pnl``
            filename in which `matrices <MappingProjection.matrix>` for all MappingProjections
            in the AutodiffComposition are saved.

        .. note::
           Matrices are saved in
           `PyTorch state_dict <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_ format.

        Return
        ------
        Path
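
        Example
        -------
        A minimal usage sketch, assuming an AutodiffComposition assigned to ``my_autodiff`` has already
        been constructed; the directory is purely illustrative::

            >>> my_autodiff.save(directory='/tmp/checkpoints')   # doctest: +SKIP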
        """
1824
        error_msg = f" (for saving weight matrices for '{self.name}') is not a legal path."
1✔
1825

1826
        if path:
1!
1827
            try:
1✔
1828
                path = Path(path)
1✔
1829
            except:
×
1830
                raise AutodiffCompositionError(f"'{path}'{error_msg}")
1831
        else:
1832
            try:
×
1833
                if directory:
×
1834
                    path = Path(directory)
×
1835
                else:
1836
                    path = Path(os.getcwd())
×
1837
                if filename:
×
1838
                    path = Path(os.path.join(path, filename))
×
1839
                else:
1840
                    path = Path(os.path.join(path, f'{self.name}_matrix_wts.pnl'))
×
1841
            except IsADirectoryError:
×
1842
                raise AutodiffCompositionError(f"'{path}'{error_msg}")
1843
        proj_state = {
1✔
1844
            p.name: p.parameters.matrix.get(context=context)
1845
            # p.name: p.matrix.base
1846
            for p in self.projections
1847
            if not (isinstance(p, ModulatoryProjection_Base)
1848
                    or isinstance(p.sender.owner, CompositionInterfaceMechanism)
1849
                    or isinstance(p.receiver.owner, CompositionInterfaceMechanism)
1850
                    or isinstance(p.sender.owner, ModulatoryMechanism_Base)
1851
                    or isinstance(p.receiver.owner, ModulatoryMechanism_Base)
1852
                    or p.sender.owner in self.get_nodes_by_role(NodeRole.LEARNING)
1853
                    or p.receiver.owner in self.get_nodes_by_role(NodeRole.LEARNING)
1854
                )}
1855
        try:
1✔
1856
            torch.save(proj_state, path)
1✔
1857
        except IsADirectoryError:
×
1858
            raise AutodiffCompositionError(f"'{path}'{error_msg}")
1859

1860
        self.last_saved_weights = path
1✔
1861

1862
        return path
1✔
1863

1864
    @handle_external_context(fallback_most_recent=True)
    def load(self, path:PosixPath=None, directory:str=None, filename:str=None, context=None, weights_only:bool=False):
        """Loads all weight matrices for all MappingProjections in the AutodiffComposition from file

        Arguments
        ---------
        path: Path : default None
            Path for file in which `MappingProjection` `matrices <MappingProjection.matrix>` are stored.
            This must be a legal PosixPath object; if it is specified, **directory** and **filename** are ignored.
        directory: str : default ``current working directory``
            directory where `MappingProjection` `matrices <MappingProjection.matrix>` are stored.
        filename: str : default ``<name of AutodiffComposition>_matrix_wts.pnl``
            name of file in which `MappingProjection` `matrices <MappingProjection.matrix>` are stored.
        weights_only: bool : default False
            passed to `torch.load <https://pytorch.org/docs/stable/generated/torch.load.html>`_;
            if True, restricts unpickling to tensor data (safer when loading files from untrusted sources).

        .. note::
           Matrices must be stored in
           `PyTorch state_dict <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_ format.
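
        Example
        -------
        A minimal usage sketch; the directory and filename are purely illustrative, and assume
        weights were previously saved with `save <AutodiffComposition.save>`::

            >>> my_autodiff.load(directory='/tmp/checkpoints',
            ...                  filename='my_autodiff_matrix_wts.pnl')   # doctest: +SKIP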
        """
        error_msg = f" (for loading weight matrices for '{self.name}') is not a legal path."
        if path:
            if not isinstance(path, Path):
                raise AutodiffCompositionError(f"'{path}'{error_msg}")
        else:
            try:
                if directory:
                    path = Path(directory)
                else:
                    path = Path(os.getcwd())
                if filename:
                    path = Path(os.path.join(path, filename))
                else:
                    path = Path(os.path.join(path, f'{self.name}_matrix_wts.pnl'))
            except IsADirectoryError:
                raise AutodiffCompositionError(f"'{path}'{error_msg}")
        try:
            state = torch.load(path, weights_only=weights_only)
        except FileNotFoundError:
            raise AutodiffCompositionError(f"'{path}'{error_msg}")

        self.last_loaded_weights = path

        for projection in [p for p in self.projections
                           if not (isinstance(p, ModulatoryProjection_Base)
                                   or isinstance(p.sender.owner, CompositionInterfaceMechanism)
                                   or isinstance(p.receiver.owner, CompositionInterfaceMechanism)
                                   or isinstance(p.sender.owner, ModulatoryMechanism_Base)
                                   or isinstance(p.receiver.owner, ModulatoryMechanism_Base)
                                   or p.sender.owner in self.get_nodes_by_role(NodeRole.LEARNING)
                                   or p.receiver.owner in self.get_nodes_by_role(NodeRole.LEARNING)
                                   )]:
            matrix = state[projection.name]
            if np.array(matrix).shape != projection.matrix.base.shape:
                raise AutodiffCompositionError(f"Shape of matrix loaded for '{projection.name}' "
                                               f"({np.array(matrix).shape}) "
                                               f"does not match its shape ({projection.matrix.base.shape})")
            projection.matrix.base = matrix
            projection.parameters.matrix.set(matrix, context=context, override=True)
            projection.parameter_ports['matrix'].parameters.value.set(matrix, context=context, override=True)

        self._build_pytorch_representation(context=context, refresh=True)

    def _get_state_ids(self):
        return super()._get_state_ids() + ["optimizer"]

    def _get_state_struct_type(self, ctx):
        comp_state_type_list = ctx.get_state_struct_type(super())
        pytorch_representation = self._build_pytorch_representation()
        optimizer_state_type = pytorch_representation._get_compiled_optimizer()._get_optimizer_struct_type(ctx)

        return pnlvm.ir.LiteralStructType((
            *comp_state_type_list,
            optimizer_state_type))

    def _get_state_initializer(self, context):
        comp_states = super()._get_state_initializer(context)
        optimizer_states = tuple()

        return (*comp_states, optimizer_states)

    if torch_available:
        @handle_external_context(fallback_most_recent=True)
        def copy_torch_param_to_projection_matrix(self,
                                                  projection:Union[str, MappingProjection],
                                                  torch_param:Union[torch.nn.Parameter, torch.Tensor, str, int],
                                                  torch_module:torch.nn.Module=None,
                                                  torch_slice:slice=None,
                                                  validate:bool=True,
                                                  context:Optional[Union[Context, str]]=None)->np.ndarray:
            """Assign torch Parameter to `matrix <MappingProjection.matrix>` Parameter of specified `MappingProjection`.
            Return torch_param as the np.ndarray assigned to the `matrix <MappingProjection.matrix>` Parameter of
            **projection**.

            Arguments
            ---------

            projection : str or MappingProjection
               specifies `MappingProjection` to which torch_param is assigned as its `matrix
               <MappingProjection.matrix>` Parameter; if specified as a str, it must be the name of a
               MappingProjection in the AutodiffComposition.

            torch_param : torch.nn.Parameter, torch.Tensor, str or int
               specifies the torch Parameter to assign to the `matrix <MappingProjection.matrix>` Parameter of
               **projection**; if it is a torch.nn.Parameter or torch.Tensor, then the **torch_module** argument
               does not need to be specified; if specified as a str or int, it must be the name of a torch Parameter
               (used to access it in the state_dict) or its index (used to access it in the ParameterList) of the
               **torch_module** argument, which must also be specified.

            torch_module : torch.nn.Module : default None
               specifies a torch.nn.Module containing **torch_param** assigned to the `matrix
               <MappingProjection.matrix>` Parameter of **projection**; this does not need to be specified if
               **torch_param** is a torch.nn.Parameter or torch.Tensor, but must be specified if **torch_param**
               is a str or int.

            torch_slice : slice : default None
               specifies a slice of **torch_param** to assign to the `matrix <MappingProjection.matrix>` Parameter
               of **projection**; if it is not specified, the entire tensor of **torch_param** is used.

              .. warning::
                 **torch_slice** should not be specified if the specification of **torch_param** already takes this
                 into account.

            validate : bool : default True
               specifies whether to validate the **projection** and **torch_param** arguments; setting it to False
               results in more efficient processing if this method is called frequently; however, invalid arguments
               will raise standard Python exceptions rather than more informative AutodiffComposition errors, and
               unexpected results may go unnoticed.

               .. warning::
                  if validate is False, for efficiency: **projection** *must* be a `MappingProjection`, **torch_param**
                  *must* be a torch.Tensor, and **torch_module** is ignored.

            context : Context or None : default most recent Context
               specifies context to use for the value of Projection.matrix; if it is not provided, then a default
               `Context` is constructed using the `name <Composition.name>` of the AutodiffComposition as the
               `execution_id <Context.execution_id>`, commensurate with the one used by default for its `execution
               <AutodiffComposition_Execution>`.
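
            Example
            -------
            A minimal sketch of copying a torch Linear layer's weights into a Projection; the module
            and Projection name here are purely illustrative::

                >>> import torch                                          # doctest: +SKIP
                >>> linear = torch.nn.Linear(3, 2)                        # doctest: +SKIP
                >>> my_autodiff.copy_torch_param_to_projection_matrix(    # doctest: +SKIP
                ...     projection='MappingProjection from INPUT to OUTPUT',
                ...     torch_param='weight',
                ...     torch_module=linear)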
            """
            if validate:
                torch_tensor, projection = self._validate_torch_param_and_projection(torch_param,
                                                                                     torch_module,
                                                                                     torch_slice,
                                                                                     projection)
            else:
                # Assume **torch_param** is passed in as a Tensor and **projection** as a Projection if validate is False
                torch_tensor = torch_param[torch_slice] if torch_slice else torch_param

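            # torch Parameters (e.g., Linear weights) are stored as (out_features, in_features), whereas
            # PsyNeuLink matrices are (sender_size, receiver_size), hence the transpose after detaching
            # the tensor and moving it to the CPU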
            torch_param_as_pnl_matrix = torch_tensor.detach().cpu().numpy().T
            projection.parameters.matrix._set(torch_param_as_pnl_matrix, context)
            projection.parameter_ports['matrix'].parameters.value._set(torch_param_as_pnl_matrix, context)
            return torch_param_as_pnl_matrix

        def copy_projection_matrix_to_torch_param(self,
                                                  projection:Union[str, MappingProjection],
                                                  torch_param:Union[torch.nn.Parameter, torch.Tensor, str, int],
                                                  torch_module:torch.nn.Module=None,
                                                  torch_slice:slice=None,
                                                  validate:bool=True,
                                                  context:Optional[Union[Context, str]]=None)->torch.Tensor:
            """Assign the `matrix <MappingProjection.matrix>` Parameter of a `MappingProjection` to a PyTorch Parameter.

            .. warning::
               If the PyTorch Parameter has requires_grad=True, this will affect its updating in PyTorch.

            Return torch.Tensor assigned to **torch_param**.

            Arguments
            ---------

            projection : str or MappingProjection
               specifies `MappingProjection`, the `matrix <MappingProjection.matrix>` of which is assigned to
               torch_param; if specified as a str, it must be the name of a MappingProjection in the
               AutodiffComposition.

            torch_param : torch.nn.Parameter, torch.Tensor, str or int
               specifies the torch Parameter to which the `matrix <MappingProjection.matrix>` of the Projection is
               assigned; if it is a torch.nn.Parameter or torch.Tensor, then the **torch_module** argument does not
               need to be specified; if specified as a str or int, it must be the name of a torch Parameter (used to
               access it in the state_dict) or its index (used to access it in the ParameterList) of the
               **torch_module** argument, which must also be specified.

            torch_module : torch.nn.Module : default None
               specifies a torch.nn.Module containing **torch_param** to which the **projection**'s `matrix
               <MappingProjection.matrix>` Parameter is assigned; this does not need to be specified if **torch_param**
               is a torch.nn.Parameter or torch.Tensor, but must be specified if **torch_param** is a str or int.

            torch_slice : slice : default None
               specifies a slice of **torch_param** to assign to the `matrix <MappingProjection.matrix>` Parameter
               of **projection**; if it is not specified, the entire tensor of **torch_param** is used.

              .. warning::
                 **torch_slice** should not be specified if the specification of **torch_param** already takes this
                 into account.

            validate : bool : default True
               specifies whether to validate the **projection** and **torch_param** arguments; setting it to False
               results in more efficient processing if this method is called frequently; however, invalid arguments
               then raise standard Python exceptions rather than more informative AutodiffComposition errors,
               and unexpected results may go unnoticed.

               .. warning::
                  if validate is False, for efficiency: **projection** *must* be a `MappingProjection`, **torch_param**
                  *must* be a torch.Tensor, and **torch_module** is ignored.

            context : Context or None : default most recent Context
               specifies context to use for the value of Projection.matrix; if it is not provided, then a default
               `Context` is constructed using the `name <Composition.name>` of the AutodiffComposition as the
               `execution_id <Context.execution_id>`, commensurate with the one used by default for its `execution
               <AutodiffComposition_Execution>`.
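
            Example
            -------
            A minimal sketch of copying a Projection's matrix into a torch Linear layer's weights; the
            module and Projection name here are purely illustrative::

                >>> import torch                                          # doctest: +SKIP
                >>> linear = torch.nn.Linear(3, 2)                        # doctest: +SKIP
                >>> my_autodiff.copy_projection_matrix_to_torch_param(    # doctest: +SKIP
                ...     projection='MappingProjection from INPUT to OUTPUT',
                ...     torch_param='weight',
                ...     torch_module=linear)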
            """
            if validate:
                torch_tensor, projection = self._validate_torch_param_and_projection(torch_param,
                                                                                     torch_module,
                                                                                     torch_slice,
                                                                                     projection)
            # Assume **torch_param** is passed in as a Tensor and **projection** as a Projection if validate is False
            else:
                # Apply torch_slice here, since _validate_torch_param_and_projection() applies it
                # only in the validated case
                torch_tensor = torch_param[torch_slice] if torch_slice else torch_param
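            # PsyNeuLink matrices are (sender_size, receiver_size); transpose (and squeeze any singleton
            # dimension) so the copied values match the torch parameter's orientation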
            matrix = projection.parameters.matrix.get(context).T.squeeze()
            matrix_as_tensor = torch.tensor(matrix, dtype=torch_tensor.dtype)
            torch_tensor.data.copy_(matrix_as_tensor)
            return matrix_as_tensor

        def _validate_torch_param_and_projection(self, torch_param, torch_module, torch_slice, projection_spec)->tuple:
            """Validate torch and projection arguments for copying between PyTorch and AutodiffComposition.
            Return tuple of torch.Tensor and MappingProjection.
            """
            method_name = 'copy_torch_param_to_projection_matrix'

            # Torch Parameter specification is a Tensor or a torch.nn.Parameter
            if isinstance(torch_param, torch.Tensor):
                torch_tensor = torch_param

            # Torch Parameter specification is missing (None)
            elif isinstance(torch_param, type(None)):
                if isinstance(torch_module, (torch.nn.Parameter, torch.Tensor)):
                    raise AutodiffCompositionError(f"Specification of 'torch_module' arg in {method_name}() is a "
                                                   f"torch Parameter or Tensor; this should be specified using the "
                                                   f"'torch_param' arg.")
                raise AutodiffCompositionError(f"The 'torch_param' arg in {method_name}() ({torch_param}) must be "
                                               f"specified, using either a torch.nn.Parameter or torch.Tensor, or a "
                                               f"str or int paired with specification of a torch.nn.Module in the "
                                               f"'torch_module' arg.")
            # Torch Parameter specification is a torch.nn.Module
            elif isinstance(torch_param, torch.nn.Module):
                raise AutodiffCompositionError(f"Specification of 'torch_param' arg in {method_name}() ({torch_param}) "
                                               f"is a Module, but must be a torch.nn.Parameter, torch.Tensor, str or "
                                               f"int; if a Module is intended, use the 'torch_module' arg, and specify "
                                               f"the Parameter name or index in the 'torch_param' arg.")

            elif isinstance(torch_param, (str, int)):
                if torch_module is None:
                    raise AutodiffCompositionError(f"Specifying the 'torch_param' arg in {method_name}() with a "
                                                   f"string or int ({torch_param}) requires the 'torch_module' "
                                                   f"arg to be specified as well.")
                if not isinstance(torch_module, torch.nn.Module):
                    raise AutodiffCompositionError(f"Specification of 'torch_module' arg in {method_name}() "
                                                   f"({torch_module}) must be a torch.nn.Module.")
                if isinstance(torch_param, str):
                    # Name of Parameter was specified, so get it from the Module's state_dict
                    if torch_param not in torch_module.state_dict():
                        raise AutodiffCompositionError(f"'{torch_param}' specified in 'torch_param' arg of "
                                                       f"{method_name}() is not the name of a Parameter in the "
                                                       f"state_dict() for '{torch_module}'.")
                    torch_tensor = torch_module.state_dict()[torch_param]
                else:
                    # Index of Parameter was specified, so get it from the Module's parameters() list
                    try:
                        torch_tensor = list(torch_module.parameters())[torch_param]
                    except IndexError:
                        raise AutodiffCompositionError(f"The value ({torch_param}) specified in the 'torch_param' arg "
                                                       f"of {method_name}() is not an index within the range of the "
                                                       f"ParameterList specified for the Module ('{torch_module}').")
            else:
                # Unrecognized specification for torch_param arg.
                raise AutodiffCompositionError(f"Specification of 'torch_param' arg in {method_name}() ({torch_param}) "
                                               f"must be a torch.nn.Parameter, torch.Tensor, str or int.")

            if torch_slice is not None:
                if not isinstance(torch_slice, slice):
                    if isinstance(torch_param, (str, int)):
                        param_ref = f"'{torch_param}'" if isinstance(torch_param, str) else f"{torch_param}"
                        raise AutodiffCompositionError(f"Specification of 'torch_slice' arg in {method_name}() "
                                                       f"('{torch_slice}') for Parameter {param_ref} of {torch_module} "
                                                       f"must be a slice.")
                    else:
                        raise AutodiffCompositionError(f"Specification of 'torch_slice' arg in {method_name}() "
                                                       f"({torch_slice}) must be a slice.")
                torch_tensor = torch_tensor[torch_slice]

            # Parse and validate projection spec
            if projection_spec not in self.projections:
                if isinstance(projection_spec, str):
                    raise AutodiffCompositionError(f"'{projection_spec}' in {method_name}() "
                                                   f"is not the name of a Projection in '{self.name}'.")
                elif isinstance(projection_spec, MappingProjection):
                    raise AutodiffCompositionError(f"'{projection_spec.name}' in {method_name}() "
                                                   f"is not a Projection in '{self.name}'.")
                else:
                    assert False, f"PROGRAM ERROR: Illegal type for 'projection' ({projection_spec}) in {method_name}."
            projection = self.projections[projection_spec]

            torch_param_as_pnl_matrix = torch_tensor.detach().cpu().numpy().T
            bias_note = ""
            if torch_param_as_pnl_matrix.ndim == 1:
                # Note: torch biases are 1d, but PNL requires matrices to be 2d
                torch_param_as_pnl_matrix = np.atleast_2d(torch_param_as_pnl_matrix)
                bias_note = (f" [Note: torch biases, usually 1d, have already been converted to 2d "
                             f"to match the Projections of PsyNeuLink BIAS Nodes.]")
            if torch_param_as_pnl_matrix.shape != projection.parameters.matrix.get().shape:
                raise AutodiffCompositionError(
                    f"Shape of torch parameter {torch_param_as_pnl_matrix.shape} in {method_name}() does not match "
                    f"shape of matrix for '{projection.name}' {projection.parameters.matrix.get().shape}.{bias_note}")
            return torch_tensor, projection

    def show_graph(self, *args, **kwargs):
        """Override to use PytorchShowGraph if show_pytorch is True"""
        return self._show_graph.show_graph(*args, **kwargs)

    @property
    def _dependent_components(self) -> Iterable[Component]:
        res = super()._dependent_components

        # NOTE: _dependent_components should possibly be reworked to be
        # a context-dependent method
        for pytorch_repr in self.parameters.pytorch_representation.values.values():
            if pytorch_repr is not None:
                res.extend([w.projection for w in pytorch_repr.projection_wrappers])

        return res