PrincetonUniversity / PsyNeuLink, build 15917088825

05 Jun 2025 04:18AM UTC, coverage: 84.482% (+0.5%) from 84.017%

Triggered by a push via GitHub (web-flow): "Merge pull request #3271 from PrincetonUniversity/devel" (Devel)

• 9909 of 12966 branches covered (76.42%); branch coverage is included in the aggregate %
• 1708 of 1908 new or added lines in 54 files covered (89.52%)
• 25 existing lines in 14 files now uncovered
• 34484 of 39581 relevant lines covered (87.12%)
• 0.87 hits per line

Source File: /psyneulink/library/compositions/pytorchshowgraph.py (74.01% covered)

# Princeton University licenses this file to You under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.  You may obtain a copy of the License at:
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.


# **************************************** PyTorch show_graph *********************************************************

from beartype import beartype

from psyneulink._typing import Optional, Union, Literal

from psyneulink.core.compositions import NodeRole
from psyneulink.core.compositions.showgraph import ShowGraph, SHOW_JUST_LEARNING_PROJECTIONS, SHOW_LEARNING
from psyneulink.core.components.mechanisms.processing.compositioninterfacemechanism import CompositionInterfaceMechanism
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.context import Context, ContextFlags, handle_external_context
from psyneulink.core.globals.keywords import SHOW_PYTORCH, PNL

EXCLUDE_FROM_GRADIENT_CALC_LINE_STYLE = 'exclude_from_gradient_calc_line_style'
EXCLUDE_FROM_GRADIENT_CALC_COLOR = 'exclude_from_gradient_calc_color'

class PytorchShowGraph(ShowGraph):
    """ShowGraph object with `show_graph <ShowGraph.show_graph>` method for displaying `Composition`.

    This is a subclass of the `ShowGraph` class that is used to display the graph of a `Composition` used for learning
    in `PyTorch mode <Composition_Learning_AutodiffComposition>` (see also `AutodiffComposition_PyTorch`).  In this
    mode, any `nested Compositions <AutodiffComposition_Nesting>` are "flattened" (i.e., incorporated into the
    outermost Composition); any `Nodes <Composition_Nodes>` designated as `exclude_from_gradient_calc
    <PytorchMechanismWrapper.exclude_from_gradient_calc>` are moved to the end of the graph (since they are executed
    after the gradient calculation); and any Projections designated as `exclude_in_autodiff
    <Projection.exclude_in_autodiff>` are not shown, since they are not used in the gradient calculations at all.

    Arguments
    ---------

    show_pytorch : bool : default False
        specifies whether the PyTorch version of the graph should be shown.

    """

    def __init__(self, *args, **kwargs):
        self.show_pytorch = kwargs.pop('show_pytorch', False)
        super().__init__(*args, **kwargs)

    @beartype
    @handle_external_context(source=ContextFlags.COMPOSITION)
    def show_graph(self, *args, **kwargs):
        """Override of show_graph to check for autodiff-specific options.
        If show_pytorch==True, build the PyTorch representation of the AutodiffComposition.
        If show_learning==PNL, infer backpropagation learning pathways for the Python version of the graph.
        """
        if SHOW_LEARNING in kwargs and kwargs[SHOW_LEARNING] == PNL:
            self.composition.infer_backpropagation_learning_pathways(ExecutionMode.Python)
            kwargs[SHOW_LEARNING] = True
            return super().show_graph(*args, **kwargs)
        self.show_pytorch = kwargs.pop('show_pytorch', False)
        context = kwargs.get('context')
        if self.show_pytorch:
            self.pytorch_rep = self.composition._build_pytorch_representation(context, refresh=False)
        self.exclude_from_gradient_calc_line_style = kwargs.pop(EXCLUDE_FROM_GRADIENT_CALC_LINE_STYLE, 'dotted')
        self.exclude_from_gradient_calc_color = kwargs.pop(EXCLUDE_FROM_GRADIENT_CALC_COLOR, 'brown')
        return super().show_graph(*args, **kwargs)
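
    # The two kwargs popped above let callers restyle nodes that are excluded from
    # the gradient calculation ('dotted' and 'brown' are the defaults set here).
    # A hedged sketch, continuing the hypothetical `autodiff` example above:
    #
    #     >>> autodiff.show_graph(show_pytorch=True,
    #     ...                     exclude_from_gradient_calc_line_style='dashed',
    #     ...                     exclude_from_gradient_calc_color='gray')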

    def _get_processing_graph(self, composition, context):
        """Helper method that creates the dependency graph for nodes of an AutodiffComposition used in PyTorch mode"""
        if self.show_pytorch:
            processing_graph = {}
            projections = self._get_projections(composition, context)
            nodes = self._get_nodes(composition, context)
            for node in nodes:
                dependencies = set()
                for projection in projections:
                    sender = projection.sender.owner
                    receiver = projection.receiver.owner
                    if node is receiver:
                        dependencies.add(sender)
                    # FIX: 3/9/25 - HANDLE NODE THAT PROJECTS TO OUTPUT_CIM IN SAME WAY:
                    # Add dependency of INPUT node of nested graph on node in outer graph that projects to it
                    elif (isinstance(receiver, CompositionInterfaceMechanism) and
                          # projection.receiver.owner._get_destination_info_from_input_CIM(projection.receiver)[1]
                          # FIX: SUPPOSED TO RETRIEVE GRU NODE HERE,
                          #      BUT NEED TO DEAL WITH INTERFERING PROJECTION FROM OUTPUT NODE
                          receiver._get_source_info_from_output_CIM(projection.receiver)[1] is node):
                        dependencies.add(sender)
                    else:
                        for proj in [proj for proj in node.afferents if proj.sender.owner in nodes]:
                            dependencies.add(proj.sender.owner)
                processing_graph[node] = dependencies

            # Add TARGET nodes
            for node in self.composition.learning_components:
                processing_graph[node] = {afferent.sender.owner for afferent in node.path_afferents}
            return {k: processing_graph[k] for k in sorted(processing_graph.keys())}

        else:
            return super()._get_processing_graph(composition, context)
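
    # The value returned above is a plain dependency dict: each node maps to the
    # set of nodes that send Projections to it. An illustrative sketch with
    # hypothetical mechanisms A -> B -> C:
    #
    #     >>> processing_graph = {A: set(), B: {A}, C: {B}}
    #
    # i.e., keys are receivers and values are their senders; _proj_in_composition
    # below consults this mapping when deciding which edges to draw.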

    def _get_nodes(self, composition, context):
        """Override to return nodes of PytorchCompositionWrapper rather than the AutodiffComposition"""
        if self.show_pytorch:
            nodes = [node for node in self.pytorch_rep.nodes_map
                     if SHOW_PYTORCH in self.pytorch_rep.nodes_map[node]._use]
            return nodes
        else:
            return super()._get_nodes(composition, context)

    def _get_projections(self, composition, context):
        """Override to return Projections of the PyTorch graph"""
        if self.show_pytorch:
            # projections = list(self.pytorch_rep.projections_map.keys())
            projections = [proj for proj in self.pytorch_rep.projections_map
                           if SHOW_PYTORCH in self.pytorch_rep.projections_map[proj]._use]
            # FIX: NEED TO ADD PROJECTIONS TO NESTED COMPS THAT ARE TO CIM
            # Add any Projections to TARGET nodes
            projections += [afferent
                            for node in self.composition.learning_components
                            for afferent in node.path_afferents
                            if not isinstance(afferent.sender.owner, CompositionInterfaceMechanism)]
            return projections
        else:
            return super()._get_projections(composition, context)

    def _proj_in_composition(self, proj, composition_projections, context) -> bool:
        """Override to include direct Projections from outer to nested comps in PyTorch mode"""
        sndr = proj.sender.owner
        rcvr = proj.receiver.owner
        # # MODIFIED 2/16/25 NEW:
        # if isinstance(rcvr, CompositionInterfaceMechanism):
        #     # If receiver is an input_CIM, get the node in the inner Composition to which it projects
        #     #   as it may be specified as dependent on the sender in the autodiff processing_graph
        #     rcvr = rcvr._get_destination_info_from_input_CIM(proj.receiver)[1]
        # MODIFIED 2/16/25 END
        if self.show_pytorch:
            processing_graph = self._get_processing_graph(self.composition, context)
            if proj in composition_projections:
                return True
            # Include if proj is between a sender and a receiver specified as dependent on it in processing_graph
            elif rcvr in processing_graph and sndr in processing_graph[rcvr]:
                return True
            else:
                return False
        else:
            return super()._proj_in_composition(proj, composition_projections, context)
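
    # Worked illustration of the dependency check above (hypothetical nodes A
    # and B): if the processing graph contains {B: {A}}, a Projection whose
    # sender.owner is A and whose receiver.owner is B is reported as in the
    # composition even when it is absent from composition_projections itself;
    # this is what admits direct Projections from the outer Composition into a
    # flattened nested one.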

    def _get_roles_by_node(self, composition, node, context):
        """Override in PyTorch mode to return NodeRole.INTERNAL for all nodes in nested compositions"""
        if self.show_pytorch:
            try:
                return composition.get_roles_by_node(node)
            except Exception:
                # Nodes of flattened nested Compositions are not registered in the
                # outer Composition, so report them as INTERNAL
                return [NodeRole.INTERNAL]
        else:
            return super()._get_roles_by_node(composition, node, context)

    def _get_nodes_by_role(self, composition, role, context):
        """Override in PyTorch mode to return all nodes in nested compositions as INTERNAL"""
        if self.show_pytorch and composition is not self.composition:
            return None
        else:
            return super()._get_nodes_by_role(composition, role, context)

    def _implement_graph_node(self, g, rcvr, context, *args, **kwargs):
        """Override to assign EXCLUDE_FROM_GRADIENT_CALC nodes their own style in PyTorch mode"""
        if self.show_pytorch:
            if hasattr(rcvr, 'exclude_from_show_graph'):
                # Exclude PsyNeuLink Nodes in AutodiffComposition marked for exclusion from PyTorch graph
                return
            if rcvr in self.pytorch_rep.nodes_map and self.pytorch_rep.nodes_map[rcvr].exclude_from_gradient_calc:
                kwargs['style'] = self.exclude_from_gradient_calc_line_style
                kwargs['color'] = self.exclude_from_gradient_calc_color
            elif rcvr not in self.composition.nodes:
                # Assign style to nodes of nested Compositions that are INPUT or OUTPUT nodes of the PyTorch graph
                # (since they are not in the outermost Composition and are therefore ignored when it is flattened)
                dependencies = self._get_processing_graph(self.composition, context)
                receivers = dependencies.keys()
                senders = [sender for sender_list in dependencies.values() for sender in sender_list]
                if rcvr in receivers and rcvr not in senders:
                    kwargs['color'] = self.output_color
                    kwargs['penwidth'] = str(self.bold_width)
                elif rcvr in senders and rcvr not in receivers:
                    kwargs['color'] = self.input_color
                    kwargs['penwidth'] = str(self.bold_width)
            g.node(*args, **kwargs)
        else:
            return super()._implement_graph_node(g, rcvr, context, *args, **kwargs)
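
    # For context, `g` above is expected to be a graphviz graph object (as used
    # by ShowGraph), so the kwargs set here become Graphviz node attributes. A
    # self-contained sketch of the same styling (node names and attribute values
    # are illustrative; assumes the `graphviz` package):
    #
    #     >>> from graphviz import Digraph
    #     >>> g = Digraph()
    #     >>> g.node('mod_mech', style='dotted', color='brown')  # excluded from gradient calc
    #     >>> g.node('output_node', color='red', penwidth='3')   # terminal node, bold outline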

    def _implement_graph_edge(self, graph, proj, context, *args, **kwargs):
        """Override to assign custom attributes to edges"""

        if self.show_pytorch:
            kwargs['color'] = self.default_node_color

            modulatory_node = None
            if proj.parameter_ports[0].mod_afferents:
                # MODIFIED 2/22/25 OLD:
                modulatory_node = self.pytorch_rep.nodes_map[proj.parameter_ports[0].mod_afferents[0].sender.owner]
                # # MODIFIED 2/22/25 NEW:
                # modulatory_node = self.nodes_map[proj.parameter_ports[0].mod_afferents[0].sender.owner]
                # # MODIFIED 2/22/25 END

            if proj in self.pytorch_rep.projections_map:
                # # MODIFIED 2/25/25 NEW:
                # if ((hasattr(proj, 'learnable') and proj.learnable)
                #         or (proj in self.pytorch_rep.projections_map and
                #             self.pytorch_rep.projections_map[proj].matrix.requires_grad)):
                #     proj_is_learnable = True
                # # MODIFIED 2/25/25 END

                # If Projection is a LearningProjection that is active, assign color and arrowhead of a LearningProjection
                # # MODIFIED 2/25/25 OLD:
                if proj.learnable or self.pytorch_rep.projections_map[proj].matrix.requires_grad:
                # # MODIFIED 2/25/25 NEW:
                # if proj_is_learnable:
                # # MODIFIED 2/25/25 END
                    kwargs['color'] = self.learning_color

                # If Projection is from a ModulatoryMechanism that is excluded from gradient calculations, assign that style
                elif modulatory_node and modulatory_node.exclude_from_gradient_calc:
                    kwargs['color'] = self.exclude_from_gradient_calc_color
                    kwargs['style'] = self.exclude_from_gradient_calc_line_style

            elif self._proj_in_composition(proj, self.pytorch_rep.projections_map, context) and proj.learnable:
                kwargs['color'] = self.learning_color

            graph.edge(*args, **kwargs)

        else:
            return super()._implement_graph_edge(graph, proj, context, *args, **kwargs)
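
# A final hedged sketch of the show_learning=PNL path handled in show_graph above,
# which renders the Python-side backpropagation pathways instead of the PyTorch
# graph (the `autodiff` name is hypothetical; PNL is the keyword imported at the
# top of this module):
#
#     >>> from psyneulink.core.globals.keywords import PNL
#     >>> autodiff.show_graph(show_learning=PNL)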