PrincetonUniversity / PsyNeuLink / 15917088825

05 Jun 2025 04:18AM UTC · coverage: 84.482% (+0.5%) from 84.017%
Build 15917088825 · push · github · web-flow
Merge pull request #3271 from PrincetonUniversity/devel (Devel)

9909 of 12966 branches covered (76.42%); branch coverage is included in the aggregate %.
1708 of 1908 new or added lines in 54 files covered (89.52%).
25 existing lines in 14 files are now uncovered.
34484 of 39581 relevant lines covered (87.12%) · 0.87 hits per line.

Source File: /psyneulink/library/compositions/emcomposition/pytorchEMwrappers.py · 92.86% covered
# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

# ********************************************* PytorchComponent *************************************************

"""PyTorch wrapper for EMComposition"""

# import torch
try:
    import torch
except (ImportError, ModuleNotFoundError):
    torch = None

from typing import Optional

from psyneulink.library.compositions.pytorchwrappers import PytorchCompositionWrapper, PytorchMechanismWrapper
from psyneulink.library.components.mechanisms.modulatory.learning.EMstoragemechanism import EMStorageMechanism
from psyneulink.core.globals.keywords import AFTER

__all__ = ['PytorchEMCompositionWrapper']

class PytorchEMCompositionWrapper(PytorchCompositionWrapper):
    """Wrapper for EMComposition as a Pytorch Module"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Assign storage_node (EMComposition's EMStorageMechanism) (assumes there is only one)
        self.storage_node = self.nodes_map[self.composition.storage_node]
        # Execute storage_node after gradient calculation,
        #     since it assigns weights manually which messes up PyTorch gradient tracking in forward() and backward()
        self.storage_node.exclude_from_gradient_calc = AFTER

        # Get PytorchProjectionWrappers for Projections to match and retrieve nodes;
        #   used by get_memory() to construct memory_matrix and store_memory() to store entry in it
        pnl_storage_mech = self.storage_node.mechanism

        num_fields = len(pnl_storage_mech.input_ports)
        num_learning_signals = len(pnl_storage_mech.learning_signals)
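        # learning_signals contains the Projections to the match nodes first, followed by one per field
        # to the retrieve nodes, so the surplus over num_fields gives the number of match fields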
        num_match_fields = num_learning_signals - num_fields

        # ProjectionWrappers for match nodes
        learning_signals_for_match_nodes = pnl_storage_mech.learning_signals[:num_match_fields]
        pnl_match_projs = [match_node_learning_signal.efferents[0].receiver.owner
                           for match_node_learning_signal in learning_signals_for_match_nodes]
        self.match_projection_wrappers = [self.projections_map[pnl_match_proj]
                                          for pnl_match_proj in pnl_match_projs]

        # ProjectionWrappers for retrieve nodes
        learning_signals_for_retrieve_nodes = pnl_storage_mech.learning_signals[num_match_fields:]
        pnl_retrieve_projs = [retrieve_node_learning_signal.efferents[0].receiver.owner
                              for retrieve_node_learning_signal in learning_signals_for_retrieve_nodes]
        self.retrieve_projection_wrappers = [self.projections_map[pnl_retrieve_proj]
                                             for pnl_retrieve_proj in pnl_retrieve_projs]

        # IMPLEMENTATION NOTE:
        #    This is needed for access by subcomponents to the PytorchEMCompositionWrapper when EMComposition is nested,
        #    and so _build_pytorch_representation is called on the outer Composition but not EMComposition itself;
        #    access must be provided via EMComposition's pytorch_representation, rather than directly assigning
        #    PytorchEMCompositionWrapper as an attribute on the subcomponents, since doing the latter introduces a
        #    recursion when torch.nn.module.state_dict() is called on any wrapper in the hierarchy.
        if self.composition.pytorch_representation is None:
            self.composition.pytorch_representation = self

    @property
    def memory(self)->Optional[torch.Tensor]:
        """Return tensor of memories in which each row (outer dimension) is an entry containing the vector for each field.
        These are derived from the matrix parameters of the afferent Projections to the retrieval_nodes.
        """
        num_fields = len(self.storage_node.afferents)
        memory_matrices = [field.matrix for field in self.retrieve_projection_wrappers]
        memory_capacity = len(memory_matrices[0])
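        # The nested stack below transposes the per-field matrices into per-entry form:
        #   memory[i][j] == memory_matrices[j][i], i.e., entry i's vector for field j, so (assuming all
        #   fields have the same length, say field_len) the result has shape (memory_capacity, num_fields, field_len)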
        return (None if not all(val for val in [num_fields, memory_matrices, memory_capacity])
                else torch.stack([torch.stack([memory_matrices[j][i]
                                               for j in range(num_fields)])
                                  for i in range(memory_capacity)]))


class PytorchEMMechanismWrapper(PytorchMechanismWrapper):
    """Wrapper for EMStorageMechanism as a Pytorch Module"""

    def execute(self, variable, optimization_num, synch_with_pnl_options, context=None):
        """Override to handle storage of entry to memory_matrix by EMStorage Function"""
        if self.mechanism is self.composition.storage_node:
            # Only execute store after last optimization repetition for current mini-batch
            # 7/10/24:  FIX: MOVE PASSING OF THESE PARAMETERS TO context
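            # For example, with optimizations_per_minibatch = 5 (an illustrative value), the store runs only
            # when (optimization_num + 1) is a multiple of 5, i.e., on the last repetition for the mini-batch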
            if not (optimization_num + 1) % context.composition.parameters.optimizations_per_minibatch.get(context):
                self.store_memory(variable, context)
        else:
            super().execute(variable, optimization_num, synch_with_pnl_options, context)

    # # MODIFIED 7/29/24 NEW: NEEDED FOR torch MPS SUPPORT
    # @torch.jit.script_method
    # MODIFIED 7/29/24 END
    def store_memory(self, memory_to_store, context):
        """Store variable in memory_matrix (parallel EMStorageMechanism._execute)

        For each node in query_input_nodes and value_input_nodes,
        assign its value to the weights of the corresponding afferents to the corresponding match_node and/or retrieved_node.
        - memory = matrix of entries made up of vectors for each field in each entry (row)
        - entry_to_store = query_input or value_input to store
        - field_projections = Projections the matrices of which comprise memory

        DIVISION OF LABOR between this method and function called by it
        store_memory (corresponds to EMStorageMechanism._execute):
         - compute norms to find weakest entry in memory
         - compute storage_prob to determine whether to store current entry in memory
         - call function with memory matrix for each field, to decay existing memory and assign input to weakest entry
        storage_node.function (corresponds to EMStorage._function):
         - decay existing memories
         - assign input to weakest entry (given index passed from EMStorageMechanism)

        :return: List[2d tensor] updated memories
        """
        pytorch_rep = self.composition.pytorch_representation

        memory = pytorch_rep.memory
        assert memory is not None, f"PROGRAM ERROR: '{pytorch_rep.name}'.memory is None"

        # Get current parameter values from EMComposition's EMStorageMechanism
        mech = self.mechanism
        random_state = mech.function.parameters.random_state._get(context)
        decay_rate = mech.parameters.decay_rate._get(context)      # modulable, so use getter
        storage_prob = mech.parameters.storage_prob._get(context)  # modulable, so use getter
        field_weights = mech.parameters.field_weights.get(context) # modulable, so use getter
        concatenation_node = mech.concatenation_node
        # MODIFIED 7/29/24 OLD:
        num_match_fields = 1 if concatenation_node else len([i for i in mech.field_types if i==1])
        # # MODIFIED 7/29/24 NEW: NEEDED FOR torch MPS SUPPORT
        # if concatenation_node:
        #     num_match_fields = 1
        # else:
        #     num_match_fields = 0
        #     for i in mech.field_types:
        #         if i==1:
        #             num_match_fields += 1
        # MODIFIED 7/29/24 END

        # Find weakest memory (i.e., with lowest norm)
        field_norms = torch.linalg.norm(memory, dim=2)
1✔
145
        if field_weights is not None:
1!
146
            field_norms *= field_weights
×
147
        row_norms = torch.sum(field_norms, axis=1)
1✔
148
        idx_of_weakest_memory = torch.argmin(row_norms)
1✔
149

150
        values = []
1✔
151
        for field_projection in pytorch_rep.match_projection_wrappers + pytorch_rep.retrieve_projection_wrappers:
1✔
152
            field_idx = pytorch_rep.composition._field_index_map[field_projection._pnl_proj]
1✔
153
            if field_projection in pytorch_rep.match_projection_wrappers:
1✔
154
                # For match projections:
155
                # - get entry to store from value of sender of Projection matrix (to accommodate concatenation_node)
156
                entry_to_store = field_projection.sender_wrapper.output
1✔
157

158
                # Retrieve the correct field (for each batch, batch is first dimension)
159
                memory_to_store_indexed = memory_to_store[:, field_idx, :]
1✔
160

161
                # - store in row
162
                axis = 0
1✔
163
                if concatenation_node is None:
1!
164
                    # Double check that the memory passed in is the output of the projection for the correct field
165
                    assert (entry_to_store == memory_to_store_indexed).all(), \
1✔
166
                        (f"PROGRAM ERROR: misalignment between memory to be stored (input passed to store_memory) "
167
                         f"and value of projection to corresponding field.")
168
            else:
169
                # For retrieve projections:
170
                # - get entry to store from memory_to_store (which has inputs to all fields)
171
                entry_to_store = memory_to_store[:, field_idx, :]
1✔
172
                # - store in column
173
                axis = 1
1✔
174
            # Get matrix containing memories for the field from the Projection
175
            field_memory_matrix = field_projection.matrix
1✔
176

177
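            # self.function corresponds to EMStorage._function (see the docstring above): decay existing
            # memories by decay_rate and assign entry_to_store at idx_of_weakest_memory along the given axis
            # (row for match matrices, column for retrieve matrices), subject to storage_prob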
            field_projection.matrix = self.function(entry_to_store,
                                                    memory_matrix=field_memory_matrix,
                                                    axis=axis,
                                                    storage_location=idx_of_weakest_memory,
                                                    storage_prob=storage_prob,
                                                    decay_rate=decay_rate,
                                                    random_state=random_state)
            values.append(field_projection.matrix)

        self.value = values
        return values